diff --git a/internal/cluster/dao/cluster.go b/internal/cluster/dao/cluster.go index 414f9b1..44a14ff 100644 --- a/internal/cluster/dao/cluster.go +++ b/internal/cluster/dao/cluster.go @@ -2,7 +2,6 @@ package dao import ( "context" - "database/sql" "strings" "time" @@ -14,59 +13,7 @@ import ( "github.com/uptrace/bun" ) -// ClusterDao is the interface for cluster operations -type ClusterDao interface { - // create cluster - CreateCluster(ctx context.Context, c *models.Cluster) error - // create or update cluster - UpdateCluster(ctx context.Context, c *models.Cluster) error - //list clusters - ListClusters(ctx context.Context, qo commonv3.QueryOptions) ([]models.Cluster, error) - // delete cluster - DeleteCluster(ctx context.Context, c *models.Cluster) error - // get cluster - GetCluster(ctx context.Context, c *models.Cluster) (*models.Cluster, error) - //get cluster for token - GetClusterForToken(ctx context.Context, token string) (cluster *models.Cluster, err error) - // update relay config information - UpdateClusterAnnotations(ctx context.Context, c *models.Cluster) error - // Notify channel - Notify(chanName, value string) error -} - -// clusterDao implements ClusterDao -type clusterDao struct { - cdao pg.EntityDAO - ctdao ClusterTokenDao - pcdao ProjectClusterDao -} - -// ClusterDao return new cluster dao -func NewClusterDao(edao pg.EntityDAO) ClusterDao { - return &clusterDao{ - cdao: edao, - ctdao: NewClusterTokenDao(edao), - pcdao: NewProjectClusterDao(edao), - } -} - -func (s *clusterDao) CreateCluster(ctx context.Context, cluster *models.Cluster) error { - - err := s.cdao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - err := s.createCluster(ctx, cluster, tx) - if err != nil { - return err - } - return nil - }) - if err != nil { - return err - } - - return nil -} - -func (s *clusterDao) createCluster(ctx context.Context, cluster *models.Cluster, tx bun.Tx) error { +func CreateCluster(ctx context.Context, tx bun.Tx, cluster *models.Cluster) error { clstrToken := &models.ClusterToken{ OrganizationId: cluster.OrganizationId, @@ -74,7 +21,7 @@ func (s *clusterDao) createCluster(ctx context.Context, cluster *models.Cluster, ProjectId: cluster.ProjectId, CreatedAt: time.Now(), } - err := s.ctdao.CreateToken(ctx, clstrToken) + err := CreateToken(ctx, tx, clstrToken) if err != nil { return err } @@ -102,9 +49,9 @@ func (s *clusterDao) createCluster(ctx context.Context, cluster *models.Cluster, return nil } -func (s *clusterDao) UpdateCluster(ctx context.Context, c *models.Cluster) error { +func UpdateCluster(ctx context.Context, db bun.IDB, c *models.Cluster) error { - _, err := s.cdao.Update(ctx, c.ID, c) + _, err := pg.Update(ctx, db, c.ID, c) if err != nil { return err } @@ -112,9 +59,9 @@ func (s *clusterDao) UpdateCluster(ctx context.Context, c *models.Cluster) error return nil } -func (s *clusterDao) UpdateClusterAnnotations(ctx context.Context, c *models.Cluster) error { +func UpdateClusterAnnotations(ctx context.Context, db bun.IDB, c *models.Cluster) error { - _, err := s.cdao.GetInstance().NewUpdate().Model((*models.Cluster)(nil)). + _, err := db.NewUpdate().Model((*models.Cluster)(nil)). Set("annotations = ?", c.Annotations). Where("id = ?", c.ID).Exec(ctx) if err != nil { @@ -124,15 +71,15 @@ func (s *clusterDao) UpdateClusterAnnotations(ctx context.Context, c *models.Clu return nil } -func (s *clusterDao) GetCluster(ctx context.Context, cluster *models.Cluster) (*models.Cluster, error) { +func GetCluster(ctx context.Context, db bun.IDB, cluster *models.Cluster) (*models.Cluster, error) { if cluster.ID != uuid.Nil { - _, err := s.cdao.GetByID(ctx, cluster.ID, cluster) + _, err := pg.GetByID(ctx, db, cluster.ID, cluster) if err != nil { return nil, err } } else { - _, err := s.cdao.GetByName(ctx, cluster.Name, cluster) + _, err := pg.GetByName(ctx, db, cluster.Name, cluster) if err != nil { return nil, err } @@ -141,8 +88,8 @@ func (s *clusterDao) GetCluster(ctx context.Context, cluster *models.Cluster) (* return cluster, nil } -func (s *clusterDao) DeleteCluster(ctx context.Context, c *models.Cluster) error { - _, err := s.cdao.GetInstance(). +func DeleteCluster(ctx context.Context, db bun.IDB, c *models.Cluster) error { + _, err := db. NewUpdate().Model(c). Set("trash = ?", true). Set("deleted_at = ?", time.Now()). @@ -150,28 +97,28 @@ func (s *clusterDao) DeleteCluster(ctx context.Context, c *models.Cluster) error return err } -func (s *clusterDao) ListClusters(ctx context.Context, qo commonv3.QueryOptions) (clusters []models.Cluster, err error) { +func ListClusters(ctx context.Context, db bun.IDB, qo commonv3.QueryOptions) (clusters []models.Cluster, err error) { pid := uuid.NullUUID{UUID: uuid.MustParse(qo.Partner), Valid: true} oid := uuid.NullUUID{UUID: uuid.MustParse(qo.Organization), Valid: true} prid := uuid.NullUUID{UUID: uuid.MustParse(qo.Project), Valid: true} - err = s.cdao.ListByProject(ctx, pid, oid, prid, &clusters) + err = pg.ListByProject(ctx, db, pid, oid, prid, &clusters) if err != nil { return nil, err } return clusters, err } -func (s *clusterDao) GetClusterForToken(ctx context.Context, token string) (cluster *models.Cluster, err error) { - entity, err := s.cdao.GetX(ctx, "token", token, &models.Cluster{}) +func GetClusterForToken(ctx context.Context, db bun.IDB, token string) (cluster *models.Cluster, err error) { + entity, err := pg.GetX(ctx, db, "token", token, &models.Cluster{}) if err != nil { return nil, err } return entity.(*models.Cluster), err } -func (s *clusterDao) Notify(chanName, value string) error { - _, err := s.cdao.GetInstance().Exec("NOTIFY ?, ?", bun.Ident(chanName), value) +func Notify(db *bun.DB, chanName string, value string) error { + _, err := db.Exec("NOTIFY ?, ?", bun.Ident(chanName), value) return err } diff --git a/internal/cluster/dao/clusteroperatorbootstrap.go b/internal/cluster/dao/clusteroperatorbootstrap.go index 8a18c12..7df7641 100644 --- a/internal/cluster/dao/clusteroperatorbootstrap.go +++ b/internal/cluster/dao/clusteroperatorbootstrap.go @@ -2,7 +2,6 @@ package dao import ( "context" - "database/sql" "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" @@ -12,68 +11,41 @@ import ( var _log = log.GetLogger() -// ClusterOperatorBootstrapDao is the interface for cluster operator bootstrap -type ClusterOperatorBootstrapDao interface { - // create edge operator bootstrap - CreateOperatorBootstrap(ctx context.Context, bootstrap *models.ClusterOperatorBootstrap) error - // GetOperatorBootstrap - GetOperatorBootstrap(ctx context.Context, edgeID string) (*models.ClusterOperatorBootstrap, error) -} - -// clusterOperatorBootstrapDao implements ClusterOperatorBootstrapDao -type clusterOperatorBootstrapDao struct { - dao pg.EntityDAO -} - -// ClusterOperatorBootstrapDao return new cluster credentials dao -func NewClusterOperatorBootstrapDao(dao pg.EntityDAO) ClusterOperatorBootstrapDao { - return &clusterOperatorBootstrapDao{ - dao: dao, - } -} - -func (es *clusterOperatorBootstrapDao) CreateOperatorBootstrap(ctx context.Context, bootstrap *models.ClusterOperatorBootstrap) error { +func CreateOperatorBootstrap(ctx context.Context, db bun.Tx, bootstrap *models.ClusterOperatorBootstrap) error { _log.Infow("CreateOperatorBootstrap: Creating operator bootstrap data", "cluster", bootstrap.ClusterId) - err := es.dao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - var bstrap *models.ClusterOperatorBootstrap + var bstrap *models.ClusterOperatorBootstrap - entity, err := es.dao.GetX(ctx, "edge_id", bootstrap.ClusterId, &bstrap) + entity, err := pg.GetX(ctx, db, "edge_id", bootstrap.ClusterId, &bstrap) + if err != nil { + _log.Infow("CreateOperatorBootstrap: No existing bootstrap data detected", "edge", bootstrap.ClusterId) + } else { + _log.Infow("CreateOperatorBootstrap: Removing existing bootstrap data", "edge", bootstrap.ClusterId) + + bstrap = entity.(*models.ClusterOperatorBootstrap) + err = pg.DeleteX(ctx, db, "edge_id", bstrap.ClusterId, bstrap) if err != nil { - _log.Infow("CreateOperatorBootstrap: No existing boostrap data detected", "edge", bootstrap.ClusterId) - } else { - _log.Infow("CreateOperatorBootstrap: Removing existing boostrap data", "edge", bootstrap.ClusterId) - - bstrap = entity.(*models.ClusterOperatorBootstrap) - err = es.dao.DeleteX(ctx, "edge_id", bstrap.ClusterId, bstrap) - if err != nil { - _log.Errorw("Error while deleting bootstrap data", "Error", err) - return err - } - _log.Infow("CreateOperatorBootstrap: Deleted existing boostrap data", "cluster", bootstrap.ClusterId) - } - - _, err = tx.NewInsert().Model(bootstrap).Exec(ctx) - if err != nil { - _log.Errorw("Error inserting bootstrap data", "Error", err) + _log.Errorw("Error while deleting bootstrap data", "Error", err) return err } - - _log.Infow("Inserted bootstrap data", "cluster", bootstrap.ClusterId) - - return nil - }) - - if err != nil { - _log.Errorw("Exception while adding bootstrap data", "Error:", err) + _log.Infow("CreateOperatorBootstrap: Deleted existing bootstrap data", "cluster", bootstrap.ClusterId) } + + _, err = db.NewInsert().Model(bootstrap).Exec(ctx) + if err != nil { + _log.Errorw("Error inserting bootstrap data", "Error", err) + return err + } + + _log.Infow("Inserted bootstrap data", "cluster", bootstrap.ClusterId) + return nil } -func (es *clusterOperatorBootstrapDao) GetOperatorBootstrap(ctx context.Context, clusterid string) (*models.ClusterOperatorBootstrap, error) { +func GetOperatorBootstrap(ctx context.Context, db bun.IDB, clusterid string) (*models.ClusterOperatorBootstrap, error) { var bootstrap models.ClusterOperatorBootstrap - entity, err := es.dao.GetX(ctx, "clusterid", clusterid, bootstrap) + entity, err := pg.GetX(ctx, db, "clusterid", clusterid, bootstrap) if err != nil { _log.Errorw("Error while fetching bootstrap data using tx ", "Error", err) return nil, err diff --git a/internal/cluster/dao/clustertoken.go b/internal/cluster/dao/clustertoken.go index e8ba7ae..1f5a88e 100644 --- a/internal/cluster/dao/clustertoken.go +++ b/internal/cluster/dao/clustertoken.go @@ -8,6 +8,7 @@ import ( "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" infrav3 "github.com/RafaySystems/rcloud-base/proto/types/infrapb/v3" "github.com/rs/xid" + "github.com/uptrace/bun" ) var ( @@ -17,44 +18,24 @@ var ( ErrUsedToken = errors.New("used token") ) -// ClusterTokenDao is the interface for cluster token operations -type ClusterTokenDao interface { - // create cluster token - CreateToken(ctx context.Context, c *models.ClusterToken) error - //register the token - RegisterToken(ctx context.Context, token string) (*models.ClusterToken, error) -} - -// clusterTokenDao implements ClusterTokenDao -type clusterTokenDao struct { - dao pg.EntityDAO -} - -// ClusterDao return new cluster dao -func NewClusterTokenDao(dao pg.EntityDAO) ClusterTokenDao { - return &clusterTokenDao{ - dao: dao, - } -} - // CreateToken creates a token for given cluster name -func (s *clusterTokenDao) CreateToken(ctx context.Context, token *models.ClusterToken) error { +func CreateToken(ctx context.Context, db bun.IDB, token *models.ClusterToken) error { token.Name = xid.New().String() - _, err := s.dao.Create(ctx, token) + _, err := pg.Create(ctx, db, token) return err } // registerToken registers the cluster token -func (s *clusterTokenDao) RegisterToken(ctx context.Context, token string) (*models.ClusterToken, error) { +func RegisterToken(ctx context.Context, db bun.IDB, token string) (*models.ClusterToken, error) { - entity, err := s.dao.GetX(ctx, "name", token, &models.ClusterToken{}) + entity, err := pg.GetX(ctx, db, "name", token, &models.ClusterToken{}) if err != nil { return nil, ErrInvalidToken } ct := entity.(*models.ClusterToken) ct.State = infrav3.ClusterTokenState_TokenUsed.String() - s.dao.Update(ctx, ct.ID, ct) + pg.Update(ctx, db, ct.ID, ct) if err != nil { return nil, ErrInvalidToken } diff --git a/internal/cluster/dao/namespaces.go b/internal/cluster/dao/namespaces.go index 3eb4607..77e987a 100644 --- a/internal/cluster/dao/namespaces.go +++ b/internal/cluster/dao/namespaces.go @@ -13,37 +13,11 @@ import ( "github.com/uptrace/bun" ) -// ClusterNamespacesDao is the interface for cluster namespaces operations -type ClusterNamespacesDao interface { - // Get Namespace - GetNamespace(ctx context.Context, clusterID uuid.UUID, name string) (models.ClusterNamespace, error) - // GetNamespaces - GetNamespaces(ctx context.Context, clusterID uuid.UUID) ([]models.ClusterNamespace, error) - // GetNamespacesForConditions - GetNamespacesForConditions(ctx context.Context, clusterID uuid.UUID, conditions []scheduler.ClusterNamespaceCondition) ([]models.ClusterNamespace, int, error) - // UpdateNamespaceStatus - UpdateNamespaceStatus(ctx context.Context, updated *models.ClusterNamespace) error - // GetNamespaceHashes - GetNamespaceHashes(ctx context.Context, clusterID uuid.UUID) ([]infrav3.NameHash, error) -} - -// clusterNamespacesDao implements ClusterNamespacesDao -type clusterNamespacesDao struct { - dao pg.EntityDAO -} - -// ClusterNamespacesDao return new cluster namespaces dao -func NewClusterNamespacesDao(dao pg.EntityDAO) ClusterNamespacesDao { - return &clusterNamespacesDao{ - dao: dao, - } -} - -func (s clusterNamespacesDao) GetNamespace(ctx context.Context, clusterID uuid.UUID, name string) (models.ClusterNamespace, error) { +func GetNamespace(ctx context.Context, db bun.IDB, clusterID uuid.UUID, name string) (models.ClusterNamespace, error) { var cn models.ClusterNamespace - err := s.dao.GetInstance().NewSelect().Model(&cn). + err := db.NewSelect().Model(&cn). Where("cluster_id = ?", clusterID). Where("name = ?", name). Scan(ctx) @@ -55,17 +29,17 @@ func (s clusterNamespacesDao) GetNamespace(ctx context.Context, clusterID uuid.U return cn, nil } -func (s clusterNamespacesDao) GetNamespaces(ctx context.Context, clusterID uuid.UUID) ([]models.ClusterNamespace, error) { +func GetNamespaces(ctx context.Context, db bun.IDB, clusterID uuid.UUID) ([]models.ClusterNamespace, error) { var cns []models.ClusterNamespace - _, err := s.dao.GetX(ctx, "cluster_id", clusterID, &cns) + _, err := pg.GetX(ctx, db, "cluster_id", clusterID, &cns) return cns, err } -func (s clusterNamespacesDao) GetNamespacesForConditions(ctx context.Context, clusterID uuid.UUID, conditions []scheduler.ClusterNamespaceCondition) ([]models.ClusterNamespace, int, error) { +func GetNamespacesForConditions(ctx context.Context, db bun.IDB, clusterID uuid.UUID, conditions []scheduler.ClusterNamespaceCondition) ([]models.ClusterNamespace, int, error) { var cns []models.ClusterNamespace - q := s.dao.GetInstance().NewSelect().Model(&cns).Where("cluster_id = ?", clusterID) + q := db.NewSelect().Model(&cns).Where("cluster_id = ?", clusterID) for _, condition := range conditions { q.WhereGroup("", func(sq *bun.SelectQuery) *bun.SelectQuery { @@ -87,9 +61,9 @@ func (s clusterNamespacesDao) GetNamespacesForConditions(ctx context.Context, cl return cns, count, err } -func (s clusterNamespacesDao) UpdateNamespaceStatus(ctx context.Context, updated *models.ClusterNamespace) error { +func UpdateNamespaceStatus(ctx context.Context, db bun.IDB, updated *models.ClusterNamespace) error { - _, err := s.dao.GetInstance().NewUpdate().Model(updated). + _, err := db.NewUpdate().Model(updated). Set("conditions = ?", updated.Conditions). Set("status = ?", updated.Status). Where("cluster_id = ?", updated.ClusterId). @@ -99,11 +73,11 @@ func (s clusterNamespacesDao) UpdateNamespaceStatus(ctx context.Context, updated return err } -func (s clusterNamespacesDao) GetNamespaceHashes(ctx context.Context, clusterID uuid.UUID) ([]infrav3.NameHash, error) { +func GetNamespaceHashes(ctx context.Context, db bun.IDB, clusterID uuid.UUID) ([]infrav3.NameHash, error) { var nameHashes []infrav3.NameHash - err := s.dao.GetInstance().NewSelect(). + err := db.NewSelect(). Model((*models.ClusterNamespace)(nil)). Column("name", "hash"). //TODO: to be changed to ClusterTaskDeleted later once task is supported diff --git a/internal/cluster/dao/projectcluster.go b/internal/cluster/dao/projectcluster.go index be22bc6..8869093 100644 --- a/internal/cluster/dao/projectcluster.go +++ b/internal/cluster/dao/projectcluster.go @@ -8,57 +8,34 @@ import ( "github.com/RafaySystems/rcloud-base/pkg/query" commonv3 "github.com/RafaySystems/rcloud-base/proto/types/commonpb/v3" "github.com/google/uuid" + "github.com/uptrace/bun" ) -// ProjectClusterDao is the interface for project cluster operations -type ProjectClusterDao interface { - // create project cluster - CreateProjectCluster(ctx context.Context, pc *models.ProjectCluster) error - // get projects for cluster - GetProjectsForCluster(ctx context.Context, clusterID uuid.UUID) ([]models.ProjectCluster, error) - // delete projects for cluster - DeleteProjectsForCluster(ctx context.Context, clusterID uuid.UUID) error - // Validate if the project in scope is owner of the cluster - ValidateClusterAccess(ctx context.Context, opts commonv3.QueryOptions) (bool, error) -} - -// projectClusterDao implements ProjectClusterDao -type projectClusterDao struct { - dao pg.EntityDAO -} - -// ProjectClusterDao return new project cluster dao -func NewProjectClusterDao(dao pg.EntityDAO) ProjectClusterDao { - return &projectClusterDao{ - dao: dao, - } -} - -func (s *projectClusterDao) CreateProjectCluster(ctx context.Context, pc *models.ProjectCluster) error { - _, err := s.dao.Create(ctx, pc) +func CreateProjectCluster(ctx context.Context, db bun.IDB, pc *models.ProjectCluster) error { + _, err := pg.Create(ctx, db, pc) if err != nil { return err } return nil } -func (s *projectClusterDao) GetProjectsForCluster(ctx context.Context, clusterID uuid.UUID) ([]models.ProjectCluster, error) { +func GetProjectsForCluster(ctx context.Context, db bun.IDB, clusterID uuid.UUID) ([]models.ProjectCluster, error) { var projectClusters []models.ProjectCluster - err := s.dao.GetInstance().NewSelect().Model(&projectClusters).Where("cluster_id = ?", clusterID).Scan(ctx) + err := db.NewSelect().Model(&projectClusters).Where("cluster_id = ?", clusterID).Scan(ctx) if err != nil { return nil, err } return projectClusters, nil } -func (s *projectClusterDao) DeleteProjectsForCluster(ctx context.Context, clusterID uuid.UUID) error { - return s.dao.DeleteX(ctx, "cluster_id", clusterID, &models.ProjectCluster{}) +func DeleteProjectsForCluster(ctx context.Context, db bun.IDB, clusterID uuid.UUID) error { + return pg.DeleteX(ctx, db, "cluster_id", clusterID, &models.ProjectCluster{}) } // Check if the project in scope is owner of the cluster -func (s *projectClusterDao) ValidateClusterAccess(ctx context.Context, opts commonv3.QueryOptions) (bool, error) { +func ValidateClusterAccess(ctx context.Context, db bun.IDB, opts commonv3.QueryOptions) (bool, error) { var _c models.Cluster - q, err := query.Select(s.dao.GetInstance().NewSelect().Model(&_c), &opts) + q, err := query.Select(db.NewSelect().Model(&_c), &opts) if err != nil { return false, err } diff --git a/internal/dao/bootstrap.go b/internal/dao/bootstrap.go index 2fae1e4..9e3f630 100644 --- a/internal/dao/bootstrap.go +++ b/internal/dao/bootstrap.go @@ -2,7 +2,6 @@ package dao import ( "context" - "database/sql" "errors" "fmt" "time" @@ -17,44 +16,9 @@ import ( "github.com/uptrace/bun" ) -// BootstrapDao is the interface for bootstrap operations -type BootstrapDao interface { - CreateOrUpdateBootstrapInfra(ctx context.Context, infra *models.BootstrapInfra) error - CreateOrUpdateBootstrapAgentTemplate(context.Context, *models.BootstrapAgentTemplate) error - GetBootstrapAgentTemplateForToken(ctx context.Context, token string) (*models.BootstrapAgentTemplate, error) - SelectBootstrapAgentTemplates(ctx context.Context, opts *commonv3.QueryOptions) (ret []models.BootstrapAgentTemplate, count int, err error) - DeleteBootstrapAgentTempate(ctx context.Context, opts *commonv3.QueryOptions, infraRef string) error - GetBootstrapAgents(ctx context.Context, opts *commonv3.QueryOptions, templateRef string) (ret []models.BootstrapAgent, count int, err error) - CreateBootstrapAgent(ctx context.Context, ba *models.BootstrapAgent) error - GetBootstrapAgent(ctx context.Context, templateRef string, opts *commonv3.QueryOptions) (*models.BootstrapAgent, error) - SelectBootstrapAgents(ctx context.Context, templateRef string, opts *commonv3.QueryOptions) (ret []models.BootstrapAgent, count int, err error) - RegisterBootstrapAgent(ctx context.Context, token string) error - DeleteBootstrapAgent(ctx context.Context, templateRef string, opts *commonv3.QueryOptions) error - UpdateBootstrapAgent(ctx context.Context, ba *models.BootstrapAgent, opts *commonv3.QueryOptions) error - GetBootstrapAgentForToken(ctx context.Context, token string) (*models.BootstrapAgent, error) - GetBootstrapAgentTemplateForHost(ctx context.Context, host string) (*models.BootstrapAgentTemplate, error) - GetBootstrapAgentCountForClusterID(ctx context.Context, clusterID string, orgID uuid.UUID) (int, error) - GetBootstrapAgentForClusterID(ctx context.Context, clusterID string, orgID uuid.UUID) (*models.BootstrapAgent, error) - UpdateBootstrapAgentDeleteAt(ctx context.Context, templateRef string) error - UpdateBootstrapAgentTempateDeleteAt(ctx context.Context, opts *commonv3.QueryOptions) error - UpdateBootstrapInfraDeleteAt(ctx context.Context, opts *commonv3.QueryOptions) error -} +func CreateOrUpdateBootstrapInfra(ctx context.Context, db bun.IDB, infra *models.BootstrapInfra) error { -// bootstrapDao implements BootstrapDao -type bootstrapDao struct { - bdao pg.EntityDAO -} - -// BootstrapDao return new bootstrap dao -func NewBootstrapDao(edao pg.EntityDAO) BootstrapDao { - return &bootstrapDao{ - bdao: edao, - } -} - -func (s *bootstrapDao) CreateOrUpdateBootstrapInfra(ctx context.Context, infra *models.BootstrapInfra) error { - - _, err := s.bdao.GetInstance().NewInsert().On("CONFLICT (name) DO UPDATE"). + _, err := db.NewInsert().On("CONFLICT (name) DO UPDATE"). Set("ca_cert = ?", infra.CaCert). Set("ca_key = ?", infra.CaKey). Set("modified_at = ?", time.Now()). @@ -63,9 +27,9 @@ func (s *bootstrapDao) CreateOrUpdateBootstrapInfra(ctx context.Context, infra * return err } -func (s *bootstrapDao) CreateOrUpdateBootstrapAgentTemplate(ctx context.Context, template *models.BootstrapAgentTemplate) error { +func CreateOrUpdateBootstrapAgentTemplate(ctx context.Context, db bun.IDB, template *models.BootstrapAgentTemplate) error { - _, err := s.bdao.GetInstance().NewInsert().On("CONFLICT (name) DO UPDATE"). + _, err := db.NewInsert().On("CONFLICT (name) DO UPDATE"). Set("infra_ref = ?", template.InfraRef). Set("ignore_multiple_register = ?", template.IgnoreMultipleRegister). Set("auto_register = ?", template.AutoRegister). @@ -80,14 +44,14 @@ func (s *bootstrapDao) CreateOrUpdateBootstrapAgentTemplate(ctx context.Context, return err } -func (s *bootstrapDao) GetBootstrapAgentTemplateForToken(ctx context.Context, token string) (*models.BootstrapAgentTemplate, error) { +func GetBootstrapAgentTemplateForToken(ctx context.Context, db bun.IDB, token string) (*models.BootstrapAgentTemplate, error) { var template models.BootstrapAgentTemplate - err := s.bdao.GetInstance().NewSelect().Model(&template).Where("token = ?", token).Scan(ctx) + err := db.NewSelect().Model(&template).Where("token = ?", token).Scan(ctx) return &template, err } -func (s *bootstrapDao) SelectBootstrapAgentTemplates(ctx context.Context, opts *commonv3.QueryOptions) (ret []models.BootstrapAgentTemplate, count int, err error) { - q, err := query.Select(s.bdao.GetInstance().NewSelect().Model(&ret), opts) +func SelectBootstrapAgentTemplates(ctx context.Context, db bun.IDB, opts *commonv3.QueryOptions) (ret []models.BootstrapAgentTemplate, count int, err error) { + q, err := query.Select(db.NewSelect().Model(&ret), opts) if err != nil { return } @@ -99,9 +63,9 @@ func (s *bootstrapDao) SelectBootstrapAgentTemplates(ctx context.Context, opts * return } -func (s *bootstrapDao) DeleteBootstrapAgentTempate(ctx context.Context, opts *commonv3.QueryOptions, infraRef string) error { +func DeleteBootstrapAgentTempate(ctx context.Context, db bun.IDB, opts *commonv3.QueryOptions, infraRef string) error { - q, err := query.Delete(s.bdao.GetInstance().NewSelect().Model((*models.BootstrapAgentTemplate)(nil)), opts) + q, err := query.Delete(db.NewSelect().Model((*models.BootstrapAgentTemplate)(nil)), opts) if err != nil { return err } @@ -109,9 +73,9 @@ func (s *bootstrapDao) DeleteBootstrapAgentTempate(ctx context.Context, opts *co return err } -func (s *bootstrapDao) GetBootstrapAgent(ctx context.Context, templateRef string, opts *commonv3.QueryOptions) (*models.BootstrapAgent, error) { +func GetBootstrapAgent(ctx context.Context, db bun.IDB, templateRef string, opts *commonv3.QueryOptions) (*models.BootstrapAgent, error) { var ba models.BootstrapAgent - q, err := query.Get(s.bdao.GetInstance().NewSelect().Model(&ba), opts) + q, err := query.Get(db.NewSelect().Model(&ba), opts) if err != nil { return nil, err } @@ -127,8 +91,8 @@ func (s *bootstrapDao) GetBootstrapAgent(ctx context.Context, templateRef string return &ba, err } -func (s *bootstrapDao) GetBootstrapAgents(ctx context.Context, opts *commonv3.QueryOptions, templateRef string) (ret []models.BootstrapAgent, count int, err error) { - q, err := query.Get(s.bdao.GetInstance().NewSelect().Model(&ret), opts) +func GetBootstrapAgents(ctx context.Context, db bun.IDB, opts *commonv3.QueryOptions, templateRef string) (ret []models.BootstrapAgent, count int, err error) { + q, err := query.Get(db.NewSelect().Model(&ret), opts) if err != nil { return nil, 0, err } @@ -142,9 +106,9 @@ func (s *bootstrapDao) GetBootstrapAgents(ctx context.Context, opts *commonv3.Qu return } -func (s *bootstrapDao) SelectBootstrapAgents(ctx context.Context, templateRef string, opts *commonv3.QueryOptions) (ret []models.BootstrapAgent, count int, err error) { +func SelectBootstrapAgents(ctx context.Context, db bun.IDB, templateRef string, opts *commonv3.QueryOptions) (ret []models.BootstrapAgent, count int, err error) { - q, err := query.Select(s.bdao.GetInstance().NewSelect().Model(&ret), opts) + q, err := query.Select(db.NewSelect().Model(&ret), opts) if err != nil { return } @@ -158,67 +122,65 @@ func (s *bootstrapDao) SelectBootstrapAgents(ctx context.Context, templateRef st return } -func (s *bootstrapDao) CreateBootstrapAgent(ctx context.Context, ba *models.BootstrapAgent) error { +func CreateBootstrapAgent(ctx context.Context, db bun.IDB, ba *models.BootstrapAgent) error { ba.TokenState = sentry.BootstrapAgentState_NotRegistered.String() - _, err := s.bdao.Create(ctx, ba) + _, err := pg.Create(ctx, db, ba) return err } -func (s *bootstrapDao) RegisterBootstrapAgent(ctx context.Context, token string) error { - - err := s.bdao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - ba, err := s.getBootstrapAgentForToken(ctx, token) - if err != nil { - return err - } - - bat, err := s.getBootstrapAgentTemplate(ctx, ba.TemplateRef) - if err != nil { - return err - } - - state := sentry.BootstrapAgentState_NotApproved - if bat.AutoApprove { - state = sentry.BootstrapAgentState_Approved - } - - switch ba.TokenState { - case sentry.BootstrapAgentState_NotRegistered.String(): - ba.TokenState = sentry.BootstrapAgentState_Approved.String() - case sentry.BootstrapAgentState_NotApproved.String(), sentry.BootstrapAgentState_Approved.String(): - if !bat.IgnoreMultipleRegister { - return fmt.Errorf("cannot register token %s state is %s", token, ba.TokenState) - } - default: - return fmt.Errorf("invalid token state %s", ba.TokenState) - } - - _, err = s.bdao.GetInstance().NewUpdate().Model(ba). - Set("token_state = ?", state). - Where("token = ?", token). - Exec(ctx) +// We are explicitly taking in Tx here as it was previously doing RunInTx +// TODO: should we take Tx here or just assume we will be passed in a tx? Or should we create one? +func RegisterBootstrapAgent(ctx context.Context, db bun.Tx, token string) error { + ba, err := getBootstrapAgentForToken(ctx, db, token) + if err != nil { return err - }) + } + + bat, err := getBootstrapAgentTemplate(ctx, db, ba.TemplateRef) + if err != nil { + return err + } + + state := sentry.BootstrapAgentState_NotApproved + if bat.AutoApprove { + state = sentry.BootstrapAgentState_Approved + } + + switch ba.TokenState { + case sentry.BootstrapAgentState_NotRegistered.String(): + ba.TokenState = sentry.BootstrapAgentState_Approved.String() + case sentry.BootstrapAgentState_NotApproved.String(), sentry.BootstrapAgentState_Approved.String(): + if !bat.IgnoreMultipleRegister { + return fmt.Errorf("cannot register token %s state is %s", token, ba.TokenState) + } + default: + return fmt.Errorf("invalid token state %s", ba.TokenState) + } + + _, err = db.NewUpdate().Model(ba). + Set("token_state = ?", state). + Where("token = ?", token). + Exec(ctx) return err } -func (s *bootstrapDao) getBootstrapAgentForToken(ctx context.Context, token string) (*models.BootstrapAgent, error) { +func getBootstrapAgentForToken(ctx context.Context, db bun.IDB, token string) (*models.BootstrapAgent, error) { var ba models.BootstrapAgent - err := s.bdao.GetInstance().NewSelect().Model(&ba).Where("token = ?", token).Scan(ctx) + err := db.NewSelect().Model(&ba).Where("token = ?", token).Scan(ctx) return &ba, err } -func (s *bootstrapDao) getBootstrapAgentTemplate(ctx context.Context, name string) (*models.BootstrapAgentTemplate, error) { +func getBootstrapAgentTemplate(ctx context.Context, db bun.IDB, name string) (*models.BootstrapAgentTemplate, error) { var template models.BootstrapAgentTemplate - err := s.bdao.GetInstance().NewSelect().Model(&template).Where("name = ?", name).Scan(ctx) + err := db.NewSelect().Model(&template).Where("name = ?", name).Scan(ctx) return &template, err } -func (s *bootstrapDao) DeleteBootstrapAgent(ctx context.Context, templateRef string, opts *commonv3.QueryOptions) error { +func DeleteBootstrapAgent(ctx context.Context, db bun.IDB, templateRef string, opts *commonv3.QueryOptions) error { - dq := s.bdao.GetInstance().NewDelete().Model((*models.BootstrapAgent)(nil)).Where("name = ?", opts.ID) + dq := db.NewDelete().Model((*models.BootstrapAgent)(nil)).Where("name = ?", opts.ID) if templateRef != "" { dq = dq.Where("template_ref = ?", templateRef) } @@ -226,22 +188,22 @@ func (s *bootstrapDao) DeleteBootstrapAgent(ctx context.Context, templateRef str return err } -func (s *bootstrapDao) UpdateBootstrapAgent(ctx context.Context, ba *models.BootstrapAgent, opts *commonv3.QueryOptions) error { - _, err := s.bdao.GetInstance().NewUpdate().Model(ba).Where("id = ?", ba.ID).Returning("*").Exec(ctx) +func UpdateBootstrapAgent(ctx context.Context, db bun.IDB, ba *models.BootstrapAgent, opts *commonv3.QueryOptions) error { + _, err := db.NewUpdate().Model(ba).Where("id = ?", ba.ID).Returning("*").Exec(ctx) return err } -func (s *bootstrapDao) GetBootstrapAgentForToken(ctx context.Context, token string) (*models.BootstrapAgent, error) { +func GetBootstrapAgentForToken(ctx context.Context, db bun.IDB, token string) (*models.BootstrapAgent, error) { var ba models.BootstrapAgent - err := s.bdao.GetInstance().NewSelect().Model(&ba).Where("token = ?", token).Scan(ctx) + err := db.NewSelect().Model(&ba).Where("token = ?", token).Scan(ctx) return &ba, err } -func (s *bootstrapDao) GetBootstrapAgentTemplateForHost(ctx context.Context, host string) (*models.BootstrapAgentTemplate, error) { +func GetBootstrapAgentTemplateForHost(ctx context.Context, db bun.IDB, host string) (*models.BootstrapAgentTemplate, error) { bat := models.BootstrapAgentTemplate{} - err := s.bdao.GetInstance().NewSelect().Model(&bat). + err := db.NewSelect().Model(&bat). ColumnExpr("bat.*"). Join("JOIN sentry_bootstrap_template_host as bth"). JoinOn("bat.name = bth.name"). @@ -251,9 +213,9 @@ func (s *bootstrapDao) GetBootstrapAgentTemplateForHost(ctx context.Context, hos return &bat, err } -func (s *bootstrapDao) GetBootstrapAgentCountForClusterID(ctx context.Context, clusterID string, orgID uuid.UUID) (int, error) { +func GetBootstrapAgentCountForClusterID(ctx context.Context, db bun.IDB, clusterID string, orgID uuid.UUID) (int, error) { var ba []models.BootstrapAgent - err := s.bdao.GetInstance().NewSelect().Model(&ba). + err := db.NewSelect().Model(&ba). Where("name = ?", clusterID). Where("organization_id = ?", orgID). Scan(ctx) @@ -263,9 +225,9 @@ func (s *bootstrapDao) GetBootstrapAgentCountForClusterID(ctx context.Context, c return len(ba), nil } -func (s *bootstrapDao) GetBootstrapAgentForClusterID(ctx context.Context, clusterID string, orgID uuid.UUID) (*models.BootstrapAgent, error) { +func GetBootstrapAgentForClusterID(ctx context.Context, db bun.IDB, clusterID string, orgID uuid.UUID) (*models.BootstrapAgent, error) { var ba models.BootstrapAgent - err := s.bdao.GetInstance().NewSelect().Model(&ba). + err := db.NewSelect().Model(&ba). Where("name = ?", clusterID). Where("organization_id = ?", orgID). Scan(ctx) @@ -276,9 +238,9 @@ func (s *bootstrapDao) GetBootstrapAgentForClusterID(ctx context.Context, cluste } // updateBootstrapAgentDeleteAt builds query for deleting resource -func (s *bootstrapDao) UpdateBootstrapAgentDeleteAt(ctx context.Context, templateRef string) error { +func UpdateBootstrapAgentDeleteAt(ctx context.Context, db bun.IDB, templateRef string) error { var toBeDeletedAgent *models.BootstrapAgent - _, err := s.bdao.GetX(ctx, "template_ref", templateRef, &toBeDeletedAgent) + _, err := pg.GetX(ctx, db, "template_ref", templateRef, &toBeDeletedAgent) if err != nil { return err } @@ -289,7 +251,7 @@ func (s *bootstrapDao) UpdateBootstrapAgentDeleteAt(ctx context.Context, templat opts := &commonv3.QueryOptions{} query.WithName(toBeDeletedAgent.Name)(opts) - q, err := query.Update(s.bdao.GetInstance().NewUpdate().Model((*models.BootstrapAgent)(nil)), opts) + q, err := query.Update(db.NewUpdate().Model((*models.BootstrapAgent)(nil)), opts) if err != nil { return err } @@ -300,8 +262,8 @@ func (s *bootstrapDao) UpdateBootstrapAgentDeleteAt(ctx context.Context, templat return err } -func (s *bootstrapDao) UpdateBootstrapAgentTempateDeleteAt(ctx context.Context, opts *commonv3.QueryOptions) error { - q, err := query.Update(s.bdao.GetInstance().NewUpdate().Model((*models.BootstrapAgentTemplate)(nil)), opts) +func UpdateBootstrapAgentTempateDeleteAt(ctx context.Context, db bun.IDB, opts *commonv3.QueryOptions) error { + q, err := query.Update(db.NewUpdate().Model((*models.BootstrapAgentTemplate)(nil)), opts) if err != nil { return err } @@ -313,8 +275,8 @@ func (s *bootstrapDao) UpdateBootstrapAgentTempateDeleteAt(ctx context.Context, } // updateBootstrapInfraDeleteAt builds query for deleting resource -func (s *bootstrapDao) UpdateBootstrapInfraDeleteAt(ctx context.Context, opts *commonv3.QueryOptions) error { - q, err := query.Update(s.bdao.GetInstance().NewUpdate().Model((*models.BootstrapInfra)(nil)), opts) +func UpdateBootstrapInfraDeleteAt(ctx context.Context, db bun.IDB, opts *commonv3.QueryOptions) error { + q, err := query.Update(db.NewUpdate().Model((*models.BootstrapInfra)(nil)), opts) if err != nil { return err } diff --git a/internal/dao/group.go b/internal/dao/group.go index a60b230..70fe4de 100644 --- a/internal/dao/group.go +++ b/internal/dao/group.go @@ -9,42 +9,20 @@ import ( "github.com/uptrace/bun" ) -type groupDAO struct { - db *bun.DB -} - -// Group specific db access -type GroupDAO interface { - Close() error - // get users for group - GetUsers(context.Context, uuid.UUID) ([]models.KratosIdentities, error) - // get roles for group - GetRoles(context.Context, uuid.UUID) ([]*userv3.ProjectNamespaceRole, error) -} - -// NewGroupDao return new group dao -func NewGroupDAO(db *bun.DB) *groupDAO { - return &groupDAO{db} -} - -func (dao *groupDAO) Close() error { - return dao.db.Close() -} - // GetUsers gets the list of users in a given group -func (dao *groupDAO) GetUsers(ctx context.Context, id uuid.UUID) ([]models.KratosIdentities, error) { +func GetUsers(ctx context.Context, db bun.IDB, id uuid.UUID) ([]models.KratosIdentities, error) { var entities = []models.KratosIdentities{} - err := dao.db.NewSelect().Model(&entities). + err := db.NewSelect().Model(&entities). Join(`JOIN authsrv_groupaccount ON identities.id=authsrv_groupaccount.account_id`). Where(`authsrv_groupaccount.group_id = ?`, id). Scan(ctx) return entities, err } -func (dao *groupDAO) GetRoles(ctx context.Context, id uuid.UUID) ([]*userv3.ProjectNamespaceRole, error) { +func GetGroupRoles(ctx context.Context, db bun.IDB, id uuid.UUID) ([]*userv3.ProjectNamespaceRole, error) { // Could possibily union them later for some speedup var r = []*userv3.ProjectNamespaceRole{} - err := dao.db.NewSelect().Table("authsrv_grouprole"). + err := db.NewSelect().Table("authsrv_grouprole"). ColumnExpr("authsrv_resourcerole.name as role"). Join(`JOIN authsrv_resourcerole ON authsrv_resourcerole.id=authsrv_grouprole.role_id`). Where("authsrv_grouprole.group_id = ?", id). @@ -54,7 +32,7 @@ func (dao *groupDAO) GetRoles(ctx context.Context, id uuid.UUID) ([]*userv3.Proj } var pr = []*userv3.ProjectNamespaceRole{} - err = dao.db.NewSelect().Table("authsrv_projectgrouprole"). + err = db.NewSelect().Table("authsrv_projectgrouprole"). ColumnExpr("authsrv_resourcerole.name as role, authsrv_project.name as project"). Join(`JOIN authsrv_resourcerole ON authsrv_resourcerole.id=authsrv_projectgrouprole.role_id`). Join(`JOIN authsrv_project ON authsrv_project.id=authsrv_projectgrouprole.project_id`). @@ -65,7 +43,7 @@ func (dao *groupDAO) GetRoles(ctx context.Context, id uuid.UUID) ([]*userv3.Proj } var pnr = []*userv3.ProjectNamespaceRole{} - err = dao.db.NewSelect().Table("authsrv_projectgroupnamespacerole"). + err = db.NewSelect().Table("authsrv_projectgroupnamespacerole"). ColumnExpr("authsrv_resourcerole.name as role, authsrv_project.name as project, namespace_id as namespace"). Join(`JOIN authsrv_resourcerole ON authsrv_resourcerole.id=authsrv_projectgroupnamespacerole.role_id`). Join(`JOIN authsrv_project ON authsrv_project.id=authsrv_projectgroupnamespacerole.project_id`). // also need a namespace join diff --git a/internal/dao/kubeconfig.go b/internal/dao/kubeconfig.go index b6f8100..d5b5da2 100644 --- a/internal/dao/kubeconfig.go +++ b/internal/dao/kubeconfig.go @@ -8,36 +8,12 @@ import ( "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/RafaySystems/rcloud-base/proto/types/sentry" "github.com/google/uuid" + "github.com/uptrace/bun" ) -// KubeconfigDao is the interface for kubeconfig operations -type KubeconfigDao interface { - GetKubeconfigRevocation(ctx context.Context, orgID, accountID uuid.UUID, isSSOUser bool) (*models.KubeconfigRevocation, error) - CreateKubeconfigRevocation(ctx context.Context, kr *models.KubeconfigRevocation) error - UpdateKubeconfigRevocation(ctx context.Context, kr *models.KubeconfigRevocation) error - GetKubeconfigSetting(ctx context.Context, orgID, accountID uuid.UUID, issSSO bool) (*models.KubeconfigSetting, error) - CreateKubeconfigSetting(ctx context.Context, ks *models.KubeconfigSetting) error - UpdateKubeconfigSetting(ctx context.Context, ks *models.KubeconfigSetting) error - GetkubectlClusterSettings(ctx context.Context, orgID uuid.UUID, name string) (*models.KubectlClusterSetting, error) - CreatekubectlClusterSettings(ctx context.Context, kc *models.KubectlClusterSetting) error - UpdatekubectlClusterSettings(ctx context.Context, kc *models.KubectlClusterSetting) error -} - -// kubeconfigDao implements BootstrapDao -type kubeconfigDao struct { - dao pg.EntityDAO -} - -// KubeconfigDao return new kube config dao -func NewKubeconfigDao(edao pg.EntityDAO) KubeconfigDao { - return &kubeconfigDao{ - dao: edao, - } -} - -func (s *kubeconfigDao) GetKubeconfigRevocation(ctx context.Context, orgID, accountID uuid.UUID, isSSOUser bool) (*models.KubeconfigRevocation, error) { +func GetKubeconfigRevocation(ctx context.Context, db bun.IDB, orgID, accountID uuid.UUID, isSSOUser bool) (*models.KubeconfigRevocation, error) { var kr models.KubeconfigRevocation - err := s.dao.GetInstance().NewSelect().Model(&kr). + err := db.NewSelect().Model(&kr). Where("organization_id = ?", orgID). Where("account_id = ?", accountID). Where("is_sso_user = ?", isSSOUser). @@ -45,13 +21,13 @@ func (s *kubeconfigDao) GetKubeconfigRevocation(ctx context.Context, orgID, acco return &kr, err } -func (s *kubeconfigDao) CreateKubeconfigRevocation(ctx context.Context, kr *models.KubeconfigRevocation) error { - _, err := s.dao.Create(ctx, kr) +func CreateKubeconfigRevocation(ctx context.Context, db bun.IDB, kr *models.KubeconfigRevocation) error { + _, err := pg.Create(ctx, db, kr) return err } -func (s *kubeconfigDao) UpdateKubeconfigRevocation(ctx context.Context, kr *models.KubeconfigRevocation) error { - q := s.dao.GetInstance().NewUpdate().Model(kr) +func UpdateKubeconfigRevocation(ctx context.Context, db bun.IDB, kr *models.KubeconfigRevocation) error { + q := db.NewUpdate().Model(kr) q = q.Where("organization_id = ?", kr.OrganizationId). Where("account_id = ?", kr.AccountId). @@ -63,9 +39,9 @@ func (s *kubeconfigDao) UpdateKubeconfigRevocation(ctx context.Context, kr *mode return err } -func (s *kubeconfigDao) GetKubeconfigSetting(ctx context.Context, orgID, accountID uuid.UUID, issSSO bool) (*models.KubeconfigSetting, error) { +func GetKubeconfigSetting(ctx context.Context, db bun.IDB, orgID, accountID uuid.UUID, issSSO bool) (*models.KubeconfigSetting, error) { var ks models.KubeconfigSetting - err := s.dao.GetInstance().NewSelect().Model(&ks). + err := db.NewSelect().Model(&ks). Where("organization_id = ?", orgID). Where("account_id = ?", accountID). Where("is_sso_user= ?", issSSO). @@ -73,18 +49,18 @@ func (s *kubeconfigDao) GetKubeconfigSetting(ctx context.Context, orgID, account return &ks, err } -func (s *kubeconfigDao) CreateKubeconfigSetting(ctx context.Context, ks *models.KubeconfigSetting) error { +func CreateKubeconfigSetting(ctx context.Context, db bun.IDB, ks *models.KubeconfigSetting) error { if ks.AccountId == uuid.Nil { ks.Scope = sentry.KubeconfigSettingOrganizationScope } else { ks.Scope = sentry.KubeconfigSettingUserScope } - _, err := s.dao.Create(ctx, ks) + _, err := pg.Create(ctx, db, ks) return err } -func (s *kubeconfigDao) UpdateKubeconfigSetting(ctx context.Context, ks *models.KubeconfigSetting) error { - q := s.dao.GetInstance().NewUpdate().Model(ks) +func UpdateKubeconfigSetting(ctx context.Context, db bun.IDB, ks *models.KubeconfigSetting) error { + q := db.NewUpdate().Model(ks) q = q.Where("organization_id = ?", ks.OrganizationId). Where("account_id = ?", ks.AccountId). @@ -103,21 +79,21 @@ func (s *kubeconfigDao) UpdateKubeconfigSetting(ctx context.Context, ks *models. return err } -func (s *kubeconfigDao) GetkubectlClusterSettings(ctx context.Context, orgID uuid.UUID, name string) (*models.KubectlClusterSetting, error) { +func GetkubectlClusterSettings(ctx context.Context, db bun.IDB, orgID uuid.UUID, name string) (*models.KubectlClusterSetting, error) { var kc models.KubectlClusterSetting - err := s.dao.GetInstance().NewSelect().Model(&kc). + err := db.NewSelect().Model(&kc). Where("organization_id = ?", orgID). Where("name = ?", name).Scan(ctx) return &kc, err } -func (s *kubeconfigDao) CreatekubectlClusterSettings(ctx context.Context, kc *models.KubectlClusterSetting) error { - _, err := s.dao.Create(ctx, kc) +func CreatekubectlClusterSettings(ctx context.Context, db bun.IDB, kc *models.KubectlClusterSetting) error { + _, err := pg.Create(ctx, db, kc) return err } -func (s *kubeconfigDao) UpdatekubectlClusterSettings(ctx context.Context, kc *models.KubectlClusterSetting) error { - q := s.dao.GetInstance().NewUpdate().Model(kc) +func UpdatekubectlClusterSettings(ctx context.Context, db bun.IDB, kc *models.KubectlClusterSetting) error { + q := db.NewUpdate().Model(kc) q = q.Where("organization_id = ?", kc.OrganizationId). Where("name = ?", kc.Name) diff --git a/internal/dao/permission.go b/internal/dao/permission.go index 4678b5f..bf5bb8d 100644 --- a/internal/dao/permission.go +++ b/internal/dao/permission.go @@ -4,47 +4,13 @@ import ( "context" "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/google/uuid" "github.com/uptrace/bun" ) -// PermissionDao is the interface for permission operations -type PermissionDao interface { - GetGroupPermissions(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID) ([]models.GroupPermission, error) - GetGroupProjectsByPermission(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID, permission string) ([]models.GroupPermission, error) - GetGroupPermissionsByProjectIDPermissions(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID, projects []string, permissions []string) ([]models.GroupPermission, error) - GetProjectByGroup(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID) ([]models.GroupPermission, error) - GetAccountPermissions(ctx context.Context, accountID, orgID, partnerID uuid.UUID) ([]models.AccountPermission, error) - IsPartnerSuperAdmin(ctx context.Context, accountID, partnerID uuid.UUID) (isPartnerAdmin, isSuperAdmin bool, err error) - GetAccountProjectsByPermission(ctx context.Context, accountID, orgID, partnerID uuid.UUID, permission string) ([]models.AccountPermission, error) - GetAccountPermissionsByProjectIDPermissions(ctx context.Context, accountID, orgID, partnerID uuid.UUID, projects []uuid.UUID, permissions []string) ([]models.AccountPermission, error) - GetSSOUsersGroupProjectRole(ctx context.Context, orgID uuid.UUID) ([]models.SSOAccountGroupProjectRole, error) - GetAcccountsWithApprovalPermission(ctx context.Context, orgID, partnerID uuid.UUID) ([]string, error) - GetSSOAcccountsWithApprovalPermission(ctx context.Context, orgID, partnerID uuid.UUID) ([]string, error) - IsOrgAdmin(ctx context.Context, accountID, partnerID uuid.UUID) (isOrgAdmin bool, err error) - GetAccountBasics(ctx context.Context, accountID uuid.UUID) (*models.Account, error) - GetAccountGroups(ctx context.Context, accountID uuid.UUID) ([]models.GroupAccount, error) - GetDefaultUserGroup(ctx context.Context, orgID uuid.UUID) (*models.Group, error) - GetDefaultUserGroupAccount(ctx context.Context, accountID, groupID uuid.UUID) (*models.GroupAccount, error) - GetDefaultAccountProject(ctx context.Context, accountID uuid.UUID) (models.AccountPermission, error) -} - -// permissionDao implements PermissionDao -type permissionDao struct { - dao pg.EntityDAO -} - -// PermissionDao return new permission dao -func NewPermissionDao(edao pg.EntityDAO) PermissionDao { - return &permissionDao{ - dao: edao, - } -} - -func (s *permissionDao) GetGroupPermissions(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID) ([]models.GroupPermission, error) { +func GetGroupPermissions(ctx context.Context, db bun.IDB, groupNames []string, orgID, partnerID uuid.UUID) ([]models.GroupPermission, error) { var gps []models.GroupPermission - err := s.dao.GetInstance().NewSelect().Model(&gps). + err := db.NewSelect().Model(&gps). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). Where("group_name IN (?)", bun.In(groupNames)). @@ -52,10 +18,10 @@ func (s *permissionDao) GetGroupPermissions(ctx context.Context, groupNames []st return gps, err } -func (s *permissionDao) GetGroupProjectsByPermission(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID, permission string) ([]models.GroupPermission, error) { +func GetGroupProjectsByPermission(ctx context.Context, db bun.IDB, groupNames []string, orgID, partnerID uuid.UUID, permission string) ([]models.GroupPermission, error) { var gps []models.GroupPermission - err := s.dao.GetInstance().NewSelect().Model(&gps). + err := db.NewSelect().Model(&gps). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). Where("group_name IN (?)", bun.In(groupNames)). @@ -65,10 +31,10 @@ func (s *permissionDao) GetGroupProjectsByPermission(ctx context.Context, groupN return gps, err } -func (s *permissionDao) GetGroupPermissionsByProjectIDPermissions(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID, projects []string, permissions []string) ([]models.GroupPermission, error) { +func GetGroupPermissionsByProjectIDPermissions(ctx context.Context, db bun.IDB, groupNames []string, orgID, partnerID uuid.UUID, projects []string, permissions []string) ([]models.GroupPermission, error) { var gps []models.GroupPermission - err := s.dao.GetInstance().NewSelect().Model(&gps). + err := db.NewSelect().Model(&gps). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). Where("group_name IN (?)", bun.In(groupNames)). @@ -79,10 +45,10 @@ func (s *permissionDao) GetGroupPermissionsByProjectIDPermissions(ctx context.Co return gps, err } -func (s *permissionDao) GetProjectByGroup(ctx context.Context, groupNames []string, orgID, partnerID uuid.UUID) ([]models.GroupPermission, error) { +func GetProjectByGroup(ctx context.Context, db bun.IDB, groupNames []string, orgID, partnerID uuid.UUID) ([]models.GroupPermission, error) { var gps []models.GroupPermission - err := s.dao.GetInstance().NewSelect().Model(&gps). + err := db.NewSelect().Model(&gps). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). Where("group_name IN (?)", bun.In(groupNames)). @@ -92,10 +58,10 @@ func (s *permissionDao) GetProjectByGroup(ctx context.Context, groupNames []stri return gps, err } -func (a *permissionDao) GetAccountPermissions(ctx context.Context, accountID, orgID, partnerID uuid.UUID) ([]models.AccountPermission, error) { +func GetAccountPermissions(ctx context.Context, db bun.IDB, accountID, orgID, partnerID uuid.UUID) ([]models.AccountPermission, error) { var aps []models.AccountPermission - err := a.dao.GetInstance().NewSelect().Model(&aps). + err := db.NewSelect().Model(&aps). Where("account_id = ?", accountID). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). @@ -104,13 +70,13 @@ func (a *permissionDao) GetAccountPermissions(ctx context.Context, accountID, or return aps, err } -func (a *permissionDao) IsPartnerSuperAdmin(ctx context.Context, accountID, partnerID uuid.UUID) (isPartnerAdmin, isSuperAdmin bool, err error) { +func IsPartnerSuperAdmin(ctx context.Context, db bun.IDB, accountID, partnerID uuid.UUID) (isPartnerAdmin, isSuperAdmin bool, err error) { var aps []models.AccountPermission isSuperAdmin = false isPartnerAdmin = false - err = a.dao.GetInstance().NewSelect().Model(&aps). + err = db.NewSelect().Model(&aps). Where("account_id = ?", accountID). Where("partner_id = ?", partnerID). WhereGroup(" AND ", func(sq *bun.SelectQuery) *bun.SelectQuery { @@ -133,10 +99,10 @@ func (a *permissionDao) IsPartnerSuperAdmin(ctx context.Context, accountID, part return isPartnerAdmin, isSuperAdmin, nil } -func (a *permissionDao) GetAccountProjectsByPermission(ctx context.Context, accountID, orgID, partnerID uuid.UUID, permission string) ([]models.AccountPermission, error) { +func GetAccountProjectsByPermission(ctx context.Context, db bun.IDB, accountID, orgID, partnerID uuid.UUID, permission string) ([]models.AccountPermission, error) { var aps []models.AccountPermission - err := a.dao.GetInstance().NewSelect().Model(&aps). + err := db.NewSelect().Model(&aps). Where("account_id = ?", accountID). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). @@ -146,10 +112,10 @@ func (a *permissionDao) GetAccountProjectsByPermission(ctx context.Context, acco return aps, err } -func (a *permissionDao) GetDefaultAccountProject(ctx context.Context, accountID uuid.UUID) (models.AccountPermission, error) { +func GetDefaultAccountProject(ctx context.Context, db bun.IDB, accountID uuid.UUID) (models.AccountPermission, error) { var aps models.AccountPermission - err := a.dao.GetInstance().NewSelect().Model(&aps). + err := db.NewSelect().Model(&aps). ColumnExpr("sap.*"). Join("JOIN authsrv_project as proj").JoinOn("proj.id = sap.project_id").JoinOn("proj.default = ?", true). Where("account_id = ?", accountID).Limit(1). @@ -158,10 +124,10 @@ func (a *permissionDao) GetDefaultAccountProject(ctx context.Context, accountID return aps, err } -func (a *permissionDao) GetAccountPermissionsByProjectIDPermissions(ctx context.Context, accountID, orgID, partnerID uuid.UUID, projects []uuid.UUID, permissions []string) ([]models.AccountPermission, error) { +func GetAccountPermissionsByProjectIDPermissions(ctx context.Context, db bun.IDB, accountID, orgID, partnerID uuid.UUID, projects []uuid.UUID, permissions []string) ([]models.AccountPermission, error) { var aps []models.AccountPermission - err := a.dao.GetInstance().NewSelect().Model(&aps). + err := db.NewSelect().Model(&aps). Where("account_id = ?", accountID). Where("organization_id = ?", orgID). Where("partner_id = ?", partnerID). @@ -172,17 +138,17 @@ func (a *permissionDao) GetAccountPermissionsByProjectIDPermissions(ctx context. return aps, err } -func (a *permissionDao) GetSSOUsersGroupProjectRole(ctx context.Context, orgID uuid.UUID) ([]models.SSOAccountGroupProjectRole, error) { +func GetSSOUsersGroupProjectRole(ctx context.Context, db bun.IDB, orgID uuid.UUID) ([]models.SSOAccountGroupProjectRole, error) { var ssos []models.SSOAccountGroupProjectRole - err := a.dao.GetInstance().NewSelect().Model(&ssos). + err := db.NewSelect().Model(&ssos). Where("organization_id = ?", orgID). Scan(ctx) return ssos, err } -func (a *permissionDao) GetAcccountsWithApprovalPermission(ctx context.Context, orgID, partnerID uuid.UUID) ([]string, error) { +func GetAcccountsWithApprovalPermission(ctx context.Context, db bun.IDB, orgID, partnerID uuid.UUID) ([]string, error) { // TODO: remove this from here once Account is structured in types.proto type accountPermission struct { bun.BaseModel `bun:"table:sentry_account_permission,alias:sap"` @@ -190,7 +156,7 @@ func (a *permissionDao) GetAcccountsWithApprovalPermission(ctx context.Context, *models.AccountPermission } var aps []accountPermission - err := a.dao.GetInstance().NewSelect().Model(&aps). + err := db.NewSelect().Model(&aps). ColumnExpr("ki.traits -> 'email'"). DistinctOn("ki.traits -> 'email'"). Join("INNER JOIN identities as ki ON ?TableAlias.account_id = ki.id"). @@ -213,9 +179,9 @@ func (a *permissionDao) GetAcccountsWithApprovalPermission(ctx context.Context, return usernames, nil } -func (a *permissionDao) GetSSOAcccountsWithApprovalPermission(ctx context.Context, orgID, partnerID uuid.UUID) ([]string, error) { +func GetSSOAcccountsWithApprovalPermission(ctx context.Context, db bun.IDB, orgID, partnerID uuid.UUID) ([]string, error) { var ssoaps []models.SSOAccountGroupProjectRole - err := a.dao.GetInstance().NewSelect().Model(&ssoaps). + err := db.NewSelect().Model(&ssoaps). Where("?TableAlias.organization_id = ?", orgID). Where("?TableAlias.partner_id = ?", partnerID). WhereGroup("grp", func(sq *bun.SelectQuery) *bun.SelectQuery { @@ -237,12 +203,12 @@ func (a *permissionDao) GetSSOAcccountsWithApprovalPermission(ctx context.Contex return usernames, nil } -func (a *permissionDao) IsOrgAdmin(ctx context.Context, accountID, partnerID uuid.UUID) (isOrgAdmin bool, err error) { +func IsOrgAdmin(ctx context.Context, db bun.IDB, accountID, partnerID uuid.UUID) (isOrgAdmin bool, err error) { var aps []models.AccountPermission isOrgAdmin = false - err = a.dao.GetInstance().NewSelect().Model(&aps). + err = db.NewSelect().Model(&aps). Where("account_id = ?", accountID). Where("partner_id = ?", partnerID). Where("role_name = ?", "ADMIN"). @@ -262,10 +228,10 @@ func (a *permissionDao) IsOrgAdmin(ctx context.Context, accountID, partnerID uui return isOrgAdmin, nil } -func (a *permissionDao) GetAccountBasics(ctx context.Context, accountID uuid.UUID) (*models.Account, error) { +func GetAccountBasics(ctx context.Context, db bun.IDB, accountID uuid.UUID) (*models.Account, error) { var acc models.Account - err := a.dao.GetInstance().NewSelect().Model(&acc). + err := db.NewSelect().Model(&acc). Column("identities.id", "traits", "state"). ColumnExpr("max(ks.authenticated_at) as lastlogin"). ColumnExpr("identities.traits -> 'email' as username"). @@ -280,10 +246,10 @@ func (a *permissionDao) GetAccountBasics(ctx context.Context, accountID uuid.UUI return &acc, nil } -func (a *permissionDao) GetAccountGroups(ctx context.Context, accountID uuid.UUID) ([]models.GroupAccount, error) { +func GetAccountGroups(ctx context.Context, db bun.IDB, accountID uuid.UUID) ([]models.GroupAccount, error) { var ga []models.GroupAccount - err := a.dao.GetInstance().NewSelect().Model(&ga). + err := db.NewSelect().Model(&ga). Where("account_id = ?", accountID). Where("trash = ?", false). Where("active = ?", true). @@ -294,9 +260,9 @@ func (a *permissionDao) GetAccountGroups(ctx context.Context, accountID uuid.UUI return ga, nil } -func (a *permissionDao) GetDefaultUserGroup(ctx context.Context, orgID uuid.UUID) (*models.Group, error) { +func GetDefaultUserGroup(ctx context.Context, db bun.IDB, orgID uuid.UUID) (*models.Group, error) { var g models.Group - err := a.dao.GetInstance().NewSelect().Model(&g). + err := db.NewSelect().Model(&g). Where("organization_id = ?", orgID). Where("type = ?", "DEFAULT_USERS"). Where("trash = ?", false). @@ -304,9 +270,9 @@ func (a *permissionDao) GetDefaultUserGroup(ctx context.Context, orgID uuid.UUID return &g, err } -func (a *permissionDao) GetDefaultUserGroupAccount(ctx context.Context, accountID, groupID uuid.UUID) (*models.GroupAccount, error) { +func GetDefaultUserGroupAccount(ctx context.Context, db bun.IDB, accountID, groupID uuid.UUID) (*models.GroupAccount, error) { var ga models.GroupAccount - err := a.dao.GetInstance().NewSelect().Model(&ga). + err := db.NewSelect().Model(&ga). Where("account_id = ?", accountID). Where("group_id = ?", groupID). Where("trash = ?", false). diff --git a/internal/dao/role.go b/internal/dao/role.go index ee17b58..7b4939c 100644 --- a/internal/dao/role.go +++ b/internal/dao/role.go @@ -8,30 +8,10 @@ import ( "github.com/uptrace/bun" ) -type roleDAO struct { - db *bun.DB -} - -// Role specific db access -type RoleDAO interface { - Close() error - // get permissions for role - GetRolePermissions(context.Context, uuid.UUID) ([]models.ResourcePermission, error) -} - -// NewRoleDao return new group dao -func NewRoleDAO(db *bun.DB) *roleDAO { - return &roleDAO{db} -} - -func (dao *roleDAO) Close() error { - return dao.db.Close() -} - -func (dao *roleDAO) GetRolePermissions(ctx context.Context, id uuid.UUID) ([]models.ResourcePermission, error) { +func GetRolePermissions(ctx context.Context, db bun.IDB, id uuid.UUID) ([]models.ResourcePermission, error) { // Could possibly union them later for some speedup var r = []models.ResourcePermission{} - err := dao.db.NewSelect().Table("authsrv_resourcepermission"). + err := db.NewSelect().Table("authsrv_resourcepermission"). ColumnExpr("authsrv_resourcepermission.name as name"). Join(`JOIN authsrv_resourcerolepermission ON authsrv_resourcerolepermission.resource_permission_id=authsrv_resourcepermission.id`). Where("authsrv_resourcerolepermission.resource_role_id = ?", id). diff --git a/internal/dao/user.go b/internal/dao/user.go index 8d3f186..69fdc31 100644 --- a/internal/dao/user.go +++ b/internal/dao/user.go @@ -9,42 +9,19 @@ import ( "github.com/uptrace/bun" ) -type userDAO struct { - db *bun.DB -} - -// User specific db access -type UserDAO interface { - Close() error - // get groups for user - GetGroups(context.Context, uuid.UUID) ([]models.Group, error) - // get roles for user - GetRoles(context.Context, uuid.UUID) ([]*userv3.ProjectNamespaceRole, error) -} - -// NewUserDao return new user dao -func NewUserDAO(db *bun.DB) *userDAO { - return &userDAO{db} -} - -func (dao *userDAO) Close() error { - // XXX: if one dao closes the db connections, won't other have issues? - return dao.db.Close() -} - -func (dao *userDAO) GetGroups(ctx context.Context, id uuid.UUID) ([]models.Group, error) { +func GetGroups(ctx context.Context, db bun.IDB, id uuid.UUID) ([]models.Group, error) { var entities = []models.Group{} - err := dao.db.NewSelect().Model(&entities). + err := db.NewSelect().Model(&entities). Join(`JOIN authsrv_groupaccount ON authsrv_groupaccount.group_id="group".id`). Where("authsrv_groupaccount.account_id = ?", id). Scan(ctx) return entities, err } -func (dao *userDAO) GetRoles(ctx context.Context, id uuid.UUID) ([]*userv3.ProjectNamespaceRole, error) { - // Could possibily union them later for some speedup +func GetUserRoles(ctx context.Context, db bun.IDB, id uuid.UUID) ([]*userv3.ProjectNamespaceRole, error) { + // Could possibly union them later for some speedup var r = []*userv3.ProjectNamespaceRole{} - err := dao.db.NewSelect().Table("authsrv_accountresourcerole"). + err := db.NewSelect().Table("authsrv_accountresourcerole"). ColumnExpr("authsrv_resourcerole.name as role"). Join(`JOIN authsrv_resourcerole ON authsrv_resourcerole.id=authsrv_accountresourcerole.role_id`). Where("authsrv_accountresourcerole.account_id = ?", id). @@ -54,7 +31,7 @@ func (dao *userDAO) GetRoles(ctx context.Context, id uuid.UUID) ([]*userv3.Proje } var pr = []*userv3.ProjectNamespaceRole{} - err = dao.db.NewSelect().Table("authsrv_projectaccountresourcerole"). + err = db.NewSelect().Table("authsrv_projectaccountresourcerole"). ColumnExpr("authsrv_resourcerole.name as role, authsrv_project.name as project"). Join(`JOIN authsrv_resourcerole ON authsrv_resourcerole.id=authsrv_projectaccountresourcerole.role_id`). Join(`JOIN authsrv_project ON authsrv_project.id=authsrv_projectaccountresourcerole.project_id`). @@ -65,7 +42,7 @@ func (dao *userDAO) GetRoles(ctx context.Context, id uuid.UUID) ([]*userv3.Proje } var pnr = []*userv3.ProjectNamespaceRole{} - err = dao.db.NewSelect().Table("authsrv_projectaccountnamespacerole"). + err = db.NewSelect().Table("authsrv_projectaccountnamespacerole"). ColumnExpr("authsrv_resourcerole.name as role, authsrv_project.name as project, namespace_id as namespace"). Join(`JOIN authsrv_resourcerole ON authsrv_resourcerole.id=authsrv_projectaccountnamespacerole.role_id`). Join(`JOIN authsrv_project ON authsrv_project.id=authsrv_projectaccountnamespacerole.project_id`). // also need a namespace join diff --git a/internal/persistence/provider/pg/entity_dao.go b/internal/persistence/provider/pg/entity_dao.go index b226a59..34a4434 100644 --- a/internal/persistence/provider/pg/entity_dao.go +++ b/internal/persistence/provider/pg/entity_dao.go @@ -8,77 +8,16 @@ import ( bun "github.com/uptrace/bun" ) -// DAO is the interface for database operations -type EntityDAO interface { - Close() error - // create entity - Create(context.Context, interface{}) (interface{}, error) - // get entity by field - GetX(context.Context, string, interface{}, interface{}) (interface{}, error) - // get entity by multiple fields - GetM(context.Context, map[string]interface{}, interface{}) (interface{}, error) - // get entity by id - GetByID(context.Context, uuid.UUID, interface{}) (interface{}, error) - - // get entity by name - GetByName(context.Context, string, interface{}) (interface{}, error) - // get entity by name partner and org - GetByNamePartnerOrg(context.Context, string, uuid.NullUUID, uuid.NullUUID, interface{}) (interface{}, error) - // get entity id by name - GetIdByName(context.Context, string, interface{}) (interface{}, error) - // get entity id by name partner and org - GetIdByNamePartnerOrg(context.Context, string, uuid.NullUUID, uuid.NullUUID, interface{}) (interface{}, error) - // get entity name by id - GetNameById(context.Context, uuid.UUID, interface{}) (interface{}, error) - //Update entity - Update(context.Context, uuid.UUID, interface{}) (interface{}, error) - // get entity by field - UpdateX(context.Context, string, interface{}, interface{}) (interface{}, error) - // delete entity by field - DeleteX(context.Context, string, interface{}, interface{}) error - // delete entity - Delete(context.Context, uuid.UUID, interface{}) error - // delete all items in table (for script) - HardDeleteAll(context.Context, interface{}) error - // get list of entities - List(context.Context, uuid.NullUUID, uuid.NullUUID, interface{}) (interface{}, error) - // get list of entities - ListByProject(context.Context, uuid.NullUUID, uuid.NullUUID, uuid.NullUUID, interface{}) error - // get list of entities without filtering - ListAll(context.Context, interface{}) (interface{}, error) - - // lookup user by traits - GetByTraits(ctx context.Context, name string, entity interface{}) (interface{}, error) - // lookup user id by traits - GetIdByTraits(ctx context.Context, name string, entity interface{}) (interface{}, error) - - //returns db object - GetInstance() *bun.DB -} - -type entityDAO struct { - db *bun.DB -} - -func (dao *entityDAO) Close() error { - return dao.db.Close() -} - -// NewEntityDao return new entity dao -func NewEntityDAO(db *bun.DB) EntityDAO { - return &entityDAO{db} -} - -func (dao *entityDAO) Create(ctx context.Context, entity interface{}) (interface{}, error) { - if _, err := dao.db.NewInsert().Model(entity).Exec(ctx); err != nil { +func Create(ctx context.Context, db bun.IDB, entity interface{}) (interface{}, error) { + if _, err := db.NewInsert().Model(entity).Exec(ctx); err != nil { return nil, err } return entity, nil } -func (dao *entityDAO) GetX(ctx context.Context, field string, value interface{}, entity interface{}) (interface{}, error) { - err := dao.db.NewSelect().Model(entity). +func GetX(ctx context.Context, db bun.IDB, field string, value interface{}, entity interface{}) (interface{}, error) { + err := db.NewSelect().Model(entity). Where(fmt.Sprintf("%s = ?", field), value). Where("trash = ?", false). Scan(ctx) @@ -90,9 +29,9 @@ func (dao *entityDAO) GetX(ctx context.Context, field string, value interface{}, } // M for multi ;) -func (dao *entityDAO) GetM(ctx context.Context, checks map[string]interface{}, entity interface{}) (interface{}, error) { +func GetM(ctx context.Context, db bun.IDB, checks map[string]interface{}, entity interface{}) (interface{}, error) { // Can we get the checks directly from entity and create an upsert sort of func? - q := dao.db.NewSelect().Model(entity) + q := db.NewSelect().Model(entity) for field := range checks { q.Where(fmt.Sprintf("%s = ?", field), checks[field]) } @@ -104,8 +43,8 @@ func (dao *entityDAO) GetM(ctx context.Context, checks map[string]interface{}, e return entity, nil } -func (dao *entityDAO) GetByID(ctx context.Context, id uuid.UUID, entity interface{}) (interface{}, error) { - err := dao.db.NewSelect().Model(entity). +func GetByID(ctx context.Context, db bun.IDB, id uuid.UUID, entity interface{}) (interface{}, error) { + err := db.NewSelect().Model(entity). Where("id = ?", id). Where("trash = ?", false). Scan(ctx) @@ -116,8 +55,8 @@ func (dao *entityDAO) GetByID(ctx context.Context, id uuid.UUID, entity interfac return entity, nil } -func (dao *entityDAO) GetByName(ctx context.Context, name string, entity interface{}) (interface{}, error) { - err := dao.db.NewSelect().Model(entity). +func GetByName(ctx context.Context, db bun.IDB, name string, entity interface{}) (interface{}, error) { + err := db.NewSelect().Model(entity). Where("name = ?", name). Where("trash = ?", false). Scan(ctx) @@ -127,8 +66,8 @@ func (dao *entityDAO) GetByName(ctx context.Context, name string, entity interfa return entity, nil } -func (dao *entityDAO) GetByNamePartnerOrg(ctx context.Context, name string, pid uuid.NullUUID, oid uuid.NullUUID, entity interface{}) (interface{}, error) { - sq := dao.db.NewSelect().Model(entity) +func GetByNamePartnerOrg(ctx context.Context, db bun.IDB, name string, pid uuid.NullUUID, oid uuid.NullUUID, entity interface{}) (interface{}, error) { + sq := db.NewSelect().Model(entity) if oid.Valid { sq = sq.Where("organization_id = ?", oid) } @@ -146,8 +85,8 @@ func (dao *entityDAO) GetByNamePartnerOrg(ctx context.Context, name string, pid return entity, nil } -func (dao *entityDAO) GetIdByName(ctx context.Context, name string, entity interface{}) (interface{}, error) { - err := dao.db.NewSelect().Column("id").Model(entity). +func GetIdByName(ctx context.Context, db bun.IDB, name string, entity interface{}) (interface{}, error) { + err := db.NewSelect().Column("id").Model(entity). Where("name = ?", name). Where("trash = ?", false). Scan(ctx) @@ -158,8 +97,8 @@ func (dao *entityDAO) GetIdByName(ctx context.Context, name string, entity inter return entity, nil } -func (dao *entityDAO) GetIdByNamePartnerOrg(ctx context.Context, name string, pid uuid.NullUUID, oid uuid.NullUUID, entity interface{}) (interface{}, error) { - sq := dao.db.NewSelect().Column("id").Model(entity) +func GetIdByNamePartnerOrg(ctx context.Context, db bun.IDB, name string, pid uuid.NullUUID, oid uuid.NullUUID, entity interface{}) (interface{}, error) { + sq := db.NewSelect().Column("id").Model(entity) if oid.Valid { sq = sq.Where("organization_id = ?", oid) } @@ -177,8 +116,8 @@ func (dao *entityDAO) GetIdByNamePartnerOrg(ctx context.Context, name string, pi return entity, nil } -func (dao *entityDAO) GetNameById(ctx context.Context, id uuid.UUID, entity interface{}) (interface{}, error) { - err := dao.db.NewSelect().Column("name").Model(entity). +func GetNameById(ctx context.Context, db bun.IDB, id uuid.UUID, entity interface{}) (interface{}, error) { + err := db.NewSelect().Column("name").Model(entity). Where("id = ?", id). Where("trash = ?", false). Scan(ctx) @@ -189,22 +128,22 @@ func (dao *entityDAO) GetNameById(ctx context.Context, id uuid.UUID, entity inte return entity, nil } -func (dao *entityDAO) Update(ctx context.Context, id uuid.UUID, entity interface{}) (interface{}, error) { - if _, err := dao.db.NewUpdate().Model(entity).Where("id = ?", id).Exec(ctx); err != nil { +func Update(ctx context.Context, db bun.IDB, id uuid.UUID, entity interface{}) (interface{}, error) { + if _, err := db.NewUpdate().Model(entity).Where("id = ?", id).Exec(ctx); err != nil { return nil, err } return entity, nil } -func (dao *entityDAO) UpdateX(ctx context.Context, field string, value interface{}, entity interface{}) (interface{}, error) { - if _, err := dao.db.NewUpdate().Model(entity).Where("? = ?", bun.Ident(field), value).Exec(ctx); err != nil { +func UpdateX(ctx context.Context, db bun.IDB, field string, value interface{}, entity interface{}) (interface{}, error) { + if _, err := db.NewUpdate().Model(entity).Where("? = ?", bun.Ident(field), value).Exec(ctx); err != nil { return nil, err } return entity, nil } -func (dao *entityDAO) Delete(ctx context.Context, id uuid.UUID, entity interface{}) error { - _, err := dao.db.NewUpdate(). +func Delete(ctx context.Context, db bun.IDB, id uuid.UUID, entity interface{}) error { + _, err := db.NewUpdate(). Model(entity). Column("trash"). Where("id = ?", id). @@ -213,8 +152,8 @@ func (dao *entityDAO) Delete(ctx context.Context, id uuid.UUID, entity interface return err } -func (dao *entityDAO) DeleteX(ctx context.Context, field string, value interface{}, entity interface{}) error { - _, err := dao.db.NewUpdate(). +func DeleteX(ctx context.Context, db bun.IDB, field string, value interface{}, entity interface{}) error { + _, err := db.NewUpdate(). Model(entity). Column("trash"). Where("? = ?", bun.Ident(field), value). @@ -224,16 +163,16 @@ func (dao *entityDAO) DeleteX(ctx context.Context, field string, value interface } // HardDeleteAll deletes all records in a table (primarily for use in scripts) -func (dao *entityDAO) HardDeleteAll(ctx context.Context, entity interface{}) error { - _, err := dao.db.NewDelete(). +func HardDeleteAll(ctx context.Context, db bun.IDB, entity interface{}) error { + _, err := db.NewDelete(). Model(entity). Where("1 = 1"). // TODO: see how to remove this Exec(ctx) return err } -func (dao *entityDAO) List(ctx context.Context, partnerId uuid.NullUUID, organizationId uuid.NullUUID, entities interface{}) (interface{}, error) { - sq := dao.db.NewSelect().Model(entities) +func List(ctx context.Context, db bun.IDB, partnerId uuid.NullUUID, organizationId uuid.NullUUID, entities interface{}) (interface{}, error) { + sq := db.NewSelect().Model(entities) if partnerId.Valid { sq = sq.Where("partner_id = ?", partnerId) } @@ -245,8 +184,8 @@ func (dao *entityDAO) List(ctx context.Context, partnerId uuid.NullUUID, organiz return entities, err } -func (dao *entityDAO) ListByProject(ctx context.Context, partnerId uuid.NullUUID, organizationId uuid.NullUUID, projectId uuid.NullUUID, entities interface{}) error { - sq := dao.db.NewSelect().Model(entities) +func ListByProject(ctx context.Context, db bun.IDB, partnerId uuid.NullUUID, organizationId uuid.NullUUID, projectId uuid.NullUUID, entities interface{}) error { + sq := db.NewSelect().Model(entities) if partnerId.Valid { sq = sq.Where("partner_id = ?", partnerId) } @@ -261,16 +200,15 @@ func (dao *entityDAO) ListByProject(ctx context.Context, partnerId uuid.NullUUID return err } -func (dao *entityDAO) ListAll(ctx context.Context, entities interface{}) (interface{}, error) { - err := dao.db.NewSelect().Model(entities).Scan(ctx) +func ListAll(ctx context.Context, db bun.IDB, entities interface{}) (interface{}, error) { + err := db.NewSelect().Model(entities).Scan(ctx) return entities, err } -func (dao *entityDAO) GetByTraits(ctx context.Context, name string, entity interface{}) (interface{}, error) { +func GetByTraits(ctx context.Context, db bun.IDB, name string, entity interface{}) (interface{}, error) { // TODO: better name and possibly pass in trait name - err := dao.db.NewSelect().Model(entity). + err := db.NewSelect().Model(entity). Where("traits ->> 'email' = ?", name). - Where("trash = ?", false). Scan(ctx) if err != nil { return nil, err @@ -279,11 +217,10 @@ func (dao *entityDAO) GetByTraits(ctx context.Context, name string, entity inter return entity, nil } -func (dao *entityDAO) GetIdByTraits(ctx context.Context, name string, entity interface{}) (interface{}, error) { +func GetIdByTraits(ctx context.Context, db bun.IDB, name string, entity interface{}) (interface{}, error) { // TODO: better name and possibly pass in trait name - err := dao.db.NewSelect().Column("id").Model(entity). + err := db.NewSelect().Column("id").Model(entity). Where("traits ->> 'email' = ?", name). - Where("trash = ?", false). Scan(ctx) if err != nil { return nil, err @@ -291,7 +228,3 @@ func (dao *entityDAO) GetIdByTraits(ctx context.Context, name string, entity int return entity, nil } - -func (dao *entityDAO) GetInstance() *bun.DB { - return dao.db -} diff --git a/internal/persistence/provider/pg/lookup.go b/internal/persistence/provider/pg/lookup.go new file mode 100644 index 0000000..d8a36a8 --- /dev/null +++ b/internal/persistence/provider/pg/lookup.go @@ -0,0 +1,76 @@ +package pg + +import ( + "context" + "fmt" + + "github.com/RafaySystems/rcloud-base/internal/models" + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +func GetPartnerId(ctx context.Context, db bun.IDB, name string) (uuid.UUID, error) { + entity, err := GetIdByName(ctx, db, name, &models.Partner{}) + if err != nil { + return uuid.Nil, err + } + if prt, ok := entity.(*models.Partner); ok { + return prt.ID, nil + } + return uuid.Nil, fmt.Errorf("no partner found with name %v", name) +} + +func GetOrganizationId(ctx context.Context, db bun.IDB, name string) (uuid.UUID, error) { + entity, err := GetIdByName(ctx, db, name, &models.Organization{}) + if err != nil { + return uuid.Nil, err + } + if org, ok := entity.(*models.Organization); ok { + return org.ID, nil + } + return uuid.Nil, fmt.Errorf("no organization found with name %v", name) +} + +func GetProjectId(ctx context.Context, db bun.IDB, name string) (uuid.UUID, error) { + entity, err := GetIdByName(ctx, db, name, &models.Project{}) + if err != nil { + return uuid.Nil, err + } + if proj, ok := entity.(*models.Project); ok { + return proj.ID, nil + } + return uuid.Nil, fmt.Errorf("no project found with name %v", name) +} + +func GetPartnerName(ctx context.Context, db bun.IDB, id uuid.UUID) (string, error) { + entity, err := GetNameById(ctx, db, id, &models.Partner{}) + if err != nil { + return "", err + } + if prt, ok := entity.(*models.Partner); ok { + return prt.Name, nil + } + return "", fmt.Errorf("no partner found with id %v", id) +} + +func GetOrganizationName(ctx context.Context, db bun.IDB, id uuid.UUID) (string, error) { + entity, err := GetNameById(ctx, db, id, &models.Organization{}) + if err != nil { + return "", err + } + if org, ok := entity.(*models.Organization); ok { + return org.Name, nil + } + return "", fmt.Errorf("no organization found with id %v", id) +} + +func GetProjectName(ctx context.Context, db bun.IDB, id uuid.UUID) (string, error) { + entity, err := GetNameById(ctx, db, id, &models.Project{}) + if err != nil { + return "", err + } + if proj, ok := entity.(*models.Project); ok { + return proj.Name, nil + } + return "", fmt.Errorf("no project found with id %v", id) +} diff --git a/internal/utils/lookup.go b/internal/utils/lookup.go deleted file mode 100644 index 58d23f2..0000000 --- a/internal/utils/lookup.go +++ /dev/null @@ -1,98 +0,0 @@ -package utils - -import ( - "context" - "fmt" - - "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/google/uuid" - "github.com/uptrace/bun" -) - -// TODO: could use a better name -type lookup struct { - dao pg.EntityDAO -} - -type Lookup interface { - GetPartnerId(context.Context, string) (uuid.UUID, error) - GetOrganizationId(context.Context, string) (uuid.UUID, error) - GetProjectId(context.Context, string) (uuid.UUID, error) - - GetPartnerName(context.Context, uuid.UUID) (string, error) - GetOrganizationName(context.Context, uuid.UUID) (string, error) - GetProjectName(context.Context, uuid.UUID) (string, error) -} - -func NewLookup(db *bun.DB) Lookup { - return &lookup{ - dao: pg.NewEntityDAO(db), - } -} - -func (l *lookup) GetPartnerId(ctx context.Context, name string) (uuid.UUID, error) { - entity, err := l.dao.GetIdByName(ctx, name, &models.Partner{}) - if err != nil { - return uuid.Nil, err - } - if prt, ok := entity.(*models.Partner); ok { - return prt.ID, nil - } - return uuid.Nil, fmt.Errorf("no partner found with name %v", name) -} - -func (l *lookup) GetOrganizationId(ctx context.Context, name string) (uuid.UUID, error) { - entity, err := l.dao.GetIdByName(ctx, name, &models.Organization{}) - if err != nil { - return uuid.Nil, err - } - if org, ok := entity.(*models.Organization); ok { - return org.ID, nil - } - return uuid.Nil, fmt.Errorf("no organization found with name %v", name) -} - -func (l *lookup) GetProjectId(ctx context.Context, name string) (uuid.UUID, error) { - entity, err := l.dao.GetIdByName(ctx, name, &models.Project{}) - if err != nil { - return uuid.Nil, err - } - if proj, ok := entity.(*models.Project); ok { - return proj.ID, nil - } - return uuid.Nil, fmt.Errorf("no project found with name %v", name) -} - -func (l *lookup) GetPartnerName(ctx context.Context, id uuid.UUID) (string, error) { - entity, err := l.dao.GetNameById(ctx, id, &models.Partner{}) - if err != nil { - return "", err - } - if prt, ok := entity.(*models.Partner); ok { - return prt.Name, nil - } - return "", fmt.Errorf("no partner found with id %v", id) -} - -func (l *lookup) GetOrganizationName(ctx context.Context, id uuid.UUID) (string, error) { - entity, err := l.dao.GetNameById(ctx, id, &models.Organization{}) - if err != nil { - return "", err - } - if org, ok := entity.(*models.Organization); ok { - return org.Name, nil - } - return "", fmt.Errorf("no organization found with id %v", id) -} - -func (l *lookup) GetProjectName(ctx context.Context, id uuid.UUID) (string, error) { - entity, err := l.dao.GetNameById(ctx, id, &models.Project{}) - if err != nil { - return "", err - } - if proj, ok := entity.(*models.Project); ok { - return proj.Name, nil - } - return "", fmt.Errorf("no project found with id %v", id) -} diff --git a/main.go b/main.go index ef485ef..a871e4b 100644 --- a/main.go +++ b/main.go @@ -512,11 +512,8 @@ func runRelayPeerRPC(wg *sync.WaitGroup, ctx context.Context) { func runRPC(wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() - defer ps.Close() defer schedulerPool.Close() - defer gs.Close() - defer rs.Close() - defer rrs.Close() + defer db.Close() partnerServer := server.NewPartnerServer(ps) organizationServer := server.NewOrganizationServer(os) diff --git a/pkg/service/account_permission.go b/pkg/service/account_permission.go index 6228e78..db6ab3e 100644 --- a/pkg/service/account_permission.go +++ b/pkg/service/account_permission.go @@ -6,7 +6,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/RafaySystems/rcloud-base/proto/types/sentry" "github.com/google/uuid" "github.com/uptrace/bun" @@ -15,7 +14,6 @@ import ( // AccountPermissionService is the interface for account permission operations type AccountPermissionService interface { - Close() error GetAccountPermissions(ctx context.Context, accountID string, orgID, partnerID string) ([]sentry.AccountPermission, error) IsPartnerSuperAdmin(ctx context.Context, accountID, partnerID string) (isPartnerAdmin, isSuperAdmin bool, err error) GetAccountProjectsByPermission(ctx context.Context, accountID, orgID, partnerID string, permission string) ([]sentry.AccountPermission, error) @@ -30,25 +28,16 @@ type AccountPermissionService interface { // accountPermissionService implements AccountPermissionService type accountPermissionService struct { - dao pg.EntityDAO - pdao dao.PermissionDao + db *bun.DB } // NewKubeconfigRevocation return new kubeconfig revocation service func NewAccountPermissionService(db *bun.DB) AccountPermissionService { - edao := pg.NewEntityDAO(db) - return &accountPermissionService{ - dao: edao, - pdao: dao.NewPermissionDao(edao), - } -} - -func (s *accountPermissionService) Close() error { - return s.dao.Close() + return &accountPermissionService{db} } func (a *accountPermissionService) GetAccountPermissions(ctx context.Context, accountID string, orgID, partnerID string) ([]sentry.AccountPermission, error) { - aps, err := a.pdao.GetAccountPermissions(ctx, uuid.MustParse(accountID), uuid.MustParse(orgID), uuid.MustParse(partnerID)) + aps, err := dao.GetAccountPermissions(ctx, a.db, uuid.MustParse(accountID), uuid.MustParse(orgID), uuid.MustParse(partnerID)) if err != nil { return nil, err } @@ -61,11 +50,11 @@ func (a *accountPermissionService) GetAccountPermissions(ctx context.Context, ac } func (a *accountPermissionService) IsPartnerSuperAdmin(ctx context.Context, accountID, partnerID string) (isPartnerAdmin, isSuperAdmin bool, err error) { - return a.pdao.IsPartnerSuperAdmin(ctx, uuid.MustParse(accountID), uuid.MustParse(partnerID)) + return dao.IsPartnerSuperAdmin(ctx, a.db, uuid.MustParse(accountID), uuid.MustParse(partnerID)) } func (a *accountPermissionService) GetAccountProjectsByPermission(ctx context.Context, accountID, orgID, partnerID string, permission string) ([]sentry.AccountPermission, error) { - aps, err := a.pdao.GetAccountProjectsByPermission(ctx, uuid.MustParse(accountID), uuid.MustParse(orgID), uuid.MustParse(partnerID), permission) + aps, err := dao.GetAccountProjectsByPermission(ctx, a.db, uuid.MustParse(accountID), uuid.MustParse(orgID), uuid.MustParse(partnerID), permission) if err != nil { return nil, err } @@ -85,7 +74,7 @@ func (a *accountPermissionService) GetAccountPermissionsByProjectIDPermissions(c projids = append(projids, id) } } - aps, err := a.pdao.GetAccountPermissionsByProjectIDPermissions(ctx, uuid.MustParse(accountID), uuid.MustParse(orgID), uuid.MustParse(partnerID), projids, permissions) + aps, err := dao.GetAccountPermissionsByProjectIDPermissions(ctx, a.db, uuid.MustParse(accountID), uuid.MustParse(orgID), uuid.MustParse(partnerID), projids, permissions) if err != nil { return nil, err } @@ -98,11 +87,11 @@ func (a *accountPermissionService) GetAccountPermissionsByProjectIDPermissions(c } func (a *accountPermissionService) GetAccount(ctx context.Context, accountID string) (*models.Account, error) { - return a.pdao.GetAccountBasics(ctx, uuid.MustParse(accountID)) + return dao.GetAccountBasics(ctx, a.db, uuid.MustParse(accountID)) } func (a *accountPermissionService) GetAccountGroups(ctx context.Context, accountID string) ([]string, error) { - ag, err := a.pdao.GetAccountGroups(ctx, uuid.MustParse(accountID)) + ag, err := dao.GetAccountGroups(ctx, a.db, uuid.MustParse(accountID)) if err != nil { return nil, err } @@ -131,7 +120,7 @@ func (a *accountPermissionService) GetSSOAccount(ctx context.Context, accountID, //TODO: this needs to be revisited as sso users for oidc are stored in identities by kratos func (a *accountPermissionService) GetSSOUsersGroupProjectRole(ctx context.Context, orgID string) ([]sentry.SSOAccountGroupProjectRoleData, error) { - ssos, err := a.pdao.GetSSOUsersGroupProjectRole(ctx, uuid.MustParse(orgID)) + ssos, err := dao.GetSSOUsersGroupProjectRole(ctx, a.db, uuid.MustParse(orgID)) if err != nil { return nil, err } @@ -144,7 +133,7 @@ func (a *accountPermissionService) GetSSOUsersGroupProjectRole(ctx context.Conte } func (a *accountPermissionService) GetAcccountsWithApprovalPermission(ctx context.Context, orgID, partnerID string) ([]string, error) { - usernames, err := a.pdao.GetAcccountsWithApprovalPermission(ctx, uuid.MustParse(orgID), uuid.MustParse(partnerID)) + usernames, err := dao.GetAcccountsWithApprovalPermission(ctx, a.db, uuid.MustParse(orgID), uuid.MustParse(partnerID)) if err != nil { return nil, err } @@ -152,7 +141,7 @@ func (a *accountPermissionService) GetAcccountsWithApprovalPermission(ctx contex } func (a *accountPermissionService) GetSSOAcccountsWithApprovalPermission(ctx context.Context, orgID, partnerID string) ([]string, error) { - usernames, err := a.pdao.GetSSOAcccountsWithApprovalPermission(ctx, uuid.MustParse(orgID), uuid.MustParse(partnerID)) + usernames, err := dao.GetSSOAcccountsWithApprovalPermission(ctx, a.db, uuid.MustParse(orgID), uuid.MustParse(partnerID)) if err != nil { return nil, err } @@ -160,17 +149,17 @@ func (a *accountPermissionService) GetSSOAcccountsWithApprovalPermission(ctx con } func (a *accountPermissionService) IsOrgAdmin(ctx context.Context, accountID, partnerID string) (isOrgAdmin bool, err error) { - return a.pdao.IsOrgAdmin(ctx, uuid.MustParse(accountID), uuid.MustParse(partnerID)) + return dao.IsOrgAdmin(ctx, a.db, uuid.MustParse(accountID), uuid.MustParse(partnerID)) } func (a *accountPermissionService) IsAccountActive(ctx context.Context, accountID, orgID string) (bool, error) { active := false - group, err := a.pdao.GetDefaultUserGroup(ctx, uuid.MustParse(orgID)) + group, err := dao.GetDefaultUserGroup(ctx, a.db, uuid.MustParse(orgID)) if err != nil { return false, err } - ga, err := a.pdao.GetDefaultUserGroupAccount(ctx, uuid.MustParse(accountID), group.ID) + ga, err := dao.GetDefaultUserGroupAccount(ctx, a.db, uuid.MustParse(accountID), group.ID) if err != nil { return active, err } diff --git a/pkg/service/account_permissions_test.go b/pkg/service/account_permissions_test.go index 2122c11..b44e98b 100644 --- a/pkg/service/account_permissions_test.go +++ b/pkg/service/account_permissions_test.go @@ -13,7 +13,6 @@ func TestGetAccountPermissions(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() oid := uuid.New().String() @@ -33,7 +32,6 @@ func TestIsPartnerSuperAdmin(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() pid := uuid.New().String() @@ -52,7 +50,6 @@ func TestGetAccountProjectsByPermission(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() oid := uuid.New().String() @@ -72,7 +69,6 @@ func TestGetAccountPermissionsByProjectIDPermissions(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() projects := []string{"myproject"} permissions := []string{"read"} @@ -94,7 +90,6 @@ func TestIsOrgAdmin(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() pid := uuid.New().String() @@ -113,7 +108,6 @@ func TestIsAccountActive(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() oid := uuid.New().String() @@ -135,7 +129,6 @@ func TestGetAccount(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() @@ -153,7 +146,6 @@ func TestGetAccountGroups(t *testing.T) { defer db.Close() ps := NewAccountPermissionService(db) - defer ps.Close() aid := uuid.New().String() diff --git a/pkg/service/apikey.go b/pkg/service/apikey.go index 77241d9..69ca4b7 100644 --- a/pkg/service/apikey.go +++ b/pkg/service/apikey.go @@ -16,7 +16,6 @@ import ( // ApiKeyService is the interface for api key operations type ApiKeyService interface { - Close() error // create api key Create(ctx context.Context, req *rpcv3.ApiKeyRequest) (*models.ApiKey, error) // get by user @@ -31,14 +30,12 @@ type ApiKeyService interface { // apiKeyService implements ApiKeyService type apiKeyService struct { - dao pg.EntityDAO + db *bun.DB } // NewApiKeyService return new api key service func NewApiKeyService(db *bun.DB) ApiKeyService { - return &apiKeyService{ - dao: pg.NewEntityDAO(db), - } + return &apiKeyService{db} } func (s *apiKeyService) Create(ctx context.Context, req *rpcv3.ApiKeyRequest) (*models.ApiKey, error) { @@ -52,7 +49,7 @@ func (s *apiKeyService) Create(ctx context.Context, req *rpcv3.ApiKeyRequest) (* Secret: crypto.GenerateSha256Secret(), } - _, err := s.dao.Create(ctx, apikey) + _, err := pg.Create(ctx, s.db, apikey) if err != nil { return nil, err } @@ -60,7 +57,7 @@ func (s *apiKeyService) Create(ctx context.Context, req *rpcv3.ApiKeyRequest) (* } func (s *apiKeyService) Delete(ctx context.Context, req *rpcv3.ApiKeyRequest) (*rpcv3.DeleteUserResponse, error) { - _, err := s.dao.GetInstance().NewUpdate().Model(&models.ApiKey{}). + _, err := s.db.NewUpdate().Model(&models.ApiKey{}). Set("trash = ?", true). Where("name = ?", req.Username). Where("key = ?", req.Id).Exec(ctx) @@ -69,7 +66,7 @@ func (s *apiKeyService) Delete(ctx context.Context, req *rpcv3.ApiKeyRequest) (* func (s *apiKeyService) List(ctx context.Context, req *rpcv3.ApiKeyRequest) (*rpcv3.ApiKeyResponseList, error) { var apikeys []models.ApiKey - resp, err := s.dao.GetByName(ctx, req.Username, &apikeys) + resp, err := pg.GetByName(ctx, s.db, req.Username, &apikeys) if err == sql.ErrNoRows { return nil, nil } @@ -93,7 +90,7 @@ func (s *apiKeyService) List(ctx context.Context, req *rpcv3.ApiKeyRequest) (*rp func (s *apiKeyService) Get(ctx context.Context, req *rpcv3.ApiKeyRequest) (*models.ApiKey, error) { var apikey models.ApiKey - _, err := s.dao.GetByName(ctx, req.Username, &apikey) + _, err := pg.GetByName(ctx, s.db, req.Username, &apikey) if err == sql.ErrNoRows { return nil, nil } @@ -102,13 +99,9 @@ func (s *apiKeyService) Get(ctx context.Context, req *rpcv3.ApiKeyRequest) (*mod func (s *apiKeyService) GetByKey(ctx context.Context, req *rpcv3.ApiKeyRequest) (*models.ApiKey, error) { var apikey models.ApiKey - _, err := s.dao.GetX(ctx, "key", req.Id, &apikey) + _, err := pg.GetX(ctx, s.db, "key", req.Id, &apikey) if err != nil { return nil, err } return &apikey, err } - -func (s *apiKeyService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/bootstrap.go b/pkg/service/bootstrap.go index f47b87d..c5f8602 100644 --- a/pkg/service/bootstrap.go +++ b/pkg/service/bootstrap.go @@ -24,8 +24,6 @@ var KEKFunc cryptoutil.PasswordFunc // BootstrapService is the interface for bootstrap operations type BootstrapService interface { - Close() error - // bootstrap infra methods PatchBootstrapInfra(ctx context.Context, infra *sentry.BootstrapInfra) error GetBootstrapInfra(ctx context.Context, name string) (*sentry.BootstrapInfra, error) @@ -44,33 +42,28 @@ type BootstrapService interface { GetBootstrapAgentForClusterID(ctx context.Context, clusterID string, orgID string) (*sentry.BootstrapAgent, error) SelectBootstrapAgents(ctx context.Context, templateRef string, opts ...query.Option) (*sentry.BootstrapAgentList, error) RegisterBootstrapAgent(ctx context.Context, token string) error - DeleteBoostrapAgent(ctx context.Context, templateRef string, opts ...query.Option) error + DeleteBootstrapAgent(ctx context.Context, templateRef string, opts ...query.Option) error PatchBootstrapAgent(ctx context.Context, ba *sentry.BootstrapAgent, templateRef string, opts ...query.Option) error } // bootstrapService implements BootstrapService type bootstrapService struct { - dao pg.EntityDAO - bdao dao.BootstrapDao + db *bun.DB } // NewBootstrapService return new bootstrap service func NewBootstrapService(db *bun.DB) BootstrapService { - edao := pg.NewEntityDAO(db) - return &bootstrapService{ - dao: edao, - bdao: dao.NewBootstrapDao(edao), - } + return &bootstrapService{db} } func (s *bootstrapService) PatchBootstrapInfra(ctx context.Context, infra *sentry.BootstrapInfra) error { - return s.bdao.CreateOrUpdateBootstrapInfra(ctx, convertToInfraModel(infra)) + return dao.CreateOrUpdateBootstrapInfra(ctx, s.db, convertToInfraModel(infra)) } func (s *bootstrapService) GetBootstrapInfra(ctx context.Context, name string) (*sentry.BootstrapInfra, error) { var bi models.BootstrapInfra - _, err := s.dao.GetByName(ctx, name, &bi) + _, err := pg.GetByName(ctx, s.db, name, &bi) if err != nil { return nil, err } @@ -96,12 +89,12 @@ func (s *bootstrapService) PatchBootstrapAgentTemplate(ctx context.Context, temp CreatedAt: time.Now(), } - return s.bdao.CreateOrUpdateBootstrapAgentTemplate(ctx, &templ) + return dao.CreateOrUpdateBootstrapAgentTemplate(ctx, s.db, &templ) } func (s *bootstrapService) GetBootstrapAgentTemplate(ctx context.Context, agentType string) (*sentry.BootstrapAgentTemplate, error) { var template models.BootstrapAgentTemplate - _, err := s.dao.GetByName(ctx, agentType, &template) + _, err := pg.GetByName(ctx, s.db, agentType, &template) if err != nil { return nil, err } @@ -110,7 +103,7 @@ func (s *bootstrapService) GetBootstrapAgentTemplate(ctx context.Context, agentT } func (s *bootstrapService) GetBootstrapAgentTemplateForToken(ctx context.Context, token string) (*sentry.BootstrapAgentTemplate, error) { - bat, err := s.bdao.GetBootstrapAgentTemplateForToken(ctx, token) + bat, err := dao.GetBootstrapAgentTemplateForToken(ctx, s.db, token) if err != nil { return nil, err } @@ -123,7 +116,7 @@ func (s *bootstrapService) SelectBootstrapAgentTemplates(ctx context.Context, op opt(queryOptions) } - batl, count, err := s.bdao.SelectBootstrapAgentTemplates(ctx, queryOptions) + batl, count, err := dao.SelectBootstrapAgentTemplates(ctx, s.db, queryOptions) if err != nil { return nil, err } @@ -141,7 +134,7 @@ func (s *bootstrapService) SelectBootstrapAgentTemplates(ctx context.Context, op func (s *bootstrapService) CreateBootstrapAgent(ctx context.Context, agent *sentry.BootstrapAgent) error { ba := convertToAgentModel(agent) ba.CreatedAt = time.Now() - return s.bdao.CreateBootstrapAgent(ctx, ba) + return dao.CreateBootstrapAgent(ctx, s.db, ba) } func convertToAgentModel(agent *sentry.BootstrapAgent) *models.BootstrapAgent { @@ -252,7 +245,7 @@ func prepareTemplateResponse(template *models.BootstrapAgentTemplate) *sentry.Bo json.Unmarshal(template.Hosts, &hosts) } templResp := sentry.BootstrapAgentTemplate{ - Kind: "BootstapAgentTemplate", + Kind: "BootstrapAgentTemplate", Metadata: &commonv3.Metadata{ Name: template.Name, DisplayName: template.DisplayName, @@ -281,7 +274,7 @@ func (s *bootstrapService) GetBootstrapAgents(ctx context.Context, templateRef s opt(queryOptions) } - agl, count, err := s.bdao.GetBootstrapAgents(ctx, queryOptions, templateRef) + agl, count, err := dao.GetBootstrapAgents(ctx, s.db, queryOptions, templateRef) if err != nil { return nil, err } @@ -301,7 +294,7 @@ func (s *bootstrapService) GetBootstrapAgent(ctx context.Context, templateRef st for _, opt := range opts { opt(queryOptions) } - ba, err := s.bdao.GetBootstrapAgent(ctx, templateRef, queryOptions) + ba, err := dao.GetBootstrapAgent(ctx, s.db, templateRef, queryOptions) if err != nil { return nil, err } @@ -314,7 +307,7 @@ func (s *bootstrapService) SelectBootstrapAgents(ctx context.Context, templateRe opt(queryOptions) } - agl, count, err := s.bdao.SelectBootstrapAgents(ctx, templateRef, queryOptions) + agl, count, err := dao.SelectBootstrapAgents(ctx, s.db, templateRef, queryOptions) if err != nil { return nil, err } @@ -331,16 +324,19 @@ func (s *bootstrapService) SelectBootstrapAgents(ctx context.Context, templateRe } func (s *bootstrapService) RegisterBootstrapAgent(ctx context.Context, token string) error { - return s.bdao.RegisterBootstrapAgent(ctx, token) + err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + return dao.RegisterBootstrapAgent(ctx, tx, token) + }) + return err } -func (s *bootstrapService) DeleteBoostrapAgent(ctx context.Context, templateRef string, opts ...query.Option) error { +func (s *bootstrapService) DeleteBootstrapAgent(ctx context.Context, templateRef string, opts ...query.Option) error { queryOptions := &commonv3.QueryOptions{} for _, opt := range opts { opt(queryOptions) } - err := s.bdao.DeleteBootstrapAgent(ctx, templateRef, queryOptions) + err := dao.DeleteBootstrapAgent(ctx, s.db, templateRef, queryOptions) return err } @@ -351,8 +347,8 @@ func (s *bootstrapService) PatchBootstrapAgent(ctx context.Context, ba *sentry.B opt(queryOptions) } - err := s.dao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - bdb, err := s.bdao.GetBootstrapAgent(ctx, templateRef, queryOptions) + err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + bdb, err := dao.GetBootstrapAgent(ctx, s.db, templateRef, queryOptions) if err != nil { return err } @@ -376,13 +372,13 @@ func (s *bootstrapService) PatchBootstrapAgent(ctx context.Context, ba *sentry.B } bdb.ModifiedAt = time.Now() bdb.DisplayName = ba.Metadata.DisplayName - return s.bdao.UpdateBootstrapAgent(ctx, bdb, queryOptions) + return dao.UpdateBootstrapAgent(ctx, s.db, bdb, queryOptions) }) return err } func (s *bootstrapService) GetBootstrapAgentForToken(ctx context.Context, token string) (*sentry.BootstrapAgent, error) { - ba, err := s.bdao.GetBootstrapAgentForToken(ctx, token) + ba, err := dao.GetBootstrapAgentForToken(ctx, s.db, token) if err != nil { return nil, err } @@ -390,7 +386,7 @@ func (s *bootstrapService) GetBootstrapAgentForToken(ctx context.Context, token } func (s *bootstrapService) GetBootstrapAgentTemplateForHost(ctx context.Context, host string) (*sentry.BootstrapAgentTemplate, error) { - bat, err := s.bdao.GetBootstrapAgentTemplateForHost(ctx, host) + bat, err := dao.GetBootstrapAgentTemplateForHost(ctx, s.db, host) if err != nil { return nil, err } @@ -400,7 +396,7 @@ func (s *bootstrapService) GetBootstrapAgentTemplateForHost(ctx context.Context, } func (s *bootstrapService) GetBootstrapAgentCountForClusterID(ctx context.Context, clusterID string, orgID string) (int, error) { - count, err := s.bdao.GetBootstrapAgentCountForClusterID(ctx, clusterID, uuid.MustParse(orgID)) + count, err := dao.GetBootstrapAgentCountForClusterID(ctx, s.db, clusterID, uuid.MustParse(orgID)) if err != nil { return 0, err } @@ -411,7 +407,7 @@ func (s *bootstrapService) GetBootstrapAgentCountForClusterID(ctx context.Contex } func (s *bootstrapService) GetBootstrapAgentForClusterID(ctx context.Context, clusterID string, orgID string) (*sentry.BootstrapAgent, error) { - ba, err := s.bdao.GetBootstrapAgentForClusterID(ctx, clusterID, uuid.MustParse(orgID)) + ba, err := dao.GetBootstrapAgentForClusterID(ctx, s.db, clusterID, uuid.MustParse(orgID)) if err != nil || ba == nil { return nil, err } @@ -455,7 +451,3 @@ func (s *bootstrapService) GetRelayAgent(ctx context.Context, ClusterScope strin _log.Infow("did not find relay bootstrap agent for", "cluster", ClusterScope, "template", queryOptions.Name) return nil, fmt.Errorf("failed to get relay agent") } - -func (s *bootstrapService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/cluster.go b/pkg/service/cluster.go index 802ec26..c1e76ea 100644 --- a/pkg/service/cluster.go +++ b/pkg/service/cluster.go @@ -45,7 +45,6 @@ const ( ) type ClusterService interface { - Close() error // create Cluster Create(ctx context.Context, cluster *infrav3.Cluster) (*infrav3.Cluster, error) // get cluster @@ -90,12 +89,7 @@ type ClusterService interface { // clusterService implements ClusterService type clusterService struct { - dao pg.EntityDAO - cdao dao.ClusterDao - eobdao dao.ClusterOperatorBootstrapDao - pcdao dao.ProjectClusterDao - ctdao dao.ClusterTokenDao - cndao dao.ClusterNamespacesDao + db *bun.DB downloadData common.DownloadData clusterHandlers []event.Handler bs BootstrapService @@ -103,17 +97,7 @@ type clusterService struct { // NewClusterService return new cluster service func NewClusterService(db *bun.DB, data *common.DownloadData, bs BootstrapService) ClusterService { - entityDao := pg.NewEntityDAO(db) - return &clusterService{ - dao: entityDao, - cdao: dao.NewClusterDao(entityDao), - eobdao: dao.NewClusterOperatorBootstrapDao(entityDao), - pcdao: dao.NewProjectClusterDao(entityDao), - ctdao: dao.NewClusterTokenDao(entityDao), - cndao: dao.NewClusterNamespacesDao(entityDao), - downloadData: *data, - bs: bs, - } + return &clusterService{db: db, downloadData: *data, bs: bs} } func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) (*infrav3.Cluster, error) { @@ -128,7 +112,7 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) } var proj models.Project - _, err := es.dao.GetByName(ctx, cluster.Metadata.Project, &proj) + _, err := pg.GetByName(ctx, es.db, cluster.Metadata.Project, &proj) if err != nil { cluster.Status = &commonv3.Status{ ConditionType: "Create", @@ -191,7 +175,7 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) return cluster, fmt.Errorf(errormsg) } - clusterPresent, err := es.dao.GetByNamePartnerOrg(ctx, cluster.Metadata.Name, uuid.NullUUID{UUID: proj.PartnerId, Valid: true}, + clusterPresent, err := pg.GetByNamePartnerOrg(ctx, es.db, cluster.Metadata.Name, uuid.NullUUID{UUID: proj.PartnerId, Valid: true}, uuid.NullUUID{UUID: proj.OrganizationId, Valid: true}, &models.Cluster{}) if err != nil && err.Error() == "sql: no rows in result set" { _log.Infof("Skipping as first time cluster create ") @@ -207,7 +191,7 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) metro := &models.Metro{} if cluster.Spec.Metro != nil && cluster.Spec.Metro.Name != "" { - if mdb, err := es.dao.GetByNamePartnerOrg(ctx, cluster.Spec.Metro.Name, uuid.NullUUID{UUID: proj.PartnerId, Valid: true}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, metro); err != nil { + if mdb, err := pg.GetByNamePartnerOrg(ctx, es.db, cluster.Spec.Metro.Name, uuid.NullUUID{UUID: proj.PartnerId, Valid: true}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, metro); err != nil { errormsg = "Invalid cluster location, provide a valid metro name" cluster.Status = &commonv3.Status{ ConditionType: "Create", @@ -276,7 +260,9 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) cluster.Spec.ClusterData.Health = infrav3.Health_EDGE_IGNORE - err = es.cdao.CreateCluster(ctx, edb) + err = es.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + return dao.CreateCluster(ctx, tx, edb) + }) if err != nil { cluster.Status = &commonv3.Status{ ConditionType: "Create", @@ -293,7 +279,7 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) ProjectID: edb.ProjectId, ClusterID: edb.ID, } - err = es.pcdao.CreateProjectCluster(ctx, pc) + err = dao.CreateProjectCluster(ctx, es.db, pc) if err != nil { cluster.Status = &commonv3.Status{ ConditionType: "Create", @@ -321,11 +307,14 @@ func (es *clusterService) Create(ctx context.Context, cluster *infrav3.Cluster) } _log.Infow("Creating cluster operator yaml", "clusterid", edb.ID) operatorSpecEncoded := base64.StdEncoding.EncodeToString([]byte(operatorSpecStr)) - boostrapData := models.ClusterOperatorBootstrap{ + bootstrapData := models.ClusterOperatorBootstrap{ ClusterId: edb.ID, YamlContent: operatorSpecEncoded, } - es.eobdao.CreateOperatorBootstrap(ctx, &boostrapData) + + es.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + return dao.CreateOperatorBootstrap(ctx, tx, &bootstrapData) + }) } ev := event.Resource{ @@ -355,7 +344,7 @@ func (s *clusterService) Select(ctx context.Context, cluster *infrav3.Cluster, i if err != nil { id = uuid.Nil } - c, err := s.cdao.GetCluster(ctx, &models.Cluster{ID: id, Name: cluster.Metadata.Name}) + c, err := dao.GetCluster(ctx, s.db, &models.Cluster{ID: id, Name: cluster.Metadata.Name}) if err != nil { clstr.Status = &commonv3.Status{ ConditionStatus: commonv3.ConditionStatus_StatusFailed, @@ -366,7 +355,7 @@ func (s *clusterService) Select(ctx context.Context, cluster *infrav3.Cluster, i } var projects []models.ProjectCluster if isExtended { - projects, err = s.pcdao.GetProjectsForCluster(ctx, c.ID) + projects, err = dao.GetProjectsForCluster(ctx, s.db, c.ID) if err != nil { clstr.Status = &commonv3.Status{ ConditionStatus: commonv3.ConditionStatus_StatusFailed, @@ -379,7 +368,7 @@ func (s *clusterService) Select(ctx context.Context, cluster *infrav3.Cluster, i var metro *models.Metro if c.MetroId != uuid.Nil { - entity, err := s.dao.GetByID(ctx, c.MetroId, &models.Metro{}) + entity, err := pg.GetByID(ctx, s.db, c.MetroId, &models.Metro{}) if err != nil { _log.Errorf("failed to fetch metro details", err) } @@ -407,7 +396,7 @@ func (s *clusterService) Get(ctx context.Context, opts ...query.Option) (*infrav if err != nil { id = uuid.Nil } - c, err := s.cdao.GetCluster(ctx, &models.Cluster{ID: id, Name: queryOptions.Name}) + c, err := dao.GetCluster(ctx, s.db, &models.Cluster{ID: id, Name: queryOptions.Name}) if err != nil { clstr.Status = &commonv3.Status{ ConditionStatus: commonv3.ConditionStatus_StatusFailed, @@ -418,7 +407,7 @@ func (s *clusterService) Get(ctx context.Context, opts ...query.Option) (*infrav } var projects []models.ProjectCluster if queryOptions.Extended { - projects, err = s.pcdao.GetProjectsForCluster(ctx, c.ID) + projects, err = dao.GetProjectsForCluster(ctx, s.db, c.ID) if err != nil { clstr.Status = &commonv3.Status{ ConditionStatus: commonv3.ConditionStatus_StatusFailed, @@ -441,19 +430,19 @@ func (s *clusterService) Get(ctx context.Context, opts ...query.Option) (*infrav func (s *clusterService) prepareClusterResponse(ctx context.Context, clstr *infrav3.Cluster, c *models.Cluster, metro *models.Metro, projects []models.ProjectCluster, isExtended bool) *infrav3.Cluster { var part models.Partner - _, err := s.dao.GetNameById(ctx, c.PartnerId, &part) + _, err := pg.GetNameById(ctx, s.db, c.PartnerId, &part) if err != nil { _log.Infow("unable to fetch partner information, ", err.Error()) } var org models.Organization - _, err = s.dao.GetNameById(ctx, c.OrganizationId, &org) + _, err = pg.GetNameById(ctx, s.db, c.OrganizationId, &org) if err != nil { _log.Infow("unable to fetch organization information, ", err.Error()) } var proj models.Project - _, err = s.dao.GetNameById(ctx, c.ProjectId, &proj) + _, err = pg.GetNameById(ctx, s.db, c.ProjectId, &proj) if err != nil { _log.Infow("unable to fetch project information, ", err.Error()) } @@ -539,7 +528,7 @@ func (cs *clusterService) Update(ctx context.Context, cluster *infrav3.Cluster) return cluster, fmt.Errorf("invalid cluster data, name is missing") } - edb, err := cs.dao.GetByName(ctx, cluster.Metadata.Name, &models.Cluster{}) + edb, err := pg.GetByName(ctx, cs.db, cluster.Metadata.Name, &models.Cluster{}) if err != nil { cluster.Status = &commonv3.Status{ ConditionType: "Update", @@ -596,7 +585,7 @@ func (cs *clusterService) Update(ctx context.Context, cluster *infrav3.Cluster) if cluster.Spec.Metro != nil && cdb.MetroId.String() != cluster.Spec.Metro.Id { metro := &models.Metro{} if cluster.Spec.Metro.Name != "" { - if mdb, err := cs.dao.GetByNamePartnerOrg(ctx, cluster.Spec.Metro.Name, uuid.NullUUID{UUID: pid, Valid: true}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, metro); err != nil { + if mdb, err := pg.GetByNamePartnerOrg(ctx, cs.db, cluster.Spec.Metro.Name, uuid.NullUUID{UUID: pid, Valid: true}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, metro); err != nil { errormsg = "Invalid cluster location, provide a valid metro name" cluster.Status = &commonv3.Status{ ConditionType: "Update", @@ -631,7 +620,7 @@ func (cs *clusterService) Update(ctx context.Context, cluster *infrav3.Cluster) } } - err = cs.cdao.UpdateCluster(ctx, cdb) + err = dao.UpdateCluster(ctx, cs.db, cdb) if err != nil { cluster.Status = &commonv3.Status{ ConditionStatus: commonv3.ConditionStatus_StatusFailed, @@ -719,11 +708,11 @@ func (cs *clusterService) deleteCluster(ctx context.Context, clusterId, projectI ID: uuid.MustParse(clusterId), ProjectId: uuid.MustParse(projectId), } - err := cs.pcdao.DeleteProjectsForCluster(ctx, uuid.MustParse(clusterId)) + err := dao.DeleteProjectsForCluster(ctx, cs.db, uuid.MustParse(clusterId)) if err != nil { return errors.Wrapf(err, "could not delete projects for cluster %s", clusterId) } - return cs.cdao.DeleteCluster(ctx, &c) + return dao.DeleteCluster(ctx, cs.db, &c) } func (cs *clusterService) List(ctx context.Context, opts ...query.Option) (*infrav3.ClusterList, error) { @@ -734,12 +723,12 @@ func (cs *clusterService) List(ctx context.Context, opts ...query.Option) (*infr } var proj models.Project - _, err := cs.dao.GetByName(ctx, queryOptions.Project, &proj) + _, err := pg.GetByName(ctx, cs.db, queryOptions.Project, &proj) if err != nil { return nil, err } - cdb, err := cs.cdao.ListClusters(ctx, commonv3.QueryOptions{ + cdb, err := dao.ListClusters(ctx, cs.db, commonv3.QueryOptions{ Project: proj.ID.String(), Organization: proj.OrganizationId.String(), Partner: proj.PartnerId.String(), @@ -760,13 +749,13 @@ func (cs *clusterService) List(ctx context.Context, opts ...query.Option) (*infr var items []*infrav3.Cluster for _, clstr := range cdb { - projects, err := cs.pcdao.GetProjectsForCluster(ctx, clstr.ID) + projects, err := dao.GetProjectsForCluster(ctx, cs.db, clstr.ID) if err != nil { return nil, err } metro := &models.Metro{} if clstr.MetroId != uuid.Nil { - entity, err := cs.dao.GetByID(ctx, clstr.MetroId, &models.Metro{}) + entity, err := pg.GetByID(ctx, cs.db, clstr.MetroId, &models.Metro{}) if err != nil { return nil, err } @@ -831,7 +820,7 @@ func (s *clusterService) UpdateClusterConditionStatus(ctx context.Context, curre func (s *clusterService) UpdateClusterAnnotations(ctx context.Context, cluster *infrav3.Cluster) error { if len(cluster.Metadata.Annotations) > 0 { annBytes, _ := json.Marshal(cluster.Metadata.Annotations) - return s.cdao.UpdateClusterAnnotations(ctx, &models.Cluster{ + return dao.UpdateClusterAnnotations(ctx, s.db, &models.Cluster{ ID: uuid.MustParse(cluster.Metadata.Id), Annotations: json.RawMessage(annBytes), }) @@ -840,7 +829,7 @@ func (s *clusterService) UpdateClusterAnnotations(ctx context.Context, cluster * } func (s *clusterService) ListenClusters(ctx context.Context, mChan chan<- commonv3.Metadata) { - listener := pgdriver.NewListener(s.dao.GetInstance()) + listener := pgdriver.NewListener(s.db) listener.Listen(ctx, clusterNotifyChan) notifyChan := listener.Channel() listenerLoop: @@ -877,11 +866,11 @@ func (s *clusterService) GetClusterProjects(ctx context.Context, cluster *infrav if err != nil { id = uuid.Nil } - c, err := s.cdao.GetCluster(ctx, &models.Cluster{ID: id, Name: cluster.Metadata.Name}) + c, err := dao.GetCluster(ctx, s.db, &models.Cluster{ID: id, Name: cluster.Metadata.Name}) if err != nil { return nil, err } - projects, err := s.pcdao.GetProjectsForCluster(ctx, c.ID) + projects, err := dao.GetProjectsForCluster(ctx, s.db, c.ID) if err != nil { return nil, err } @@ -895,7 +884,7 @@ func (s *clusterService) UpdateStatus(ctx context.Context, current *infrav3.Clus opt(&queryOptions) } - isAllowed, err := s.pcdao.ValidateClusterAccess(ctx, queryOptions) + isAllowed, err := dao.ValidateClusterAccess(ctx, s.db, queryOptions) if err != nil { return err } @@ -958,7 +947,7 @@ func (s *clusterService) deleteBootstrapAgentForCluster(ctx context.Context, clu return err } - err = s.bs.DeleteBoostrapAgent(ctx, templateRef, query.WithMeta(agent.Metadata)) + err = s.bs.DeleteBootstrapAgent(ctx, templateRef, query.WithMeta(agent.Metadata)) if err != nil { return err } @@ -1163,7 +1152,7 @@ func (s *clusterService) notifyCluster(ctx context.Context, c *infrav3.Cluster) return } - err = s.cdao.Notify(clusterNotifyChan, string(b)) + err = dao.Notify(s.db, clusterNotifyChan, string(b)) if err != nil { _log.Infow("unable to send cluster notification", "error", err) return @@ -1173,7 +1162,3 @@ func (s *clusterService) notifyCluster(ctx context.Context, c *infrav3.Cluster) func (s *clusterService) AddEventHandler(evh event.Handler) { s.clusterHandlers = append(s.clusterHandlers, evh) } - -func (s *clusterService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/cluster_test.go b/pkg/service/cluster_test.go index a636110..07ae858 100644 --- a/pkg/service/cluster_test.go +++ b/pkg/service/cluster_test.go @@ -31,7 +31,6 @@ func TestCreateCluster(t *testing.T) { } ps := NewClusterService(db, downloadData, NewBootstrapService(db)) - defer ps.Close() puuid := uuid.New().String() cuuid := uuid.New().String() @@ -77,7 +76,6 @@ func TestUpdateCluster(t *testing.T) { } ps := NewClusterService(db, downloadData, NewBootstrapService(db)) - defer ps.Close() puuid := uuid.New().String() cuuid := uuid.New().String() @@ -112,7 +110,6 @@ func TestSelectCluster(t *testing.T) { } ps := NewClusterService(db, downloadData, NewBootstrapService(db)) - defer ps.Close() puuid := uuid.New().String() cuuid := uuid.New().String() @@ -147,7 +144,6 @@ func TestGetCluster(t *testing.T) { } ps := NewClusterService(db, downloadData, NewBootstrapService(db)) - defer ps.Close() puuid := uuid.New().String() cuuid := uuid.New().String() @@ -181,7 +177,6 @@ func TestListCluster(t *testing.T) { } ps := NewClusterService(db, downloadData, NewBootstrapService(db)) - defer ps.Close() puuid := uuid.New().String() ouuid := uuid.New().String() diff --git a/pkg/service/group.go b/pkg/service/group.go index 9f3300c..f6cb3c3 100644 --- a/pkg/service/group.go +++ b/pkg/service/group.go @@ -9,7 +9,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/RafaySystems/rcloud-base/internal/utils" authzv1 "github.com/RafaySystems/rcloud-base/proto/types/authz" v3 "github.com/RafaySystems/rcloud-base/proto/types/commonpb/v3" userv3 "github.com/RafaySystems/rcloud-base/proto/types/userpb/v3" @@ -25,7 +24,6 @@ const ( // GroupService is the interface for group operations type GroupService interface { - Close() error // create group Create(context.Context, *userv3.Group) (*userv3.Group, error) // get group by id @@ -42,41 +40,34 @@ type GroupService interface { // groupService implements GroupService type groupService struct { - dao pg.EntityDAO - gdao dao.GroupDAO - l utils.Lookup - azc AuthzService + db *bun.DB + azc AuthzService } // NewGroupService return new group service func NewGroupService(db *bun.DB, azc AuthzService) GroupService { - return &groupService{ - dao: pg.NewEntityDAO(db), - gdao: dao.NewGroupDAO(db), - l: utils.NewLookup(db), - azc: azc, - } + return &groupService{db: db, azc: azc} } func (s *groupService) deleteGroupRoleRelaitons(ctx context.Context, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { // delete previous entries // TODO: single delete command - err := s.dao.DeleteX(ctx, "group_id", groupId, &models.GroupRole{}) + err := pg.DeleteX(ctx, s.db, "group_id", groupId, &models.GroupRole{}) if err != nil { return &userv3.Group{}, err } - err = s.dao.DeleteX(ctx, "group_id", groupId, &models.ProjectGroupRole{}) + err = pg.DeleteX(ctx, s.db, "group_id", groupId, &models.ProjectGroupRole{}) if err != nil { return &userv3.Group{}, err } - err = s.dao.DeleteX(ctx, "group_id", groupId, &models.ProjectGroupNamespaceRole{}) + err = pg.DeleteX(ctx, s.db, "group_id", groupId, &models.ProjectGroupNamespaceRole{}) if err != nil { return &userv3.Group{}, err } _, err = s.azc.DeletePolicies(ctx, &authzv1.Policy{Sub: "g:" + group.GetMetadata().GetName()}) if err != nil { - return &userv3.Group{}, fmt.Errorf("unable to delete gorup-role relations from authz; %v", err) + return &userv3.Group{}, fmt.Errorf("unable to delete group-role relations from authz; %v", err) } return group, nil } @@ -92,7 +83,7 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user var ps []*authzv1.Policy for _, pnr := range projectNamespaceRoles { role := pnr.GetRole() - entity, err := s.dao.GetIdByName(ctx, role, &models.Role{}) + entity, err := pg.GetIdByName(ctx, s.db, role, &models.Role{}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find role '%v'", role) } @@ -108,7 +99,7 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user namespaceId := pnr.GetNamespace() // TODO: lookup id from name switch { case namespaceId != 0: - projectId, err := s.l.GetProjectId(ctx, project) + projectId, err := pg.GetProjectId(ctx, s.db, project) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find project '%v'", project) } @@ -133,7 +124,7 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user Act: "*", }) case project != "": - projectId, err := s.l.GetProjectId(ctx, project) + projectId, err := pg.GetProjectId(ctx, s.db, project) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find project '%v'", project) } @@ -177,19 +168,19 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user } } if len(pgnrs) > 0 { - _, err := s.dao.Create(ctx, &pgnrs) + _, err := pg.Create(ctx, s.db, &pgnrs) if err != nil { return &userv3.Group{}, err } } if len(pgrs) > 0 { - _, err := s.dao.Create(ctx, &pgrs) + _, err := pg.Create(ctx, s.db, &pgrs) if err != nil { return &userv3.Group{}, err } } if len(grs) > 0 { - _, err := s.dao.Create(ctx, &grs) + _, err := pg.Create(ctx, s.db, &grs) if err != nil { return &userv3.Group{}, err } @@ -206,14 +197,14 @@ func (s *groupService) createGroupRoleRelations(ctx context.Context, group *user } func (s *groupService) deleteGroupAccountRelations(ctx context.Context, groupId uuid.UUID, group *userv3.Group) (*userv3.Group, error) { - err := s.dao.DeleteX(ctx, "group_id", groupId, &models.GroupAccount{}) + err := pg.DeleteX(ctx, s.db, "group_id", groupId, &models.GroupAccount{}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to delete user; %v", err) } _, err = s.azc.DeleteUserGroups(ctx, &authzv1.UserGroup{Grp: "g:" + group.GetMetadata().GetName()}) if err != nil { - return &userv3.Group{}, fmt.Errorf("unable to delete gorup-user relations from authz; %v", err) + return &userv3.Group{}, fmt.Errorf("unable to delete group-user relations from authz; %v", err) } return group, nil } @@ -225,7 +216,7 @@ func (s *groupService) createGroupAccountRelations(ctx context.Context, groupId var ugs []*authzv1.UserGroup for _, account := range unique(group.GetSpec().GetUsers()) { // FIXME: do combined lookup - entity, err := s.dao.GetIdByTraits(ctx, account, &models.KratosIdentities{}) + entity, err := pg.GetIdByTraits(ctx, s.db, account, &models.KratosIdentities{}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to find user '%v'", account) } @@ -248,13 +239,13 @@ func (s *groupService) createGroupAccountRelations(ctx context.Context, groupId if len(grpaccs) == 0 { return group, nil } - _, err := s.dao.Create(ctx, &grpaccs) + _, err := pg.Create(ctx, s.db, &grpaccs) if err != nil { return &userv3.Group{}, err } // TODO: revert our db inserts if this fails - // Just FYI, the succcess can be false if we delete the db directly but casbin has it available internally + // Just FYI, the success can be false if we delete the db directly but casbin has it available internally _, err = s.azc.CreateUserGroups(ctx, &authzv1.UserGroups{UserGroups: ugs}) if err != nil { return &userv3.Group{}, fmt.Errorf("unable to create mapping in authz; %v", err) @@ -266,11 +257,11 @@ func (s *groupService) createGroupAccountRelations(ctx context.Context, groupId func (s *groupService) getPartnerOrganization(ctx context.Context, group *userv3.Group) (uuid.UUID, uuid.UUID, error) { partner := group.GetMetadata().GetPartner() org := group.GetMetadata().GetOrganization() - partnerId, err := s.l.GetPartnerId(ctx, partner) + partnerId, err := pg.GetPartnerId(ctx, s.db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := s.l.GetOrganizationId(ctx, org) + organizationId, err := pg.GetOrganizationId(ctx, s.db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -283,7 +274,7 @@ func (s *groupService) Create(ctx context.Context, group *userv3.Group) (*userv3 if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - g, _ := s.dao.GetIdByNamePartnerOrg(ctx, group.GetMetadata().GetName(), uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) + g, _ := pg.GetIdByNamePartnerOrg(ctx, s.db, group.GetMetadata().GetName(), uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) if g != nil { return nil, fmt.Errorf("group '%v' already exists", group.GetMetadata().GetName()) } @@ -298,7 +289,7 @@ func (s *groupService) Create(ctx context.Context, group *userv3.Group) (*userv3 PartnerId: partnerId, Type: group.GetSpec().GetType(), } - entity, err := s.dao.Create(ctx, &grp) + entity, err := pg.Create(ctx, s.db, &grp) if err != nil { return &userv3.Group{}, err } @@ -335,7 +326,7 @@ func (s *groupService) toV3Group(ctx context.Context, group *userv3.Group, grp * Labels: labels, ModifiedAt: timestamppb.New(grp.ModifiedAt), } - users, err := s.gdao.GetUsers(ctx, grp.ID) + users, err := dao.GetUsers(ctx, s.db, grp.ID) if err != nil { return &userv3.Group{}, err } @@ -344,7 +335,7 @@ func (s *groupService) toV3Group(ctx context.Context, group *userv3.Group, grp * userNames = append(userNames, u.Traits["email"].(string)) } - roles, err := s.gdao.GetRoles(ctx, grp.ID) + roles, err := dao.GetGroupRoles(ctx, s.db, grp.ID) if err != nil { return &userv3.Group{}, err } @@ -362,7 +353,7 @@ func (s *groupService) GetByID(ctx context.Context, group *userv3.Group) (*userv if err != nil { return &userv3.Group{}, err } - entity, err := s.dao.GetByID(ctx, uid, &models.Group{}) + entity, err := pg.GetByID(ctx, s.db, uid, &models.Group{}) if err != nil { return &userv3.Group{}, err } @@ -380,7 +371,7 @@ func (s *groupService) GetByName(ctx context.Context, group *userv3.Group) (*use if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - entity, err := s.dao.GetByNamePartnerOrg(ctx, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) + entity, err := pg.GetByNamePartnerOrg(ctx, s.db, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) if err != nil { return &userv3.Group{}, err } @@ -399,7 +390,7 @@ func (s *groupService) Update(ctx context.Context, group *userv3.Group) (*userv3 if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - entity, err := s.dao.GetByNamePartnerOrg(ctx, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) + entity, err := pg.GetByNamePartnerOrg(ctx, s.db, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) if err != nil { return &userv3.Group{}, fmt.Errorf("no group found with name '%v'", name) } @@ -429,7 +420,7 @@ func (s *groupService) Update(ctx context.Context, group *userv3.Group) (*userv3 return &userv3.Group{}, err } - _, err = s.dao.Update(ctx, grp.ID, grp) + _, err = pg.Update(ctx, s.db, grp.ID, grp) if err != nil { return &userv3.Group{}, err } @@ -451,7 +442,7 @@ func (s *groupService) Delete(ctx context.Context, group *userv3.Group) (*userv3 if err != nil { return &userv3.Group{}, fmt.Errorf("unable to get partner and org id") } - entity, err := s.dao.GetByNamePartnerOrg(ctx, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) + entity, err := pg.GetByNamePartnerOrg(ctx, s.db, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Group{}) if err != nil { return &userv3.Group{}, err } @@ -464,7 +455,7 @@ func (s *groupService) Delete(ctx context.Context, group *userv3.Group) (*userv3 if err != nil { return &userv3.Group{}, err } - err = s.dao.Delete(ctx, grp.ID, grp) + err = pg.Delete(ctx, s.db, grp.ID, grp) if err != nil { return &userv3.Group{}, err } @@ -483,16 +474,16 @@ func (s *groupService) List(ctx context.Context, group *userv3.Group) (*userv3.G }, } if len(group.Metadata.Organization) > 0 { - orgId, err := s.l.GetOrganizationId(ctx, group.Metadata.Organization) + orgId, err := pg.GetOrganizationId(ctx, s.db, group.Metadata.Organization) if err != nil { return groupList, err } - partId, err := s.l.GetPartnerId(ctx, group.Metadata.Partner) + partId, err := pg.GetPartnerId(ctx, s.db, group.Metadata.Partner) if err != nil { return groupList, err } var grps []models.Group - entities, err := s.dao.List(ctx, uuid.NullUUID{UUID: partId, Valid: true}, uuid.NullUUID{UUID: orgId, Valid: true}, &grps) + entities, err := pg.List(ctx, s.db, uuid.NullUUID{UUID: partId, Valid: true}, uuid.NullUUID{UUID: orgId, Valid: true}, &grps) if err != nil { return groupList, err } @@ -518,7 +509,3 @@ func (s *groupService) List(ctx context.Context, group *userv3.Group) (*userv3.G } return groupList, nil } - -func (s *groupService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/group_permission.go b/pkg/service/group_permission.go index 4708d54..ee4bd3e 100644 --- a/pkg/service/group_permission.go +++ b/pkg/service/group_permission.go @@ -6,7 +6,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/RafaySystems/rcloud-base/proto/types/sentry" "github.com/google/uuid" "github.com/uptrace/bun" @@ -14,7 +13,6 @@ import ( // GroupPermissionService is the interface for group permission operations type GroupPermissionService interface { - Close() error GetGroupPermissions(ctx context.Context, groupNames []string, orgID, partnerID string) ([]sentry.GroupPermission, error) GetGroupProjectsByPermission(ctx context.Context, groupNames []string, orgID, partnerID string, permission string) ([]sentry.GroupPermission, error) GetGroupPermissionsByProjectIDPermissions(ctx context.Context, groupNames []string, orgID, partnerID string, projects []string, permissions []string) ([]sentry.GroupPermission, error) @@ -23,25 +21,16 @@ type GroupPermissionService interface { // groupPermissionService implements GroupPermissionService type groupPermissionService struct { - dao pg.EntityDAO - pdao dao.PermissionDao + db *bun.DB } // NewKubeconfigRevocation return new kubeconfig revocation service func NewGroupPermissionService(db *bun.DB) GroupPermissionService { - edao := pg.NewEntityDAO(db) - return &groupPermissionService{ - dao: edao, - pdao: dao.NewPermissionDao(edao), - } -} - -func (s *groupPermissionService) Close() error { - return s.dao.Close() + return &groupPermissionService{db} } func (s *groupPermissionService) GetGroupPermissions(ctx context.Context, groupNames []string, orgID, partnerID string) ([]sentry.GroupPermission, error) { - gps, err := s.pdao.GetGroupPermissions(ctx, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID)) + gps, err := dao.GetGroupPermissions(ctx, s.db, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID)) if err != nil { return nil, err } @@ -53,8 +42,8 @@ func (s *groupPermissionService) GetGroupPermissions(ctx context.Context, groupN return groupPermissions, nil } -func (a *groupPermissionService) GetGroupProjectsByPermission(ctx context.Context, groupNames []string, orgID, partnerID string, permission string) ([]sentry.GroupPermission, error) { - aps, err := a.pdao.GetGroupProjectsByPermission(ctx, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID), permission) +func (s *groupPermissionService) GetGroupProjectsByPermission(ctx context.Context, groupNames []string, orgID, partnerID string, permission string) ([]sentry.GroupPermission, error) { + aps, err := dao.GetGroupProjectsByPermission(ctx, s.db, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID), permission) if err != nil { return nil, err } @@ -67,7 +56,7 @@ func (a *groupPermissionService) GetGroupProjectsByPermission(ctx context.Contex } func (s *groupPermissionService) GetGroupPermissionsByProjectIDPermissions(ctx context.Context, groupNames []string, orgID, partnerID string, projects []string, permissions []string) ([]sentry.GroupPermission, error) { - gps, err := s.pdao.GetGroupPermissionsByProjectIDPermissions(ctx, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID), projects, permissions) + gps, err := dao.GetGroupPermissionsByProjectIDPermissions(ctx, s.db, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID), projects, permissions) if err != nil { return nil, err } @@ -80,7 +69,7 @@ func (s *groupPermissionService) GetGroupPermissionsByProjectIDPermissions(ctx c } func (s *groupPermissionService) GetProjectByGroup(ctx context.Context, groupNames []string, orgID, partnerID string) ([]sentry.GroupPermission, error) { - gps, err := s.pdao.GetProjectByGroup(ctx, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID)) + gps, err := dao.GetProjectByGroup(ctx, s.db, groupNames, uuid.MustParse(orgID), uuid.MustParse(partnerID)) if err != nil { return nil, err } diff --git a/pkg/service/group_permission_test.go b/pkg/service/group_permission_test.go index b4d0566..1edd111 100644 --- a/pkg/service/group_permission_test.go +++ b/pkg/service/group_permission_test.go @@ -13,7 +13,6 @@ func TestGetGroupPermissions(t *testing.T) { defer db.Close() ps := NewGroupPermissionService(db) - defer ps.Close() groupNames := []string{"mygroup", "admin"} gid := uuid.New() @@ -34,7 +33,6 @@ func TestGetGroupProjectsByPermission(t *testing.T) { defer db.Close() ps := NewGroupPermissionService(db) - defer ps.Close() groupNames := []string{"mygroup", "admin"} gid := uuid.New() @@ -55,7 +53,6 @@ func TestGetGroupPermissionsByProjectIDPermissions(t *testing.T) { defer db.Close() ps := NewGroupPermissionService(db) - defer ps.Close() groupNames := []string{"mygroup", "admin"} projectNames := []string{"myproject"} diff --git a/pkg/service/group_test.go b/pkg/service/group_test.go index 09d5c05..c65c8c1 100644 --- a/pkg/service/group_test.go +++ b/pkg/service/group_test.go @@ -83,7 +83,6 @@ func TestCreateGroupNoUsersNoRoles(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -117,7 +116,6 @@ func TestCreateGroupDuplicate(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -161,7 +159,6 @@ func TestCreateGroupWithUsersNoRoles(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -226,7 +223,6 @@ func TestCreateGroupNoUsersWithRoles(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -303,7 +299,6 @@ func TestCreateGroupWithUsersWithRoles(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -384,7 +379,6 @@ func TestUpdateGroupWithUsersWithRoles(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -457,7 +451,6 @@ func TestGroupDelete(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -496,7 +489,6 @@ func TestGroupDeleteNonExist(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -520,7 +512,6 @@ func TestGroupGetByName(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -571,7 +562,6 @@ func TestGroupGetById(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid := uuid.New().String() puuid := uuid.New().String() @@ -618,7 +608,6 @@ func TestGroupList(t *testing.T) { mazc := mockAuthzClient{} gs := NewGroupService(db, &mazc) - defer gs.Close() guuid1 := uuid.New().String() guuid2 := uuid.New().String() diff --git a/pkg/service/idp.go b/pkg/service/idp.go index 6809df9..9905db5 100644 --- a/pkg/service/idp.go +++ b/pkg/service/idp.go @@ -16,7 +16,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/RafaySystems/rcloud-base/internal/utils" commonv3 "github.com/RafaySystems/rcloud-base/proto/types/commonpb/v3" systemv3 "github.com/RafaySystems/rcloud-base/proto/types/systempb/v3" "github.com/google/uuid" @@ -35,17 +34,12 @@ type IdpService interface { } type idpService struct { - dao pg.EntityDAO + db *bun.DB appHost string - l utils.Lookup } func NewIdpService(db *bun.DB, hostUrl string) IdpService { - return &idpService{ - dao: pg.NewEntityDAO(db), - appHost: hostUrl, - l: utils.NewLookup(db), - } + return &idpService{db: db, appHost: hostUrl} } func generateAcsURL(id string, hostUrl string) string { @@ -110,11 +104,11 @@ func generateSpCert(host string) (string, string, error) { func (s *idpService) getPartnerOrganization(ctx context.Context, provider *systemv3.Idp) (uuid.UUID, uuid.UUID, error) { partner := provider.GetMetadata().GetPartner() org := provider.GetMetadata().GetOrganization() - partnerId, err := s.l.GetPartnerId(ctx, partner) + partnerId, err := pg.GetPartnerId(ctx, s.db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := s.l.GetOrganizationId(ctx, org) + organizationId, err := pg.GetOrganizationId(ctx, s.db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -137,8 +131,9 @@ func (s *idpService) Create(ctx context.Context, idp *systemv3.Idp) (*systemv3.I if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - i, _ := s.dao.GetIdByNamePartnerOrg( + i, _ := pg.GetIdByNamePartnerOrg( ctx, + s.db, idp.GetMetadata().GetName(), uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, @@ -149,7 +144,7 @@ func (s *idpService) Create(ctx context.Context, idp *systemv3.Idp) (*systemv3.I } e := &models.Idp{} - s.dao.GetX(ctx, "domain", domain, e) + pg.GetX(ctx, s.db, "domain", domain, e) if e.Domain == domain { return &systemv3.Idp{}, fmt.Errorf("DUPLICATE DOMAIN") } @@ -181,7 +176,7 @@ func (s *idpService) Create(ctx context.Context, idp *systemv3.Idp) (*systemv3.I entity.SpCert = spcert entity.SpKey = spkey } - _, err = s.dao.Create(ctx, entity) + _, err = pg.Create(ctx, s.db, entity) if err != nil { return &systemv3.Idp{}, err } @@ -219,8 +214,8 @@ func (s *idpService) GetByID(ctx context.Context, idp *systemv3.Idp) (*systemv3. return &systemv3.Idp{}, err } entity := &models.Idp{} - // TODO: Check for existance of id before GetByID - _, err = s.dao.GetByID(ctx, id, entity) + // TODO: Check for existence of id before GetByID + _, err = pg.GetByID(ctx, s.db, id, entity) if err != nil { return &systemv3.Idp{}, err } @@ -261,7 +256,7 @@ func (s *idpService) GetByName(ctx context.Context, idp *systemv3.Idp) (*systemv return &systemv3.Idp{}, status.Error(codes.InvalidArgument, "EMPTY NAME") } entity := &models.Idp{} - _, err := s.dao.GetByName(ctx, name, entity) + _, err := pg.GetByName(ctx, s.db, name, entity) if err != nil { return &systemv3.Idp{}, err } @@ -307,14 +302,14 @@ func (s *idpService) Update(ctx context.Context, idp *systemv3.Idp) (*systemv3.I return &systemv3.Idp{}, status.Error(codes.InvalidArgument, "EMPTY DOMAIN") } - _, err := s.dao.GetByName(ctx, name, existingIdp) + _, err := pg.GetByName(ctx, s.db, name, existingIdp) if err != nil { // TODO: Handle both db and idp not exist errors // separately. return &systemv3.Idp{}, status.Errorf(codes.InvalidArgument, "IDP %q NOT EXIST", name) } - s.dao.GetX(ctx, "domain", domain, existingIdp) + pg.GetX(ctx, s.db, "domain", domain, existingIdp) if existingIdp.Domain == domain { return &systemv3.Idp{}, status.Error(codes.InvalidArgument, "DUPLICATE DOMAIN") } @@ -357,7 +352,7 @@ func (s *idpService) Update(ctx context.Context, idp *systemv3.Idp) (*systemv3.I entity.SpKey = spkey } - _, err = s.dao.Update(ctx, existingIdp.Id, entity) + _, err = pg.Update(ctx, s.db, existingIdp.Id, entity) if err != nil { return &systemv3.Idp{}, err } @@ -395,7 +390,7 @@ func (s *idpService) List(ctx context.Context) (*systemv3.IdpList, error) { orgID uuid.NullUUID parID uuid.NullUUID ) - _, err := s.dao.List(ctx, parID, orgID, &entities) + _, err := pg.List(ctx, s.db, parID, orgID, &entities) if err != nil { return &systemv3.IdpList{}, err } @@ -448,12 +443,12 @@ func (s *idpService) Delete(ctx context.Context, idp *systemv3.Idp) error { return status.Error(codes.InvalidArgument, "EMPTY NAME") } - _, err := s.dao.GetByName(ctx, name, entity) + _, err := pg.GetByName(ctx, s.db, name, entity) if err != nil { return status.Errorf(codes.InvalidArgument, "IDP %q NOT EXISTS", name) } - err = s.dao.Delete(ctx, entity.Id, &models.Idp{}) + err = pg.Delete(ctx, s.db, entity.Id, &models.Idp{}) if err != nil { return err } diff --git a/pkg/service/kubeconfig_revocation.go b/pkg/service/kubeconfig_revocation.go index 40063b5..7f2c95d 100644 --- a/pkg/service/kubeconfig_revocation.go +++ b/pkg/service/kubeconfig_revocation.go @@ -8,7 +8,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/constants" "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/RafaySystems/rcloud-base/proto/types/sentry" "github.com/google/uuid" "github.com/uptrace/bun" @@ -17,32 +16,22 @@ import ( // KubeconfigRevocation is the interface for bootstrap operations type KubeconfigRevocationService interface { - Close() error Get(ctx context.Context, orgID string, accountID string, isSSOUser bool) (*sentry.KubeconfigRevocation, error) Patch(ctx context.Context, kr *sentry.KubeconfigRevocation) error } // bootstrapService implements BootstrapService type kubeconfigRevocationService struct { - dao pg.EntityDAO - kdao dao.KubeconfigDao + db *bun.DB } // NewKubeconfigRevocation return new kubeconfig revocation service func NewKubeconfigRevocationService(db *bun.DB) KubeconfigRevocationService { - edao := pg.NewEntityDAO(db) - return &kubeconfigRevocationService{ - dao: edao, - kdao: dao.NewKubeconfigDao(edao), - } -} - -func (krs *kubeconfigRevocationService) Close() error { - return krs.dao.Close() + return &kubeconfigRevocationService{db} } func (krs *kubeconfigRevocationService) Get(ctx context.Context, orgID string, accountID string, isSSOUser bool) (*sentry.KubeconfigRevocation, error) { - kr, err := krs.kdao.GetKubeconfigRevocation(ctx, uuid.MustParse(orgID), uuid.MustParse(accountID), isSSOUser) + kr, err := dao.GetKubeconfigRevocation(ctx, krs.db, uuid.MustParse(orgID), uuid.MustParse(accountID), isSSOUser) if err == sql.ErrNoRows { return nil, constants.ErrNotFound } else if err != nil { @@ -63,14 +52,14 @@ func prepareKubeCfgRevocationResponse(kr *models.KubeconfigRevocation) *sentry.K } func (krs *kubeconfigRevocationService) Patch(ctx context.Context, kr *sentry.KubeconfigRevocation) error { - err := krs.dao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := krs.kdao.GetKubeconfigRevocation(ctx, uuid.MustParse(kr.OrganizationID), uuid.MustParse(kr.AccountID), kr.IsSSOUser) + err := krs.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := dao.GetKubeconfigRevocation(ctx, krs.db, uuid.MustParse(kr.OrganizationID), uuid.MustParse(kr.AccountID), kr.IsSSOUser) if err != nil && err == sql.ErrNoRows { kcr := convertToModel(kr) kcr.CreatedAt = time.Now() - return krs.kdao.CreateKubeconfigRevocation(ctx, kcr) + return dao.CreateKubeconfigRevocation(ctx, krs.db, kcr) } - return krs.kdao.UpdateKubeconfigRevocation(ctx, convertToModel(kr)) + return dao.UpdateKubeconfigRevocation(ctx, krs.db, convertToModel(kr)) }) return err } diff --git a/pkg/service/kubeconfig_revocation_test.go b/pkg/service/kubeconfig_revocation_test.go index 4635d58..0b92e21 100644 --- a/pkg/service/kubeconfig_revocation_test.go +++ b/pkg/service/kubeconfig_revocation_test.go @@ -13,7 +13,6 @@ func TestGetKubeconfigRevocation(t *testing.T) { defer db.Close() ps := NewKubeconfigRevocationService(db) - defer ps.Close() ouuid := uuid.New().String() cuuid := uuid.New().String() diff --git a/pkg/service/kubeconfig_settings.go b/pkg/service/kubeconfig_settings.go index 3038f96..95a28ac 100644 --- a/pkg/service/kubeconfig_settings.go +++ b/pkg/service/kubeconfig_settings.go @@ -8,7 +8,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/constants" "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/RafaySystems/rcloud-base/proto/types/sentry" "github.com/google/uuid" "github.com/uptrace/bun" @@ -17,28 +16,18 @@ import ( // KubeconfigSettingService is the interface for kube config setting operations type KubeconfigSettingService interface { - Close() error Get(ctx context.Context, orgID string, accountID string, isSSO bool) (*sentry.KubeconfigSetting, error) Patch(ctx context.Context, ks *sentry.KubeconfigSetting) error } // kubeconfigSettingService implements KubeconfigSettingService type kubeconfigSettingService struct { - dao pg.EntityDAO - kdao dao.KubeconfigDao + db *bun.DB } // NewKubeconfigSettingService return new kubeconfig setting service func NewKubeconfigSettingService(db *bun.DB) KubeconfigSettingService { - edao := pg.NewEntityDAO(db) - return &kubeconfigSettingService{ - dao: edao, - kdao: dao.NewKubeconfigDao(edao), - } -} - -func (krs *kubeconfigSettingService) Close() error { - return krs.dao.Close() + return &kubeconfigSettingService{db} } func (kss *kubeconfigSettingService) Get(ctx context.Context, orgID string, accountID string, isSSO bool) (*sentry.KubeconfigSetting, error) { @@ -51,7 +40,7 @@ func (kss *kubeconfigSettingService) Get(ctx context.Context, orgID string, acco _log.Info("account identifier is empty") } - kr, err := kss.kdao.GetKubeconfigSetting(ctx, oid, aid, isSSO) + kr, err := dao.GetKubeconfigSetting(ctx, kss.db, oid, aid, isSSO) if err == sql.ErrNoRows { return nil, constants.ErrNotFound } else if err != nil { @@ -65,15 +54,15 @@ func (kss *kubeconfigSettingService) Patch(ctx context.Context, ks *sentry.Kubec if err != nil { accId = uuid.Nil } - err = kss.dao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := kss.kdao.GetKubeconfigSetting(ctx, uuid.MustParse(ks.OrganizationID), accId, ks.IsSSOUser) + err = kss.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := dao.GetKubeconfigSetting(ctx, kss.db, uuid.MustParse(ks.OrganizationID), accId, ks.IsSSOUser) db := convertToKubeCfgSettingModel(ks) if err != nil && err == sql.ErrNoRows { db.CreatedAt = time.Now() - return kss.kdao.CreateKubeconfigSetting(ctx, convertToKubeCfgSettingModel(ks)) + return dao.CreateKubeconfigSetting(ctx, kss.db, convertToKubeCfgSettingModel(ks)) } db.ModifiedAt = time.Now() - return kss.kdao.UpdateKubeconfigSetting(ctx, convertToKubeCfgSettingModel(ks)) + return dao.UpdateKubeconfigSetting(ctx, kss.db, convertToKubeCfgSettingModel(ks)) }) return err } diff --git a/pkg/service/kubectl_cluster_setting.go b/pkg/service/kubectl_cluster_setting.go index f4a024b..86aaf75 100644 --- a/pkg/service/kubectl_cluster_setting.go +++ b/pkg/service/kubectl_cluster_setting.go @@ -8,7 +8,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/constants" "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" - "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" "github.com/RafaySystems/rcloud-base/proto/types/sentry" "github.com/google/uuid" "github.com/uptrace/bun" @@ -17,32 +16,22 @@ import ( // KubectlClusterSettingsService is the interface for kubectl cluster setting operations type KubectlClusterSettingsService interface { - Close() error Get(ctx context.Context, orgID string, clusterID string) (*sentry.KubectlClusterSettings, error) Patch(ctx context.Context, kc *sentry.KubectlClusterSettings) error } // kubectlClusterSettingsService implements KubectlClusterSettingsService type kubectlClusterSettingsService struct { - dao pg.EntityDAO - kdao dao.KubeconfigDao + db *bun.DB } // NewKubectlClusterSettingsService return new kubectl cluster setting service func NewkubectlClusterSettingsService(db *bun.DB) KubectlClusterSettingsService { - edao := pg.NewEntityDAO(db) - return &kubectlClusterSettingsService{ - dao: edao, - kdao: dao.NewKubeconfigDao(edao), - } -} - -func (kcs *kubectlClusterSettingsService) Close() error { - return kcs.dao.Close() + return &kubectlClusterSettingsService{db} } func (kcs *kubectlClusterSettingsService) Get(ctx context.Context, orgID string, clusterID string) (*sentry.KubectlClusterSettings, error) { - kc, err := kcs.kdao.GetkubectlClusterSettings(ctx, uuid.MustParse(orgID), clusterID) + kc, err := dao.GetkubectlClusterSettings(ctx, kcs.db, uuid.MustParse(orgID), clusterID) if err == sql.ErrNoRows { return nil, constants.ErrNotFound } else if err != nil { @@ -52,19 +41,19 @@ func (kcs *kubectlClusterSettingsService) Get(ctx context.Context, orgID string, } func (kcs *kubectlClusterSettingsService) Patch(ctx context.Context, kc *sentry.KubectlClusterSettings) error { - err := kcs.dao.GetInstance().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := kcs.kdao.GetkubectlClusterSettings(ctx, uuid.MustParse(kc.OrganizationID), kc.Name) + err := kcs.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + _, err := dao.GetkubectlClusterSettings(ctx, kcs.db, uuid.MustParse(kc.OrganizationID), kc.Name) if err != nil { if err == sql.ErrNoRows { kcsdb := convertToKubeCtlSettingModel(kc) kcsdb.CreatedAt = time.Now() - kcs.kdao.CreatekubectlClusterSettings(ctx, kcsdb) + dao.CreatekubectlClusterSettings(ctx, kcs.db, kcsdb) } return err } kcsdb := convertToKubeCtlSettingModel(kc) kcsdb.ModifiedAt = time.Now() - return kcs.kdao.UpdatekubectlClusterSettings(ctx, kcsdb) + return dao.UpdatekubectlClusterSettings(ctx, kcs.db, kcsdb) }) return err } diff --git a/pkg/service/kubectl_cluster_setting_test.go b/pkg/service/kubectl_cluster_setting_test.go index 7f6afb2..0294159 100644 --- a/pkg/service/kubectl_cluster_setting_test.go +++ b/pkg/service/kubectl_cluster_setting_test.go @@ -13,7 +13,6 @@ func TestGetKubectlSetting(t *testing.T) { defer db.Close() ps := NewkubectlClusterSettingsService(db) - defer ps.Close() ouuid := uuid.New().String() cuuid := uuid.New().String() diff --git a/pkg/service/metro.go b/pkg/service/metro.go index 2dd63dc..87df315 100644 --- a/pkg/service/metro.go +++ b/pkg/service/metro.go @@ -15,7 +15,6 @@ import ( // MetroService is the interface for metro operations type MetroService interface { - Close() error // create metro Create(ctx context.Context, metro *infrav3.Location) (*infrav3.Location, error) // get metro by id @@ -34,20 +33,18 @@ type MetroService interface { // metroService implements MetroService type metroService struct { - dao pg.EntityDAO + db *bun.DB } // NewProjectService return new project service func NewMetroService(db *bun.DB) MetroService { - return &metroService{ - dao: pg.NewEntityDAO(db), - } + return &metroService{db} } func (s *metroService) Create(ctx context.Context, metro *infrav3.Location) (*infrav3.Location, error) { var part models.Partner - _, err := s.dao.GetByName(ctx, metro.Metadata.Partner, &part) + _, err := pg.GetByName(ctx, s.db, metro.Metadata.Partner, &part) if err != nil { return nil, err } @@ -68,7 +65,7 @@ func (s *metroService) Create(ctx context.Context, metro *infrav3.Location) (*in OrganizationId: uuid.Nil, PartnerId: part.ID, } - _, err = s.dao.Create(ctx, &metrodb) + _, err = pg.Create(ctx, s.db, &metrodb) if err != nil { return nil, err } @@ -81,7 +78,7 @@ func (s *metroService) GetByName(ctx context.Context, name string) (*infrav3.Loc var metro infrav3.Location - entity, err := s.dao.GetByName(ctx, name, &models.Metro{}) + entity, err := pg.GetByName(ctx, s.db, name, &models.Metro{}) if err != nil { return nil, err } @@ -112,7 +109,7 @@ func (s *metroService) GetByName(ctx context.Context, name string) (*infrav3.Loc func (s *metroService) GetById(ctx context.Context, id uuid.UUID) (*infrav3.Location, error) { var location infrav3.Location - entity, err := s.dao.GetByID(ctx, id, &models.Metro{}) + entity, err := pg.GetByID(ctx, s.db, id, &models.Metro{}) if err != nil { return nil, err } @@ -144,7 +141,7 @@ func (s *metroService) GetById(ctx context.Context, id uuid.UUID) (*infrav3.Loca func (s *metroService) Update(ctx context.Context, metro *infrav3.Location) (*infrav3.Location, error) { - entity, err := s.dao.GetByName(ctx, metro.Metadata.Name, &models.Metro{}) + entity, err := pg.GetByName(ctx, s.db, metro.Metadata.Name, &models.Metro{}) if err != nil { return metro, err } @@ -160,7 +157,7 @@ func (s *metroService) Update(ctx context.Context, metro *infrav3.Location) (*in metrodb.Longitude = metro.Spec.Longitude metrodb.ModifiedAt = time.Now() - _, err = s.dao.Update(ctx, metrodb.ID, metrodb) + _, err = pg.Update(ctx, s.db, metrodb.ID, metrodb) if err != nil { return metro, err } @@ -171,12 +168,12 @@ func (s *metroService) Update(ctx context.Context, metro *infrav3.Location) (*in func (s *metroService) Delete(ctx context.Context, metro *infrav3.Location) (*infrav3.Location, error) { - entity, err := s.dao.GetByName(ctx, metro.Metadata.Name, &models.Metro{}) + entity, err := pg.GetByName(ctx, s.db, metro.Metadata.Name, &models.Metro{}) if err != nil { return metro, err } if metrodb, ok := entity.(*models.Metro); ok { - err = s.dao.Delete(ctx, metrodb.ID, metrodb) + err = pg.Delete(ctx, s.db, metrodb.ID, metrodb) if err != nil { return metro, err } @@ -191,12 +188,12 @@ func (s *metroService) List(ctx context.Context, partner string) (*infrav3.Locat var metrodbs []models.Metro var part models.Partner - _, err := s.dao.GetByName(ctx, partner, &part) + _, err := pg.GetByName(ctx, s.db, partner, &part) if err != nil { return nil, err } - entities, err := s.dao.List(ctx, uuid.NullUUID{UUID: part.ID, Valid: true}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, &metrodbs) + entities, err := pg.List(ctx, s.db, uuid.NullUUID{UUID: part.ID, Valid: true}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, &metrodbs) if err != nil { return nil, err } @@ -229,7 +226,7 @@ func (s *metroService) List(ctx context.Context, partner string) (*infrav3.Locat } func (s *metroService) GetIDByName(ctx context.Context, name string) (uuid.UUID, error) { - entity, err := s.dao.GetByName(ctx, name, &models.Metro{}) + entity, err := pg.GetByName(ctx, s.db, name, &models.Metro{}) if err != nil { return uuid.Nil, err } @@ -239,7 +236,3 @@ func (s *metroService) GetIDByName(ctx context.Context, name string) (uuid.UUID, } return uuid.Nil, nil } - -func (s *metroService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/metro_test.go b/pkg/service/metro_test.go index 9a05474..b006edd 100644 --- a/pkg/service/metro_test.go +++ b/pkg/service/metro_test.go @@ -22,7 +22,6 @@ func TestCreateMetro(t *testing.T) { defer db.Close() ps := NewMetroService(db) - defer ps.Close() puuid := uuid.New().String() muuid := uuid.New().String() @@ -49,7 +48,6 @@ func TestCreateMetroDuplicate(t *testing.T) { defer db.Close() gs := NewMetroService(db) - defer gs.Close() muuid := uuid.New().String() @@ -72,7 +70,6 @@ func TestMetroDelete(t *testing.T) { defer db.Close() ps := NewMetroService(db) - defer ps.Close() puuid := uuid.New().String() @@ -97,7 +94,6 @@ func TestMetroDeleteNonExist(t *testing.T) { defer db.Close() ps := NewMetroService(db) - defer ps.Close() puuid := uuid.New().String() @@ -119,7 +115,6 @@ func TestMetroGetByName(t *testing.T) { defer db.Close() ps := NewMetroService(db) - defer ps.Close() muuid := uuid.New().String() @@ -142,7 +137,6 @@ func TestMetroUpdate(t *testing.T) { defer db.Close() ps := NewMetroService(db) - defer ps.Close() puuid := uuid.New().String() diff --git a/pkg/service/namespace.go b/pkg/service/namespace.go index 4b150d2..1fa280e 100644 --- a/pkg/service/namespace.go +++ b/pkg/service/namespace.go @@ -5,6 +5,7 @@ import ( "encoding/json" "strconv" + "github.com/RafaySystems/rcloud-base/internal/cluster/dao" "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/pkg/converter" "github.com/RafaySystems/rcloud-base/pkg/patch" @@ -17,7 +18,7 @@ import ( func (s *clusterService) GetNamespacesForConditions(ctx context.Context, conditions []scheduler.ClusterNamespaceCondition, clusterID string) (*scheduler.ClusterNamespaceList, error) { - cns, count, err := s.cndao.GetNamespacesForConditions(ctx, uuid.MustParse(clusterID), conditions) + cns, count, err := dao.GetNamespacesForConditions(ctx, s.db, uuid.MustParse(clusterID), conditions) if err != nil { return nil, err } @@ -65,7 +66,7 @@ func (s *clusterService) GetNamespacesForConditions(ctx context.Context, conditi func (s *clusterService) GetNamespaces(ctx context.Context, clusterID string) (*scheduler.ClusterNamespaceList, error) { - cns, err := s.cndao.GetNamespaces(ctx, uuid.MustParse(clusterID)) + cns, err := dao.GetNamespaces(ctx, s.db, uuid.MustParse(clusterID)) if err != nil { return nil, err } @@ -113,7 +114,7 @@ func (s *clusterService) GetNamespaces(ctx context.Context, clusterID string) (* func (s *clusterService) GetNamespace(ctx context.Context, namespace string, clusterID string) (*scheduler.ClusterNamespace, error) { - cn, err := s.cndao.GetNamespace(ctx, uuid.MustParse(clusterID), namespace) + cn, err := dao.GetNamespace(ctx, s.db, uuid.MustParse(clusterID), namespace) if err != nil { return nil, err } @@ -173,7 +174,7 @@ func (s *clusterService) UpdateNamespaceStatus(ctx context.Context, current *sch Status: converter.ConvertToJsonRawMessage(existing.Status), } - err = s.cndao.UpdateNamespaceStatus(ctx, &cn) + err = dao.UpdateNamespaceStatus(ctx, s.db, &cn) if err != nil { return err } @@ -192,6 +193,6 @@ func (s *clusterService) UpdateNamespaceStatus(ctx context.Context, current *sch } func (s *clusterService) GetNamespaceHashes(ctx context.Context, clusterID string) ([]infrav3.NameHash, error) { - nameHashes, err := s.cndao.GetNamespaceHashes(ctx, uuid.MustParse(clusterID)) + nameHashes, err := dao.GetNamespaceHashes(ctx, s.db, uuid.MustParse(clusterID)) return nameHashes, err } diff --git a/pkg/service/oidc_provider.go b/pkg/service/oidc_provider.go index e50d568..3b184e9 100644 --- a/pkg/service/oidc_provider.go +++ b/pkg/service/oidc_provider.go @@ -10,7 +10,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/RafaySystems/rcloud-base/internal/utils" commonv3 "github.com/RafaySystems/rcloud-base/proto/types/commonpb/v3" systemv3 "github.com/RafaySystems/rcloud-base/proto/types/systempb/v3" "github.com/google/uuid" @@ -30,17 +29,12 @@ type OIDCProviderService interface { } type oidcProvider struct { - dao pg.EntityDAO + db *bun.DB kratosUrl string - l utils.Lookup } func NewOIDCProviderService(db *bun.DB, kratosUrl string) OIDCProviderService { - return &oidcProvider{ - dao: pg.NewEntityDAO(db), - kratosUrl: kratosUrl, - l: utils.NewLookup(db), - } + return &oidcProvider{db: db, kratosUrl: kratosUrl} } func generateCallbackUrl(id string, kUrl string) string { @@ -56,11 +50,11 @@ func validateURL(rawURL string) error { func (s *oidcProvider) getPartnerOrganization(ctx context.Context, provider *systemv3.OIDCProvider) (uuid.UUID, uuid.UUID, error) { partner := provider.GetMetadata().GetPartner() org := provider.GetMetadata().GetOrganization() - partnerId, err := s.l.GetPartnerId(ctx, partner) + partnerId, err := pg.GetPartnerId(ctx, s.db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := s.l.GetOrganizationId(ctx, org) + organizationId, err := pg.GetOrganizationId(ctx, s.db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -81,8 +75,9 @@ func (s *oidcProvider) Create(ctx context.Context, provider *systemv3.OIDCProvid if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - p, _ := s.dao.GetIdByNamePartnerOrg( + p, _ := pg.GetIdByNamePartnerOrg( ctx, + s.db, provider.GetMetadata().GetName(), uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, @@ -129,7 +124,7 @@ func (s *oidcProvider) Create(ctx context.Context, provider *systemv3.OIDCProvid RequestedClaims: provider.Spec.GetRequestedClaims().AsMap(), Predefined: provider.Spec.GetPredefined(), } - _, err = s.dao.Create(ctx, entity) + _, err = pg.Create(ctx, s.db, entity) if err != nil { return &systemv3.OIDCProvider{}, err } @@ -167,7 +162,7 @@ func (s *oidcProvider) GetByID(ctx context.Context, provider *systemv3.OIDCProvi } entity := &models.OIDCProvider{} - _, err = s.dao.GetByID(ctx, id, entity) + _, err = pg.GetByID(ctx, s.db, id, entity) // TODO: Return proper error for Id not exist if err != nil { return &systemv3.OIDCProvider{}, err @@ -206,7 +201,7 @@ func (s *oidcProvider) GetByName(ctx context.Context, provider *systemv3.OIDCPro } entity := &models.OIDCProvider{} - _, err := s.dao.GetByName(ctx, name, entity) + _, err := pg.GetByName(ctx, s.db, name, entity) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -251,7 +246,7 @@ func (s *oidcProvider) List(ctx context.Context) (*systemv3.OIDCProviderList, er orgID uuid.NullUUID parID uuid.NullUUID ) - _, err := s.dao.List(ctx, parID, orgID, &entities) + _, err := pg.List(ctx, s.db, parID, orgID, &entities) if err != nil { return &systemv3.OIDCProviderList{}, nil } @@ -307,7 +302,7 @@ func (s *oidcProvider) Update(ctx context.Context, provider *systemv3.OIDCProvid } existingP := &models.OIDCProvider{} - _, err = s.dao.GetByName(ctx, name, existingP) + _, err = pg.GetByName(ctx, s.db, name, existingP) if err != nil { if errors.Is(err, sql.ErrNoRows) { return &systemv3.OIDCProvider{}, status.Errorf(codes.InvalidArgument, "OIDC PROVIDER %q NOT EXIST", name) @@ -352,7 +347,7 @@ func (s *oidcProvider) Update(ctx context.Context, provider *systemv3.OIDCProvid RequestedClaims: provider.Spec.GetRequestedClaims().AsMap(), Predefined: provider.Spec.GetPredefined(), } - _, err = s.dao.Update(ctx, existingP.Id, entity) + _, err = pg.Update(ctx, s.db, existingP.Id, entity) if err != nil { return &systemv3.OIDCProvider{}, err } @@ -389,12 +384,12 @@ func (s *oidcProvider) Delete(ctx context.Context, provider *systemv3.OIDCProvid if len(name) == 0 { return status.Error(codes.InvalidArgument, "EMPTY NAME") } - _, err := s.dao.GetByName(ctx, name, entity) + _, err := pg.GetByName(ctx, s.db, name, entity) if err != nil { return status.Errorf(codes.InvalidArgument, "OIDC PROVIDER %q NOT EXIST", name) } - err = s.dao.Delete(ctx, entity.Id, &models.OIDCProvider{}) + err = pg.Delete(ctx, s.db, entity.Id, &models.OIDCProvider{}) if err != nil { return err } diff --git a/pkg/service/organization.go b/pkg/service/organization.go index 314f3a2..9fee198 100644 --- a/pkg/service/organization.go +++ b/pkg/service/organization.go @@ -22,7 +22,6 @@ const ( // OrganizationService is the interface for organization operations type OrganizationService interface { - Close() error // create organization Create(ctx context.Context, organization *systemv3.Organization) (*systemv3.Organization, error) // get organization by id @@ -39,20 +38,18 @@ type OrganizationService interface { // organizationService implements OrganizationService type organizationService struct { - dao pg.EntityDAO + db *bun.DB } // NewOrganizationService return new organization service func NewOrganizationService(db *bun.DB) OrganizationService { - return &organizationService{ - dao: pg.NewEntityDAO(db), - } + return &organizationService{db} } func (s *organizationService) Create(ctx context.Context, org *systemv3.Organization) (*systemv3.Organization, error) { var partner models.Partner - _, err := s.dao.GetByName(ctx, org.Metadata.Partner, &partner) + _, err := pg.GetByName(ctx, s.db, org.Metadata.Partner, &partner) if err != nil { return nil, err } @@ -99,7 +96,7 @@ func (s *organizationService) Create(ctx context.Context, org *systemv3.Organiza CreatedAt: time.Now(), ModifiedAt: time.Now(), } - entity, err := s.dao.Create(ctx, &organization) + entity, err := pg.Create(ctx, s.db, &organization) if err != nil { org.Status = &v3.Status{ ConditionType: "Create", @@ -143,7 +140,7 @@ func (s *organizationService) GetByID(ctx context.Context, id string) (*systemv3 } return organization, err } - entity, err := s.dao.GetByID(ctx, uid, &models.Organization{}) + entity, err := pg.GetByID(ctx, s.db, uid, &models.Organization{}) if err != nil { organization.Status = &v3.Status{ ConditionType: "Describe", @@ -156,7 +153,7 @@ func (s *organizationService) GetByID(ctx context.Context, id string) (*systemv3 if org, ok := entity.(*models.Organization); ok { var partner models.Partner - _, err := s.dao.GetByID(ctx, org.PartnerId, &partner) + _, err := pg.GetByID(ctx, s.db, org.PartnerId, &partner) if err != nil { organization.Status = &v3.Status{ ConditionType: "Describe", @@ -206,7 +203,7 @@ func (s *organizationService) GetByName(ctx context.Context, name string) (*syst Name: name, }, } - entity, err := s.dao.GetByName(ctx, name, &models.Organization{}) + entity, err := pg.GetByName(ctx, s.db, name, &models.Organization{}) if err != nil { organization.Metadata = &v3.Metadata{ Name: name, @@ -223,7 +220,7 @@ func (s *organizationService) GetByName(ctx context.Context, name string) (*syst if org, ok := entity.(*models.Organization); ok { var partner models.Partner - _, err := s.dao.GetByID(ctx, org.PartnerId, &partner) + _, err := pg.GetByID(ctx, s.db, org.PartnerId, &partner) if err != nil { organization.Metadata = &v3.Metadata{ Name: name, @@ -254,7 +251,7 @@ func (s *organizationService) GetByName(ctx context.Context, name string) (*syst func (s *organizationService) Update(ctx context.Context, organization *systemv3.Organization) (*systemv3.Organization, error) { - entity, err := s.dao.GetByName(ctx, organization.Metadata.Name, &models.Organization{}) + entity, err := pg.GetByName(ctx, s.db, organization.Metadata.Name, &models.Organization{}) if err != nil { organization.Status = &v3.Status{ ConditionType: "Update", @@ -298,7 +295,7 @@ func (s *organizationService) Update(ctx context.Context, organization *systemv3 org.IsTOTPEnabled = organization.GetSpec().GetIsTotpEnabled() org.AreClustersShared = organization.GetSpec().GetAreClustersShared() - _, err = s.dao.Update(ctx, org.ID, org) + _, err = pg.Update(ctx, s.db, org.ID, org) if err != nil { organization.Status = &v3.Status{ ConditionType: "Update", @@ -324,7 +321,7 @@ func (s *organizationService) Update(ctx context.Context, organization *systemv3 func (s *organizationService) Delete(ctx context.Context, organization *systemv3.Organization) (*systemv3.Organization, error) { - entity, err := s.dao.GetByName(ctx, organization.Metadata.Name, &models.Organization{}) + entity, err := pg.GetByName(ctx, s.db, organization.Metadata.Name, &models.Organization{}) if err != nil { organization.Status = &v3.Status{ ConditionType: "Delete", @@ -337,7 +334,7 @@ func (s *organizationService) Delete(ctx context.Context, organization *systemv3 if org, ok := entity.(*models.Organization); ok { org.Trash = true - _, err := s.dao.Update(ctx, org.ID, org) + _, err := pg.Update(ctx, s.db, org.ID, org) if err != nil { organization.Status = &v3.Status{ ConditionType: "Delete", @@ -372,13 +369,13 @@ func (s *organizationService) List(ctx context.Context, organization *systemv3.O } if len(organization.Metadata.Partner) > 0 { var partner models.Partner - _, err := s.dao.GetByName(ctx, organization.Metadata.Partner, &partner) + _, err := pg.GetByName(ctx, s.db, organization.Metadata.Partner, &partner) if err != nil { return organinzationList, err } var orgs []models.Organization - entities, err := s.dao.List(ctx, uuid.NullUUID{UUID: partner.ID, Valid: true}, uuid.NullUUID{UUID: uuid.Nil}, &orgs) + entities, err := pg.List(ctx, s.db, uuid.NullUUID{UUID: partner.ID, Valid: true}, uuid.NullUUID{UUID: uuid.Nil}, &orgs) if err != nil { return organinzationList, err } @@ -470,7 +467,3 @@ func prepareOrganizationResponse(organization *systemv3.Organization, org *model return organization, nil } - -func (s *organizationService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/organization_test.go b/pkg/service/organization_test.go index 6f74024..a24274a 100644 --- a/pkg/service/organization_test.go +++ b/pkg/service/organization_test.go @@ -22,7 +22,6 @@ func TestCreateOrganization(t *testing.T) { defer db.Close() ps := NewOrganizationService(db) - defer ps.Close() puuid := uuid.New().String() ouuid := uuid.New().String() @@ -49,7 +48,6 @@ func TestCreateOrganizationDuplicate(t *testing.T) { defer db.Close() gs := NewOrganizationService(db) - defer gs.Close() ouuid := uuid.New().String() @@ -72,7 +70,6 @@ func TestOrganizationDelete(t *testing.T) { defer db.Close() ps := NewOrganizationService(db) - defer ps.Close() ouuid := uuid.New().String() @@ -96,7 +93,6 @@ func TestOrganizationDeleteNonExist(t *testing.T) { defer db.Close() ps := NewOrganizationService(db) - defer ps.Close() ouuid := uuid.New().String() @@ -117,7 +113,6 @@ func TestOrganizationGetByName(t *testing.T) { defer db.Close() ps := NewOrganizationService(db) - defer ps.Close() partuuid := uuid.New().String() ouuid := uuid.New().String() @@ -143,7 +138,6 @@ func TestOrganizationGetById(t *testing.T) { defer db.Close() ps := NewOrganizationService(db) - defer ps.Close() partuuid := uuid.New().String() puuid := uuid.New().String() @@ -169,7 +163,6 @@ func TestOrganizationUpdate(t *testing.T) { defer db.Close() ps := NewOrganizationService(db) - defer ps.Close() puuid := uuid.New().String() diff --git a/pkg/service/partner.go b/pkg/service/partner.go index 93f2211..fc7e1fc 100644 --- a/pkg/service/partner.go +++ b/pkg/service/partner.go @@ -16,7 +16,6 @@ import ( // PartnerService is the interface for partner operations type PartnerService interface { - Close() error // create partner Create(ctx context.Context, partner *systemv3.Partner) (*systemv3.Partner, error) // get partner by id @@ -31,14 +30,12 @@ type PartnerService interface { // partnerService implements PartnerService type partnerService struct { - dao pg.EntityDAO + db *bun.DB } // NewPartnerService return new partner service func NewPartnerService(db *bun.DB) PartnerService { - return &partnerService{ - dao: pg.NewEntityDAO(db), - } + return &partnerService{db} } func (s *partnerService) Create(ctx context.Context, partner *systemv3.Partner) (*systemv3.Partner, error) { @@ -68,7 +65,7 @@ func (s *partnerService) Create(ctx context.Context, partner *systemv3.Partner) CreatedAt: time.Now(), ModifiedAt: time.Now(), } - entity, err := s.dao.Create(ctx, &part) + entity, err := pg.Create(ctx, s.db, &part) if err != nil { partner.Status = &v3.Status{ ConditionStatus: v3.ConditionStatus_StatusFailed, @@ -113,7 +110,7 @@ func (s *partnerService) GetByID(ctx context.Context, id string) (*systemv3.Part } return partner, err } - entity, err := s.dao.GetByID(ctx, uid, &models.Partner{}) + entity, err := pg.GetByID(ctx, s.db, uid, &models.Partner{}) if err != nil { partner.Status = &v3.Status{ ConditionType: "Describe", @@ -182,7 +179,7 @@ func (s *partnerService) GetByName(ctx context.Context, name string) (*systemv3. }, } - entity, err := s.dao.GetByName(ctx, name, &models.Partner{}) + entity, err := pg.GetByName(ctx, s.db, name, &models.Partner{}) if err != nil { partner.Status = &v3.Status{ ConditionType: "Describe", @@ -244,7 +241,7 @@ func (s *partnerService) GetByName(ctx context.Context, name string) (*systemv3. func (s *partnerService) Update(ctx context.Context, partner *systemv3.Partner) (*systemv3.Partner, error) { - entity, err := s.dao.GetByName(ctx, partner.Metadata.Name, &models.Partner{}) + entity, err := pg.GetByName(ctx, s.db, partner.Metadata.Name, &models.Partner{}) if err != nil { partner.Status = &v3.Status{ ConditionStatus: v3.ConditionStatus_StatusFailed, @@ -278,7 +275,7 @@ func (s *partnerService) Update(ctx context.Context, partner *systemv3.Partner) part.ModifiedAt = time.Now() //Update the partner details - _, err = s.dao.Update(ctx, part.ID, part) + _, err = pg.Update(ctx, s.db, part.ID, part) if err != nil { partner.Status = &v3.Status{ ConditionStatus: v3.ConditionStatus_StatusFailed, @@ -301,7 +298,7 @@ func (s *partnerService) Update(ctx context.Context, partner *systemv3.Partner) } func (s *partnerService) Delete(ctx context.Context, partner *systemv3.Partner) (*systemv3.Partner, error) { - entity, err := s.dao.GetByName(ctx, partner.Metadata.Name, &models.Partner{}) + entity, err := pg.GetByName(ctx, s.db, partner.Metadata.Name, &models.Partner{}) if err != nil { partner.Status = &v3.Status{ ConditionType: "Delete", @@ -314,7 +311,7 @@ func (s *partnerService) Delete(ctx context.Context, partner *systemv3.Partner) if part, ok := entity.(*models.Partner); ok { part.Trash = true - _, err := s.dao.Update(ctx, part.ID, part) + _, err := pg.Update(ctx, s.db, part.ID, part) if err != nil { partner.Status = &v3.Status{ ConditionType: "Delete", @@ -338,7 +335,3 @@ func (s *partnerService) Delete(ctx context.Context, partner *systemv3.Partner) return partner, nil } - -func (s *partnerService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/partner_test.go b/pkg/service/partner_test.go index a67a5e4..ebcf9bd 100644 --- a/pkg/service/partner_test.go +++ b/pkg/service/partner_test.go @@ -38,7 +38,6 @@ func TestCreatePartner(t *testing.T) { defer db.Close() ps := NewPartnerService(db) - defer ps.Close() puuid := uuid.New().String() @@ -61,7 +60,6 @@ func TestCreatePartnerDuplicate(t *testing.T) { defer db.Close() gs := NewPartnerService(db) - defer gs.Close() puuid := uuid.New().String() @@ -84,7 +82,6 @@ func TestPartnerDelete(t *testing.T) { defer db.Close() ps := NewPartnerService(db) - defer ps.Close() puuid := uuid.New().String() @@ -108,7 +105,6 @@ func TestPartnerDeleteNonExist(t *testing.T) { defer db.Close() gs := NewPartnerService(db) - defer gs.Close() puuid := uuid.New().String() @@ -129,7 +125,6 @@ func TestPartnerGetByName(t *testing.T) { defer db.Close() ps := NewPartnerService(db) - defer ps.Close() puuid := uuid.New().String() @@ -150,7 +145,6 @@ func TestPartnerGetById(t *testing.T) { defer db.Close() ps := NewPartnerService(db) - defer ps.Close() puuid := uuid.New().String() @@ -172,7 +166,6 @@ func TestPartnerUpdate(t *testing.T) { defer db.Close() ps := NewPartnerService(db) - defer ps.Close() puuid := uuid.New().String() diff --git a/pkg/service/project.go b/pkg/service/project.go index 602e059..998d556 100644 --- a/pkg/service/project.go +++ b/pkg/service/project.go @@ -21,7 +21,6 @@ const ( // ProjectService is the interface for project operations type ProjectService interface { - Close() error // create project Create(ctx context.Context, project *systemv3.Project) (*systemv3.Project, error) // get project by id @@ -39,14 +38,12 @@ type ProjectService interface { // projectService implements ProjectService type projectService struct { - dao pg.EntityDAO + db *bun.DB } // NewProjectService return new project service func NewProjectService(db *bun.DB) ProjectService { - return &projectService{ - dao: pg.NewEntityDAO(db), - } + return &projectService{db} } func (s *projectService) Create(ctx context.Context, project *systemv3.Project) (*systemv3.Project, error) { @@ -56,7 +53,7 @@ func (s *projectService) Create(ctx context.Context, project *systemv3.Project) } var org models.Organization - _, err := s.dao.GetByName(ctx, project.Metadata.Organization, &org) + _, err := pg.GetByName(ctx, s.db, project.Metadata.Organization, &org) if err != nil { return nil, err } @@ -72,7 +69,7 @@ func (s *projectService) Create(ctx context.Context, project *systemv3.Project) PartnerId: org.PartnerId, Default: project.GetSpec().GetDefault(), } - entity, err := s.dao.Create(ctx, &proj) + entity, err := pg.Create(ctx, s.db, &proj) if err != nil { project.Status = &v3.Status{ ConditionType: "Create", @@ -121,7 +118,7 @@ func (s *projectService) GetByID(ctx context.Context, id string) (*systemv3.Proj } return project, err } - entity, err := s.dao.GetByID(ctx, uid, &models.Project{}) + entity, err := pg.GetByID(ctx, s.db, uid, &models.Project{}) if err != nil { project.Status = &v3.Status{ ConditionType: "Describe", @@ -167,7 +164,7 @@ func (s *projectService) GetByName(ctx context.Context, name string) (*systemv3. }, } - entity, err := s.dao.GetByName(ctx, name, &models.Project{}) + entity, err := pg.GetByName(ctx, s.db, name, &models.Project{}) if err != nil { project.Status = &v3.Status{ ConditionType: "Describe", @@ -181,13 +178,13 @@ func (s *projectService) GetByName(ctx context.Context, name string) (*systemv3. if proj, ok := entity.(*models.Project); ok { var org models.Organization - _, err := s.dao.GetByID(ctx, proj.OrganizationId, &org) + _, err := pg.GetByID(ctx, s.db, proj.OrganizationId, &org) if err != nil { return nil, err } var partner models.Partner - _, err = s.dao.GetByID(ctx, proj.PartnerId, &partner) + _, err = pg.GetByID(ctx, s.db, proj.PartnerId, &partner) if err != nil { return nil, err } @@ -216,7 +213,7 @@ func (s *projectService) GetByName(ctx context.Context, name string) (*systemv3. func (s *projectService) Update(ctx context.Context, project *systemv3.Project) (*systemv3.Project, error) { - entity, err := s.dao.GetByName(ctx, project.Metadata.Name, &models.Project{}) + entity, err := pg.GetByName(ctx, s.db, project.Metadata.Name, &models.Project{}) if err != nil { project.Status = &v3.Status{ ConditionType: "Update", @@ -233,7 +230,7 @@ func (s *projectService) Update(ctx context.Context, project *systemv3.Project) proj.Default = project.Spec.Default proj.ModifiedAt = time.Now() - _, err = s.dao.Update(ctx, proj.ID, proj) + _, err = pg.Update(ctx, s.db, proj.ID, proj) if err != nil { project.Status = &v3.Status{ ConditionType: "Update", @@ -259,7 +256,7 @@ func (s *projectService) Update(ctx context.Context, project *systemv3.Project) } func (s *projectService) Delete(ctx context.Context, project *systemv3.Project) (*systemv3.Project, error) { - entity, err := s.dao.GetByName(ctx, project.Metadata.Name, &models.Project{}) + entity, err := pg.GetByName(ctx, s.db, project.Metadata.Name, &models.Project{}) if err != nil { project.Status = &v3.Status{ ConditionType: "Delete", @@ -271,7 +268,7 @@ func (s *projectService) Delete(ctx context.Context, project *systemv3.Project) } if proj, ok := entity.(*models.Project); ok { proj.Trash = true - _, err := s.dao.Update(ctx, proj.ID, proj) + _, err := pg.Update(ctx, s.db, proj.ID, proj) if err != nil { project.Status = &v3.Status{ ConditionType: "Delete", @@ -306,17 +303,17 @@ func (s *projectService) List(ctx context.Context, project *systemv3.Project) (* } if len(project.Metadata.Organization) > 0 { var org models.Organization - _, err := s.dao.GetByName(ctx, project.Metadata.Organization, &org) + _, err := pg.GetByName(ctx, s.db, project.Metadata.Organization, &org) if err != nil { return projectList, err } var part models.Partner - _, err = s.dao.GetByName(ctx, project.Metadata.Partner, &part) + _, err = pg.GetByName(ctx, s.db, project.Metadata.Partner, &part) if err != nil { return projectList, err } var projs []models.Project - entities, err := s.dao.List(ctx, uuid.NullUUID{UUID: part.ID, Valid: true}, uuid.NullUUID{UUID: org.ID, Valid: true}, &projs) + entities, err := pg.List(ctx, s.db, uuid.NullUUID{UUID: part.ID, Valid: true}, uuid.NullUUID{UUID: org.ID, Valid: true}, &projs) if err != nil { return projectList, err } @@ -353,7 +350,3 @@ func (s *projectService) List(ctx context.Context, project *systemv3.Project) (* } return projectList, nil } - -func (s *projectService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/project_test.go b/pkg/service/project_test.go index 01d2470..59a5579 100644 --- a/pkg/service/project_test.go +++ b/pkg/service/project_test.go @@ -22,7 +22,6 @@ func TestCreateProject(t *testing.T) { defer db.Close() ps := NewProjectService(db) - defer ps.Close() puuid := uuid.New().String() ouuid := uuid.New().String() @@ -49,7 +48,6 @@ func TestCreateProjectDuplicate(t *testing.T) { defer db.Close() gs := NewProjectService(db) - defer gs.Close() puuid := uuid.New().String() @@ -72,7 +70,6 @@ func TestProjectDelete(t *testing.T) { defer db.Close() ps := NewProjectService(db) - defer ps.Close() puuid := uuid.New().String() @@ -96,7 +93,6 @@ func TestProjectDeleteNonExist(t *testing.T) { defer db.Close() ps := NewProjectService(db) - defer ps.Close() puuid := uuid.New().String() @@ -117,7 +113,6 @@ func TestProjectGetByName(t *testing.T) { defer db.Close() ps := NewProjectService(db) - defer ps.Close() partuuid := uuid.New().String() ouuid := uuid.New().String() @@ -146,7 +141,6 @@ func TestProjectGetById(t *testing.T) { defer db.Close() ps := NewProjectService(db) - defer ps.Close() puuid := uuid.New().String() @@ -168,7 +162,6 @@ func TestProjectUpdate(t *testing.T) { defer db.Close() ps := NewProjectService(db) - defer ps.Close() puuid := uuid.New().String() diff --git a/pkg/service/role.go b/pkg/service/role.go index 9ae4023..01104e8 100644 --- a/pkg/service/role.go +++ b/pkg/service/role.go @@ -9,7 +9,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/dao" "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/RafaySystems/rcloud-base/internal/utils" authzv1 "github.com/RafaySystems/rcloud-base/proto/types/authz" v3 "github.com/RafaySystems/rcloud-base/proto/types/commonpb/v3" rolev3 "github.com/RafaySystems/rcloud-base/proto/types/rolepb/v3" @@ -25,7 +24,6 @@ const ( // RoleService is the interface for role operations type RoleService interface { - Close() error // create role Create(context.Context, *rolev3.Role) (*rolev3.Role, error) // get role by id @@ -42,30 +40,23 @@ type RoleService interface { // roleService implements RoleService type roleService struct { - dao pg.EntityDAO - rdao dao.RoleDAO - l utils.Lookup - azc AuthzService + db *bun.DB + azc AuthzService } // NewRoleService return new role service func NewRoleService(db *bun.DB, azc AuthzService) RoleService { - return &roleService{ - dao: pg.NewEntityDAO(db), - rdao: dao.NewRoleDAO(db), - l: utils.NewLookup(db), - azc: azc, - } + return &roleService{db: db, azc: azc} } func (s *roleService) getPartnerOrganization(ctx context.Context, role *rolev3.Role) (uuid.UUID, uuid.UUID, error) { partner := role.GetMetadata().GetPartner() org := role.GetMetadata().GetOrganization() - partnerId, err := s.l.GetPartnerId(ctx, partner) + partnerId, err := pg.GetPartnerId(ctx, s.db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := s.l.GetOrganizationId(ctx, org) + organizationId, err := pg.GetOrganizationId(ctx, s.db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -74,7 +65,7 @@ func (s *roleService) getPartnerOrganization(ctx context.Context, role *rolev3.R } func (s *roleService) deleteRolePermissionMapping(ctx context.Context, rleId uuid.UUID, role *rolev3.Role) (*rolev3.Role, error) { - err := s.dao.DeleteX(ctx, "resource_role_id", rleId, &models.ResourceRolePermission{}) + err := pg.DeleteX(ctx, s.db, "resource_role_id", rleId, &models.ResourceRolePermission{}) if err != nil { return &rolev3.Role{}, err } @@ -96,7 +87,7 @@ func (s *roleService) createRolePermissionMapping(ctx context.Context, role *rol var items []models.ResourceRolePermission for _, p := range perms { - entity, err := s.dao.GetIdByName(ctx, p, &models.ResourcePermission{}) + entity, err := pg.GetIdByName(ctx, s.db, p, &models.ResourcePermission{}) if err != nil { return role, fmt.Errorf("unable to find role permission '%v'", p) } @@ -110,7 +101,7 @@ func (s *roleService) createRolePermissionMapping(ctx context.Context, role *rol } } if len(items) > 0 { - _, err := s.dao.Create(ctx, &items) + _, err := pg.Create(ctx, s.db, &items) if err != nil { return role, err } @@ -134,7 +125,7 @@ func (s *roleService) Create(ctx context.Context, role *rolev3.Role) (*rolev3.Ro if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - r, _ := s.dao.GetIdByNamePartnerOrg(ctx, role.GetMetadata().GetName(), uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) + r, _ := pg.GetIdByNamePartnerOrg(ctx, s.db, role.GetMetadata().GetName(), uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) if r != nil { return nil, fmt.Errorf("role '%v' already exists", role.GetMetadata().GetName()) } @@ -159,7 +150,7 @@ func (s *roleService) Create(ctx context.Context, role *rolev3.Role) (*rolev3.Ro IsGlobal: role.GetSpec().GetIsGlobal(), Scope: strings.ToLower(scope), } - entity, err := s.dao.Create(ctx, &rle) + entity, err := pg.Create(ctx, s.db, &rle) if err != nil { return &rolev3.Role{}, err } @@ -184,7 +175,7 @@ func (s *roleService) GetByID(ctx context.Context, role *rolev3.Role) (*rolev3.R if err != nil { return &rolev3.Role{}, err } - entity, err := s.dao.GetByID(ctx, uid, &models.Role{}) + entity, err := pg.GetByID(ctx, s.db, uid, &models.Role{}) if err != nil { return &rolev3.Role{}, err } @@ -206,7 +197,7 @@ func (s *roleService) GetByName(ctx context.Context, role *rolev3.Role) (*rolev3 if err != nil { return nil, fmt.Errorf("unable to get partner and org id") } - entity, err := s.dao.GetByNamePartnerOrg(ctx, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) + entity, err := pg.GetByNamePartnerOrg(ctx, s.db, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) if err != nil { return &rolev3.Role{}, err } @@ -230,7 +221,7 @@ func (s *roleService) Update(ctx context.Context, role *rolev3.Role) (*rolev3.Ro } name := role.GetMetadata().GetName() - entity, err := s.dao.GetByNamePartnerOrg(ctx, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) + entity, err := pg.GetByNamePartnerOrg(ctx, s.db, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) if err != nil { return role, fmt.Errorf("unable to find role '%v'", name) } @@ -243,7 +234,7 @@ func (s *roleService) Update(ctx context.Context, role *rolev3.Role) (*rolev3.Ro rle.Scope = role.Spec.Scope rle.ModifiedAt = time.Now() - _, err = s.dao.Update(ctx, rle.ID, rle) + _, err = pg.Update(ctx, s.db, rle.ID, rle) if err != nil { return &rolev3.Role{}, err } @@ -277,7 +268,7 @@ func (s *roleService) Delete(ctx context.Context, role *rolev3.Role) (*rolev3.Ro return &rolev3.Role{}, fmt.Errorf("unable to get partner and org id; %v", err) } - entity, err := s.dao.GetByNamePartnerOrg(ctx, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) + entity, err := pg.GetByNamePartnerOrg(ctx, s.db, name, uuid.NullUUID{UUID: partnerId, Valid: true}, uuid.NullUUID{UUID: organizationId, Valid: true}, &models.Role{}) if err != nil { return &rolev3.Role{}, err } @@ -288,7 +279,7 @@ func (s *roleService) Delete(ctx context.Context, role *rolev3.Role) (*rolev3.Ro return &rolev3.Role{}, err } - err = s.dao.Delete(ctx, rle.ID, rle) + err = pg.Delete(ctx, s.db, rle.ID, rle) if err != nil { return &rolev3.Role{}, err } @@ -312,7 +303,7 @@ func (s *roleService) toV3Role(ctx context.Context, role *rolev3.Role, rle *mode Labels: labels, ModifiedAt: timestamppb.New(rle.ModifiedAt), } - entities, err := s.rdao.GetRolePermissions(ctx, rle.ID) + entities, err := dao.GetRolePermissions(ctx, s.db, rle.ID) if err != nil { return role, err } @@ -339,16 +330,16 @@ func (s *roleService) List(ctx context.Context, role *rolev3.Role) (*rolev3.Role }, } if len(role.Metadata.Organization) > 0 { - orgId, err := s.l.GetOrganizationId(ctx, role.Metadata.Organization) + orgId, err := pg.GetOrganizationId(ctx, s.db, role.Metadata.Organization) if err != nil { return roleList, err } - partId, err := s.l.GetPartnerId(ctx, role.Metadata.Partner) + partId, err := pg.GetPartnerId(ctx, s.db, role.Metadata.Partner) if err != nil { return roleList, err } var rles []models.Role - entities, err := s.dao.List(ctx, uuid.NullUUID{UUID: partId, Valid: true}, uuid.NullUUID{UUID: orgId, Valid: true}, &rles) + entities, err := pg.List(ctx, s.db, uuid.NullUUID{UUID: partId, Valid: true}, uuid.NullUUID{UUID: orgId, Valid: true}, &rles) if err != nil { return roleList, err } @@ -374,7 +365,3 @@ func (s *roleService) List(ctx context.Context, role *rolev3.Role) (*rolev3.Role } return roleList, nil } - -func (s *roleService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/role_test.go b/pkg/service/role_test.go index 8f19d32..00de9f4 100644 --- a/pkg/service/role_test.go +++ b/pkg/service/role_test.go @@ -53,7 +53,6 @@ func TestCreateRole(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() puuid := uuid.New().String() @@ -86,7 +85,6 @@ func TestCreateRoleWithPermissions(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() puuid := uuid.New().String() @@ -123,7 +121,6 @@ func TestCreateRoleDuplicate(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() puuid := uuid.New().String() @@ -155,7 +152,6 @@ func TestUpdateRole(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() puuid := uuid.New().String() @@ -195,7 +191,6 @@ func TestRoleDelete(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() puuid := uuid.New().String() @@ -227,7 +222,6 @@ func TestRoleDeleteNonExist(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() puuid := uuid.New().String() @@ -252,7 +246,6 @@ func TestRoleGetByName(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() rruuid := uuid.New().String() @@ -284,7 +277,6 @@ func TestRoleGetById(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid := uuid.New().String() rruuid := uuid.New().String() @@ -312,7 +304,6 @@ func TestRoleList(t *testing.T) { mazc := mockAuthzClient{} rs := NewRoleService(db, &mazc) - defer rs.Close() ruuid1 := uuid.New().String() ruuid2 := uuid.New().String() diff --git a/pkg/service/rolepermission.go b/pkg/service/rolepermission.go index 6a4c3dd..58c91b8 100644 --- a/pkg/service/rolepermission.go +++ b/pkg/service/rolepermission.go @@ -5,7 +5,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/models" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/RafaySystems/rcloud-base/internal/utils" v3 "github.com/RafaySystems/rcloud-base/proto/types/commonpb/v3" rolev3 "github.com/RafaySystems/rcloud-base/proto/types/rolepb/v3" "github.com/google/uuid" @@ -19,7 +18,6 @@ const ( // RolepermissionService is the interface for rolepermission operations type RolepermissionService interface { - Close() error // get rolepermission by name GetByName(context.Context, *rolev3.RolePermission) (*rolev3.RolePermission, error) // list rolepermissions @@ -28,16 +26,12 @@ type RolepermissionService interface { // rolepermissionService implements RolepermissionService type rolepermissionService struct { - dao pg.EntityDAO - l utils.Lookup + db *bun.DB } // NewRolepermissionService return new rolepermission service func NewRolepermissionService(db *bun.DB) RolepermissionService { - return &rolepermissionService{ - dao: pg.NewEntityDAO(db), - l: utils.NewLookup(db), - } + return &rolepermissionService{db: db} } func (s *rolepermissionService) toV3Rolepermission(rolepermission *rolev3.RolePermission, rlp *models.ResourcePermission) *rolev3.RolePermission { @@ -53,11 +47,11 @@ func (s *rolepermissionService) toV3Rolepermission(rolepermission *rolev3.RolePe func (s *rolepermissionService) getPartnerOrganization(ctx context.Context, rolepermission *rolev3.RolePermission) (uuid.UUID, uuid.UUID, error) { partner := rolepermission.GetMetadata().GetPartner() org := rolepermission.GetMetadata().GetOrganization() - partnerId, err := s.l.GetPartnerId(ctx, partner) + partnerId, err := pg.GetPartnerId(ctx, s.db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := s.l.GetOrganizationId(ctx, org) + organizationId, err := pg.GetOrganizationId(ctx, s.db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -67,7 +61,7 @@ func (s *rolepermissionService) getPartnerOrganization(ctx context.Context, role func (s *rolepermissionService) GetByName(ctx context.Context, rolepermission *rolev3.RolePermission) (*rolev3.RolePermission, error) { name := rolepermission.GetMetadata().GetName() - entity, err := s.dao.GetByName(ctx, name, &models.ResourcePermission{}) + entity, err := pg.GetByName(ctx, s.db, name, &models.ResourcePermission{}) if err != nil { return rolepermission, err } @@ -91,7 +85,7 @@ func (s *rolepermissionService) List(ctx context.Context, rolepermission *rolev3 }, } var rles []models.ResourcePermission - entities, err := s.dao.List(ctx, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, &rles) + entities, err := pg.List(ctx, s.db, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, uuid.NullUUID{UUID: uuid.Nil, Valid: false}, &rles) if err != nil { return rolepermissionList, err } @@ -111,7 +105,3 @@ func (s *rolepermissionService) List(ctx context.Context, rolepermission *rolev3 return rolepermissionList, nil } - -func (s *rolepermissionService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/rolepermission_test.go b/pkg/service/rolepermission_test.go index df3beb6..4713be8 100644 --- a/pkg/service/rolepermission_test.go +++ b/pkg/service/rolepermission_test.go @@ -29,7 +29,6 @@ func TestRolePermissionList(t *testing.T) { defer db.Close() rs := NewRolepermissionService(db) - defer rs.Close() ruuid1 := uuid.New().String() ruuid2 := uuid.New().String() diff --git a/pkg/service/user.go b/pkg/service/user.go index 73f0980..f7e67ba 100644 --- a/pkg/service/user.go +++ b/pkg/service/user.go @@ -14,7 +14,6 @@ import ( "github.com/RafaySystems/rcloud-base/internal/models" providers "github.com/RafaySystems/rcloud-base/internal/persistence/provider/kratos" "github.com/RafaySystems/rcloud-base/internal/persistence/provider/pg" - "github.com/RafaySystems/rcloud-base/internal/utils" "github.com/RafaySystems/rcloud-base/pkg/common" userrpcv3 "github.com/RafaySystems/rcloud-base/proto/rpc/user" authzv1 "github.com/RafaySystems/rcloud-base/proto/types/authz" @@ -29,7 +28,6 @@ const ( // GroupService is the interface for group operations type UserService interface { - Close() error // create user Create(context.Context, *userv3.User) (*userv3.User, error) // get user by id @@ -47,14 +45,11 @@ type UserService interface { } type userService struct { - ap providers.AuthProvider - dao pg.EntityDAO - udao dao.UserDAO - l utils.Lookup - azc AuthzService - pdao dao.PermissionDao - ks ApiKeyService - cc common.CliConfigDownloadData + ap providers.AuthProvider + db *bun.DB + azc AuthzService + ks ApiKeyService + cc common.CliConfigDownloadData } type userTraits struct { @@ -72,8 +67,7 @@ type parsedIds struct { } func NewUserService(ap providers.AuthProvider, db *bun.DB, azc AuthzService, kss ApiKeyService, cfg common.CliConfigDownloadData) UserService { - edao := pg.NewEntityDAO(db) - return &userService{ap: ap, dao: edao, udao: dao.NewUserDAO(db), l: utils.NewLookup(db), azc: azc, pdao: dao.NewPermissionDao(edao), ks: kss, cc: cfg} + return &userService{ap: ap, db: db, azc: azc, ks: kss, cc: cfg} } func getUserTraits(traits map[string]interface{}) userTraits { @@ -114,7 +108,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. var ps []*authzv1.Policy for _, pnr := range projectNamespaceRoles { role := pnr.GetRole() - entity, err := s.dao.GetIdByName(ctx, role, &models.Role{}) + entity, err := pg.GetIdByName(ctx, s.db, role, &models.Role{}) if err != nil { return user, fmt.Errorf("unable to find role '%v'", role) } @@ -131,7 +125,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. switch { case pnr.Namespace != nil: - projectId, err := s.l.GetProjectId(ctx, project) + projectId, err := pg.GetProjectId(ctx, s.db, project) if err != nil { return user, fmt.Errorf("unable to find project '%v'", project) } @@ -158,7 +152,7 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. Act: "*", }) case project != "": - projectId, err := s.l.GetProjectId(ctx, project) + projectId, err := pg.GetProjectId(ctx, s.db, project) if err != nil { return user, fmt.Errorf("unable to find project '%v'", project) } @@ -209,19 +203,19 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. } } if len(panrs) > 0 { - _, err := s.dao.Create(ctx, &panrs) + _, err := pg.Create(ctx, s.db, &panrs) if err != nil { return &userv3.User{}, err } } if len(pars) > 0 { - _, err := s.dao.Create(ctx, &pars) + _, err := pg.Create(ctx, s.db, &pars) if err != nil { return &userv3.User{}, err } } if len(ars) > 0 { - _, err := s.dao.Create(ctx, &ars) + _, err := pg.Create(ctx, s.db, &ars) if err != nil { return &userv3.User{}, err } @@ -241,11 +235,11 @@ func (s *userService) createUserRoleRelations(ctx context.Context, user *userv3. func (s *userService) getPartnerOrganization(ctx context.Context, user *userv3.User) (uuid.UUID, uuid.UUID, error) { partner := user.GetMetadata().GetPartner() org := user.GetMetadata().GetOrganization() - partnerId, err := s.l.GetPartnerId(ctx, partner) + partnerId, err := pg.GetPartnerId(ctx, s.db, partner) if err != nil { return uuid.Nil, uuid.Nil, err } - organizationId, err := s.l.GetOrganizationId(ctx, org) + organizationId, err := pg.GetOrganizationId(ctx, s.db, org) if err != nil { return partnerId, uuid.Nil, err } @@ -288,7 +282,7 @@ func (s *userService) Create(ctx context.Context, user *userv3.User) (*userv3.Us func (s *userService) identitiesModelToUser(ctx context.Context, user *userv3.User, usr *models.KratosIdentities) (*userv3.User, error) { traits := getUserTraits(usr.Traits) - groups, err := s.udao.GetGroups(ctx, usr.ID) + groups, err := dao.GetGroups(ctx, s.db, usr.ID) if err != nil { return &userv3.User{}, err } @@ -299,7 +293,7 @@ func (s *userService) identitiesModelToUser(ctx context.Context, user *userv3.Us labels := make(map[string]string) - roles, err := s.udao.GetRoles(ctx, usr.ID) + roles, err := dao.GetUserRoles(ctx, s.db, usr.ID) if err != nil { return &userv3.User{}, err } @@ -328,7 +322,7 @@ func (s *userService) GetByID(ctx context.Context, user *userv3.User) (*userv3.U if err != nil { return &userv3.User{}, err } - entity, err := s.dao.GetByID(ctx, uid, &models.KratosIdentities{}) + entity, err := pg.GetByID(ctx, s.db, uid, &models.KratosIdentities{}) if err != nil { return &userv3.User{}, err } @@ -347,7 +341,7 @@ func (s *userService) GetByID(ctx context.Context, user *userv3.User) (*userv3.U func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3.User, error) { name := user.GetMetadata().GetName() - entity, err := s.dao.GetByTraits(ctx, name, &models.KratosIdentities{}) + entity, err := pg.GetByTraits(ctx, s.db, name, &models.KratosIdentities{}) if err != nil { return &userv3.User{}, err } @@ -365,15 +359,15 @@ func (s *userService) GetByName(ctx context.Context, user *userv3.User) (*userv3 } func (s *userService) deleteUserRoleRelations(ctx context.Context, userId uuid.UUID, user *userv3.User) error { - err := s.dao.DeleteX(ctx, "account_id", userId, &models.AccountResourcerole{}) + err := pg.DeleteX(ctx, s.db, "account_id", userId, &models.AccountResourcerole{}) if err != nil { return err } - err = s.dao.DeleteX(ctx, "account_id", userId, &models.ProjectAccountResourcerole{}) + err = pg.DeleteX(ctx, s.db, "account_id", userId, &models.ProjectAccountResourcerole{}) if err != nil { return err } - err = s.dao.DeleteX(ctx, "account_id", userId, &models.ProjectAccountNamespaceRole{}) + err = pg.DeleteX(ctx, s.db, "account_id", userId, &models.ProjectAccountNamespaceRole{}) if err != nil { return err } @@ -388,7 +382,7 @@ func (s *userService) deleteUserRoleRelations(ctx context.Context, userId uuid.U func (s *userService) Update(ctx context.Context, user *userv3.User) (*userv3.User, error) { name := user.GetMetadata().GetName() - entity, err := s.dao.GetIdByTraits(ctx, name, &models.KratosIdentities{}) + entity, err := pg.GetIdByTraits(ctx, s.db, name, &models.KratosIdentities{}) if err != nil { return &userv3.User{}, fmt.Errorf("no user found with name '%v'", name) } @@ -426,7 +420,7 @@ func (s *userService) Update(ctx context.Context, user *userv3.User) (*userv3.Us func (s *userService) Delete(ctx context.Context, user *userv3.User) (*userrpcv3.DeleteUserResponse, error) { name := user.GetMetadata().GetName() - entity, err := s.dao.GetIdByTraits(ctx, name, &models.KratosIdentities{}) + entity, err := pg.GetIdByTraits(ctx, s.db, name, &models.KratosIdentities{}) if err != nil { return &userrpcv3.DeleteUserResponse{}, fmt.Errorf("no user founnd with username '%v'", name) } @@ -442,7 +436,7 @@ func (s *userService) Delete(ctx context.Context, user *userv3.User) (*userrpcv3 return &userrpcv3.DeleteUserResponse{}, err } - err = s.dao.DeleteX(ctx, "account_id", usr.ID, &models.GroupAccount{}) + err = pg.DeleteX(ctx, s.db, "account_id", usr.ID, &models.GroupAccount{}) if err != nil { return &userrpcv3.DeleteUserResponse{}, fmt.Errorf("unable to delete user; %v", err) } @@ -463,7 +457,7 @@ func (s *userService) List(ctx context.Context, _ *userv3.User) (*userv3.UserLis }, } var accs []models.KratosIdentities - entities, err := s.dao.ListAll(ctx, &accs) + entities, err := pg.ListAll(ctx, s.db, &accs) if err != nil { return userList, err } @@ -489,25 +483,25 @@ func (s *userService) List(ctx context.Context, _ *userv3.User) (*userv3.UserLis func (s *userService) RetrieveCliConfig(ctx context.Context, req *userrpcv3.ApiKeyRequest) (*common.CliConfigDownloadData, error) { // get the default project associated to this account - ap, err := s.pdao.GetDefaultAccountProject(ctx, uuid.MustParse(req.Id)) + ap, err := dao.GetDefaultAccountProject(ctx, s.db, uuid.MustParse(req.Id)) if err != nil { return nil, err } // fetch the metadata information required to populate cli config var proj models.Project - _, err = s.dao.GetByID(ctx, ap.ProjecttId, &proj) + _, err = pg.GetByID(ctx, s.db, ap.ProjecttId, &proj) if err != nil { return nil, err } var org models.Organization - _, err = s.dao.GetByID(ctx, ap.OrganizationId, &org) + _, err = pg.GetByID(ctx, s.db, ap.OrganizationId, &org) if err != nil { return nil, err } var part models.Partner - _, err = s.dao.GetByID(ctx, ap.PartnerId, &part) + _, err = pg.GetByID(ctx, s.db, ap.PartnerId, &part) if err != nil { return nil, err } @@ -539,7 +533,3 @@ func (s *userService) RetrieveCliConfig(ctx context.Context, req *userrpcv3.ApiK return cliConfig, nil } - -func (s *userService) Close() error { - return s.dao.Close() -} diff --git a/pkg/service/user_test.go b/pkg/service/user_test.go index 6f8e57c..e8e289f 100644 --- a/pkg/service/user_test.go +++ b/pkg/service/user_test.go @@ -61,7 +61,6 @@ func TestCreateUser(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid := uuid.New().String() puuid := uuid.New().String() @@ -115,7 +114,6 @@ func TestCreateUserWithRole(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid := uuid.New().String() puuid := uuid.New().String() @@ -168,7 +166,6 @@ func TestUpdateUser(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid := uuid.New().String() puuid := uuid.New().String() @@ -222,7 +219,6 @@ func TestUserGetByName(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid := uuid.New().String() puuid := uuid.New().String() @@ -273,7 +269,6 @@ func TestUserGetById(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid := uuid.New().String() puuid := uuid.New().String() @@ -322,7 +317,6 @@ func TestUserList(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid1 := uuid.New().String() uuuid2 := uuid.New().String() @@ -391,7 +385,6 @@ func TestUserDelete(t *testing.T) { ap := &mockAuthProvider{} mazc := mockAuthzClient{} us := NewUserService(ap, db, &mazc, nil, common.CliConfigDownloadData{}) - defer us.Close() uuuid := uuid.New().String() puuid := uuid.New().String() diff --git a/server/bootstrap.go b/server/bootstrap.go index 874923c..a15d81a 100644 --- a/server/bootstrap.go +++ b/server/bootstrap.go @@ -115,7 +115,7 @@ func (s *bootstrapServer) DeleteBootstrapAgent(ctx context.Context, in *sentry.B return } - err = s.bs.DeleteBoostrapAgent(ctx, templateRef, query.WithMeta(in.Metadata)) + err = s.bs.DeleteBootstrapAgent(ctx, templateRef, query.WithMeta(in.Metadata)) if err == sql.ErrNoRows { err = status.Error(codes.NotFound, err.Error()) }