mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-08 03:40:24 +00:00
Compare commits
2 Commits
main
...
ldap-sync-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3650d797db | ||
|
|
6b0ce57081 |
@@ -35,6 +35,7 @@ type LdapService struct {
|
||||
userService *UserService
|
||||
groupService *UserGroupService
|
||||
fileStorage storage.FileStorage
|
||||
clientFactory func() (ldapClient, error)
|
||||
}
|
||||
|
||||
type savePicture struct {
|
||||
@@ -43,8 +44,33 @@ type savePicture struct {
|
||||
picture string
|
||||
}
|
||||
|
||||
type ldapDesiredUser struct {
|
||||
ldapID string
|
||||
input dto.UserCreateDto
|
||||
picture string
|
||||
}
|
||||
|
||||
type ldapDesiredGroup struct {
|
||||
ldapID string
|
||||
input dto.UserGroupCreateDto
|
||||
memberUsernames []string
|
||||
}
|
||||
|
||||
type ldapDesiredState struct {
|
||||
users []ldapDesiredUser
|
||||
userIDs map[string]struct{}
|
||||
groups []ldapDesiredGroup
|
||||
groupIDs map[string]struct{}
|
||||
}
|
||||
|
||||
type ldapClient interface {
|
||||
Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
|
||||
Bind(username, password string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
|
||||
return &LdapService{
|
||||
service := &LdapService{
|
||||
db: db,
|
||||
httpClient: httpClient,
|
||||
appConfigService: appConfigService,
|
||||
@@ -52,9 +78,12 @@ func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppC
|
||||
groupService: groupService,
|
||||
fileStorage: fileStorage,
|
||||
}
|
||||
|
||||
service.clientFactory = service.createClient
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
func (s *LdapService) createClient() (ldapClient, error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
if !dbConfig.LdapEnabled.IsTrue() {
|
||||
@@ -79,24 +108,33 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
|
||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
client, err := s.clientFactory()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create LDAP client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Start a transaction
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
// First, we fetch all users and group from LDAP, which is our "desired state"
|
||||
desiredState, err := s.fetchDesiredState(ctx, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch LDAP state: %w", err)
|
||||
}
|
||||
|
||||
savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
|
||||
// Start a transaction
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to begin database transaction: %w", tx.Error)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Reconcile users
|
||||
savePictures, deleteFiles, err := s.reconcileUsers(ctx, tx, desiredState.users, desiredState.userIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users: %w", err)
|
||||
}
|
||||
|
||||
err = s.SyncGroups(ctx, tx, client)
|
||||
// Reconcile groups
|
||||
err = s.reconcileGroups(ctx, tx, desiredState.groups, desiredState.groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync groups: %w", err)
|
||||
}
|
||||
@@ -129,10 +167,31 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
||||
func (s *LdapService) fetchDesiredState(ctx context.Context, client ldapClient) (ldapDesiredState, error) {
|
||||
// Fetch users first so we can use their DNs when resolving group members
|
||||
users, userIDs, usernamesByDN, err := s.fetchUsersFromLDAP(ctx, client)
|
||||
if err != nil {
|
||||
return ldapDesiredState{}, err
|
||||
}
|
||||
|
||||
// Then fetch groups to complete the desired LDAP state snapshot
|
||||
groups, groupIDs, err := s.fetchGroupsFromLDAP(ctx, client, usernamesByDN)
|
||||
if err != nil {
|
||||
return ldapDesiredState{}, err
|
||||
}
|
||||
|
||||
return ldapDesiredState{
|
||||
users: users,
|
||||
userIDs: userIDs,
|
||||
groups: groups,
|
||||
groupIDs: groupIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) fetchGroupsFromLDAP(ctx context.Context, client ldapClient, usernamesByDN map[string]string) (desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Query LDAP for all groups we want to manage
|
||||
searchAttrs := []string{
|
||||
dbConfig.LdapAttributeGroupName.Value,
|
||||
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
|
||||
@@ -149,90 +208,42 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
)
|
||||
result, err := client.Search(searchReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query LDAP: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to query LDAP groups: %w", err)
|
||||
}
|
||||
|
||||
// Create a mapping for groups that exist
|
||||
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
|
||||
// Build the in-memory desired state for groups
|
||||
ldapGroupIDs = make(map[string]struct{}, len(result.Entries))
|
||||
desiredGroups = make([]ldapDesiredGroup, 0, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
|
||||
// Skip groups without a valid LDAP ID
|
||||
if ldapId == "" {
|
||||
if ldapID == "" {
|
||||
slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
continue
|
||||
}
|
||||
|
||||
ldapGroupIDs[ldapId] = struct{}{}
|
||||
|
||||
// Try to find the group in the database
|
||||
var databaseGroup model.UserGroup
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseGroup).
|
||||
Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
|
||||
}
|
||||
ldapGroupIDs[ldapID] = struct{}{}
|
||||
|
||||
// Get group members and add to the correct Group
|
||||
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
|
||||
membersUserId := make([]string, 0, len(groupMembers))
|
||||
memberUsernames := make([]string, 0, len(groupMembers))
|
||||
for _, member := range groupMembers {
|
||||
username := getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
|
||||
|
||||
// If username extraction fails, try to query LDAP directly for the user
|
||||
username := s.resolveGroupMemberUsername(ctx, client, member, usernamesByDN)
|
||||
if username == "" {
|
||||
// Query LDAP to get the user by their DN
|
||||
userSearchReq := ldap.NewSearchRequest(
|
||||
member,
|
||||
ldap.ScopeBaseObject,
|
||||
0, 0, 0, false,
|
||||
"(objectClass=*)",
|
||||
[]string{dbConfig.LdapAttributeUserUsername.Value, dbConfig.LdapAttributeUserUniqueIdentifier.Value},
|
||||
[]ldap.Control{},
|
||||
)
|
||||
|
||||
userResult, err := client.Search(userSearchReq)
|
||||
if err != nil || len(userResult.Entries) == 0 {
|
||||
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
|
||||
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
||||
if username == "" {
|
||||
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
username = norm.NFC.String(username)
|
||||
|
||||
var databaseUser model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", username).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// The user collides with a non-LDAP user, so we skip it
|
||||
continue
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query for existing user '%s': %w", username, err)
|
||||
}
|
||||
|
||||
membersUserId = append(membersUserId, databaseUser.ID)
|
||||
memberUsernames = append(memberUsernames, username)
|
||||
}
|
||||
|
||||
syncGroup := dto.UserGroupCreateDto{
|
||||
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||
LdapID: ldapId,
|
||||
LdapID: ldapID,
|
||||
}
|
||||
dto.Normalize(syncGroup)
|
||||
dto.Normalize(&syncGroup)
|
||||
|
||||
err = syncGroup.Validate()
|
||||
if err != nil {
|
||||
@@ -240,64 +251,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
continue
|
||||
}
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
}
|
||||
desiredGroups = append(desiredGroups, ldapDesiredGroup{
|
||||
ldapID: ldapID,
|
||||
input: syncGroup,
|
||||
memberUsernames: memberUsernames,
|
||||
})
|
||||
}
|
||||
|
||||
// Get all LDAP groups from the database
|
||||
var ldapGroupsInDb []model.UserGroup
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
||||
}
|
||||
|
||||
// Delete groups that no longer exist in LDAP
|
||||
for _, group := range ldapGroupsInDb {
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted group", slog.String("group", group.Name))
|
||||
}
|
||||
|
||||
return nil
|
||||
return desiredGroups, ldapGroupIDs, nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||
func (s *LdapService) fetchUsersFromLDAP(ctx context.Context, client ldapClient) (desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}, usernamesByDN map[string]string, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Query LDAP for all users we want to manage
|
||||
searchAttrs := []string{
|
||||
"memberOf",
|
||||
"sn",
|
||||
@@ -323,50 +290,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
|
||||
result, err := client.Search(searchReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to query LDAP users: %w", err)
|
||||
}
|
||||
|
||||
// Create a mapping for users that exist
|
||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
||||
savePictures = make([]savePicture, 0, len(result.Entries))
|
||||
// Build the in-memory desired state for users and a DN lookup for group membership resolution
|
||||
ldapUserIDs = make(map[string]struct{}, len(result.Entries))
|
||||
usernamesByDN = make(map[string]string, len(result.Entries))
|
||||
desiredUsers = make([]ldapDesiredUser, 0, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
username := norm.NFC.String(value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value))
|
||||
if normalizedDN := normalizeLDAPDN(value.DN); normalizedDN != "" && username != "" {
|
||||
usernamesByDN[normalizedDN] = username
|
||||
}
|
||||
|
||||
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
|
||||
// Skip users without a valid LDAP ID
|
||||
if ldapId == "" {
|
||||
if ldapID == "" {
|
||||
slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
continue
|
||||
}
|
||||
|
||||
ldapUserIDs[ldapId] = struct{}{}
|
||||
|
||||
// Get the user from the database
|
||||
var databaseUser model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
|
||||
// If a user is found (even if disabled), enable them since they're now back in LDAP
|
||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", databaseUser.ID).
|
||||
Update("disabled", false).
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||
}
|
||||
ldapUserIDs[ldapID] = struct{}{}
|
||||
|
||||
// Check if user is admin by checking if they are in the admin group
|
||||
isAdmin := false
|
||||
@@ -385,14 +331,14 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
|
||||
DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value),
|
||||
IsAdmin: isAdmin,
|
||||
LdapID: ldapId,
|
||||
LdapID: ldapID,
|
||||
}
|
||||
|
||||
if newUser.DisplayName == "" {
|
||||
newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName)
|
||||
}
|
||||
|
||||
dto.Normalize(newUser)
|
||||
dto.Normalize(&newUser)
|
||||
|
||||
err = newUser.Validate()
|
||||
if err != nil {
|
||||
@@ -400,53 +346,201 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
continue
|
||||
}
|
||||
|
||||
userID := databaseUser.ID
|
||||
if databaseUser.ID == "" {
|
||||
createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
desiredUsers = append(desiredUsers, ldapDesiredUser{
|
||||
ldapID: ldapID,
|
||||
input: newUser,
|
||||
picture: value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value),
|
||||
})
|
||||
}
|
||||
|
||||
return desiredUsers, ldapUserIDs, usernamesByDN, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) resolveGroupMemberUsername(ctx context.Context, client ldapClient, member string, usernamesByDN map[string]string) string {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// First try the DN cache we built while loading users
|
||||
username, exists := usernamesByDN[normalizeLDAPDN(member)]
|
||||
if exists && username != "" {
|
||||
return username
|
||||
}
|
||||
|
||||
// Then try to extract the username directly from the DN
|
||||
username = getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
|
||||
if username != "" {
|
||||
return norm.NFC.String(username)
|
||||
}
|
||||
|
||||
// As a fallback, query LDAP for the referenced entry
|
||||
userSearchReq := ldap.NewSearchRequest(
|
||||
member,
|
||||
ldap.ScopeBaseObject,
|
||||
0, 0, 0, false,
|
||||
"(objectClass=*)",
|
||||
[]string{dbConfig.LdapAttributeUserUsername.Value},
|
||||
[]ldap.Control{},
|
||||
)
|
||||
|
||||
userResult, err := client.Search(userSearchReq)
|
||||
if err != nil || len(userResult.Entries) == 0 {
|
||||
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
||||
return ""
|
||||
}
|
||||
|
||||
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
||||
if username == "" {
|
||||
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
||||
return ""
|
||||
}
|
||||
|
||||
return norm.NFC.String(username)
|
||||
}
|
||||
|
||||
func (s *LdapService) reconcileGroups(ctx context.Context, tx *gorm.DB, desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}) error {
|
||||
// Load the current LDAP-managed state from the database
|
||||
ldapGroupsInDB, ldapGroupsByID, err := s.loadLDAPGroupsInDB(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
||||
}
|
||||
|
||||
_, _, ldapUsersByUsername, err := s.loadLDAPUsersInDB(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Apply creates and updates to match the desired LDAP group state
|
||||
for _, desiredGroup := range desiredGroups {
|
||||
memberUserIDs := make([]string, 0, len(desiredGroup.memberUsernames))
|
||||
for _, username := range desiredGroup.memberUsernames {
|
||||
databaseUser, exists := ldapUsersByUsername[username]
|
||||
if !exists {
|
||||
// The user collides with a non-LDAP user or was skipped during user sync, so we ignore it
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
userID = createdUser.ID
|
||||
} else {
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
|
||||
memberUserIDs = append(memberUserIDs, databaseUser.ID)
|
||||
}
|
||||
|
||||
// Save profile picture
|
||||
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
||||
if pictureString != "" {
|
||||
// Storage operations must be executed outside of a transaction
|
||||
savePictures = append(savePictures, savePicture{
|
||||
userID: databaseUser.ID,
|
||||
username: userID,
|
||||
picture: pictureString,
|
||||
})
|
||||
databaseGroup := ldapGroupsByID[desiredGroup.ldapID]
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, desiredGroup.input, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
ldapGroupsByID[desiredGroup.ldapID] = newGroup
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, memberUserIDs, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, desiredGroup.input, true, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, memberUserIDs, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all LDAP users from the database
|
||||
var ldapUsersInDb []model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
|
||||
Select("id, username, ldap_id, disabled").
|
||||
Error
|
||||
// Delete groups that are no longer present in LDAP
|
||||
for _, group := range ldapGroupsInDB {
|
||||
if group.LdapID == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", *group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted group", slog.String("group", group.Name))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) reconcileUsers(ctx context.Context, tx *gorm.DB, desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Load the current LDAP-managed state from the database
|
||||
ldapUsersInDB, ldapUsersByID, _, err := s.loadLDAPUsersInDB(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Mark users as disabled or delete users that no longer exist in LDAP
|
||||
deleteFiles = make([]string, 0, len(ldapUserIDs))
|
||||
for _, user := range ldapUsersInDb {
|
||||
// Skip if the user ID exists in the fetched LDAP results
|
||||
// Apply creates and updates to match the desired LDAP user state
|
||||
savePictures = make([]savePicture, 0, len(desiredUsers))
|
||||
|
||||
for _, desiredUser := range desiredUsers {
|
||||
databaseUser := ldapUsersByID[desiredUser.ldapID]
|
||||
|
||||
// If a user is found (even if disabled), enable them since they're now back in LDAP.
|
||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", databaseUser.ID).
|
||||
Update("disabled", false).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
}
|
||||
|
||||
databaseUser.Disabled = false
|
||||
ldapUsersByID[desiredUser.ldapID] = databaseUser
|
||||
}
|
||||
|
||||
userID := databaseUser.ID
|
||||
if databaseUser.ID == "" {
|
||||
createdUser, err := s.userService.createUserInternal(ctx, desiredUser.input, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping creating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating user '%s': %w", desiredUser.input.Username, err)
|
||||
}
|
||||
|
||||
userID = createdUser.ID
|
||||
ldapUsersByID[desiredUser.ldapID] = createdUser
|
||||
} else {
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, desiredUser.input, false, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping updating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error updating user '%s': %w", desiredUser.input.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
if desiredUser.picture != "" {
|
||||
savePictures = append(savePictures, savePicture{
|
||||
userID: userID,
|
||||
username: desiredUser.input.Username,
|
||||
picture: desiredUser.picture,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Disable or delete users that are no longer present in LDAP
|
||||
deleteFiles = make([]string, 0, len(ldapUsersInDB))
|
||||
for _, user := range ldapUsersInDB {
|
||||
if user.LdapID == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
@@ -458,29 +552,73 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
}
|
||||
|
||||
slog.Info("Disabled user", slog.String("username", user.Username))
|
||||
} else {
|
||||
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||
if err != nil {
|
||||
target := &common.LdapUserUpdateError{}
|
||||
if errors.As(err, &target) {
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||
|
||||
// Storage operations must be executed outside of a transaction
|
||||
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||
continue
|
||||
}
|
||||
|
||||
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||
if err != nil {
|
||||
target := &common.LdapUserUpdateError{}
|
||||
if errors.As(err, &target) {
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||
}
|
||||
|
||||
return savePictures, deleteFiles, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) loadLDAPUsersInDB(ctx context.Context, tx *gorm.DB) (users []model.User, byLdapID map[string]model.User, byUsername map[string]model.User, err error) {
|
||||
// Load all LDAP-managed users and index them by LDAP ID and by username
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Select("id, username, ldap_id, disabled").
|
||||
Where("ldap_id IS NOT NULL").
|
||||
Find(&users).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
byLdapID = make(map[string]model.User, len(users))
|
||||
byUsername = make(map[string]model.User, len(users))
|
||||
for _, user := range users {
|
||||
byLdapID[*user.LdapID] = user
|
||||
byUsername[user.Username] = user
|
||||
}
|
||||
|
||||
return users, byLdapID, byUsername, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) loadLDAPGroupsInDB(ctx context.Context, tx *gorm.DB) ([]model.UserGroup, map[string]model.UserGroup, error) {
|
||||
var groups []model.UserGroup
|
||||
|
||||
// Load all LDAP-managed groups and index them by LDAP ID
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Select("id, name, ldap_id").
|
||||
Where("ldap_id IS NOT NULL").
|
||||
Find(&groups).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupsByID := make(map[string]model.UserGroup, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsByID[*group.LdapID] = group
|
||||
}
|
||||
|
||||
return groups, groupsByID, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||
var reader io.ReadSeeker
|
||||
|
||||
// Accept either a URL, a base64-encoded payload, or raw binary data
|
||||
_, err := url.ParseRequestURI(pictureString)
|
||||
if err == nil {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||
@@ -522,6 +660,31 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeLDAPDN returns a canonical lowercase form of a DN for use as a map key.
|
||||
// Different LDAP servers may format the same DN with varying attribute type casing (e.g. "CN=" vs "cn=") or extra whitespace (e.g. "dc=example, dc=com").
|
||||
// Without normalization, cache lookups in usernamesByDN would miss when a member attribute value uses a different format than the DN returned in the search entry
|
||||
//
|
||||
// ldap.ParseDN is used instead of simple lowercasing because it correctly handles multi-valued RDNs (joined with "+") and strips inter-component whitespace.
|
||||
// If parsing fails for any reason, we fall back to a simple lowercase+trim.
|
||||
func normalizeLDAPDN(dn string) string {
|
||||
parsed, err := ldap.ParseDN(dn)
|
||||
if err != nil {
|
||||
return strings.ToLower(strings.TrimSpace(dn))
|
||||
}
|
||||
|
||||
// Reconstruct the DN in a canonical form: lowercase type=lowercase value, with RDN components separated by "," and multi-value attributes by "+"
|
||||
parts := make([]string, 0, len(parsed.RDNs))
|
||||
for _, rdn := range parsed.RDNs {
|
||||
attrs := make([]string, 0, len(rdn.Attributes))
|
||||
for _, attr := range rdn.Attributes {
|
||||
attrs = append(attrs, strings.ToLower(attr.Type)+"="+strings.ToLower(attr.Value))
|
||||
}
|
||||
parts = append(parts, strings.Join(attrs, "+"))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// getDNProperty returns the value of a property from a LDAP identifier
|
||||
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
|
||||
func getDNProperty(property string, str string) string {
|
||||
|
||||
@@ -1,9 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
type fakeLDAPClient struct {
|
||||
searchFn func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
|
||||
}
|
||||
|
||||
func (c *fakeLDAPClient) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||
if c.searchFn == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return c.searchFn(searchRequest)
|
||||
}
|
||||
|
||||
func (c *fakeLDAPClient) Bind(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeLDAPClient) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestLdapServiceSyncAllReconcilesUsersAndGroups(t *testing.T) {
|
||||
service, db := newTestLdapService(t, newFakeLDAPClient(
|
||||
ldapSearchResult(
|
||||
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-alice"},
|
||||
"uid": {"alice"},
|
||||
"mail": {"alice@example.com"},
|
||||
"givenName": {"Alice"},
|
||||
"sn": {"Jones"},
|
||||
"displayName": {""},
|
||||
"memberOf": {"cn=admins,ou=groups,dc=example,dc=com"},
|
||||
}),
|
||||
ldapEntry("uid=bob,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-bob"},
|
||||
"uid": {"bob"},
|
||||
"mail": {"bob@example.com"},
|
||||
"givenName": {"Bob"},
|
||||
"sn": {"Brown"},
|
||||
"displayName": {""},
|
||||
}),
|
||||
),
|
||||
ldapSearchResult(
|
||||
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"g-team"},
|
||||
"cn": {"team"},
|
||||
"member": {
|
||||
"UID=Alice, OU=People, DC=example, DC=com",
|
||||
"uid=bob, ou=people, dc=example, dc=com",
|
||||
},
|
||||
}),
|
||||
),
|
||||
))
|
||||
|
||||
aliceLdapID := "u-alice"
|
||||
missingLdapID := "u-missing"
|
||||
teamLdapID := "g-team"
|
||||
oldGroupLdapID := "g-old"
|
||||
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Username: "alice-old",
|
||||
Email: new("alice-old@example.com"),
|
||||
EmailVerified: true,
|
||||
FirstName: "Old",
|
||||
LastName: "Name",
|
||||
DisplayName: "Old Name",
|
||||
LdapID: &aliceLdapID,
|
||||
Disabled: true,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Username: "missing",
|
||||
Email: new("missing@example.com"),
|
||||
EmailVerified: true,
|
||||
FirstName: "Missing",
|
||||
LastName: "User",
|
||||
DisplayName: "Missing User",
|
||||
LdapID: &missingLdapID,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&model.UserGroup{
|
||||
Name: "team-old",
|
||||
FriendlyName: "team-old",
|
||||
LdapID: &teamLdapID,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&model.UserGroup{
|
||||
Name: "old-group",
|
||||
FriendlyName: "old-group",
|
||||
LdapID: &oldGroupLdapID,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, service.SyncAll(t.Context()))
|
||||
|
||||
var alice model.User
|
||||
require.NoError(t, db.First(&alice, "ldap_id = ?", aliceLdapID).Error)
|
||||
assert.Equal(t, "alice", alice.Username)
|
||||
assert.Equal(t, new("alice@example.com"), alice.Email)
|
||||
assert.Equal(t, "Alice", alice.FirstName)
|
||||
assert.Equal(t, "Jones", alice.LastName)
|
||||
assert.Equal(t, "Alice Jones", alice.DisplayName)
|
||||
assert.True(t, alice.IsAdmin)
|
||||
assert.False(t, alice.Disabled)
|
||||
|
||||
var bob model.User
|
||||
require.NoError(t, db.First(&bob, "ldap_id = ?", "u-bob").Error)
|
||||
assert.Equal(t, "bob", bob.Username)
|
||||
assert.Equal(t, "Bob Brown", bob.DisplayName)
|
||||
|
||||
var missing model.User
|
||||
require.NoError(t, db.First(&missing, "ldap_id = ?", missingLdapID).Error)
|
||||
assert.True(t, missing.Disabled)
|
||||
|
||||
var oldGroupCount int64
|
||||
require.NoError(t, db.Model(&model.UserGroup{}).Where("ldap_id = ?", oldGroupLdapID).Count(&oldGroupCount).Error)
|
||||
assert.Zero(t, oldGroupCount)
|
||||
|
||||
var team model.UserGroup
|
||||
require.NoError(t, db.Preload("Users").First(&team, "ldap_id = ?", teamLdapID).Error)
|
||||
assert.Equal(t, "team", team.Name)
|
||||
assert.Equal(t, "team", team.FriendlyName)
|
||||
assert.ElementsMatch(t, []string{"alice", "bob"}, usernames(team.Users))
|
||||
}
|
||||
|
||||
func TestLdapServiceSyncAllHandlesDuplicateLDAPIDsInSingleRun(t *testing.T) {
|
||||
service, db := newTestLdapService(t, newFakeLDAPClient(
|
||||
ldapSearchResult(
|
||||
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-dup"},
|
||||
"uid": {"alice"},
|
||||
"mail": {"alice@example.com"},
|
||||
"givenName": {"Alice"},
|
||||
"sn": {"Doe"},
|
||||
"displayName": {"Alice Doe"},
|
||||
}),
|
||||
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-dup"},
|
||||
"uid": {"alice"},
|
||||
"mail": {"alice@example.com"},
|
||||
"givenName": {"Alicia"},
|
||||
"sn": {"Doe"},
|
||||
"displayName": {"Alicia Doe"},
|
||||
}),
|
||||
),
|
||||
ldapSearchResult(
|
||||
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"g-dup"},
|
||||
"cn": {"team"},
|
||||
"member": {"uid=alice,ou=people,dc=example,dc=com"},
|
||||
}),
|
||||
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"g-dup"},
|
||||
"cn": {"team-renamed"},
|
||||
"member": {"uid=alice,ou=people,dc=example,dc=com"},
|
||||
}),
|
||||
),
|
||||
))
|
||||
|
||||
require.NoError(t, service.SyncAll(t.Context()))
|
||||
|
||||
var users []model.User
|
||||
require.NoError(t, db.Find(&users, "ldap_id = ?", "u-dup").Error)
|
||||
require.Len(t, users, 1)
|
||||
assert.Equal(t, "alice", users[0].Username)
|
||||
assert.Equal(t, "Alicia", users[0].FirstName)
|
||||
assert.Equal(t, "Alicia Doe", users[0].DisplayName)
|
||||
|
||||
var groups []model.UserGroup
|
||||
require.NoError(t, db.Preload("Users").Find(&groups, "ldap_id = ?", "g-dup").Error)
|
||||
require.Len(t, groups, 1)
|
||||
assert.Equal(t, "team-renamed", groups[0].Name)
|
||||
assert.Equal(t, "team-renamed", groups[0].FriendlyName)
|
||||
assert.ElementsMatch(t, []string{"alice"}, usernames(groups[0].Users))
|
||||
}
|
||||
|
||||
func newTestLdapService(t *testing.T, client ldapClient) (*LdapService, *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
fileStorage, err := storage.NewDatabaseStorage(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
appConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
RequireUserEmail: model.AppConfigVariable{Value: "false"},
|
||||
LdapEnabled: model.AppConfigVariable{Value: "true"},
|
||||
LdapBase: model.AppConfigVariable{Value: "dc=example,dc=com"},
|
||||
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
|
||||
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
|
||||
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
|
||||
LdapAttributeUserUsername: model.AppConfigVariable{Value: "uid"},
|
||||
LdapAttributeUserEmail: model.AppConfigVariable{Value: "mail"},
|
||||
LdapAttributeUserFirstName: model.AppConfigVariable{Value: "givenName"},
|
||||
LdapAttributeUserLastName: model.AppConfigVariable{Value: "sn"},
|
||||
LdapAttributeUserDisplayName: model.AppConfigVariable{Value: "displayName"},
|
||||
LdapAttributeUserProfilePicture: model.AppConfigVariable{Value: "jpegPhoto"},
|
||||
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
|
||||
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
|
||||
LdapAttributeGroupName: model.AppConfigVariable{Value: "cn"},
|
||||
LdapAdminGroupName: model.AppConfigVariable{Value: "admins"},
|
||||
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
|
||||
})
|
||||
|
||||
groupService := NewUserGroupService(db, appConfig, nil)
|
||||
userService := NewUserService(
|
||||
db,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
appConfig,
|
||||
NewCustomClaimService(db),
|
||||
NewAppImagesService(map[string]string{}, fileStorage),
|
||||
nil,
|
||||
fileStorage,
|
||||
)
|
||||
|
||||
service := NewLdapService(db, &http.Client{}, appConfig, userService, groupService, fileStorage)
|
||||
service.clientFactory = func() (ldapClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
return service, db
|
||||
}
|
||||
|
||||
func newFakeLDAPClient(userResult, groupResult *ldap.SearchResult) ldapClient {
|
||||
return &fakeLDAPClient{
|
||||
searchFn: func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||
switch searchRequest.Filter {
|
||||
case "(objectClass=person)":
|
||||
return userResult, nil
|
||||
case "(objectClass=groupOfNames)":
|
||||
return groupResult, nil
|
||||
default:
|
||||
return &ldap.SearchResult{}, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ldapSearchResult(entries ...*ldap.Entry) *ldap.SearchResult {
|
||||
return &ldap.SearchResult{Entries: entries}
|
||||
}
|
||||
|
||||
func ldapEntry(dn string, attrs map[string][]string) *ldap.Entry {
|
||||
entry := &ldap.Entry{
|
||||
DN: dn,
|
||||
Attributes: make([]*ldap.EntryAttribute, 0, len(attrs)),
|
||||
}
|
||||
|
||||
for name, values := range attrs {
|
||||
entry.Attributes = append(entry.Attributes, &ldap.EntryAttribute{
|
||||
Name: name,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
func usernames(users []model.User) []string {
|
||||
result := make([]string, 0, len(users))
|
||||
for _, user := range users {
|
||||
result = append(result, user.Username)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func TestGetDNProperty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -64,10 +341,58 @@ func TestGetDNProperty(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getDNProperty(tt.property, tt.dn)
|
||||
if result != tt.expectedResult {
|
||||
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
|
||||
tt.property, tt.dn, result, tt.expectedResult)
|
||||
}
|
||||
assert.Equalf(t, tt.expectedResult, result, "getDNProperty(%q, %q)", tt.property, tt.dn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeLDAPDN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "already normalized",
|
||||
input: "cn=alice,dc=example,dc=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "uppercase attribute types",
|
||||
input: "CN=Alice,DC=example,DC=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "spaces after commas",
|
||||
input: "cn=alice, dc=example, dc=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "uppercase types and spaces",
|
||||
input: "CN=Alice, DC=example, DC=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "multi-valued RDN",
|
||||
input: "cn=alice+uid=a123,dc=example,dc=com",
|
||||
expected: "cn=alice+uid=a123,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "invalid DN falls back to lowercase+trim",
|
||||
input: " NOT A VALID DN ",
|
||||
expected: "not a valid dn",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeLDAPDN(tt.input)
|
||||
assert.Equalf(t, tt.expected, result, "normalizeLDAPDN(%q)", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -98,9 +423,7 @@ func TestConvertLdapIdToString(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := convertLdapIdToString(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, got)
|
||||
}
|
||||
assert.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +96,10 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -126,7 +129,10 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -175,7 +181,10 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -238,7 +247,10 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -315,6 +327,9 @@ func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id strin
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -225,7 +225,10 @@ func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userI
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -310,7 +313,10 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
}
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -456,7 +462,10 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
return user, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -515,7 +524,10 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -576,7 +588,10 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user
|
||||
return err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user