diff --git a/pkg/oidc/client/client.go b/pkg/oidc/client/client.go index 73e9102..5fa0610 100644 --- a/pkg/oidc/client/client.go +++ b/pkg/oidc/client/client.go @@ -121,11 +121,8 @@ func authorizationRequestOptions(nonce string, pkceParams pkce.Params, extraPara oauth2.AccessTypeOffline, gooidc.Nonce(nonce), } - if pkceParams.CodeChallenge != "" { - opts = append(opts, oauth2.SetAuthURLParam("code_challenge", pkceParams.CodeChallenge)) - } - if pkceParams.CodeChallengeMethod != "" { - opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", pkceParams.CodeChallengeMethod)) + if pkceOpt := pkceParams.AuthCodeOption(); pkceOpt != nil { + opts = append(opts, pkceOpt) } for key, value := range extraParams { opts = append(opts, oauth2.SetAuthURLParam(key, value)) @@ -134,11 +131,10 @@ func authorizationRequestOptions(nonce string, pkceParams pkce.Params, extraPara } func tokenRequestOptions(pkceParams pkce.Params) []oauth2.AuthCodeOption { - var opts []oauth2.AuthCodeOption - if pkceParams.CodeVerifier != "" { - opts = append(opts, oauth2.SetAuthURLParam("code_verifier", pkceParams.CodeVerifier)) + if pkceOpt := pkceParams.TokenRequestOption(); pkceOpt != nil { + return []oauth2.AuthCodeOption{pkceOpt} } - return opts + return nil } func (c *client) NegotiatedPKCEMethod() pkce.Method { diff --git a/pkg/pkce/pkce.go b/pkg/pkce/pkce.go index c69c953..0b0d9f2 100644 --- a/pkg/pkce/pkce.go +++ b/pkg/pkce/pkce.go @@ -2,13 +2,7 @@ // See also https://tools.ietf.org/html/rfc7636. package pkce -import ( - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "fmt" -) +import "golang.org/x/oauth2" type Method int @@ -20,9 +14,22 @@ const ( // Params represents a set of the PKCE parameters. type Params struct { - CodeChallenge string - CodeChallengeMethod string - CodeVerifier string + Method Method + Verifier string +} + +func (params Params) AuthCodeOption() oauth2.AuthCodeOption { + if params.Method == MethodS256 { + return oauth2.S256ChallengeOption(params.Verifier) + } + return nil +} + +func (params Params) TokenRequestOption() oauth2.AuthCodeOption { + if params.Method == MethodS256 { + return oauth2.VerifierOption(params.Verifier) + } + return nil } // New returns a parameters supported by the provider. @@ -30,39 +37,10 @@ type Params struct { // It returns a zero value if no method is available. func New(method Method) (Params, error) { if method == MethodS256 { - return NewS256() + return Params{ + Method: MethodS256, + Verifier: oauth2.GenerateVerifier(), + }, nil } return Params{}, nil } - -// NewS256 generates a parameters for S256. -func NewS256() (Params, error) { - b, err := random32() - if err != nil { - return Params{}, fmt.Errorf("could not generate a random: %w", err) - } - return computeS256(b), nil -} - -func random32() ([]byte, error) { - b := make([]byte, 32) - if err := binary.Read(rand.Reader, binary.LittleEndian, b); err != nil { - return nil, fmt.Errorf("read error: %w", err) - } - return b, nil -} - -func computeS256(b []byte) Params { - v := base64URLEncode(b) - s := sha256.New() - _, _ = s.Write([]byte(v)) - return Params{ - CodeChallenge: base64URLEncode(s.Sum(nil)), - CodeChallengeMethod: "S256", - CodeVerifier: v, - } -} - -func base64URLEncode(b []byte) string { - return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b) -} diff --git a/pkg/pkce/pkce_test.go b/pkg/pkce/pkce_test.go deleted file mode 100644 index cc38349..0000000 --- a/pkg/pkce/pkce_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package pkce - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestNew(t *testing.T) { - t.Run("S256", func(t *testing.T) { - params, err := New(MethodS256) - if err != nil { - t.Fatalf("New error: %s", err) - } - if params.CodeChallengeMethod != "S256" { - t.Errorf("CodeChallengeMethod wants S256 but was %s", params.CodeChallengeMethod) - } - if params.CodeChallenge == "" { - t.Errorf("CodeChallenge wants non-empty but was empty") - } - if params.CodeVerifier == "" { - t.Errorf("CodeVerifier wants non-empty but was empty") - } - }) - t.Run("NoMethod", func(t *testing.T) { - params, err := New(NoMethod) - if err != nil { - t.Fatalf("New error: %s", err) - } - if diff := cmp.Diff(Params{}, params); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } - }) -} - -func Test_computeS256(t *testing.T) { - // Testdata described at: - // https://tools.ietf.org/html/rfc7636#appendix-B - b := []byte{ - 116, 24, 223, 180, 151, 153, 224, 37, 79, 250, 96, 125, 216, 173, - 187, 186, 22, 212, 37, 77, 105, 214, 191, 240, 91, 88, 5, 88, 83, - 132, 141, 121, - } - p := computeS256(b) - if want := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; want != p.CodeVerifier { - t.Errorf("CodeVerifier wants %s but was %s", want, p.CodeVerifier) - } - if want := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; want != p.CodeChallenge { - t.Errorf("CodeChallenge wants %s but was %s", want, p.CodeChallenge) - } - if p.CodeChallengeMethod != "S256" { - t.Errorf("CodeChallengeMethod wants S256 but was %s", p.CodeChallengeMethod) - } -}