From 4bf3b1bdd4b343f83ea9828c8c14c72c6735957c Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Mon, 16 Jun 2025 09:55:35 +0200 Subject: [PATCH] refactor: move string generator to crypto package --- internal/crypto/text.go | 16 ++++++++++++++++ pkg/openid/client/login.go | 5 ++--- pkg/openid/client/logout.go | 3 +-- pkg/session/id.go | 4 ++-- pkg/strings/generator.go | 29 ----------------------------- 5 files changed, 21 insertions(+), 36 deletions(-) create mode 100644 internal/crypto/text.go delete mode 100644 pkg/strings/generator.go diff --git a/internal/crypto/text.go b/internal/crypto/text.go new file mode 100644 index 0000000..893edd2 --- /dev/null +++ b/internal/crypto/text.go @@ -0,0 +1,16 @@ +package crypto + +import ( + "crypto/rand" + "encoding/base64" +) + +// Text generates a cryptographically secure random string of a given length, and base64 URL-encodes it. +func Text(length int) (string, error) { + data := make([]byte, length) + if _, err := rand.Read(data); err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(data), nil +} diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 75db5cb..86f3b67 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -16,7 +16,6 @@ import ( mw "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/acr" - "github.com/nais/wonderwall/pkg/strings" "github.com/nais/wonderwall/pkg/url" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" @@ -76,12 +75,12 @@ func (c *Client) newAuthorizationCodeParams(r *http.Request) (openid.Authorizati return req, fmt.Errorf("generating callback url: %w", err) } - nonce, err := strings.GenerateBase64(32) + nonce, err := crypto.Text(32) if err != nil { return req, fmt.Errorf("creating nonce: %w", err) } - state, err := strings.GenerateBase64(32) + state, err := crypto.Text(32) if err != nil { return req, fmt.Errorf("creating state: %w", err) } diff --git a/pkg/openid/client/logout.go b/pkg/openid/client/logout.go index 4b7c868..7465d80 100644 --- a/pkg/openid/client/logout.go +++ b/pkg/openid/client/logout.go @@ -8,7 +8,6 @@ import ( "github.com/nais/wonderwall/internal/crypto" "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/strings" urlpkg "github.com/nais/wonderwall/pkg/url" ) @@ -24,7 +23,7 @@ func NewLogout(c *Client, r *http.Request) (*Logout, error) { return nil, fmt.Errorf("generating logout callback url: %w", err) } - state, err := strings.GenerateBase64(32) + state, err := crypto.Text(32) if err != nil { return nil, fmt.Errorf("generating state: %w", err) } diff --git a/pkg/session/id.go b/pkg/session/id.go index e525809..7ab14c5 100644 --- a/pkg/session/id.go +++ b/pkg/session/id.go @@ -4,9 +4,9 @@ import ( "fmt" "net/http" + "github.com/nais/wonderwall/internal/crypto" "github.com/nais/wonderwall/pkg/openid" openidconfig "github.com/nais/wonderwall/pkg/openid/config" - "github.com/nais/wonderwall/pkg/strings" ) // ExternalID returns the external session ID, derived from the given request or id_token; e.g. `sid` or `session_state`. @@ -33,7 +33,7 @@ func ExternalID(r *http.Request, cfg openidconfig.Provider, idToken *openid.IDTo } // 3. generate ID if all else fails - sessionID, err = strings.GenerateBase64(64) + sessionID, err = crypto.Text(64) if err != nil { return "", fmt.Errorf("generating session ID: %w", err) } diff --git a/pkg/strings/generator.go b/pkg/strings/generator.go deleted file mode 100644 index 1f4c32f..0000000 --- a/pkg/strings/generator.go +++ /dev/null @@ -1,29 +0,0 @@ -package strings - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -// GenerateBase64 generates a random string of a given length, and base64 URL-encodes it. -func GenerateBase64(length int) (string, error) { - bytes, err := Generate(length) - if err != nil { - return "", err - } - - return base64.RawURLEncoding.EncodeToString(bytes), nil -} - -// Generate generates a random byte array of a given length. -func Generate(length int) ([]byte, error) { - bytes := make([]byte, length) - _, err := io.ReadFull(rand.Reader, bytes) - if err != nil { - return nil, fmt.Errorf("reading rand.Reader: %w", err) - } - - return bytes, nil -}