diff --git a/main.go b/main.go index 38bf593..6819fc7 100644 --- a/main.go +++ b/main.go @@ -581,7 +581,7 @@ func runRPC(wg *sync.WaitGroup, ctx context.Context) { var asv authv3.AuthService if !dev { _log.Infow("adding auth interceptor") - ac := authv3.NewAuthContext(kc, ks, as) + ac := authv3.NewAuthContext(db, kc, ks, as) asv = authv3.NewAuthService(ac) o := authv3.Option{ ExcludeRPCMethods: []string{ diff --git a/pkg/auth/v3/auth.go b/pkg/auth/v3/auth.go index 241d52d..be0ec6e 100644 --- a/pkg/auth/v3/auth.go +++ b/pkg/auth/v3/auth.go @@ -37,6 +37,7 @@ type Option struct { } type authContext struct { + db *bun.DB kc *kclient.APIClient ks service.ApiKeyService as service.AuthzService @@ -78,7 +79,7 @@ func SetupAuthContext(auditLogger *zap.Logger) authContext { } as := service.NewAuthzService(db, enforcer) - return authContext{kc: kc, as: as, ks: service.NewApiKeyService(db, auditLogger)} + return authContext{db: db, kc: kc, as: as, ks: service.NewApiKeyService(db, auditLogger)} } func getDSN() string { @@ -106,11 +107,13 @@ func getEnvWithDefault(env, def string) string { // instead of creating new instances. To create authContext along with // its dependencies, use SetupAuthContext. func NewAuthContext( + db *bun.DB, kc *kclient.APIClient, apiKeySvc service.ApiKeyService, authzSvc service.AuthzService, ) authContext { return authContext{ + db: db, kc: kc, ks: apiKeySvc, as: authzSvc, diff --git a/pkg/auth/v3/core.go b/pkg/auth/v3/core.go index ed0148b..8d5f194 100644 --- a/pkg/auth/v3/core.go +++ b/pkg/auth/v3/core.go @@ -7,9 +7,11 @@ import ( "errors" "strings" + "github.com/RafayLabs/rcloud-base/internal/dao" rpcv3 "github.com/RafayLabs/rcloud-base/proto/rpc/user" authzv1 "github.com/RafayLabs/rcloud-base/proto/types/authz" commonv3 "github.com/RafayLabs/rcloud-base/proto/types/commonpb/v3" + "github.com/google/uuid" ) var ( @@ -93,6 +95,23 @@ func (ac *authContext) authenticate(ctx context.Context, req *commonv3.IsRequest t := session.Identity.Traits.(map[string]interface{}) res.SessionData.Username = t["email"].(string) + uid, err := uuid.Parse(session.Identity.Id) + if err != nil { + res.Status = commonv3.RequestStatus_RequestNotAuthenticated + res.Reason = "unable to find identity" + return false, err + } + groups, err := dao.GetGroups(ctx, ac.db, uid) + if err != nil { + res.Status = commonv3.RequestStatus_RequestNotAuthenticated + res.Reason = "unable to find identity" + return false, err + } + groupNames := []string{} + for _, g := range groups { + groupNames = append(groupNames, g.Name) + } + res.SessionData.Groups = groupNames } else { res.Status = commonv3.RequestStatus_RequestNotAuthenticated res.Reason = "no active session"