refactor(openid): extract magic strings

This commit is contained in:
Trong Huu Nguyen
2022-08-19 10:34:32 +02:00
parent 5990e4bb71
commit 08f570363a
6 changed files with 58 additions and 27 deletions

View File

@@ -19,6 +19,10 @@ import (
const (
LocaleURLParameter = "locale"
SecurityLevelURLParameter = "level"
ResponseModeQuery = "query"
CodeChallengeMethodS256 = "S256"
)
var (
@@ -28,8 +32,8 @@ var (
// LoginParameterMapping maps incoming login parameters to OpenID Connect parameters
LoginParameterMapping = map[string]string{
LocaleURLParameter: "ui_locales",
SecurityLevelURLParameter: "acr_values",
LocaleURLParameter: openid.UILocales,
SecurityLevelURLParameter: openid.ACRValues,
}
)
@@ -141,15 +145,15 @@ func newLoginParameters(c Client) (*loginParameters, error) {
func (in *loginParameters) authCodeURL(r *http.Request, callbackURL string, loginstatus loginstatus.Loginstatus) (string, error) {
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("nonce", in.Nonce),
oauth2.SetAuthURLParam("response_mode", "query"),
oauth2.SetAuthURLParam("code_challenge", in.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("redirect_uri", callbackURL),
oauth2.SetAuthURLParam(openid.Nonce, in.Nonce),
oauth2.SetAuthURLParam(openid.ResponseMode, ResponseModeQuery),
oauth2.SetAuthURLParam(openid.CodeChallenge, in.CodeChallenge),
oauth2.SetAuthURLParam(openid.CodeChallengeMethod, CodeChallengeMethodS256),
oauth2.SetAuthURLParam(openid.RedirectURI, callbackURL),
}
if loginstatus.NeedsResourceIndicator() {
opts = append(opts, oauth2.SetAuthURLParam("resource", loginstatus.ResourceIndicator()))
opts = append(opts, oauth2.SetAuthURLParam(openid.Resource, loginstatus.ResourceIndicator()))
}
opts, err := in.withSecurityLevel(r, opts)

View File

@@ -13,6 +13,10 @@ import (
"github.com/nais/wonderwall/pkg/openid/provider"
)
const (
ClientAssertionTypeJwtBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)
type LoginCallback interface {
IdentityProviderError() error
StateMismatchError() error
@@ -42,9 +46,9 @@ func NewLoginCallback(c Client, r *http.Request, p provider.Provider, cookie *op
}
func (in *loginCallback) IdentityProviderError() error {
if in.requestParams.Get("error") != "" {
oauthError := in.requestParams.Get("error")
oauthErrorDescription := in.requestParams.Get("error_description")
if in.requestParams.Get(openid.Error) != "" {
oauthError := in.requestParams.Get(openid.Error)
oauthErrorDescription := in.requestParams.Get(openid.ErrorDescription)
return fmt.Errorf("error from identity provider: %s: %s", oauthError, oauthErrorDescription)
}
@@ -53,7 +57,7 @@ func (in *loginCallback) IdentityProviderError() error {
func (in *loginCallback) StateMismatchError() error {
expectedState := in.cookie.State
actualState := in.requestParams.Get("state")
actualState := in.requestParams.Get(openid.State)
if len(actualState) <= 0 {
return fmt.Errorf("missing state parameter in request (possible csrf)")
@@ -73,12 +77,12 @@ func (in *loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro
}
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", in.cookie.CodeVerifier),
oauth2.SetAuthURLParam("client_assertion", clientAssertion),
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
oauth2.SetAuthURLParam(openid.CodeVerifier, in.cookie.CodeVerifier),
oauth2.SetAuthURLParam(openid.ClientAssertion, clientAssertion),
oauth2.SetAuthURLParam(openid.ClientAssertionType, ClientAssertionTypeJwtBearer),
}
code := in.requestParams.Get("code")
code := in.requestParams.Get(openid.Code)
rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts)
if err != nil {
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)

View File

@@ -5,6 +5,7 @@ import (
"net/http"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/openid"
)
type Logout interface {
@@ -33,10 +34,10 @@ func NewLogout(c Client, r *http.Request) (Logout, error) {
func (in *logout) SingleLogoutURL(idToken string) string {
endSessionEndpoint := in.config().Provider().EndSessionEndpointURL()
v := endSessionEndpoint.Query()
v.Add("post_logout_redirect_uri", in.logoutCallbackURL)
v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL)
if len(idToken) > 0 {
v.Add("id_token_hint", idToken)
v.Add(openid.IDTokenHint, idToken)
}
endSessionEndpoint.RawQuery = v.Encode()

View File

@@ -1,6 +1,10 @@
package client
import "net/http"
import (
"net/http"
"github.com/nais/wonderwall/pkg/openid"
)
type LogoutFrontchannel interface {
// Sid is the session identifier which SHOULD be included as a parameter in the front-channel logout request.
@@ -14,7 +18,7 @@ type logoutFrontchannel struct {
func NewLogoutFrontchannel(r *http.Request) LogoutFrontchannel {
params := r.URL.Query()
sid := params.Get("sid")
sid := params.Get(openid.Sid)
return &logoutFrontchannel{
sid: sid,

23
pkg/openid/params.go Normal file
View File

@@ -0,0 +1,23 @@
package openid
const (
ACRValues = "acr_values"
ClientAssertion = "client_assertion"
ClientAssertionType = "client_assertion_type"
CodeChallenge = "code_challenge"
CodeChallengeMethod = "code_challenge_method"
Code = "code"
CodeVerifier = "code_verifier"
Error = "error"
ErrorDescription = "error_description"
IDTokenHint = "id_token_hint"
Nonce = "nonce"
PostLogoutRedirectURI = "post_logout_redirect_uri"
SessionState = "session_state"
Sid = "sid"
State = "state"
RedirectURI = "redirect_uri"
Resource = "resource"
ResponseMode = "response_mode"
UILocales = "ui_locales"
)

View File

@@ -172,11 +172,6 @@ func (h *Handler) Key(sessionID string) string {
return fmt.Sprintf("%s:%s:%s", provider.Name(), client.ClientID(), sessionID)
}
const (
// TODO - move to url_params consts for openid pkg
SessionStateParamKey = "session_state"
)
func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) {
// 1. check for 'sid' claim in id_token
sessionID, err := idToken.GetSidClaim()
@@ -207,9 +202,9 @@ func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url
}
func getSessionStateFrom(params url.Values) (string, error) {
sessionState := params.Get(SessionStateParamKey)
sessionState := params.Get(openid.SessionState)
if len(sessionState) == 0 {
return "", fmt.Errorf("missing required '%s' in params", SessionStateParamKey)
return "", fmt.Errorf("missing required '%s' in params", openid.SessionState)
}
return sessionState, nil
}