diff --git a/components/usermgmt/main.go b/components/usermgmt/main.go index 22754e3..4b12dc0 100644 --- a/components/usermgmt/main.go +++ b/components/usermgmt/main.go @@ -62,6 +62,7 @@ var ( gs service.GroupService rs service.RoleService is service.IdpService + ps service.OIDCProviderService dev bool _log = logv2.GetLogger() authPool authv3.AuthPool @@ -124,6 +125,7 @@ func setup() { gs = service.NewGroupService(db) rs = service.NewRoleService(db) is = service.NewIdpService(db) + ps = service.NewOIDCProviderService(db) _log.Infow("usermgmt setup complete") } @@ -158,6 +160,7 @@ func runAPI(wg *sync.WaitGroup, ctx context.Context) { pbrpcv3.RegisterGroupHandlerFromEndpoint, pbrpcv3.RegisterRoleHandlerFromEndpoint, pbrpcv3.RegisterIdpHandlerFromEndpoint, + pbrpcv3.RegisterOIDCProviderHandlerFromEndpoint, ) if err != nil { _log.Fatalw("unable to create gateway", "error", err) @@ -192,6 +195,7 @@ func runRPC(wg *sync.WaitGroup, ctx context.Context) { groupServer := server.NewGroupServer(gs) roleServer := server.NewRoleServer(rs) idpServer := server.NewIdpServer(is) + oidcProviderServer := server.NewOIDCServer(ps) l, err := net.Listen("tcp", fmt.Sprintf(":%d", rpcPort)) if err != nil { @@ -229,6 +233,7 @@ func runRPC(wg *sync.WaitGroup, ctx context.Context) { rpcv3.RegisterGroupServer(s, groupServer) rpcv3.RegisterRoleServer(s, roleServer) rpcv3.RegisterIdpServer(s, idpServer) + rpcv3.RegisterOIDCProviderServer(s, oidcProviderServer) _log.Infow("starting rpc server", "port", rpcPort) err = s.Serve(l) diff --git a/components/usermgmt/pkg/service/oidc_provider.go b/components/usermgmt/pkg/service/oidc_provider.go index 415d4f8..c400f86 100644 --- a/components/usermgmt/pkg/service/oidc_provider.go +++ b/components/usermgmt/pkg/service/oidc_provider.go @@ -2,10 +2,18 @@ package service import ( "context" + "fmt" + "net/url" + "os" + "time" "github.com/RafaySystems/rcloud-base/components/common/pkg/persistence/provider/pg" + commonv3 "github.com/RafaySystems/rcloud-base/components/common/proto/types/commonpb/v3" + "github.com/RafaySystems/rcloud-base/components/usermgmt/pkg/internal/models" userv3 "github.com/RafaySystems/rcloud-base/components/usermgmt/proto/types/userpb/v3" + "github.com/google/uuid" bun "github.com/uptrace/bun" + "google.golang.org/protobuf/types/known/structpb" ) type OIDCProviderService interface { @@ -26,16 +34,163 @@ func NewOIDCProviderService(db *bun.DB) OIDCProviderService { } } +func generateCallbackUrl() (string, error) { + base, err := url.Parse(os.Getenv("APP_HOST_HTTP")) + if err != nil { + return "", err + } + uuid := uuid.New() + return fmt.Sprintf("%s/auth/v3/sso/callback/%s", base, uuid), nil +} + func (s *oidcProvider) Create(ctx context.Context, provider *userv3.OIDCProvider) (*userv3.OIDCProvider, error) { - return &userv3.OIDCProvider{}, nil + // validate name + name := provider.Metadata.GetName() + if len(name) == 0 { + return &userv3.OIDCProvider{}, fmt.Errorf("EMPTY NAME") + } + e := &models.OIDCProvider{} + s.dao.GetByName(ctx, name, e) + if e.Name == name { + return &userv3.OIDCProvider{}, fmt.Errorf("DUPLICATE NAME") + } + + callback, err := generateCallbackUrl() + if err != nil { + return &userv3.OIDCProvider{}, err + } + entity := &models.OIDCProvider{ + Name: name, + CreatedAt: time.Time{}, + ModifiedAt: time.Time{}, + ProviderName: provider.Spec.GetProviderName(), + MapperURL: provider.Spec.GetMapperUrl(), + MapperFilename: provider.Spec.GetMapperFilename(), + ClientId: provider.Spec.GetClientId(), + ClientSecret: provider.Spec.GetClientSecret(), + Scopes: provider.Spec.GetScopes(), + IssuerURL: provider.Spec.GetIssuerUrl(), + AuthURL: provider.Spec.GetAuthUrl(), + TokenURL: provider.Spec.GetTokenUrl(), + RequestedClaims: provider.Spec.GetRequestedClaims().AsMap(), + Predefined: provider.Spec.GetPredefined(), + CallbackURL: callback, + } + _, err = s.dao.Create(ctx, entity) + if err != nil { + return &userv3.OIDCProvider{}, err + } + + rclaims, _ := structpb.NewStruct(entity.RequestedClaims) + rv := &userv3.OIDCProvider{ + ApiVersion: "usermgmt.k8smgmt.io/v3", + Kind: "OIDCProvider", + Metadata: &commonv3.Metadata{ + Name: entity.Name, + Description: entity.Description, + Id: entity.Id.String(), + }, + Spec: &userv3.OIDCProviderSpec{ + ProviderName: entity.ProviderName, + MapperUrl: entity.MapperURL, + MapperFilename: entity.MapperFilename, + ClientId: entity.ClientId, + ClientSecret: entity.ClientSecret, + Scopes: entity.Scopes, + IssuerUrl: entity.IssuerURL, + AuthUrl: entity.AuthURL, + TokenUrl: entity.TokenURL, + RequestedClaims: rclaims, + Predefined: entity.Predefined, + CallbackUrl: entity.CallbackURL, + }, + } + return rv, nil } func (s *oidcProvider) GetByID(ctx context.Context, provider *userv3.OIDCProvider) (*userv3.OIDCProvider, error) { - return &userv3.OIDCProvider{}, nil + id, err := uuid.Parse(provider.Metadata.GetId()) + if err != nil { + return &userv3.OIDCProvider{}, err + } + + entity := &models.OIDCProvider{} + _, err = s.dao.GetByID(ctx, id, entity) + // TODO: Return proper error for Id not exist + if err != nil { + return &userv3.OIDCProvider{}, err + } + + rclaims, _ := structpb.NewStruct(entity.RequestedClaims) + rv := &userv3.OIDCProvider{ + ApiVersion: "usermgmt.k8smgmt.io/v3", + Kind: "OIDCProvider", + Metadata: &commonv3.Metadata{ + Name: entity.Name, + Description: entity.Description, + Id: entity.Id.String(), + }, + Spec: &userv3.OIDCProviderSpec{ + ProviderName: entity.ProviderName, + MapperUrl: entity.MapperURL, + MapperFilename: entity.MapperFilename, + ClientId: entity.ClientId, + Scopes: entity.Scopes, + IssuerUrl: entity.IssuerURL, + AuthUrl: entity.AuthURL, + TokenUrl: entity.TokenURL, + RequestedClaims: rclaims, + Predefined: entity.Predefined, + CallbackUrl: entity.CallbackURL, + }, + } + return rv, nil } func (s *oidcProvider) List(ctx context.Context) (*userv3.OIDCProviderList, error) { - return &userv3.OIDCProviderList{}, nil + var ( + entities []models.OIDCProvider + orgID uuid.NullUUID + parID uuid.NullUUID + ) + _, err := s.dao.List(ctx, parID, orgID, &entities) + if err != nil { + return &userv3.OIDCProviderList{}, nil + } + var result []*userv3.OIDCProvider + for _, entity := range entities { + rclaims, _ := structpb.NewStruct(entity.RequestedClaims) + e := &userv3.OIDCProvider{ + ApiVersion: "usermgmt.k8smgmt.io/v3", + Kind: "OIDCProvider", + Metadata: &commonv3.Metadata{ + Name: entity.Name, + Description: entity.Description, + Id: entity.Id.String(), + }, + Spec: &userv3.OIDCProviderSpec{ + ProviderName: entity.ProviderName, + MapperUrl: entity.MapperURL, + MapperFilename: entity.MapperFilename, + ClientId: entity.ClientId, + Scopes: entity.Scopes, + IssuerUrl: entity.IssuerURL, + AuthUrl: entity.AuthURL, + TokenUrl: entity.TokenURL, + RequestedClaims: rclaims, + Predefined: entity.Predefined, + CallbackUrl: entity.CallbackURL, + }, + } + result = append(result, e) + } + + rv := &userv3.OIDCProviderList{ + ApiVersion: "usermgmt.k8smgmt.io/v3", + Kind: "OIDCProviderList", + Items: result, + } + return rv, nil } func (s *oidcProvider) Update(ctx context.Context, provider *userv3.OIDCProvider) (*userv3.OIDCProvider, error) {