diff --git a/pkg/openid/login.go b/pkg/openid/login.go index 0f3bb47..7c7e07a 100644 --- a/pkg/openid/login.go +++ b/pkg/openid/login.go @@ -1,11 +1,11 @@ package openid import ( - "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" - "io" + + "github.com/nais/wonderwall/pkg/strings" ) type LoginParameters struct { @@ -16,36 +16,33 @@ type LoginParameters struct { } func GenerateLoginParameters() (*LoginParameters, error) { - codeVerifier := make([]byte, 64) - nonce := make([]byte, 32) - state := make([]byte, 32) - - var err error - - _, err = io.ReadFull(rand.Reader, state) + codeVerifier, err := strings.GenerateBase64(64) if err != nil { - return nil, fmt.Errorf("failed to create state: %w", err) + return nil, fmt.Errorf("creating code verifier: %w", err) } - _, err = io.ReadFull(rand.Reader, nonce) + nonce, err := strings.GenerateBase64(32) if err != nil { - return nil, fmt.Errorf("failed to create nonce: %w", err) + return nil, fmt.Errorf("creating nonce: %w", err) } - _, err = io.ReadFull(rand.Reader, codeVerifier) + state, err := strings.GenerateBase64(32) if err != nil { - return nil, fmt.Errorf("failed to create code verifier: %w", err) + return nil, fmt.Errorf("creating state: %w", err) } - codeVerifier = []byte(base64.RawURLEncoding.EncodeToString(codeVerifier)) - hasher := sha256.New() - hasher.Write(codeVerifier) - codeVerifierHash := hasher.Sum(nil) - return &LoginParameters{ - CodeVerifier: string(codeVerifier), - CodeChallenge: base64.RawURLEncoding.EncodeToString(codeVerifierHash), - Nonce: base64.RawURLEncoding.EncodeToString(nonce), - State: base64.RawURLEncoding.EncodeToString(state), + CodeVerifier: codeVerifier, + CodeChallenge: CodeChallenge(codeVerifier), + Nonce: nonce, + State: state, }, nil } + +func CodeChallenge(codeVerifier string) string { + hasher := sha256.New() + hasher.Write([]byte(codeVerifier)) + codeVerifierHash := hasher.Sum(nil) + + return base64.RawURLEncoding.EncodeToString(codeVerifierHash) +} diff --git a/pkg/strings/generator.go b/pkg/strings/generator.go new file mode 100644 index 0000000..1f4c32f --- /dev/null +++ b/pkg/strings/generator.go @@ -0,0 +1,29 @@ +package strings + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +// GenerateBase64 generates a random string of a given length, and base64 URL-encodes it. +func GenerateBase64(length int) (string, error) { + bytes, err := Generate(length) + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(bytes), nil +} + +// Generate generates a random byte array of a given length. +func Generate(length int) ([]byte, error) { + bytes := make([]byte, length) + _, err := io.ReadFull(rand.Reader, bytes) + if err != nil { + return nil, fmt.Errorf("reading rand.Reader: %w", err) + } + + return bytes, nil +}