diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 9a7f493..9faed8e 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -19,6 +19,10 @@ import ( const ( LocaleURLParameter = "locale" SecurityLevelURLParameter = "level" + + ResponseModeQuery = "query" + + CodeChallengeMethodS256 = "S256" ) var ( @@ -28,8 +32,8 @@ var ( // LoginParameterMapping maps incoming login parameters to OpenID Connect parameters LoginParameterMapping = map[string]string{ - LocaleURLParameter: "ui_locales", - SecurityLevelURLParameter: "acr_values", + LocaleURLParameter: openid.UILocales, + SecurityLevelURLParameter: openid.ACRValues, } ) @@ -141,15 +145,15 @@ func newLoginParameters(c Client) (*loginParameters, error) { func (in *loginParameters) authCodeURL(r *http.Request, callbackURL string, loginstatus loginstatus.Loginstatus) (string, error) { opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("nonce", in.Nonce), - oauth2.SetAuthURLParam("response_mode", "query"), - oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge), - oauth2.SetAuthURLParam("code_challenge_method", "S256"), - oauth2.SetAuthURLParam("redirect_uri", callbackURL), + oauth2.SetAuthURLParam(openid.Nonce, in.Nonce), + oauth2.SetAuthURLParam(openid.ResponseMode, ResponseModeQuery), + oauth2.SetAuthURLParam(openid.CodeChallenge, in.CodeChallenge), + oauth2.SetAuthURLParam(openid.CodeChallengeMethod, CodeChallengeMethodS256), + oauth2.SetAuthURLParam(openid.RedirectURI, callbackURL), } if loginstatus.NeedsResourceIndicator() { - opts = append(opts, oauth2.SetAuthURLParam("resource", loginstatus.ResourceIndicator())) + opts = append(opts, oauth2.SetAuthURLParam(openid.Resource, loginstatus.ResourceIndicator())) } opts, err := in.withSecurityLevel(r, opts) diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index d78fa10..084999a 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -13,6 +13,10 @@ import ( "github.com/nais/wonderwall/pkg/openid/provider" ) +const ( + ClientAssertionTypeJwtBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +) + type LoginCallback interface { IdentityProviderError() error StateMismatchError() error @@ -42,9 +46,9 @@ func NewLoginCallback(c Client, r *http.Request, p provider.Provider, cookie *op } func (in *loginCallback) IdentityProviderError() error { - if in.requestParams.Get("error") != "" { - oauthError := in.requestParams.Get("error") - oauthErrorDescription := in.requestParams.Get("error_description") + if in.requestParams.Get(openid.Error) != "" { + oauthError := in.requestParams.Get(openid.Error) + oauthErrorDescription := in.requestParams.Get(openid.ErrorDescription) return fmt.Errorf("error from identity provider: %s: %s", oauthError, oauthErrorDescription) } @@ -53,7 +57,7 @@ func (in *loginCallback) IdentityProviderError() error { func (in *loginCallback) StateMismatchError() error { expectedState := in.cookie.State - actualState := in.requestParams.Get("state") + actualState := in.requestParams.Get(openid.State) if len(actualState) <= 0 { return fmt.Errorf("missing state parameter in request (possible csrf)") @@ -73,12 +77,12 @@ func (in *loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro } opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("code_verifier", in.cookie.CodeVerifier), - oauth2.SetAuthURLParam("client_assertion", clientAssertion), - oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), + oauth2.SetAuthURLParam(openid.CodeVerifier, in.cookie.CodeVerifier), + oauth2.SetAuthURLParam(openid.ClientAssertion, clientAssertion), + oauth2.SetAuthURLParam(openid.ClientAssertionType, ClientAssertionTypeJwtBearer), } - code := in.requestParams.Get("code") + code := in.requestParams.Get(openid.Code) rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts) if err != nil { return nil, fmt.Errorf("exchanging authorization code for token: %w", err) diff --git a/pkg/openid/client/logout.go b/pkg/openid/client/logout.go index 3051025..b89a9a0 100644 --- a/pkg/openid/client/logout.go +++ b/pkg/openid/client/logout.go @@ -5,6 +5,7 @@ import ( "net/http" urlpkg "github.com/nais/wonderwall/pkg/handler/url" + "github.com/nais/wonderwall/pkg/openid" ) type Logout interface { @@ -33,10 +34,10 @@ func NewLogout(c Client, r *http.Request) (Logout, error) { func (in *logout) SingleLogoutURL(idToken string) string { endSessionEndpoint := in.config().Provider().EndSessionEndpointURL() v := endSessionEndpoint.Query() - v.Add("post_logout_redirect_uri", in.logoutCallbackURL) + v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL) if len(idToken) > 0 { - v.Add("id_token_hint", idToken) + v.Add(openid.IDTokenHint, idToken) } endSessionEndpoint.RawQuery = v.Encode() diff --git a/pkg/openid/client/logout_frontchannel.go b/pkg/openid/client/logout_frontchannel.go index 89bbb29..179c965 100644 --- a/pkg/openid/client/logout_frontchannel.go +++ b/pkg/openid/client/logout_frontchannel.go @@ -1,6 +1,10 @@ package client -import "net/http" +import ( + "net/http" + + "github.com/nais/wonderwall/pkg/openid" +) type LogoutFrontchannel interface { // Sid is the session identifier which SHOULD be included as a parameter in the front-channel logout request. @@ -14,7 +18,7 @@ type logoutFrontchannel struct { func NewLogoutFrontchannel(r *http.Request) LogoutFrontchannel { params := r.URL.Query() - sid := params.Get("sid") + sid := params.Get(openid.Sid) return &logoutFrontchannel{ sid: sid, diff --git a/pkg/openid/params.go b/pkg/openid/params.go new file mode 100644 index 0000000..24d853d --- /dev/null +++ b/pkg/openid/params.go @@ -0,0 +1,23 @@ +package openid + +const ( + ACRValues = "acr_values" + ClientAssertion = "client_assertion" + ClientAssertionType = "client_assertion_type" + CodeChallenge = "code_challenge" + CodeChallengeMethod = "code_challenge_method" + Code = "code" + CodeVerifier = "code_verifier" + Error = "error" + ErrorDescription = "error_description" + IDTokenHint = "id_token_hint" + Nonce = "nonce" + PostLogoutRedirectURI = "post_logout_redirect_uri" + SessionState = "session_state" + Sid = "sid" + State = "state" + RedirectURI = "redirect_uri" + Resource = "resource" + ResponseMode = "response_mode" + UILocales = "ui_locales" +) diff --git a/pkg/session/handler.go b/pkg/session/handler.go index 75b8888..b9e4ea3 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -172,11 +172,6 @@ func (h *Handler) Key(sessionID string) string { return fmt.Sprintf("%s:%s:%s", provider.Name(), client.ClientID(), sessionID) } -const ( - // TODO - move to url_params consts for openid pkg - SessionStateParamKey = "session_state" -) - func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) { // 1. check for 'sid' claim in id_token sessionID, err := idToken.GetSidClaim() @@ -207,9 +202,9 @@ func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url } func getSessionStateFrom(params url.Values) (string, error) { - sessionState := params.Get(SessionStateParamKey) + sessionState := params.Get(openid.SessionState) if len(sessionState) == 0 { - return "", fmt.Errorf("missing required '%s' in params", SessionStateParamKey) + return "", fmt.Errorf("missing required '%s' in params", openid.SessionState) } return sessionState, nil }