diff --git a/pkg/service/user.go b/pkg/service/user.go index ad9bae3..bf62087 100644 --- a/pkg/service/user.go +++ b/pkg/service/user.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "fmt" "strconv" "time" @@ -98,7 +99,7 @@ func getUserTraits(traits map[string]interface{}) userTraits { } // Map roles to accounts -func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3.User, ids parsedIds) (*userv3.User, error) { +func (s *userService) createUserRoleRelations(ctx context.Context, db bun.IDB, user *userv3.User, ids parsedIds) (*userv3.User, error) { projectNamespaceRoles := user.GetSpec().GetProjectNamespaceRoles() // TODO: add transactions @@ -108,7 +109,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. var ps []*authzv1.Policy for _, pnr := range projectNamespaceRoles { role := pnr.GetRole() - entity, err := pg.GetIdByName(ctx, s.db, role, &models.Role{}) + entity, err := pg.GetIdByName(ctx, db, role, &models.Role{}) if err != nil { return user, fmt.Errorf("unable to find role '%v'", role) } @@ -125,7 +126,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. switch { case pnr.Namespace != nil: - projectId, err := pg.GetProjectId(ctx, s.db, project) + projectId, err := pg.GetProjectId(ctx, db, project) if err != nil { return user, fmt.Errorf("unable to find project '%v'", project) } @@ -151,7 +152,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. Obj: role, }) case project != "": - projectId, err := pg.GetProjectId(ctx, s.db, project) + projectId, err := pg.GetProjectId(ctx, db, project) if err != nil { return user, fmt.Errorf("unable to find project '%v'", project) } @@ -200,19 +201,19 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. } } if len(panrs) > 0 { - _, err := pg.Create(ctx, s.db, &panrs) + _, err := pg.Create(ctx, db, &panrs) if err != nil { return &userv3.User{}, err } } if len(pars) > 0 { - _, err := pg.Create(ctx, s.db, &pars) + _, err := pg.Create(ctx, db, &pars) if err != nil { return &userv3.User{}, err } } if len(ars) > 0 { - _, err := pg.Create(ctx, s.db, &ars) + _, err := pg.Create(ctx, db, &ars) if err != nil { return &userv3.User{}, err } @@ -229,14 +230,14 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. } // FIXME: make this generic -func (s *userService) getPartnerOrganization(ctx context.Context, user *userv3.User) (uuid.UUID, uuid.UUID, error) { +func (s *userService) getPartnerOrganization(ctx context.Context, db bun.IDB, user *userv3.User) (uuid.UUID, uuid.UUID, error) { partner := user.GetMetadata().GetPartner() org := user.GetMetadata().GetOrganization() - partnerId, err := pg.GetPartnerId(ctx, s.db, partner) + partnerId, err := pg.GetPartnerId(ctx, db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := pg.GetOrganizationId(ctx, s.db, org) + organizationId, err := pg.GetOrganizationId(ctx, db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -246,7 +247,7 @@ func (s *userService) getPartnerOrganization(ctx context.Context, user *userv3.U func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.User, error) { // TODO: restrict endpoint to admin - partnerId, organizationId, err := s.getPartnerOrganization(ctx, user) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, user) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -263,11 +264,22 @@ func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.Us } uid, _ := uuid.Parse(id) - user, err = s.createUserRoleRelations(ctx, user, parsedIds{Id: uid, Partner: partnerId, Organization: organizationId}) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.User{}, err } + user, err = s.createUserRoleRelations(ctx, tx, user, parsedIds{Id: uid, Partner: partnerId, Organization: organizationId}) + if err != nil { + tx.Rollback() + return &userv3.User{}, err + } + err = tx.Commit() + if err != nil { + fmt.Println("unable to commit changes", err) + } + rl, err := s.ap.GetRecoveryLink(ctx, id) fmt.Println("Recovery link:", rl) // TODO: email the recovery link to the user if err != nil { @@ -277,9 +289,9 @@ func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.Us return user, nil } -func (s *userService) identitiesModelToUser(ctx context.Context, user *userv3.User, usr *models.KratosIdentities) (*userv3.User, error) { +func (s *userService) identitiesModelToUser(ctx context.Context, db bun.IDB, user *userv3.User, usr *models.KratosIdentities) (*userv3.User, error) { traits := getUserTraits(usr.Traits) - groups, err := dao.GetGroups(ctx, s.db, usr.ID) + groups, err := dao.GetGroups(ctx, db, usr.ID) if err != nil { return &userv3.User{}, err } @@ -290,7 +302,7 @@ func (s *userService) identitiesModelToUser(ctx context.Context, user *userv3.Us labels := make(map[string]string) - roles, err := dao.GetUserRoles(ctx, s.db, usr.ID) + roles, err := dao.GetUserRoles(ctx, db, usr.ID) if err != nil { return &userv3.User{}, err } @@ -325,7 +337,7 @@ func (s *userService) GetByID(ctx context.Context, user *userv3.User) (*userv3.U } if usr, ok := entity.(*models.KratosIdentities); ok { - user, err := s.identitiesModelToUser(ctx, user, usr) + user, err := s.identitiesModelToUser(ctx, s.db, user, usr) if err != nil { return &userv3.User{}, err } @@ -344,7 +356,7 @@ func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3 } if usr, ok := entity.(*models.KratosIdentities); ok { - user, err := s.identitiesModelToUser(ctx, user, usr) + user, err := s.identitiesModelToUser(ctx, s.db, user, usr) if err != nil { return &userv3.User{}, err } @@ -355,16 +367,16 @@ func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3 return user, nil } -func (s *userService) deleteUserRoleRelations(ctx context.Context, userId uuid.UUID, user *userv3.User) error { - err := pg.DeleteX(ctx, s.db, "account_id", userId, &models.AccountResourcerole{}) +func (s *userService) deleteUserRoleRelations(ctx context.Context, db bun.IDB, userId uuid.UUID, user *userv3.User) error { + err := pg.DeleteX(ctx, db, "account_id", userId, &models.AccountResourcerole{}) if err != nil { return err } - err = pg.DeleteX(ctx, s.db, "account_id", userId, &models.ProjectAccountResourcerole{}) + err = pg.DeleteX(ctx, db, "account_id", userId, &models.ProjectAccountResourcerole{}) if err != nil { return err } - err = pg.DeleteX(ctx, s.db, "account_id", userId, &models.ProjectAccountNamespaceRole{}) + err = pg.DeleteX(ctx, db, "account_id", userId, &models.ProjectAccountNamespaceRole{}) if err != nil { return err } @@ -385,7 +397,7 @@ func (s *userService) Update(ctx context.Context, user *userv3.User) (*userv3.Us } if usr, ok := entity.(*models.KratosIdentities); ok { - partnerId, organizationId, err := s.getPartnerOrganization(ctx, user) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, user) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -399,20 +411,32 @@ func (s *userService) Update(ctx context.Context, user *userv3.User) (*userv3.Us return &userv3.User{}, err } - err = s.deleteUserRoleRelations(ctx, usr.ID, user) + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.User{}, err } - user, err = s.createUserRoleRelations(ctx, user, parsedIds{Id: usr.ID, Partner: partnerId, Organization: organizationId}) + err = s.deleteUserRoleRelations(ctx, tx, usr.ID, user) if err != nil { + tx.Rollback() return &userv3.User{}, err } - } else { - return &userv3.User{}, fmt.Errorf("unable to update user '%v'", name) + + user, err = s.createUserRoleRelations(ctx, tx, user, parsedIds{Id: usr.ID, Partner: partnerId, Organization: organizationId}) + if err != nil { + tx.Rollback() + return &userv3.User{}, err + } + + err = tx.Commit() + if err != nil { + fmt.Println("unable to commit changes", err) + } + return user, nil } - return user, nil + return &userv3.User{}, fmt.Errorf("unable to update user '%v'", name) + } func (s *userService) Delete(ctx context.Context, user *userv3.User) (*userrpcv3.DeleteUserResponse, error) { @@ -423,21 +447,34 @@ func (s *userService) Delete(ctx context.Context, user *userv3.User) (*userrpcv3 } if usr, ok := entity.(*models.KratosIdentities); ok { - err = s.deleteUserRoleRelations(ctx, usr.ID, user) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userrpcv3.DeleteUserResponse{}, err } - err := s.ap.Delete(ctx, usr.ID.String()) + err = s.deleteUserRoleRelations(ctx, s.db, usr.ID, user) if err != nil { + tx.Rollback() + return &userrpcv3.DeleteUserResponse{}, err + } + + err = s.ap.Delete(ctx, usr.ID.String()) + if err != nil { + tx.Rollback() return &userrpcv3.DeleteUserResponse{}, err } err = pg.DeleteX(ctx, s.db, "account_id", usr.ID, &models.GroupAccount{}) if err != nil { + tx.Rollback() return &userrpcv3.DeleteUserResponse{}, fmt.Errorf("unable to delete user; %v", err) } + err = tx.Commit() + if err != nil { + fmt.Println("unable to commit changes", err) + } return &userrpcv3.DeleteUserResponse{}, nil } return &userrpcv3.DeleteUserResponse{}, fmt.Errorf("unable to delete user '%v'", user.Metadata.Name) @@ -461,7 +498,7 @@ func (s *userService) List(ctx context.Context, _ *userv3.User) (*userv3.UserLis if usrs, ok := entities.(*[]models.KratosIdentities); ok { for _, usr := range *usrs { user := &userv3.User{} - user, err := s.identitiesModelToUser(ctx, user, &usr) + user, err := s.identitiesModelToUser(ctx, s.db, user, &usr) if err != nil { return userList, err } diff --git a/pkg/service/user_test.go b/pkg/service/user_test.go index e8e289f..ad9d7fd 100644 --- a/pkg/service/user_test.go +++ b/pkg/service/user_test.go @@ -71,6 +71,9 @@ func TestCreateUser(t *testing.T) { mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + mock.ExpectBegin() + mock.ExpectCommit() + user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid}, Spec: &userv3.UserSpec{}, @@ -124,6 +127,8 @@ func TestCreateUserWithRole(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(puuid)) mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + + mock.ExpectBegin() mock.ExpectQuery(`SELECT "resourcerole"."id" FROM "authsrv_resourcerole" AS "resourcerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pruuid)) if tc.roles[0].Project != nil { @@ -132,6 +137,7 @@ func TestCreateUserWithRole(t *testing.T) { } mock.ExpectQuery(fmt.Sprintf(`INSERT INTO "%v"`, tc.dbname)). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid}, @@ -184,6 +190,8 @@ func TestUpdateUser(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(puuid)) mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_accountresourcerole" AS "accountresourcerole" SET trash = TRUE WHERE`). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_projectaccountresourcerole" AS "projectaccountresourcerole" SET trash = TRUE WHERE`). @@ -196,6 +204,7 @@ func TestUpdateUser(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pruuid)) mock.ExpectQuery(`INSERT INTO "authsrv_projectaccountnamespacerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid}, @@ -392,6 +401,7 @@ func TestUserDelete(t *testing.T) { mock.ExpectQuery(`SELECT "identities"."id" FROM "identities" WHERE .*traits ->> 'email' = 'user-` + uuuid + `'`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id", "traits"}).AddRow(uuuid, []byte(`{"email":"johndoe@provider.com"}`))) + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_accountresourcerole" AS "accountresourcerole" SET trash = TRUE WHERE`). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_projectaccountresourcerole" AS "projectaccountresourcerole" SET trash = TRUE WHERE`). @@ -401,6 +411,7 @@ func TestUserDelete(t *testing.T) { // User delete is via kratos mock.ExpectExec(`UPDATE "authsrv_groupaccount" AS "groupaccount" SET trash = TRUE WHERE`). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid},