diff --git a/internal/cluster/dao/cluster.go b/internal/cluster/dao/cluster.go index 44a14ff..f23c2ac 100644 --- a/internal/cluster/dao/cluster.go +++ b/internal/cluster/dao/cluster.go @@ -13,7 +13,7 @@ import ( "github.com/uptrace/bun" ) -func CreateCluster(ctx context.Context, tx bun.Tx, cluster *models.Cluster) error { +func CreateCluster(ctx context.Context, tx bun.IDB, cluster *models.Cluster) error { clstrToken := &models.ClusterToken{ OrganizationId: cluster.OrganizationId, diff --git a/internal/dao/bootstrap.go b/internal/dao/bootstrap.go index 9e3f630..3bcf48d 100644 --- a/internal/dao/bootstrap.go +++ b/internal/dao/bootstrap.go @@ -128,10 +128,7 @@ func CreateBootstrapAgent(ctx context.Context, db bun.IDB, ba *models.BootstrapA return err } -// We are explicitly taking in Tx here as it was previously doing RunInTx -// TODO: should we take Tx here or just assume we will be passed in a tx? Or should we create one? func RegisterBootstrapAgent(ctx context.Context, db bun.Tx, token string) error { - ba, err := getBootstrapAgentForToken(ctx, db, token) if err != nil { return err diff --git a/pkg/service/cluster.go b/pkg/service/cluster.go index e9a5384..6cecaa8 100644 --- a/pkg/service/cluster.go +++ b/pkg/service/cluster.go @@ -235,13 +235,17 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) cluster.Spec.ClusterData.Health = infrav3.Health_EDGE_IGNORE - err = es.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - return dao.CreateCluster(ctx, tx, edb) - }) + tx, err := es.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &infrav3.Cluster{}, err } + err = dao.CreateCluster(ctx, tx, edb) + if err != nil { + tx.Rollback() + return &infrav3.Cluster{}, err + } + // if project is set create project cluster var pc *models.ProjectCluster pcList := make([]models.ProjectCluster, 0) @@ -250,8 +254,9 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) ProjectID: edb.ProjectId, ClusterID: edb.ID, } - err = dao.CreateProjectCluster(ctx, es.db, pc) + err = dao.CreateProjectCluster(ctx, tx, pc) if err != nil { + tx.Rollback() return &infrav3.Cluster{}, err } pcList = append(pcList, *pc) @@ -273,9 +278,22 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) YamlContent: operatorSpecEncoded, } - es.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - return dao.CreateOperatorBootstrap(ctx, tx, &bootstrapData) - }) + err = dao.CreateOperatorBootstrap(ctx, tx, &bootstrapData) + if err != nil { + tx.Rollback() + cluster.Status = &commonv3.Status{ + ConditionType: "Create", + ConditionStatus: commonv3.ConditionStatus_StatusFailed, + Reason: err.Error(), + } + return cluster, err + } + } + + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) } ev := event.Resource{ diff --git a/pkg/service/cluster_test.go b/pkg/service/cluster_test.go index 07ae858..b11ed08 100644 --- a/pkg/service/cluster_test.go +++ b/pkg/service/cluster_test.go @@ -47,10 +47,9 @@ func TestCreateCluster(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(cuuid)) mock.ExpectExec(`UPDATE "cluster_clusters"`). WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectExec(`INSERT INTO "cluster_project_cluster"`). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() cluster := &infrav3.Cluster{ Metadata: &v3.Metadata{Id: cuuid, Name: "cluster-" + cuuid, Organization: "orgname", Project: "project-" + puuid}, diff --git a/pkg/service/group.go b/pkg/service/group.go index 13d8d48..c1a1da9 100644 --- a/pkg/service/group.go +++ b/pkg/service/group.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "fmt" "strconv" "time" @@ -49,18 +50,18 @@ func NewGroupService(db *bun.DB, azc AuthzService) GroupService { return &groupService{db: db, azc: azc} } -func (s *groupService) deleteGroupRoleRelaitons(ctx context.Context, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { +func (s *groupService) deleteGroupRoleRelaitons(ctx context.Context, db bun.IDB, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { // delete previous entries // TODO: single delete command - err := pg.DeleteX(ctx, s.db, "group_id", groupId, &models.GroupRole{}) + err := pg.DeleteX(ctx, db, "group_id", groupId, &models.GroupRole{}) if err != nil { return &userv3.Group{}, err } - err = pg.DeleteX(ctx, s.db, "group_id", groupId, &models.ProjectGroupRole{}) + err = pg.DeleteX(ctx, db, "group_id", groupId, &models.ProjectGroupRole{}) if err != nil { return &userv3.Group{}, err } - err = pg.DeleteX(ctx, s.db, "group_id", groupId, &models.ProjectGroupNamespaceRole{}) + err = pg.DeleteX(ctx, db, "group_id", groupId, &models.ProjectGroupNamespaceRole{}) if err != nil { return &userv3.Group{}, err } @@ -73,7 +74,7 @@ func (s *groupService) deleteGroupRoleRelaitons(ctx context.Context, groupId uui } // Map roles to groups -func (s *groupService) createGroupRoleRelations(ctx context.Context, group *userv3.Group, ids parsedIds) (*userv3.Group, error) { +func (s *groupService) createGroupRoleRelations(ctx context.Context, db bun.IDB, group *userv3.Group, ids parsedIds) (*userv3.Group, error) { // TODO: add transactions projectNamespaceRoles := group.GetSpec().GetProjectNamespaceRoles() @@ -83,7 +84,7 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user var ps []*authzv1.Policy for _, pnr := range projectNamespaceRoles { role := pnr.GetRole() - entity, err := pg.GetIdByName(ctx, s.db, role, &models.Role{}) + entity, err := pg.GetIdByName(ctx, db, role, &models.Role{}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find role '%v'", role) } @@ -99,7 +100,7 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user namespaceId := pnr.GetNamespace() // TODO: lookup id from name switch { case namespaceId != 0: - projectId, err := pg.GetProjectId(ctx, s.db, project) + projectId, err := pg.GetProjectId(ctx, db, project) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find project '%v'", project) } @@ -123,7 +124,7 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user Obj: role, }) case project != "": - projectId, err := pg.GetProjectId(ctx, s.db, project) + projectId, err := pg.GetProjectId(ctx, db, project) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find project '%v'", project) } @@ -165,19 +166,19 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user } } if len(pgnrs) > 0 { - _, err := pg.Create(ctx, s.db, &pgnrs) + _, err := pg.Create(ctx, db, &pgnrs) if err != nil { return &userv3.Group{}, err } } if len(pgrs) > 0 { - _, err := pg.Create(ctx, s.db, &pgrs) + _, err := pg.Create(ctx, db, &pgrs) if err != nil { return &userv3.Group{}, err } } if len(grs) > 0 { - _, err := pg.Create(ctx, s.db, &grs) + _, err := pg.Create(ctx, db, &grs) if err != nil { return &userv3.Group{}, err } @@ -193,8 +194,8 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user return group, nil } -func (s *groupService) deleteGroupAccountRelations(ctx context.Context, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { - err := pg.DeleteX(ctx, s.db, "group_id", groupId, &models.GroupAccount{}) +func (s *groupService) deleteGroupAccountRelations(ctx context.Context, db bun.IDB, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { + err := pg.DeleteX(ctx, db, "group_id", groupId, &models.GroupAccount{}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to delete user; %v", err) } @@ -207,13 +208,13 @@ func (s *groupService) deleteGroupAccountRelations(ctx context.Context, groupId } // Update the users(account) mapped to each group -func (s *groupService) createGroupAccountRelations(ctx context.Context, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { +func (s *groupService) createGroupAccountRelations(ctx context.Context, db bun.IDB, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { // TODO: add transactions var grpaccs []models.GroupAccount var ugs []*authzv1.UserGroup for _, account := range unique(group.GetSpec().GetUsers()) { // FIXME: do combined lookup - entity, err := pg.GetIdByTraits(ctx, s.db, account, &models.KratosIdentities{}) + entity, err := pg.GetIdByTraits(ctx, db, account, &models.KratosIdentities{}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find user '%v'", account) } @@ -236,7 +237,7 @@ func (s *groupService) createGroupAccountRelations(ctx context.Context, groupId if len(grpaccs) == 0 { return group, nil } - _, err := pg.Create(ctx, s.db, &grpaccs) + _, err := pg.Create(ctx, db, &grpaccs) if err != nil { return &userv3.Group{}, err } @@ -251,14 +252,14 @@ func (s *groupService) createGroupAccountRelations(ctx context.Context, groupId return group, nil } -func (s *groupService) getPartnerOrganization(ctx context.Context, group *userv3.Group) (uuid.UUID, uuid.UUID, error) { +func (s *groupService) getPartnerOrganization(ctx context.Context, db bun.IDB, group *userv3.Group) (uuid.UUID, uuid.UUID, error) { partner := group.GetMetadata().GetPartner() org := group.GetMetadata().GetOrganization() - partnerId, err := pg.GetPartnerId(ctx, s.db, partner) + partnerId, err := pg.GetPartnerId(ctx, db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := pg.GetOrganizationId(ctx, s.db, org) + organizationId, err := pg.GetOrganizationId(ctx, db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -267,7 +268,7 @@ func (s *groupService) getPartnerOrganization(ctx context.Context, group *userv3 } func (s *groupService) Create(ctx context.Context, group *userv3.Group) (*userv3.Group, error) { - partnerId, organizationId, err := s.getPartnerOrganization(ctx, group) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, group) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -286,29 +287,45 @@ func (s *groupService) Create(ctx context.Context, group *userv3.Group) (*userv3 PartnerId: partnerId, Type: group.GetSpec().GetType(), } - entity, err := pg.Create(ctx, s.db, &grp) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.Group{}, err } + entity, err := pg.Create(ctx, tx, &grp) + if err != nil { + tx.Rollback() // TODO: check errors for rollback (and do what?) + return &userv3.Group{}, err + } + //update v3 spec if grp, ok := entity.(*models.Group); ok { // we can get previous group using the id, find users/roles from that and delete those - group, err = s.createGroupAccountRelations(ctx, grp.ID, group) + group, err = s.createGroupAccountRelations(ctx, tx, grp.ID, group) if err != nil { + tx.Rollback() return &userv3.Group{}, err } - group, err = s.createGroupRoleRelations(ctx, group, parsedIds{Id: grp.ID, Partner: partnerId, Organization: organizationId}) + group, err = s.createGroupRoleRelations(ctx, tx, group, parsedIds{Id: grp.ID, Partner: partnerId, Organization: organizationId}) if err != nil { + tx.Rollback() return &userv3.Group{}, err } + + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } return group, nil } + tx.Rollback() return &userv3.Group{}, fmt.Errorf("unable to create group") } -func (s *groupService) toV3Group(ctx context.Context, group *userv3.Group, grp *models.Group) (*userv3.Group, error) { +func (s *groupService) toV3Group(ctx context.Context, db bun.IDB, group *userv3.Group, grp *models.Group) (*userv3.Group, error) { labels := make(map[string]string) labels["organization"] = group.GetMetadata().GetOrganization() labels["partner"] = group.GetMetadata().GetPartner() @@ -323,7 +340,7 @@ func (s *groupService) toV3Group(ctx context.Context, group *userv3.Group, grp * Labels: labels, ModifiedAt: timestamppb.New(grp.ModifiedAt), } - users, err := dao.GetUsers(ctx, s.db, grp.ID) + users, err := dao.GetUsers(ctx, db, grp.ID) if err != nil { return &userv3.Group{}, err } @@ -332,7 +349,7 @@ func (s *groupService) toV3Group(ctx context.Context, group *userv3.Group, grp * userNames = append(userNames, u.Traits["email"].(string)) } - roles, err := dao.GetGroupRoles(ctx, s.db, grp.ID) + roles, err := dao.GetGroupRoles(ctx, db, grp.ID) if err != nil { return &userv3.Group{}, err } @@ -356,7 +373,7 @@ func (s *groupService) GetByID(ctx context.Context, group *userv3.Group) (*userv } if grp, ok := entity.(*models.Group); ok { - return s.toV3Group(ctx, group, grp) + return s.toV3Group(ctx, s.db, group, grp) } return group, nil @@ -364,7 +381,7 @@ func (s *groupService) GetByID(ctx context.Context, group *userv3.Group) (*userv func (s *groupService) GetByName(ctx context.Context, group *userv3.Group) (*userv3.Group, error) { name := group.GetMetadata().GetName() - partnerId, organizationId, err := s.getPartnerOrganization(ctx, group) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, group) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -374,7 +391,7 @@ func (s *groupService) GetByName(ctx context.Context, group *userv3.Group) (*use } if grp, ok := entity.(*models.Group); ok { - return s.toV3Group(ctx, group, grp) + return s.toV3Group(ctx, s.db, group, grp) } return group, nil @@ -383,7 +400,7 @@ func (s *groupService) GetByName(ctx context.Context, group *userv3.Group) (*use func (s *groupService) Update(ctx context.Context, group *userv3.Group) (*userv3.Group, error) { // TODO: inform when unchanged name := group.GetMetadata().GetName() - partnerId, organizationId, err := s.getPartnerOrganization(ctx, group) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, group) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -399,28 +416,44 @@ func (s *groupService) Update(ctx context.Context, group *userv3.Group) (*userv3 grp.Type = group.Spec.Type grp.ModifiedAt = time.Now() - // update account/role links - group, err = s.deleteGroupAccountRelations(ctx, grp.ID, group) - if err != nil { - return &userv3.Group{}, err - } - group, err = s.createGroupAccountRelations(ctx, grp.ID, group) - if err != nil { - return &userv3.Group{}, err - } - group, err = s.deleteGroupRoleRelaitons(ctx, grp.ID, group) - if err != nil { - return &userv3.Group{}, err - } - group, err = s.createGroupRoleRelations(ctx, group, parsedIds{Id: grp.ID, Partner: partnerId, Organization: organizationId}) + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.Group{}, err } - _, err = pg.Update(ctx, s.db, grp.ID, grp) + // update account/role links + group, err = s.deleteGroupAccountRelations(ctx, tx, grp.ID, group) if err != nil { + tx.Rollback() return &userv3.Group{}, err } + group, err = s.createGroupAccountRelations(ctx, tx, grp.ID, group) + if err != nil { + tx.Rollback() + return &userv3.Group{}, err + } + group, err = s.deleteGroupRoleRelaitons(ctx, tx, grp.ID, group) + if err != nil { + tx.Rollback() + return &userv3.Group{}, err + } + group, err = s.createGroupRoleRelations(ctx, tx, group, parsedIds{Id: grp.ID, Partner: partnerId, Organization: organizationId}) + if err != nil { + tx.Rollback() + return &userv3.Group{}, err + } + + _, err = pg.Update(ctx, tx, grp.ID, grp) + if err != nil { + tx.Rollback() + return &userv3.Group{}, err + } + + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } // update spec and status group.Spec = &userv3.GroupSpec{ @@ -435,7 +468,7 @@ func (s *groupService) Update(ctx context.Context, group *userv3.Group) (*userv3 func (s *groupService) Delete(ctx context.Context, group *userv3.Group) (*userv3.Group, error) { name := group.GetMetadata().GetName() - partnerId, organizationId, err := s.getPartnerOrganization(ctx, group) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, group) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to get partner and org id") } @@ -444,21 +477,37 @@ func (s *groupService) Delete(ctx context.Context, group *userv3.Group) (*userv3 return &userv3.Group{}, err } if grp, ok := entity.(*models.Group); ok { - group, err = s.deleteGroupRoleRelaitons(ctx, grp.ID, group) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.Group{}, err } - group, err = s.deleteGroupAccountRelations(ctx, grp.ID, group) + + group, err = s.deleteGroupRoleRelaitons(ctx, s.db, grp.ID, group) if err != nil { + tx.Rollback() + return &userv3.Group{}, err + } + group, err = s.deleteGroupAccountRelations(ctx, s.db, grp.ID, group) + if err != nil { + tx.Rollback() return &userv3.Group{}, err } err = pg.Delete(ctx, s.db, grp.ID, grp) if err != nil { + tx.Rollback() return &userv3.Group{}, err } + + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } + return group, nil } - return group, nil + return &userv3.Group{}, fmt.Errorf("unable to delete group") } func (s *groupService) List(ctx context.Context, group *userv3.Group) (*userv3.GroupList, error) { @@ -487,7 +536,7 @@ func (s *groupService) List(ctx context.Context, group *userv3.Group) (*userv3.G if grps, ok := entities.(*[]models.Group); ok { for _, grp := range *grps { entry := &userv3.Group{Metadata: group.GetMetadata()} - entry, err = s.toV3Group(ctx, entry, &grp) + entry, err = s.toV3Group(ctx, s.db, entry, &grp) if err != nil { return groupList, err } diff --git a/pkg/service/group_test.go b/pkg/service/group_test.go index c65c8c1..c90b556 100644 --- a/pkg/service/group_test.go +++ b/pkg/service/group_test.go @@ -94,8 +94,11 @@ func TestCreateGroupNoUsersNoRoles(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "group"."id" FROM "authsrv_group" AS "group" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'group-` + guuid + `'.`). WillReturnError(fmt.Errorf("no data available")) + + mock.ExpectBegin() mock.ExpectQuery(`INSERT INTO "authsrv_group"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(guuid)) + mock.ExpectCommit() group := &userv3.Group{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "group-" + guuid}, @@ -133,9 +136,12 @@ func TestCreateGroupDuplicate(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(puuid)) mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + + mock.ExpectBegin() // TODO: more precise checks mock.ExpectQuery(`INSERT INTO "authsrv_group"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(guuid)) + mock.ExpectCommit() _, err := gs.Create(context.Background(), group) if err == nil { t.Fatal("should not be able to recreate group with same name") @@ -171,6 +177,8 @@ func TestCreateGroupWithUsersNoRoles(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "group"."id" FROM "authsrv_group" AS "group" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'group-` + guuid + `'.`). WillReturnError(fmt.Errorf("no data available")) + + mock.ExpectBegin() mock.ExpectQuery(`INSERT INTO "authsrv_group"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(guuid)) for _, u := range tc.users { @@ -179,6 +187,7 @@ func TestCreateGroupWithUsersNoRoles(t *testing.T) { } mock.ExpectQuery(`INSERT INTO "authsrv_groupaccount"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() group := &userv3.Group{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "group-" + guuid}, @@ -236,6 +245,8 @@ func TestCreateGroupNoUsersWithRoles(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "group"."id" FROM "authsrv_group" AS "group" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'group-` + guuid + `'.`). WillReturnError(fmt.Errorf("no data available")) + + mock.ExpectBegin() mock.ExpectQuery(`INSERT INTO "authsrv_group"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(guuid)) mock.ExpectQuery(`SELECT "resourcerole"."id" FROM "authsrv_resourcerole" AS "resourcerole"`). @@ -246,6 +257,7 @@ func TestCreateGroupNoUsersWithRoles(t *testing.T) { } mock.ExpectQuery(fmt.Sprintf(`INSERT INTO "%v"`, tc.dbname)). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() group := &userv3.Group{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "group-" + guuid}, @@ -311,6 +323,7 @@ func TestCreateGroupWithUsersWithRoles(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "group"."id" FROM "authsrv_group" AS "group" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'group-` + guuid + `'.`).WithArgs() + mock.ExpectBegin() // TODO: more precise checks mock.ExpectQuery(`INSERT INTO "authsrv_group"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(guuid)) @@ -329,6 +342,7 @@ func TestCreateGroupWithUsersWithRoles(t *testing.T) { } mock.ExpectQuery(fmt.Sprintf(`INSERT INTO "%v"`, tc.dbname)). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() group := &userv3.Group{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "group-" + guuid}, @@ -393,7 +407,7 @@ func TestUpdateGroupWithUsersWithRoles(t *testing.T) { mock.ExpectQuery(`SELECT "group"."id", "group"."name",.* FROM "authsrv_group" AS "group" WHERE .*name = 'group-` + guuid + `'`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(guuid, "group-"+guuid)) - // TODO: more precise checks + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_groupaccount" AS "groupaccount" SET trash = TRUE WHERE ."group_id" = '` + guuid). WillReturnResult(sqlmock.NewResult(1, 1)) for _, u := range tc.users { @@ -418,6 +432,7 @@ func TestUpdateGroupWithUsersWithRoles(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) mock.ExpectExec(`UPDATE "authsrv_group"`). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() group := &userv3.Group{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "group-" + guuid}, @@ -463,6 +478,7 @@ func TestGroupDelete(t *testing.T) { mock.ExpectQuery(`SELECT "group"."id", "group"."name", .* FROM "authsrv_group" AS "group" WHERE`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(guuid, "group-"+guuid)) + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_grouprole" AS "grouprole" SET trash = TRUE WHERE ."group_id" = '` + guuid). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_projectgrouprole" AS "projectgrouprole" SET trash = TRUE WHERE ."group_id" = '` + guuid). @@ -473,6 +489,7 @@ func TestGroupDelete(t *testing.T) { WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_group" AS "group" SET trash = TRUE WHERE .id = '` + guuid). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() group := &userv3.Group{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "group-" + guuid}, diff --git a/pkg/service/kubeconfig_revocation.go b/pkg/service/kubeconfig_revocation.go index 7f2c95d..7817bf7 100644 --- a/pkg/service/kubeconfig_revocation.go +++ b/pkg/service/kubeconfig_revocation.go @@ -52,16 +52,15 @@ func prepareKubeCfgRevocationResponse(kr *models.KubeconfigRevocation) *sentry.K } func (krs *kubeconfigRevocationService) Patch(ctx context.Context, kr *sentry.KubeconfigRevocation) error { - err := krs.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := dao.GetKubeconfigRevocation(ctx, krs.db, uuid.MustParse(kr.OrganizationID), uuid.MustParse(kr.AccountID), kr.IsSSOUser) + return krs.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := dao.GetKubeconfigRevocation(ctx, tx, uuid.MustParse(kr.OrganizationID), uuid.MustParse(kr.AccountID), kr.IsSSOUser) if err != nil && err == sql.ErrNoRows { kcr := convertToModel(kr) kcr.CreatedAt = time.Now() - return dao.CreateKubeconfigRevocation(ctx, krs.db, kcr) + return dao.CreateKubeconfigRevocation(ctx, tx, kcr) } - return dao.UpdateKubeconfigRevocation(ctx, krs.db, convertToModel(kr)) + return dao.UpdateKubeconfigRevocation(ctx, tx, convertToModel(kr)) }) - return err } func convertToModel(kr *sentry.KubeconfigRevocation) *models.KubeconfigRevocation { diff --git a/pkg/service/kubeconfig_settings.go b/pkg/service/kubeconfig_settings.go index 95a28ac..7319729 100644 --- a/pkg/service/kubeconfig_settings.go +++ b/pkg/service/kubeconfig_settings.go @@ -54,17 +54,16 @@ func (kss *kubeconfigSettingService) Patch(ctx context.Context, ks *sentry.Kubec if err != nil { accId = uuid.Nil } - err = kss.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := dao.GetKubeconfigSetting(ctx, kss.db, uuid.MustParse(ks.OrganizationID), accId, ks.IsSSOUser) + return kss.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := dao.GetKubeconfigSetting(ctx, tx, uuid.MustParse(ks.OrganizationID), accId, ks.IsSSOUser) db := convertToKubeCfgSettingModel(ks) if err != nil && err == sql.ErrNoRows { db.CreatedAt = time.Now() - return dao.CreateKubeconfigSetting(ctx, kss.db, convertToKubeCfgSettingModel(ks)) + return dao.CreateKubeconfigSetting(ctx, tx, convertToKubeCfgSettingModel(ks)) } db.ModifiedAt = time.Now() - return dao.UpdateKubeconfigSetting(ctx, kss.db, convertToKubeCfgSettingModel(ks)) + return dao.UpdateKubeconfigSetting(ctx, tx, convertToKubeCfgSettingModel(ks)) }) - return err } func prepareKubeCfgSettingResponse(ks *models.KubeconfigSetting) *sentry.KubeconfigSetting { diff --git a/pkg/service/kubectl_cluster_setting.go b/pkg/service/kubectl_cluster_setting.go index 86aaf75..50e4375 100644 --- a/pkg/service/kubectl_cluster_setting.go +++ b/pkg/service/kubectl_cluster_setting.go @@ -41,21 +41,20 @@ func (kcs *kubectlClusterSettingsService) Get(ctx context.Context, orgID string, } func (kcs *kubectlClusterSettingsService) Patch(ctx context.Context, kc *sentry.KubectlClusterSettings) error { - err := kcs.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := dao.GetkubectlClusterSettings(ctx, kcs.db, uuid.MustParse(kc.OrganizationID), kc.Name) + return kcs.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := dao.GetkubectlClusterSettings(ctx, tx, uuid.MustParse(kc.OrganizationID), kc.Name) if err != nil { if err == sql.ErrNoRows { kcsdb := convertToKubeCtlSettingModel(kc) kcsdb.CreatedAt = time.Now() - dao.CreatekubectlClusterSettings(ctx, kcs.db, kcsdb) + dao.CreatekubectlClusterSettings(ctx, tx, kcsdb) } return err } kcsdb := convertToKubeCtlSettingModel(kc) kcsdb.ModifiedAt = time.Now() - return dao.UpdatekubectlClusterSettings(ctx, kcs.db, kcsdb) + return dao.UpdatekubectlClusterSettings(ctx, tx, kcsdb) }) - return err } func convertToKubeCtlSettingModel(kcs *sentry.KubectlClusterSettings) *models.KubectlClusterSetting { diff --git a/pkg/service/organization.go b/pkg/service/organization.go index 4864962..795baa2 100644 --- a/pkg/service/organization.go +++ b/pkg/service/organization.go @@ -239,8 +239,7 @@ func (s *organizationService) Delete(ctx context.Context, organization *systemv3 } if org, ok := entity.(*models.Organization); ok { - org.Trash = true - _, err := pg.Update(ctx, s.db, org.ID, org) + err := pg.Delete(ctx, s.db, org.ID, org) if err != nil { return &systemv3.Organization{}, err } diff --git a/pkg/service/partner.go b/pkg/service/partner.go index 92a2891..828daf7 100644 --- a/pkg/service/partner.go +++ b/pkg/service/partner.go @@ -252,8 +252,7 @@ func (s *partnerService) Delete(ctx context.Context, partner *systemv3.Partner) } if part, ok := entity.(*models.Partner); ok { - part.Trash = true - _, err := pg.Update(ctx, s.db, part.ID, part) + err := pg.Delete(ctx, s.db, part.ID, part) if err != nil { return &systemv3.Partner{}, err } diff --git a/pkg/service/project.go b/pkg/service/project.go index e7f40ba..be852ea 100644 --- a/pkg/service/project.go +++ b/pkg/service/project.go @@ -204,8 +204,7 @@ func (s *projectService) Delete(ctx context.Context, project *systemv3.Project) return &systemv3.Project{}, err } if proj, ok := entity.(*models.Project); ok { - proj.Trash = true - _, err := pg.Update(ctx, s.db, proj.ID, proj) + err := pg.Delete(ctx, s.db, proj.ID, proj) if err != nil { return &systemv3.Project{}, err } diff --git a/pkg/service/role.go b/pkg/service/role.go index 01104e8..56572e6 100644 --- a/pkg/service/role.go +++ b/pkg/service/role.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "fmt" "strings" "time" @@ -49,7 +50,7 @@ func NewRoleService(db *bun.DB, azc AuthzService) RoleService { return &roleService{db: db, azc: azc} } -func (s *roleService) getPartnerOrganization(ctx context.Context, role *rolev3.Role) (uuid.UUID, uuid.UUID, error) { +func (s *roleService) getPartnerOrganization(ctx context.Context, db bun.IDB, role *rolev3.Role) (uuid.UUID, uuid.UUID, error) { partner := role.GetMetadata().GetPartner() org := role.GetMetadata().GetOrganization() partnerId, err := pg.GetPartnerId(ctx, s.db, partner) @@ -64,7 +65,7 @@ func (s *roleService) getPartnerOrganization(ctx context.Context, role *rolev3.R } -func (s *roleService) deleteRolePermissionMapping(ctx context.Context, rleId uuid.UUID, role *rolev3.Role) (*rolev3.Role, error) { +func (s *roleService) deleteRolePermissionMapping(ctx context.Context, db bun.IDB, rleId uuid.UUID, role *rolev3.Role) (*rolev3.Role, error) { err := pg.DeleteX(ctx, s.db, "resource_role_id", rleId, &models.ResourceRolePermission{}) if err != nil { return &rolev3.Role{}, err @@ -82,7 +83,7 @@ func (s *roleService) deleteRolePermissionMapping(ctx context.Context, rleId uui return role, nil } -func (s *roleService) createRolePermissionMapping(ctx context.Context, role *rolev3.Role, ids parsedIds) (*rolev3.Role, error) { +func (s *roleService) createRolePermissionMapping(ctx context.Context, db bun.IDB, role *rolev3.Role, ids parsedIds) (*rolev3.Role, error) { perms := role.GetSpec().GetRolepermissions() var items []models.ResourceRolePermission @@ -121,7 +122,7 @@ func (s *roleService) createRolePermissionMapping(ctx context.Context, role *rol } func (s *roleService) Create(ctx context.Context, role *rolev3.Role) (*rolev3.Role, error) { - partnerId, organizationId, err := s.getPartnerOrganization(ctx, role) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, role) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -150,21 +151,35 @@ func (s *roleService) Create(ctx context.Context, role *rolev3.Role) (*rolev3.Ro IsGlobal: role.GetSpec().GetIsGlobal(), Scope: strings.ToLower(scope), } - entity, err := pg.Create(ctx, s.db, &rle) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &rolev3.Role{}, err } + entity, err := pg.Create(ctx, tx, &rle) + if err != nil { + tx.Rollback() + return &rolev3.Role{}, err + } + //update v3 spec if createdRole, ok := entity.(*models.Role); ok { - role, err = s.createRolePermissionMapping(ctx, role, parsedIds{Id: createdRole.ID, Partner: partnerId, Organization: organizationId}) + role, err = s.createRolePermissionMapping(ctx, tx, role, parsedIds{Id: createdRole.ID, Partner: partnerId, Organization: organizationId}) if err != nil { + tx.Rollback() return &rolev3.Role{}, err } } else { + tx.Rollback() return &rolev3.Role{}, fmt.Errorf("unable to create role '%v'", role.GetMetadata().GetName()) } + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } return role, nil } @@ -181,7 +196,7 @@ func (s *roleService) GetByID(ctx context.Context, role *rolev3.Role) (*rolev3.R } if rle, ok := entity.(*models.Role); ok { - role, err = s.toV3Role(ctx, role, rle) + role, err = s.toV3Role(ctx, s.db, role, rle) if err != nil { return &rolev3.Role{}, err } @@ -193,7 +208,7 @@ func (s *roleService) GetByID(ctx context.Context, role *rolev3.Role) (*rolev3.R func (s *roleService) GetByName(ctx context.Context, role *rolev3.Role) (*rolev3.Role, error) { name := role.GetMetadata().GetName() - partnerId, organizationId, err := s.getPartnerOrganization(ctx, role) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, role) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -203,7 +218,7 @@ func (s *roleService) GetByName(ctx context.Context, role *rolev3.Role) (*rolev3 } if rle, ok := entity.(*models.Role); ok { - role, err = s.toV3Role(ctx, role, rle) + role, err = s.toV3Role(ctx, s.db, role, rle) if err != nil { return &rolev3.Role{}, err } @@ -215,7 +230,7 @@ func (s *roleService) GetByName(ctx context.Context, role *rolev3.Role) (*rolev3 } func (s *roleService) Update(ctx context.Context, role *rolev3.Role) (*rolev3.Role, error) { - partnerId, organizationId, err := s.getPartnerOrganization(ctx, role) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, role) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -234,18 +249,26 @@ func (s *roleService) Update(ctx context.Context, role *rolev3.Role) (*rolev3.Ro rle.Scope = role.Spec.Scope rle.ModifiedAt = time.Now() - _, err = pg.Update(ctx, s.db, rle.ID, rle) + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &rolev3.Role{}, err } - role, err = s.deleteRolePermissionMapping(ctx, rle.ID, role) + _, err = pg.Update(ctx, tx, rle.ID, rle) if err != nil { + tx.Rollback() return &rolev3.Role{}, err } - role, err = s.createRolePermissionMapping(ctx, role, parsedIds{Id: rle.ID, Partner: partnerId, Organization: organizationId}) + role, err = s.deleteRolePermissionMapping(ctx, tx, rle.ID, role) if err != nil { + tx.Rollback() + return &rolev3.Role{}, err + } + + role, err = s.createRolePermissionMapping(ctx, tx, role, parsedIds{Id: rle.ID, Partner: partnerId, Organization: organizationId}) + if err != nil { + tx.Rollback() return &rolev3.Role{}, err } @@ -254,16 +277,21 @@ func (s *roleService) Update(ctx context.Context, role *rolev3.Role) (*rolev3.Ro IsGlobal: rle.IsGlobal, Scope: rle.Scope, } - } else { - return &rolev3.Role{}, fmt.Errorf("unable to update role '%v'", role.GetMetadata().GetName()) - } - return role, nil + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } + return role, nil + } + return &rolev3.Role{}, fmt.Errorf("unable to update role '%v'", role.GetMetadata().GetName()) + } func (s *roleService) Delete(ctx context.Context, role *rolev3.Role) (*rolev3.Role, error) { name := role.GetMetadata().GetName() - partnerId, organizationId, err := s.getPartnerOrganization(ctx, role) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, role) if err != nil { return &rolev3.Role{}, fmt.Errorf("unable to get partner and org id; %v", err) } @@ -274,21 +302,36 @@ func (s *roleService) Delete(ctx context.Context, role *rolev3.Role) (*rolev3.Ro } if rle, ok := entity.(*models.Role); ok { - role, err = s.deleteRolePermissionMapping(ctx, rle.ID, role) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &rolev3.Role{}, err } + role, err = s.deleteRolePermissionMapping(ctx, tx, rle.ID, role) + if err != nil { + tx.Rollback() + return &rolev3.Role{}, err + } + err = pg.Delete(ctx, s.db, rle.ID, rle) if err != nil { + tx.Rollback() return &rolev3.Role{}, err } + + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } + return role, nil } - return role, nil + return &rolev3.Role{}, fmt.Errorf("unable to delete role '%v'", role.GetMetadata().GetName()) } -func (s *roleService) toV3Role(ctx context.Context, role *rolev3.Role, rle *models.Role) (*rolev3.Role, error) { +func (s *roleService) toV3Role(ctx context.Context, db bun.IDB, role *rolev3.Role, rle *models.Role) (*rolev3.Role, error) { labels := make(map[string]string) labels["organization"] = role.GetMetadata().GetOrganization() labels["partner"] = role.GetMetadata().GetPartner() @@ -346,7 +389,7 @@ func (s *roleService) List(ctx context.Context, role *rolev3.Role) (*rolev3.Role if rles, ok := entities.(*[]models.Role); ok { for _, rle := range *rles { entry := &rolev3.Role{Metadata: role.GetMetadata()} - entry, err = s.toV3Role(ctx, entry, &rle) + entry, err = s.toV3Role(ctx, s.db, entry, &rle) if err != nil { return roleList, err } diff --git a/pkg/service/role_test.go b/pkg/service/role_test.go index 00de9f4..57548b2 100644 --- a/pkg/service/role_test.go +++ b/pkg/service/role_test.go @@ -64,9 +64,12 @@ func TestCreateRole(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "resourcerole"."id" FROM "authsrv_resourcerole" AS "resourcerole" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'role-` + ruuid + `'.`). WillReturnError(fmt.Errorf("no data available")) + + mock.ExpectBegin() // TODO: more precise checks mock.ExpectQuery(`INSERT INTO "authsrv_resourcerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ruuid)) + mock.ExpectCommit() role := &rolev3.Role{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "role-" + ruuid}, @@ -96,12 +99,15 @@ func TestCreateRoleWithPermissions(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "resourcerole"."id" FROM "authsrv_resourcerole" AS "resourcerole" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'role-` + ruuid + `'.`). WillReturnError(fmt.Errorf("no data available")) + + mock.ExpectBegin() mock.ExpectQuery(`INSERT INTO "authsrv_resourcerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ruuid)) mock.ExpectQuery(`SELECT "resourcepermission"."id" FROM "authsrv_resourcepermission" AS "resourcepermission" WHERE .name = 'ops_star.all'.`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) mock.ExpectQuery(`INSERT INTO "authsrv_resourcerolepermission"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() role := &rolev3.Role{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "role-" + ruuid}, @@ -132,9 +138,12 @@ func TestCreateRoleDuplicate(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(` SELECT "resourcerole"."id" FROM "authsrv_resourcerole" AS "resourcerole" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'role-` + ruuid + `'.`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ruuid)) + + mock.ExpectBegin() // TODO: more precise checks mock.ExpectQuery(`INSERT INTO "authsrv_resourcerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ruuid)) + mock.ExpectCommit() role := &rolev3.Role{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "role-" + ruuid}, @@ -164,6 +173,7 @@ func TestUpdateRole(t *testing.T) { mock.ExpectQuery(`SELECT "resourcerole"."id", "resourcerole"."name", .*FROM "authsrv_resourcerole" AS "resourcerole" WHERE .organization_id = '` + ouuid + `'. AND .partner_id = '` + puuid + `'. AND .name = 'role-` + ruuid + `'.`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id", "name", "organization_id", "partner_id"}).AddRow(ruuid, "role-"+ruuid, ouuid, puuid)) + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_resourcerole" AS "resourcerole" SET "name" = 'role-` + ruuid + `', .*"organization_id" = '` + ouuid + `', "partner_id" = '` + puuid + `', "is_global" = TRUE, "scope" = 'system' WHERE .id = '` + ruuid + `'.`). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_resourcerolepermission" AS "resourcerolepermission" SET trash = TRUE WHERE ."resource_role_id" = '` + ruuid + `'.`). @@ -173,6 +183,7 @@ func TestUpdateRole(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("ops_star.all")) mock.ExpectQuery(`INSERT INTO "authsrv_resourcerolepermission"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ruuid)) + mock.ExpectCommit() role := &rolev3.Role{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "role-" + ruuid}, @@ -202,10 +213,12 @@ func TestRoleDelete(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) mock.ExpectQuery(`SELECT "resourcerole"."id", "resourcerole"."name", .* FROM "authsrv_resourcerole" AS "resourcerole" WHERE`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(ruuid, "role-"+ruuid)) + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_resourcerolepermission" AS "resourcerolepermission" SET trash = TRUE WHERE ."resource_role_id" = '` + ruuid + `'.`). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_resourcerole" AS "resourcerole" SET trash = TRUE WHERE .id = '` + ruuid + `'.`). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() role := &rolev3.Role{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "role-" + ruuid}, diff --git a/pkg/service/user.go b/pkg/service/user.go index ad9bae3..1b1e1e0 100644 --- a/pkg/service/user.go +++ b/pkg/service/user.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "fmt" "strconv" "time" @@ -98,7 +99,7 @@ func getUserTraits(traits map[string]interface{}) userTraits { } // Map roles to accounts -func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3.User, ids parsedIds) (*userv3.User, error) { +func (s *userService) createUserRoleRelations(ctx context.Context, db bun.IDB, user *userv3.User, ids parsedIds) (*userv3.User, error) { projectNamespaceRoles := user.GetSpec().GetProjectNamespaceRoles() // TODO: add transactions @@ -108,7 +109,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. var ps []*authzv1.Policy for _, pnr := range projectNamespaceRoles { role := pnr.GetRole() - entity, err := pg.GetIdByName(ctx, s.db, role, &models.Role{}) + entity, err := pg.GetIdByName(ctx, db, role, &models.Role{}) if err != nil { return user, fmt.Errorf("unable to find role '%v'", role) } @@ -125,7 +126,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. switch { case pnr.Namespace != nil: - projectId, err := pg.GetProjectId(ctx, s.db, project) + projectId, err := pg.GetProjectId(ctx, db, project) if err != nil { return user, fmt.Errorf("unable to find project '%v'", project) } @@ -151,7 +152,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. Obj: role, }) case project != "": - projectId, err := pg.GetProjectId(ctx, s.db, project) + projectId, err := pg.GetProjectId(ctx, db, project) if err != nil { return user, fmt.Errorf("unable to find project '%v'", project) } @@ -200,19 +201,19 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. } } if len(panrs) > 0 { - _, err := pg.Create(ctx, s.db, &panrs) + _, err := pg.Create(ctx, db, &panrs) if err != nil { return &userv3.User{}, err } } if len(pars) > 0 { - _, err := pg.Create(ctx, s.db, &pars) + _, err := pg.Create(ctx, db, &pars) if err != nil { return &userv3.User{}, err } } if len(ars) > 0 { - _, err := pg.Create(ctx, s.db, &ars) + _, err := pg.Create(ctx, db, &ars) if err != nil { return &userv3.User{}, err } @@ -229,14 +230,14 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. } // FIXME: make this generic -func (s *userService) getPartnerOrganization(ctx context.Context, user *userv3.User) (uuid.UUID, uuid.UUID, error) { +func (s *userService) getPartnerOrganization(ctx context.Context, db bun.IDB, user *userv3.User) (uuid.UUID, uuid.UUID, error) { partner := user.GetMetadata().GetPartner() org := user.GetMetadata().GetOrganization() - partnerId, err := pg.GetPartnerId(ctx, s.db, partner) + partnerId, err := pg.GetPartnerId(ctx, db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := pg.GetOrganizationId(ctx, s.db, org) + organizationId, err := pg.GetOrganizationId(ctx, db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -246,7 +247,7 @@ func (s *userService) getPartnerOrganization(ctx context.Context, user *userv3.U func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.User, error) { // TODO: restrict endpoint to admin - partnerId, organizationId, err := s.getPartnerOrganization(ctx, user) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, user) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -263,11 +264,23 @@ func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.Us } uid, _ := uuid.Parse(id) - user, err = s.createUserRoleRelations(ctx, user, parsedIds{Id: uid, Partner: partnerId, Organization: organizationId}) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.User{}, err } + user, err = s.createUserRoleRelations(ctx, tx, user, parsedIds{Id: uid, Partner: partnerId, Organization: organizationId}) + if err != nil { + tx.Rollback() + return &userv3.User{}, err + } + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } + rl, err := s.ap.GetRecoveryLink(ctx, id) fmt.Println("Recovery link:", rl) // TODO: email the recovery link to the user if err != nil { @@ -277,9 +290,9 @@ func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.Us return user, nil } -func (s *userService) identitiesModelToUser(ctx context.Context, user *userv3.User, usr *models.KratosIdentities) (*userv3.User, error) { +func (s *userService) identitiesModelToUser(ctx context.Context, db bun.IDB, user *userv3.User, usr *models.KratosIdentities) (*userv3.User, error) { traits := getUserTraits(usr.Traits) - groups, err := dao.GetGroups(ctx, s.db, usr.ID) + groups, err := dao.GetGroups(ctx, db, usr.ID) if err != nil { return &userv3.User{}, err } @@ -290,7 +303,7 @@ func (s *userService) identitiesModelToUser(ctx context.Context, user *userv3.Us labels := make(map[string]string) - roles, err := dao.GetUserRoles(ctx, s.db, usr.ID) + roles, err := dao.GetUserRoles(ctx, db, usr.ID) if err != nil { return &userv3.User{}, err } @@ -325,7 +338,7 @@ func (s *userService) GetByID(ctx context.Context, user *userv3.User) (*userv3.U } if usr, ok := entity.(*models.KratosIdentities); ok { - user, err := s.identitiesModelToUser(ctx, user, usr) + user, err := s.identitiesModelToUser(ctx, s.db, user, usr) if err != nil { return &userv3.User{}, err } @@ -344,7 +357,7 @@ func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3 } if usr, ok := entity.(*models.KratosIdentities); ok { - user, err := s.identitiesModelToUser(ctx, user, usr) + user, err := s.identitiesModelToUser(ctx, s.db, user, usr) if err != nil { return &userv3.User{}, err } @@ -355,16 +368,16 @@ func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3 return user, nil } -func (s *userService) deleteUserRoleRelations(ctx context.Context, userId uuid.UUID, user *userv3.User) error { - err := pg.DeleteX(ctx, s.db, "account_id", userId, &models.AccountResourcerole{}) +func (s *userService) deleteUserRoleRelations(ctx context.Context, db bun.IDB, userId uuid.UUID, user *userv3.User) error { + err := pg.DeleteX(ctx, db, "account_id", userId, &models.AccountResourcerole{}) if err != nil { return err } - err = pg.DeleteX(ctx, s.db, "account_id", userId, &models.ProjectAccountResourcerole{}) + err = pg.DeleteX(ctx, db, "account_id", userId, &models.ProjectAccountResourcerole{}) if err != nil { return err } - err = pg.DeleteX(ctx, s.db, "account_id", userId, &models.ProjectAccountNamespaceRole{}) + err = pg.DeleteX(ctx, db, "account_id", userId, &models.ProjectAccountNamespaceRole{}) if err != nil { return err } @@ -385,7 +398,7 @@ func (s *userService) Update(ctx context.Context, user *userv3.User) (*userv3.Us } if usr, ok := entity.(*models.KratosIdentities); ok { - partnerId, organizationId, err := s.getPartnerOrganization(ctx, user) + partnerId, organizationId, err := s.getPartnerOrganization(ctx, s.db, user) if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } @@ -399,20 +412,33 @@ func (s *userService) Update(ctx context.Context, user *userv3.User) (*userv3.Us return &userv3.User{}, err } - err = s.deleteUserRoleRelations(ctx, usr.ID, user) + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userv3.User{}, err } - user, err = s.createUserRoleRelations(ctx, user, parsedIds{Id: usr.ID, Partner: partnerId, Organization: organizationId}) + err = s.deleteUserRoleRelations(ctx, tx, usr.ID, user) if err != nil { + tx.Rollback() return &userv3.User{}, err } - } else { - return &userv3.User{}, fmt.Errorf("unable to update user '%v'", name) + + user, err = s.createUserRoleRelations(ctx, tx, user, parsedIds{Id: usr.ID, Partner: partnerId, Organization: organizationId}) + if err != nil { + tx.Rollback() + return &userv3.User{}, err + } + + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } + return user, nil } - return user, nil + return &userv3.User{}, fmt.Errorf("unable to update user '%v'", name) + } func (s *userService) Delete(ctx context.Context, user *userv3.User) (*userrpcv3.DeleteUserResponse, error) { @@ -423,21 +449,35 @@ func (s *userService) Delete(ctx context.Context, user *userv3.User) (*userrpcv3 } if usr, ok := entity.(*models.KratosIdentities); ok { - err = s.deleteUserRoleRelations(ctx, usr.ID, user) + + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return &userrpcv3.DeleteUserResponse{}, err } - err := s.ap.Delete(ctx, usr.ID.String()) + err = s.deleteUserRoleRelations(ctx, s.db, usr.ID, user) if err != nil { + tx.Rollback() + return &userrpcv3.DeleteUserResponse{}, err + } + + err = s.ap.Delete(ctx, usr.ID.String()) + if err != nil { + tx.Rollback() return &userrpcv3.DeleteUserResponse{}, err } err = pg.DeleteX(ctx, s.db, "account_id", usr.ID, &models.GroupAccount{}) if err != nil { + tx.Rollback() return &userrpcv3.DeleteUserResponse{}, fmt.Errorf("unable to delete user; %v", err) } + err = tx.Commit() + if err != nil { + tx.Rollback() + _log.Warn("unable to commit changes", err) + } return &userrpcv3.DeleteUserResponse{}, nil } return &userrpcv3.DeleteUserResponse{}, fmt.Errorf("unable to delete user '%v'", user.Metadata.Name) @@ -461,7 +501,7 @@ func (s *userService) List(ctx context.Context, _ *userv3.User) (*userv3.UserLis if usrs, ok := entities.(*[]models.KratosIdentities); ok { for _, usr := range *usrs { user := &userv3.User{} - user, err := s.identitiesModelToUser(ctx, user, &usr) + user, err := s.identitiesModelToUser(ctx, s.db, user, &usr) if err != nil { return userList, err } diff --git a/pkg/service/user_test.go b/pkg/service/user_test.go index e8e289f..ad9d7fd 100644 --- a/pkg/service/user_test.go +++ b/pkg/service/user_test.go @@ -71,6 +71,9 @@ func TestCreateUser(t *testing.T) { mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + mock.ExpectBegin() + mock.ExpectCommit() + user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid}, Spec: &userv3.UserSpec{}, @@ -124,6 +127,8 @@ func TestCreateUserWithRole(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(puuid)) mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + + mock.ExpectBegin() mock.ExpectQuery(`SELECT "resourcerole"."id" FROM "authsrv_resourcerole" AS "resourcerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pruuid)) if tc.roles[0].Project != nil { @@ -132,6 +137,7 @@ func TestCreateUserWithRole(t *testing.T) { } mock.ExpectQuery(fmt.Sprintf(`INSERT INTO "%v"`, tc.dbname)). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid}, @@ -184,6 +190,8 @@ func TestUpdateUser(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(puuid)) mock.ExpectQuery(`SELECT "organization"."id" FROM "authsrv_organization" AS "organization"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ouuid)) + + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_accountresourcerole" AS "accountresourcerole" SET trash = TRUE WHERE`). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_projectaccountresourcerole" AS "projectaccountresourcerole" SET trash = TRUE WHERE`). @@ -196,6 +204,7 @@ func TestUpdateUser(t *testing.T) { WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(pruuid)) mock.ExpectQuery(`INSERT INTO "authsrv_projectaccountnamespacerole"`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(uuid.New().String())) + mock.ExpectCommit() user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid}, @@ -392,6 +401,7 @@ func TestUserDelete(t *testing.T) { mock.ExpectQuery(`SELECT "identities"."id" FROM "identities" WHERE .*traits ->> 'email' = 'user-` + uuuid + `'`). WithArgs().WillReturnRows(sqlmock.NewRows([]string{"id", "traits"}).AddRow(uuuid, []byte(`{"email":"johndoe@provider.com"}`))) + mock.ExpectBegin() mock.ExpectExec(`UPDATE "authsrv_accountresourcerole" AS "accountresourcerole" SET trash = TRUE WHERE`). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec(`UPDATE "authsrv_projectaccountresourcerole" AS "projectaccountresourcerole" SET trash = TRUE WHERE`). @@ -401,6 +411,7 @@ func TestUserDelete(t *testing.T) { // User delete is via kratos mock.ExpectExec(`UPDATE "authsrv_groupaccount" AS "groupaccount" SET trash = TRUE WHERE`). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() user := &userv3.User{ Metadata: &v3.Metadata{Partner: "partner-" + puuid, Organization: "org-" + ouuid, Name: "user-" + uuuid},