diff --git a/components/usermgmt/pkg/service/idp.go b/components/usermgmt/pkg/service/idp.go index 06ff6f5..f5fc417 100644 --- a/components/usermgmt/pkg/service/idp.go +++ b/components/usermgmt/pkg/service/idp.go @@ -23,6 +23,20 @@ import ( "github.com/uptrace/bun" ) +var baseUrl *url.URL + +func init() { + base, ok := os.LookupEnv("APP_HOST_HTTP") + if !ok || len(base) == 0 { + panic("APP_HOST_HTTP env not set") + } + var err error + baseUrl, err = url.Parse(base) + if err != nil { + panic("Failed to get application url") + } +} + type IdpService interface { Create(context.Context, *userv3.Idp) (*userv3.Idp, error) GetByID(context.Context, *userv3.Idp) (*userv3.Idp, error) @@ -41,10 +55,9 @@ func NewIdpService(db *bun.DB) IdpService { } } -func generateAcsURL(baseURL string) string { +func generateAcsURL() (string, error) { uuid := uuid.New() - acsURL := fmt.Sprintf("%s/%s/", baseURL, uuid.String()) - return acsURL + return fmt.Sprintf("%s/%s/", baseUrl.String(), uuid.String()), nil } // generateSpCert generates self signed certificate. Returns cert and @@ -105,6 +118,13 @@ func (s *idpService) Create(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, name := idp.Metadata.GetName() domain := idp.Spec.GetDomain() + // validate name and domain + if len(name) == 0 { + return &userv3.Idp{}, fmt.Errorf("EMPTY NAME") + } + if len(domain) == 0 { + return &userv3.Idp{}, fmt.Errorf("EMPTY DOMAIN") + } e := &models.Idp{} s.dao.GetByName(ctx, name, e) if e.Name == name { @@ -115,11 +135,10 @@ func (s *idpService) Create(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, return &userv3.Idp{}, fmt.Errorf("DUPLICATE DOMAIN") } - base, err := url.Parse(os.Getenv("APP_HOST_HTTP")) + acsURL, err := generateAcsURL() if err != nil { return &userv3.Idp{}, err } - acsURL := generateAcsURL(base.String()) entity := &models.Idp{ Name: name, Description: idp.Metadata.GetDescription(), @@ -135,7 +154,7 @@ func (s *idpService) Create(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, SaeEnabled: idp.Spec.GetSaeEnabled(), } if entity.SaeEnabled { - spcert, spkey, err := generateSpCert(base.Host) + spcert, spkey, err := generateSpCert(baseUrl.Host) if err != nil { return &userv3.Idp{}, err } @@ -181,13 +200,11 @@ func (s *idpService) GetByID(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, return &userv3.Idp{}, err } entity := &models.Idp{} + // TODO: Check for existance of id before GetByID _, err = s.dao.GetByID(ctx, id, entity) if err != nil { return &userv3.Idp{}, err } - if entity.Id != id { - return &userv3.Idp{}, fmt.Errorf("IDP ID DOES NOT EXISTS") - } rv := &userv3.Idp{ ApiVersion: "usermgmt.k8smgmt.io/v3", Kind: "Idp", @@ -217,17 +234,40 @@ func (s *idpService) GetByID(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, } func (s *idpService) Update(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, error) { + var id, orgId, partId uuid.UUID id, err := uuid.Parse(idp.Metadata.GetId()) + // TODO: 400 Bad Request if err != nil { return &userv3.Idp{}, err } + if len(idp.Metadata.GetOrganization()) != 0 { + orgId, err = uuid.Parse(idp.Metadata.GetOrganization()) + if err != nil { + return &userv3.Idp{}, err + } + } + if len(idp.Metadata.GetPartner()) != 0 { + partId, err = uuid.Parse(idp.Metadata.GetPartner()) + if err != nil { + return &userv3.Idp{}, err + } + } + _, err = s.dao.GetByID(ctx, id, &models.Idp{}) + // TODO: Return proper error for Id not exist + if err != nil { + return &userv3.Idp{}, err + } + entity := &models.Idp{ + Id: id, Name: idp.Metadata.GetName(), Description: idp.Metadata.GetDescription(), ModifiedAt: time.Now(), IdpName: idp.Spec.GetIdpName(), Domain: idp.Spec.GetDomain(), AcsURL: idp.Spec.GetAcsUrl(), + OrganizationId: orgId, + PartnerId: partId, SsoURL: idp.Spec.GetSsoUrl(), IdpCert: idp.Spec.GetIdpCert(), MetadataURL: idp.Spec.GetMetadataUrl(), @@ -236,17 +276,14 @@ func (s *idpService) Update(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, SaeEnabled: idp.Spec.GetSaeEnabled(), } if entity.SaeEnabled { - base, err := url.Parse(os.Getenv("APP_HOST_HTTP")) - if err != nil { - return &userv3.Idp{}, err - } - spcert, spkey, err := generateSpCert(base.Host) + spcert, spkey, err := generateSpCert(baseUrl.Host) if err != nil { return &userv3.Idp{}, err } entity.SpCert = spcert entity.SpKey = spkey } + _, err = s.dao.Update(ctx, id, entity) if err != nil { return &userv3.Idp{}, err @@ -280,10 +317,15 @@ func (s *idpService) Update(ctx context.Context, idp *userv3.Idp) (*userv3.Idp, } func (s *idpService) List(ctx context.Context) (*userv3.IdpList, error) { - var entities []models.Idp - var orgID uuid.NullUUID - var parID uuid.NullUUID - s.dao.List(ctx, parID, orgID, &entities) + var ( + entities []models.Idp + orgID uuid.NullUUID + parID uuid.NullUUID + ) + _, err := s.dao.List(ctx, parID, orgID, &entities) + if err != nil { + return &userv3.IdpList{}, err + } // Get idps only till limit var result []*userv3.Idp @@ -318,7 +360,7 @@ func (s *idpService) List(ctx context.Context) (*userv3.IdpList, error) { rv := &userv3.IdpList{ ApiVersion: "usermgmt.k8smgmt.io/v3", - Kind: "Idp", + Kind: "IdpList", Items: result, } return rv, nil @@ -329,9 +371,13 @@ func (s *idpService) Delete(ctx context.Context, idp *userv3.Idp) error { if err != nil { return err } - entity := &models.Idp{} - err = s.dao.Delete(ctx, id, entity) + _, err = s.dao.GetByID(ctx, id, entity) + if entity.Id != id { + return fmt.Errorf("ID DOES NOT EXISTS") + } + + err = s.dao.Delete(ctx, id, &models.Idp{}) if err != nil { return err }