From 19b2401831fb49945604065c758aa0adfac9a201 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 18 Apr 2023 12:20:23 +0200 Subject: [PATCH] feat(metrics): add authentication method reference label for successful logins --- pkg/handler/handler.go | 2 +- pkg/jwt/jwt.go | 57 ++++++++++++++++++++++++++++++++++++------ pkg/metrics/metrics.go | 15 +++++++---- pkg/openid/tokens.go | 10 ++++++++ 4 files changed, 71 insertions(+), 13 deletions(-) diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index f073706..6813b2c 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -214,7 +214,7 @@ func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) { } mw.LogEntryFrom(r).WithFields(fields).Info("callback: successful login") - metrics.ObserveLogin() + metrics.ObserveLogin(tokens.IDToken.GetAmrClaim()) cookie.Clear(w, cookie.Retry, s.GetCookieOptions(r)) http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) } diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index 5db6e27..f92d72a 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -12,6 +12,7 @@ import ( const ( AcceptableClockSkew = 5 * time.Second + AmrClaim = "amr" JtiClaim = "jti" SidClaim = "sid" UtiClaim = "uti" @@ -22,6 +23,19 @@ type Token struct { token jwt.Token } +func (in *Token) GetClaim(claim string) (any, error) { + if in.token == nil { + return nil, fmt.Errorf("token is nil") + } + + gotClaim, ok := in.token.Get(claim) + if !ok { + return nil, fmt.Errorf("missing required '%s' claim in id_token", claim) + } + + return gotClaim, nil +} + func (in *Token) GetExpiration() time.Time { return in.token.Expiration() } @@ -44,13 +58,9 @@ func (in *Token) GetSerialized() string { } func (in *Token) GetStringClaim(claim string) (string, error) { - if in.token == nil { - return "", fmt.Errorf("token is nil") - } - - gotClaim, ok := in.token.Get(claim) - if !ok { - return "", fmt.Errorf("missing required '%s' claim in id_token", claim) + gotClaim, err := in.GetClaim(claim) + if err != nil { + return "", err } claimString, ok := gotClaim.(string) @@ -61,6 +71,30 @@ func (in *Token) GetStringClaim(claim string) (string, error) { return claimString, nil } +func (in *Token) GetStringSliceClaim(claim string) ([]string, error) { + gotClaim, err := in.GetClaim(claim) + if err != nil { + return nil, err + } + + // the claim is a slice of interfaces... + claimValues, ok := gotClaim.([]interface{}) + if !ok { + return nil, fmt.Errorf("'%s' claim is not a slice", claim) + } + + // ...so we need to assert the actual type for each interface + strings := make([]string, 0) + + for _, v := range claimValues { + if str, ok := v.(string); ok { + strings = append(strings, str) + } + } + + return strings, nil +} + func (in *Token) GetStringClaimOrEmpty(claim string) string { str, err := in.GetStringClaim(claim) if err != nil { @@ -70,6 +104,15 @@ func (in *Token) GetStringClaimOrEmpty(claim string) string { return str } +func (in *Token) GetStringSliceClaimOrEmpty(claim string) []string { + s, err := in.GetStringSliceClaim(claim) + if err != nil { + return make([]string, 0) + } + + return s +} + func (in *Token) GetToken() jwt.Token { return in.token } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index a120369..25cd06b 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -13,6 +13,7 @@ import ( const ( Namespace = "wonderwall" + LabelAmr = "amr" LabelHpa = "hpa" LabelOperation = "operation" LabelProvider = "provider" @@ -62,7 +63,7 @@ func redisLatency(constLabels ...prometheus.Labels) *prometheus.HistogramVec { return prometheus.NewHistogramVec(opts, []string{LabelOperation}) } -func logins(constLabels ...prometheus.Labels) prometheus.Counter { +func logins(constLabels ...prometheus.Labels) *prometheus.CounterVec { opts := prometheus.CounterOpts{ Name: "logins", Namespace: Namespace, @@ -76,7 +77,7 @@ func logins(constLabels ...prometheus.Labels) prometheus.Counter { opts.ConstLabels = constLabels[0] } - return prometheus.NewCounter(opts) + return prometheus.NewCounterVec(opts, []string{LabelAmr}) } func logouts(constLabels ...prometheus.Labels) *prometheus.CounterVec { @@ -114,11 +115,13 @@ func WithProvider(provider string) { // InitLabels zeroes out all possible label combinations func InitLabels() { - logoutOperations := []LogoutOperation{LogoutOperationSelfInitiated, LogoutOperationFrontChannel} + logoutOperations := []LogoutOperation{LogoutOperationSelfInitiated, LogoutOperationFrontChannel, LogoutOperationLocal} for _, operation := range logoutOperations { Logouts.With(prometheus.Labels{LabelOperation: operation}) } + + Logins.With(prometheus.Labels{LabelAmr: ""}) } func Handle(address string, provider config.Provider) error { @@ -148,8 +151,10 @@ func ObserveRedisLatency(operation string, fun func() error) error { return err } -func ObserveLogin() { - Logins.Inc() +func ObserveLogin(amrValue string) { + Logins.With(prometheus.Labels{ + LabelAmr: amrValue, + }).Inc() } func ObserveLogout(operation LogoutOperation) { diff --git a/pkg/openid/tokens.go b/pkg/openid/tokens.go index e1b3b67..dbdd823 100644 --- a/pkg/openid/tokens.go +++ b/pkg/openid/tokens.go @@ -2,6 +2,7 @@ package openid import ( "fmt" + "strings" "time" "github.com/lestrrat-go/jwx/v2/jwk" @@ -39,6 +40,15 @@ type IDToken struct { jwt.Token } +func (in *IDToken) GetAmrClaim() string { + s := in.GetStringClaimOrEmpty(jwt.AmrClaim) + if len(s) == 0 { + s = strings.Join(in.GetStringSliceClaimOrEmpty(jwt.AmrClaim), ",") + } + + return s +} + func (in *IDToken) GetSidClaim() (string, error) { return in.GetStringClaim(jwt.SidClaim) }