diff --git a/pkg/auth/v3/interceptor.go b/pkg/auth/v3/interceptor.go index 35cfc30..d4ffffb 100644 --- a/pkg/auth/v3/interceptor.go +++ b/pkg/auth/v3/interceptor.go @@ -4,6 +4,7 @@ import ( context "context" "strings" + "github.com/RafayLabs/rcloud-base/pkg/common" "github.com/RafayLabs/rcloud-base/pkg/gateway" commonv3 "github.com/RafayLabs/rcloud-base/proto/types/commonpb/v3" grpc "google.golang.org/grpc" @@ -88,7 +89,7 @@ func (ac authContext) NewAuthUnaryInterceptor(opt Option) grpc.UnaryServerInterc s := res.GetStatus() switch s { case commonv3.RequestStatus_RequestAllowed: - ctx := NewSessionContext(ctx, res.SessionData) + ctx := context.WithValue(ctx, common.SessionDataKey, res.SessionData) return handler(ctx, req) case commonv3.RequestStatus_RequestMethodOrURLNotAllowed: return nil, status.Error(codes.PermissionDenied, res.GetReason()) diff --git a/pkg/auth/v3/middleware.go b/pkg/auth/v3/middleware.go index 3eae9b8..b89be21 100644 --- a/pkg/auth/v3/middleware.go +++ b/pkg/auth/v3/middleware.go @@ -1,11 +1,13 @@ package authv3 import ( + context "context" "net/http" "regexp" "strings" "github.com/RafayLabs/rcloud-base/internal/dao" + "github.com/RafayLabs/rcloud-base/pkg/common" commonpbv3 "github.com/RafayLabs/rcloud-base/proto/types/commonpb/v3" "github.com/google/uuid" "github.com/uptrace/bun" @@ -92,7 +94,7 @@ func (am *authMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, nex s := res.GetStatus() switch s { case commonpbv3.RequestStatus_RequestAllowed: - ctx := NewSessionContext(r.Context(), res.SessionData) + ctx := context.WithValue(r.Context(), common.SessionDataKey, res.SessionData) next(rw, r.WithContext(ctx)) return case commonpbv3.RequestStatus_RequestMethodOrURLNotAllowed: diff --git a/pkg/auth/v3/session.go b/pkg/auth/v3/session.go deleted file mode 100644 index 7c1ee90..0000000 --- a/pkg/auth/v3/session.go +++ /dev/null @@ -1,17 +0,0 @@ -package authv3 - -import ( - "context" - - "github.com/RafayLabs/rcloud-base/pkg/common" - commonv3 "github.com/RafayLabs/rcloud-base/proto/types/commonpb/v3" -) - -func NewSessionContext(ctx context.Context, s *commonv3.SessionData) context.Context { - return context.WithValue(ctx, common.SessionDataKey, s) -} - -func GetSession(ctx context.Context) (*commonv3.SessionData, bool) { - s, ok := ctx.Value(common.SessionDataKey).(*commonv3.SessionData) - return s, ok -} diff --git a/pkg/common/constants.go b/pkg/common/constants.go index 65f4542..5c51554 100644 --- a/pkg/common/constants.go +++ b/pkg/common/constants.go @@ -47,6 +47,4 @@ const ( RelayCommandsAuditType = "RelayCommands" ) -type contextKey struct{} - var SessionDataKey contextKey diff --git a/pkg/common/types.go b/pkg/common/types.go index 4f5a61b..d933a30 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -24,3 +24,5 @@ type CliConfigDownloadData struct { Organization string `json:"organization"` Partner string `json:"partner"` } + +type contextKey struct{} diff --git a/pkg/service/project.go b/pkg/service/project.go index 0a2ef4a..37189bc 100644 --- a/pkg/service/project.go +++ b/pkg/service/project.go @@ -219,18 +219,12 @@ func (s *projectService) Delete(ctx context.Context, project *systemv3.Project) } func (s *projectService) List(ctx context.Context, project *systemv3.Project) (*systemv3.ProjectList, error) { - sessionData := ctx.Value(common.SessionDataKey) + sd, ok := ctx.Value(common.SessionDataKey).(*commonv3.SessionData) username := "" - if sessionData == nil { + if !ok { return &systemv3.ProjectList{}, fmt.Errorf("cannot perform project listing without auth") - } else { - sd, ok := sessionData.(*commonv3.SessionData) - if !ok { - return &systemv3.ProjectList{}, fmt.Errorf("cannot perform project listing without auth") - } else { - username = sd.Username - } } + username = sd.Username var projects []*systemv3.Project projectList := &systemv3.ProjectList{ diff --git a/pkg/service/user.go b/pkg/service/user.go index 9db0abe..5d4a719 100644 --- a/pkg/service/user.go +++ b/pkg/service/user.go @@ -376,18 +376,12 @@ func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3 } func (s *userService) GetUserInfo(ctx context.Context, user *userv3.User) (*userv3.UserInfo, error) { - sessionData := ctx.Value(common.SessionDataKey) + sd, ok := ctx.Value(common.SessionDataKey).(*commonv3.SessionData) username := "" - if sessionData == nil { + if !ok { return &userv3.UserInfo{}, fmt.Errorf("cannot perform project listing without auth") - } else { - sd, ok := sessionData.(*commonv3.SessionData) - if !ok { - return &userv3.UserInfo{}, fmt.Errorf("cannot perform project listing without auth") - } else { - username = sd.Username - } } + username = sd.Username entity, err := dao.GetByTraits(ctx, s.db, username, &models.KratosIdentities{}) if err != nil { diff --git a/server/user.go b/server/user.go index a2a45a1..6fb4164 100644 --- a/server/user.go +++ b/server/user.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" - authv3 "github.com/RafayLabs/rcloud-base/pkg/auth/v3" + "github.com/RafayLabs/rcloud-base/pkg/common" "github.com/RafayLabs/rcloud-base/pkg/query" "github.com/RafayLabs/rcloud-base/pkg/service" rpcv3 "github.com/RafayLabs/rcloud-base/proto/rpc/user" @@ -75,7 +75,7 @@ func (s *userServer) UpdateUser(ctx context.Context, req *userpbv3.User) (*userp } func (s *userServer) DownloadCliConfig(ctx context.Context, req *rpcv3.CliConfigRequest) (*commonv3.HttpBody, error) { - sessData, ok := authv3.GetSession(ctx) + sessData, ok := ctx.Value(common.SessionDataKey).(*commonv3.SessionData) if !ok { return nil, fmt.Errorf("unable to retrieve session data") }