Use PKCE verifier of oauth2 package (#1340)

This commit is contained in:
Hidetake Iwata
2025-05-19 18:51:43 +09:00
committed by GitHub
parent 7af43af614
commit db764cd328
3 changed files with 26 additions and 106 deletions

View File

@@ -121,11 +121,8 @@ func authorizationRequestOptions(nonce string, pkceParams pkce.Params, extraPara
oauth2.AccessTypeOffline,
gooidc.Nonce(nonce),
}
if pkceParams.CodeChallenge != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_challenge", pkceParams.CodeChallenge))
}
if pkceParams.CodeChallengeMethod != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", pkceParams.CodeChallengeMethod))
if pkceOpt := pkceParams.AuthCodeOption(); pkceOpt != nil {
opts = append(opts, pkceOpt)
}
for key, value := range extraParams {
opts = append(opts, oauth2.SetAuthURLParam(key, value))
@@ -134,11 +131,10 @@ func authorizationRequestOptions(nonce string, pkceParams pkce.Params, extraPara
}
func tokenRequestOptions(pkceParams pkce.Params) []oauth2.AuthCodeOption {
var opts []oauth2.AuthCodeOption
if pkceParams.CodeVerifier != "" {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", pkceParams.CodeVerifier))
if pkceOpt := pkceParams.TokenRequestOption(); pkceOpt != nil {
return []oauth2.AuthCodeOption{pkceOpt}
}
return opts
return nil
}
func (c *client) NegotiatedPKCEMethod() pkce.Method {

View File

@@ -2,13 +2,7 @@
// See also https://tools.ietf.org/html/rfc7636.
package pkce
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
)
import "golang.org/x/oauth2"
type Method int
@@ -20,9 +14,22 @@ const (
// Params represents a set of the PKCE parameters.
type Params struct {
CodeChallenge string
CodeChallengeMethod string
CodeVerifier string
Method Method
Verifier string
}
func (params Params) AuthCodeOption() oauth2.AuthCodeOption {
if params.Method == MethodS256 {
return oauth2.S256ChallengeOption(params.Verifier)
}
return nil
}
func (params Params) TokenRequestOption() oauth2.AuthCodeOption {
if params.Method == MethodS256 {
return oauth2.VerifierOption(params.Verifier)
}
return nil
}
// New returns a parameters supported by the provider.
@@ -30,39 +37,10 @@ type Params struct {
// It returns a zero value if no method is available.
func New(method Method) (Params, error) {
if method == MethodS256 {
return NewS256()
return Params{
Method: MethodS256,
Verifier: oauth2.GenerateVerifier(),
}, nil
}
return Params{}, nil
}
// NewS256 generates a parameters for S256.
func NewS256() (Params, error) {
b, err := random32()
if err != nil {
return Params{}, fmt.Errorf("could not generate a random: %w", err)
}
return computeS256(b), nil
}
func random32() ([]byte, error) {
b := make([]byte, 32)
if err := binary.Read(rand.Reader, binary.LittleEndian, b); err != nil {
return nil, fmt.Errorf("read error: %w", err)
}
return b, nil
}
func computeS256(b []byte) Params {
v := base64URLEncode(b)
s := sha256.New()
_, _ = s.Write([]byte(v))
return Params{
CodeChallenge: base64URLEncode(s.Sum(nil)),
CodeChallengeMethod: "S256",
CodeVerifier: v,
}
}
func base64URLEncode(b []byte) string {
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b)
}

View File

@@ -1,54 +0,0 @@
package pkce
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func TestNew(t *testing.T) {
t.Run("S256", func(t *testing.T) {
params, err := New(MethodS256)
if err != nil {
t.Fatalf("New error: %s", err)
}
if params.CodeChallengeMethod != "S256" {
t.Errorf("CodeChallengeMethod wants S256 but was %s", params.CodeChallengeMethod)
}
if params.CodeChallenge == "" {
t.Errorf("CodeChallenge wants non-empty but was empty")
}
if params.CodeVerifier == "" {
t.Errorf("CodeVerifier wants non-empty but was empty")
}
})
t.Run("NoMethod", func(t *testing.T) {
params, err := New(NoMethod)
if err != nil {
t.Fatalf("New error: %s", err)
}
if diff := cmp.Diff(Params{}, params); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
}
func Test_computeS256(t *testing.T) {
// Testdata described at:
// https://tools.ietf.org/html/rfc7636#appendix-B
b := []byte{
116, 24, 223, 180, 151, 153, 224, 37, 79, 250, 96, 125, 216, 173,
187, 186, 22, 212, 37, 77, 105, 214, 191, 240, 91, 88, 5, 88, 83,
132, 141, 121,
}
p := computeS256(b)
if want := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; want != p.CodeVerifier {
t.Errorf("CodeVerifier wants %s but was %s", want, p.CodeVerifier)
}
if want := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; want != p.CodeChallenge {
t.Errorf("CodeChallenge wants %s but was %s", want, p.CodeChallenge)
}
if p.CodeChallengeMethod != "S256" {
t.Errorf("CodeChallengeMethod wants S256 but was %s", p.CodeChallengeMethod)
}
}