mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-27 13:06:56 +00:00
Compare commits
1 Commits
main
...
callback-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76da41f126 |
@@ -89,12 +89,19 @@ type OidcController struct {
|
||||
// @Router /api/oidc/authorize [post]
|
||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
err := c.ShouldBindJSON(&input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
||||
code, callbackURL, err := oc.oidcService.Authorize(
|
||||
c.Request.Context(),
|
||||
input,
|
||||
c.GetString("userID"),
|
||||
c.ClientIP(),
|
||||
c.Request.UserAgent(),
|
||||
)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -33,8 +33,8 @@ type OidcClientWithAllowedGroupsCountDto struct {
|
||||
|
||||
type OidcClientUpdateDto struct {
|
||||
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
||||
CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url"`
|
||||
CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url_pattern"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url_pattern"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
RequiresReauthentication bool `json:"requiresReauthentication"`
|
||||
@@ -66,7 +66,7 @@ type OidcClientFederatedIdentityDto struct {
|
||||
type AuthorizeOidcClientRequestDto struct {
|
||||
ClientID string `json:"clientID" binding:"required"`
|
||||
Scope string `json:"scope" binding:"required"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
CallbackURL string `json:"callbackURL" binding:"omitempty,callback_url"`
|
||||
Nonce string `json:"nonce"`
|
||||
CodeChallenge string `json:"codeChallenge"`
|
||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
@@ -19,38 +21,38 @@ var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9]([a-zA-Z0-9_.@-]*[a-
|
||||
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
||||
|
||||
func init() {
|
||||
v := binding.Validator.Engine().(*validator.Validate)
|
||||
engine := binding.Validator.Engine().(*validator.Validate)
|
||||
|
||||
// Maximum allowed value for TTLs
|
||||
const maxTTL = 31 * 24 * time.Hour
|
||||
|
||||
if err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||
return ValidateUsername(fl.Field().String())
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for username: " + err.Error())
|
||||
validators := map[string]validator.Func{
|
||||
"username": func(fl validator.FieldLevel) bool {
|
||||
return ValidateUsername(fl.Field().String())
|
||||
},
|
||||
"client_id": func(fl validator.FieldLevel) bool {
|
||||
return ValidateClientID(fl.Field().String())
|
||||
},
|
||||
"ttl": func(fl validator.FieldLevel) bool {
|
||||
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// Allow zero, which means the field wasn't set
|
||||
return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL)
|
||||
},
|
||||
"callback_url": func(fl validator.FieldLevel) bool {
|
||||
return ValidateCallbackURL(fl.Field().String())
|
||||
},
|
||||
"callback_url_pattern": func(fl validator.FieldLevel) bool {
|
||||
return ValidateCallbackURLPattern(fl.Field().String())
|
||||
},
|
||||
}
|
||||
|
||||
if err := v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
||||
return ValidateClientID(fl.Field().String())
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for client_id: " + err.Error())
|
||||
}
|
||||
|
||||
if err := v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
|
||||
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
||||
if !ok {
|
||||
return false
|
||||
for k, v := range validators {
|
||||
err := engine.RegisterValidation(k, v)
|
||||
if err != nil {
|
||||
panic("Failed to register custom validation for " + k + ": " + err.Error())
|
||||
}
|
||||
// Allow zero, which means the field wasn't set
|
||||
return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL)
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for ttl: " + err.Error())
|
||||
}
|
||||
|
||||
if err := v.RegisterValidation("callback_url", func(fl validator.FieldLevel) bool {
|
||||
return ValidateCallbackURL(fl.Field().String())
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for callback_url: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,8 +66,24 @@ func ValidateClientID(clientID string) bool {
|
||||
return validateClientIDRegex.MatchString(clientID)
|
||||
}
|
||||
|
||||
// ValidateCallbackURL validates callback URLs with support for wildcards
|
||||
func ValidateCallbackURL(raw string) bool {
|
||||
// ValidateCallbackURL validates the input callback URL
|
||||
func ValidateCallbackURL(str string) bool {
|
||||
// Ensure the URL is a valid one and that the protocol is not "javascript:" or "data:"
|
||||
u, err := url.Parse(str)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "javascript", "data":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateCallbackURLPattern validates callback URL patterns, with support for wildcards
|
||||
func ValidateCallbackURLPattern(raw string) bool {
|
||||
err := utils.ValidateCallbackURLPattern(raw)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -57,3 +57,28 @@ func TestValidateClientID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCallbackURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"valid https URL", "https://example.com/callback", true},
|
||||
{"valid loopback URL", "http://127.0.0.1:49813/callback", true},
|
||||
{"empty scheme", "//127.0.0.1:49813/callback", true},
|
||||
{"valid custom scheme", "pocketid://callback", true},
|
||||
{"invalid malformed URL", "http://[::1", false},
|
||||
{"invalid missing scheme separator", "://example.com/callback", false},
|
||||
{"rejects javascript scheme", "javascript:alert(1)", false},
|
||||
{"rejects mixed case javascript scheme", "JavaScript:alert(1)", false},
|
||||
{"rejects data scheme", "data:text/html;base64,PGgxPkhlbGxvPC9oMT4=", false},
|
||||
{"rejects mixed case data scheme", "DaTa:text/html;base64,PGgxPkhlbGxvPC9oMT4=", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, ValidateCallbackURL(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,9 +125,7 @@ func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
|
||||
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
defer tx.Rollback()
|
||||
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
|
||||
@@ -71,19 +71,16 @@
|
||||
reauthToken = await webauthnService.reauthenticate(authResponse);
|
||||
}
|
||||
|
||||
await oidService
|
||||
.authorize(
|
||||
client!.id,
|
||||
scope,
|
||||
callbackURL,
|
||||
nonce,
|
||||
codeChallenge,
|
||||
codeChallengeMethod,
|
||||
reauthToken
|
||||
)
|
||||
.then(async ({ code, callbackURL, issuer }) => {
|
||||
onSuccess(code, callbackURL, issuer);
|
||||
});
|
||||
const authResult = await oidService.authorize(
|
||||
client!.id,
|
||||
scope,
|
||||
callbackURL,
|
||||
nonce,
|
||||
codeChallenge,
|
||||
codeChallengeMethod,
|
||||
reauthToken
|
||||
);
|
||||
onSuccess(authResult.code, authResult.callbackURL, authResult.issuer);
|
||||
} catch (e) {
|
||||
errorMessage = getWebauthnErrorMessage(e);
|
||||
isLoading = false;
|
||||
@@ -91,13 +88,17 @@
|
||||
}
|
||||
|
||||
function onSuccess(code: string, callbackURL: string, issuer: string) {
|
||||
const redirectURL = new URL(callbackURL);
|
||||
if (redirectURL.protocol == 'javascript:' || redirectURL.protocol == 'data:') {
|
||||
throw new Error('Invalid redirect URL protocol');
|
||||
}
|
||||
|
||||
redirectURL.searchParams.append('code', code);
|
||||
redirectURL.searchParams.append('state', authorizeState);
|
||||
redirectURL.searchParams.append('iss', issuer);
|
||||
|
||||
success = true;
|
||||
setTimeout(() => {
|
||||
const redirectURL = new URL(callbackURL);
|
||||
redirectURL.searchParams.append('code', code);
|
||||
redirectURL.searchParams.append('state', authorizeState);
|
||||
redirectURL.searchParams.append('iss', issuer);
|
||||
|
||||
window.location.href = redirectURL.toString();
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user