mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-23 13:33:52 +00:00
Compare commits
1 Commits
main
...
fix/callba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
718c3e74eb |
@@ -85,7 +85,7 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
controller.NewAuditLogController(apiGroup, svc.auditLogService, authMiddleware)
|
||||
controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService)
|
||||
controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService)
|
||||
controller.NewVersionController(apiGroup, authMiddleware, svc.versionService)
|
||||
controller.NewVersionController(apiGroup, svc.versionService)
|
||||
controller.NewScimController(apiGroup, authMiddleware, svc.scimService)
|
||||
controller.NewUserSignupController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), svc.userSignUpService, svc.appConfigService)
|
||||
|
||||
|
||||
@@ -5,17 +5,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
// NewVersionController registers version-related routes.
|
||||
func NewVersionController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, versionService *service.VersionService) {
|
||||
func NewVersionController(group *gin.RouterGroup, versionService *service.VersionService) {
|
||||
vc := &VersionController{versionService: versionService}
|
||||
group.GET("/version/latest", vc.getLatestVersionHandler)
|
||||
group.GET("/version/current", authMiddleware.WithAdminNotRequired().Add(), vc.getCurrentVersionHandler)
|
||||
}
|
||||
|
||||
type VersionController struct {
|
||||
@@ -41,16 +38,3 @@ func (vc *VersionController) getLatestVersionHandler(c *gin.Context) {
|
||||
"latestVersion": tag,
|
||||
})
|
||||
}
|
||||
|
||||
// getCurrentVersionHandler godoc
|
||||
// @Summary Get current deployed version of Pocket ID
|
||||
// @Tags Version
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]string "Current version information"
|
||||
// @Router /api/version/current [get]
|
||||
func (vc *VersionController) getCurrentVersionHandler(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"currentVersion": common.Version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
@@ -67,19 +65,5 @@ func ValidateClientID(clientID string) bool {
|
||||
|
||||
// ValidateCallbackURL validates callback URLs with support for wildcards
|
||||
func ValidateCallbackURL(raw string) bool {
|
||||
// Don't validate if it contains a wildcard
|
||||
if strings.Contains(raw, "*") {
|
||||
return true
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !u.IsAbs() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return utils.ValidateCallbackURLPattern(raw) == nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
@@ -194,7 +193,6 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
|
||||
IssuedAt(now).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
JwtID(uuid.New().String()).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
@@ -249,7 +247,6 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
JwtID(uuid.New().String()).
|
||||
Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
@@ -339,7 +336,6 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
JwtID(uuid.New().String()).
|
||||
Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
|
||||
@@ -27,8 +27,6 @@ import (
|
||||
|
||||
const testEncryptionKey = "0123456789abcdef0123456789abcdef"
|
||||
|
||||
const uuidRegexPattern = "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
|
||||
func newTestEnvConfig() *common.EnvConfigSchema {
|
||||
return &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
@@ -325,9 +323,6 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
audience, ok := claims.Audience()
|
||||
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||
assert.Equal(t, []string{service.envConfig.AppURL}, audience, "Audience should contain the app URL")
|
||||
jwtID, ok := claims.JwtID()
|
||||
_ = assert.True(t, ok, "JWT ID not found in token") &&
|
||||
assert.Regexp(t, uuidRegexPattern, jwtID, "JWT ID is not a UUID")
|
||||
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
expiration, ok := claims.Expiration()
|
||||
@@ -525,9 +520,6 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
jwtID, ok := claims.JwtID()
|
||||
_ = assert.True(t, ok, "JWT ID not found in token") &&
|
||||
assert.Regexp(t, uuidRegexPattern, jwtID, "JWT ID is not a UUID")
|
||||
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
expiration, ok := claims.Expiration()
|
||||
@@ -762,9 +754,6 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
jwtID, ok := claims.JwtID()
|
||||
_ = assert.True(t, ok, "JWT ID not found in token") &&
|
||||
assert.Regexp(t, uuidRegexPattern, jwtID, "JWT ID is not a UUID")
|
||||
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
expiration, ok := claims.Expiration()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"path"
|
||||
@@ -8,7 +9,38 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL
|
||||
const (
|
||||
patternParseSchemePlaceholder = "https"
|
||||
patternParsePortPlaceholder = "65535"
|
||||
)
|
||||
|
||||
var errInvalidCallbackURLPattern = errors.New("invalid callback URL pattern")
|
||||
|
||||
type callbackURLPattern struct {
|
||||
SchemePattern string
|
||||
HasUserInfo bool
|
||||
UsernamePattern string
|
||||
HasPassword bool
|
||||
PasswordPattern string
|
||||
HostnamePattern string
|
||||
HasPort bool
|
||||
PortPattern string
|
||||
PathPattern string
|
||||
}
|
||||
|
||||
type callbackURLValue struct {
|
||||
Scheme string
|
||||
HasUserInfo bool
|
||||
Username string
|
||||
HasPassword bool
|
||||
Password string
|
||||
Hostname string
|
||||
HasPort bool
|
||||
Port string
|
||||
Path string
|
||||
}
|
||||
|
||||
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL.
|
||||
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
|
||||
// Special case for Loopback Interface Redirection. Quoting from RFC 8252 section 7.3:
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
|
||||
@@ -24,7 +56,12 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
|
||||
host := u.Hostname()
|
||||
ip := net.ParseIP(host)
|
||||
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
|
||||
u.Host = host
|
||||
// For IPv6 loopback hosts, brackets are required when serializing without a port.
|
||||
if strings.Contains(host, ":") {
|
||||
u.Host = "[" + host + "]"
|
||||
} else {
|
||||
u.Host = host
|
||||
}
|
||||
loopbackCallbackURLWithoutPort = u.String()
|
||||
}
|
||||
}
|
||||
@@ -54,6 +91,61 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// ValidateCallbackURLPattern validates callback URL patterns, including wildcard patterns.
|
||||
func ValidateCallbackURLPattern(raw string) error {
|
||||
if raw == "*" {
|
||||
return nil
|
||||
}
|
||||
|
||||
raw, _, _ = strings.Cut(raw, "#")
|
||||
base, rawQuery, hasQuery := strings.Cut(raw, "?")
|
||||
|
||||
if hasQuery {
|
||||
query, err := url.ParseQuery(rawQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, values := range query {
|
||||
for _, value := range values {
|
||||
if err := validateGlobPattern(value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pattern, err := parseCallbackURLPattern(base)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateGlobPattern(pattern.SchemePattern); err != nil {
|
||||
return err
|
||||
}
|
||||
if pattern.HasUserInfo {
|
||||
if err := validateGlobPattern(pattern.UsernamePattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if pattern.HasPassword {
|
||||
if err := validateGlobPattern(pattern.PasswordPattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, segment := range splitHostLabels(pattern.HostnamePattern) {
|
||||
if err := validateGlobPattern(segment); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if pattern.HasPort {
|
||||
if err := validateGlobPattern(pattern.PortPattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchCallbackURL checks if the input callback URL matches the given pattern.
|
||||
// It supports wildcard matching for paths and query parameters.
|
||||
//
|
||||
@@ -64,53 +156,92 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Strip fragment part
|
||||
// Strip fragment part.
|
||||
// The endpoint URI MUST NOT include a fragment component.
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
|
||||
pattern, _, _ = strings.Cut(pattern, "#")
|
||||
inputCallbackURL, _, _ = strings.Cut(inputCallbackURL, "#")
|
||||
|
||||
// Store and strip query part
|
||||
patternBase, rawPatternQuery, patternHasQuery := strings.Cut(pattern, "?")
|
||||
inputBase, rawInputQuery, inputHasQuery := strings.Cut(inputCallbackURL, "?")
|
||||
|
||||
// Store and parse query parts.
|
||||
var patternQuery url.Values
|
||||
if i := strings.Index(pattern, "?"); i >= 0 {
|
||||
patternQuery, err = url.ParseQuery(pattern[i+1:])
|
||||
if patternHasQuery {
|
||||
patternQuery, err = url.ParseQuery(rawPatternQuery)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
pattern = pattern[:i]
|
||||
}
|
||||
var inputQuery url.Values
|
||||
if i := strings.Index(inputCallbackURL, "?"); i >= 0 {
|
||||
inputQuery, err = url.ParseQuery(inputCallbackURL[i+1:])
|
||||
if inputHasQuery {
|
||||
inputQuery, err = url.ParseQuery(rawInputQuery)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
inputCallbackURL = inputCallbackURL[:i]
|
||||
}
|
||||
|
||||
// Split both pattern and input parts
|
||||
patternParts, patternPath := splitParts(pattern)
|
||||
inputParts, inputPath := splitParts(inputCallbackURL)
|
||||
|
||||
// Verify everything except the path and query parameters
|
||||
if len(patternParts) != len(inputParts) {
|
||||
patternURL, err := parseCallbackURLPattern(patternBase)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for i, patternPart := range patternParts {
|
||||
matched, err := path.Match(patternPart, inputParts[i])
|
||||
inputURL, err := parseCallbackURLValue(inputBase)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Verify scheme.
|
||||
matched, err := path.Match(patternURL.SchemePattern, inputURL.Scheme)
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Verify userinfo.
|
||||
if patternURL.HasUserInfo != inputURL.HasUserInfo {
|
||||
return false, nil
|
||||
}
|
||||
if patternURL.HasUserInfo {
|
||||
matched, err = path.Match(patternURL.UsernamePattern, inputURL.Username)
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if patternURL.HasPassword != inputURL.HasPassword {
|
||||
return false, nil
|
||||
}
|
||||
if patternURL.HasPassword {
|
||||
matched, err = path.Match(patternURL.PasswordPattern, inputURL.Password)
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify host.
|
||||
matched, err = matchHostPattern(patternURL.HostnamePattern, inputURL.Hostname)
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Verify port.
|
||||
if patternURL.HasPort != inputURL.HasPort {
|
||||
return false, nil
|
||||
}
|
||||
if patternURL.HasPort {
|
||||
matched, err = path.Match(patternURL.PortPattern, inputURL.Port)
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Verify path with wildcard support
|
||||
matched, err := matchPath(patternPath, inputPath)
|
||||
// Verify path with wildcard support.
|
||||
matched, err = matchPath(patternURL.PathPattern, inputURL.Path)
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Verify query parameters
|
||||
// Verify query parameters.
|
||||
if len(patternQuery) != len(inputQuery) {
|
||||
return false, nil
|
||||
}
|
||||
@@ -126,7 +257,7 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
|
||||
}
|
||||
|
||||
for i := range patternValues {
|
||||
matched, err := path.Match(patternValues[i], inputValues[i])
|
||||
matched, err = path.Match(patternValues[i], inputValues[i])
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
@@ -136,6 +267,205 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func parseCallbackURLPattern(raw string) (callbackURLPattern, error) {
|
||||
schemePattern, rest, hasScheme := strings.Cut(raw, "://")
|
||||
if !hasScheme || schemePattern == "" {
|
||||
return callbackURLPattern{}, errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
authority := rest
|
||||
pathPattern := ""
|
||||
if i := strings.IndexRune(rest, '/'); i >= 0 {
|
||||
authority = rest[:i]
|
||||
pathPattern = rest[i:]
|
||||
}
|
||||
if authority == "" {
|
||||
return callbackURLPattern{}, errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
userinfoPattern, hostPort, hasUserinfo := splitAuthority(authority)
|
||||
if hostPort == "" {
|
||||
return callbackURLPattern{}, errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
hasPassword := false
|
||||
usernamePattern := ""
|
||||
passwordPattern := ""
|
||||
if hasUserinfo {
|
||||
usernamePattern, passwordPattern, hasPassword = strings.Cut(userinfoPattern, ":")
|
||||
}
|
||||
|
||||
sanitizedAuthority, wildcardPortPattern, err := sanitizePatternAuthority(authority)
|
||||
if err != nil {
|
||||
return callbackURLPattern{}, err
|
||||
}
|
||||
|
||||
sanitizedScheme := schemePattern
|
||||
if strings.ContainsRune(schemePattern, '*') {
|
||||
sanitizedScheme = patternParseSchemePlaceholder
|
||||
}
|
||||
|
||||
sanitizedURL := sanitizedScheme + "://" + sanitizedAuthority + pathPattern
|
||||
u, err := url.Parse(sanitizedURL)
|
||||
if err != nil || !u.IsAbs() || u.Hostname() == "" {
|
||||
return callbackURLPattern{}, errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
portPattern := u.Port()
|
||||
hasPort := portPattern != ""
|
||||
if wildcardPortPattern != "" {
|
||||
portPattern = wildcardPortPattern
|
||||
hasPort = true
|
||||
}
|
||||
|
||||
return callbackURLPattern{
|
||||
SchemePattern: schemePattern,
|
||||
HasUserInfo: hasUserinfo,
|
||||
UsernamePattern: usernamePattern,
|
||||
HasPassword: hasPassword,
|
||||
PasswordPattern: passwordPattern,
|
||||
HostnamePattern: u.Hostname(),
|
||||
HasPort: hasPort,
|
||||
PortPattern: portPattern,
|
||||
PathPattern: pathPattern,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseCallbackURLValue(raw string) (callbackURLValue, error) {
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil || !u.IsAbs() || u.Hostname() == "" {
|
||||
return callbackURLValue{}, errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
hasUserinfo := u.User != nil
|
||||
username := ""
|
||||
password := ""
|
||||
hasPassword := false
|
||||
if hasUserinfo {
|
||||
username = u.User.Username()
|
||||
password, hasPassword = u.User.Password()
|
||||
}
|
||||
|
||||
resolvedPath := u.EscapedPath()
|
||||
if resolvedPath == "" {
|
||||
resolvedPath = u.Path
|
||||
}
|
||||
|
||||
return callbackURLValue{
|
||||
Scheme: u.Scheme,
|
||||
HasUserInfo: hasUserinfo,
|
||||
Username: username,
|
||||
HasPassword: hasPassword,
|
||||
Password: password,
|
||||
Hostname: u.Hostname(),
|
||||
HasPort: u.Port() != "",
|
||||
Port: u.Port(),
|
||||
Path: resolvedPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sanitizePatternAuthority(authority string) (sanitizedAuthority string, wildcardPortPattern string, err error) {
|
||||
userinfo, hostPort, hasUserinfo := splitAuthority(authority)
|
||||
if hostPort == "" {
|
||||
return "", "", errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
sanitizedHostPort := hostPort
|
||||
|
||||
if strings.HasPrefix(hostPort, "[") {
|
||||
end := strings.Index(hostPort, "]")
|
||||
if end < 0 {
|
||||
return "", "", errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
rest := hostPort[end+1:]
|
||||
if rest != "" {
|
||||
if !strings.HasPrefix(rest, ":") {
|
||||
return "", "", errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
port := rest[1:]
|
||||
if port == "" {
|
||||
return "", "", errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
if strings.Contains(port, "*") {
|
||||
sanitizedHostPort = hostPort[:end+1] + ":" + patternParsePortPlaceholder
|
||||
wildcardPortPattern = port
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lastColon := strings.LastIndex(hostPort, ":")
|
||||
if lastColon >= 0 {
|
||||
hostCandidate := hostPort[:lastColon]
|
||||
portCandidate := hostPort[lastColon+1:]
|
||||
|
||||
isBareIPv6WithoutPort := strings.Count(hostPort, ":") > 1 && net.ParseIP(hostPort) != nil
|
||||
if !isBareIPv6WithoutPort {
|
||||
if hostCandidate == "" || portCandidate == "" {
|
||||
return "", "", errInvalidCallbackURLPattern
|
||||
}
|
||||
|
||||
if strings.Contains(portCandidate, "*") {
|
||||
sanitizedHostPort = hostCandidate + ":" + patternParsePortPlaceholder
|
||||
wildcardPortPattern = portCandidate
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasUserinfo {
|
||||
sanitizedAuthority = userinfo + "@" + sanitizedHostPort
|
||||
} else {
|
||||
sanitizedAuthority = sanitizedHostPort
|
||||
}
|
||||
|
||||
return sanitizedAuthority, wildcardPortPattern, nil
|
||||
}
|
||||
|
||||
func splitAuthority(authority string) (userinfo string, hostPort string, hasUserinfo bool) {
|
||||
lastAt := strings.LastIndex(authority, "@")
|
||||
if lastAt < 0 {
|
||||
return "", authority, false
|
||||
}
|
||||
|
||||
return authority[:lastAt], authority[lastAt+1:], true
|
||||
}
|
||||
|
||||
func matchHostPattern(patternHost, inputHost string) (bool, error) {
|
||||
patternSegments := splitHostLabels(patternHost)
|
||||
inputSegments := splitHostLabels(inputHost)
|
||||
|
||||
if len(patternSegments) != len(inputSegments) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for i := range patternSegments {
|
||||
matched, err := path.Match(patternSegments[i], inputSegments[i])
|
||||
if err != nil || !matched {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func splitHostLabels(host string) []string {
|
||||
if strings.Contains(host, ":") {
|
||||
return []string{host}
|
||||
}
|
||||
|
||||
return strings.Split(host, ".")
|
||||
}
|
||||
|
||||
func validateGlobPattern(pattern string) error {
|
||||
if _, err := path.Match(pattern, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchPath matches the input path against the pattern with wildcard support
|
||||
// Supported wildcards:
|
||||
//
|
||||
@@ -172,35 +502,3 @@ func matchPath(pattern string, input string) (matches bool, err error) {
|
||||
matched, err := regexp.MatchString(regexPattern.String(), input)
|
||||
return matched, err
|
||||
}
|
||||
|
||||
// splitParts splits the URL into parts by special characters and returns the path separately
|
||||
func splitParts(s string) (parts []string, path string) {
|
||||
split := func(r rune) bool {
|
||||
return r == ':' || r == '/' || r == '[' || r == ']' || r == '@' || r == '.'
|
||||
}
|
||||
|
||||
pathStart := -1
|
||||
|
||||
// Look for scheme:// first
|
||||
if i := strings.Index(s, "://"); i >= 0 {
|
||||
// Look for the next slash after scheme://
|
||||
rest := s[i+3:]
|
||||
if j := strings.IndexRune(rest, '/'); j >= 0 {
|
||||
pathStart = i + 3 + j
|
||||
}
|
||||
} else {
|
||||
// Otherwise, first slash is path start
|
||||
pathStart = strings.IndexRune(s, '/')
|
||||
}
|
||||
|
||||
if pathStart >= 0 {
|
||||
path = s[pathStart:]
|
||||
base := s[:pathStart]
|
||||
parts = strings.FieldsFunc(base, split)
|
||||
} else {
|
||||
parts = strings.FieldsFunc(s, split)
|
||||
path = ""
|
||||
}
|
||||
|
||||
return parts, path
|
||||
}
|
||||
|
||||
@@ -91,6 +91,18 @@ func TestMatchCallbackURL(t *testing.T) {
|
||||
"https://еxample.com/callback",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"userinfo prefix doesn't bypass exact hostname",
|
||||
"https://auth.company.com/callback",
|
||||
"https://auth@company.com/callback",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"userinfo prefix doesn't bypass wildcard subdomain",
|
||||
"https://*.victim.com/callback",
|
||||
"https://wildcard@victim.com/callback",
|
||||
false,
|
||||
},
|
||||
|
||||
// Port
|
||||
{
|
||||
@@ -706,86 +718,78 @@ func TestMatchPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitParts(t *testing.T) {
|
||||
func TestValidateCallbackURLPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedParts []string
|
||||
expectedPath string
|
||||
name string
|
||||
pattern string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "simple https URL",
|
||||
input: "https://example.com/callback",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
name: "exact URL",
|
||||
pattern: "https://example.com/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with port",
|
||||
input: "https://example.com:8080/callback",
|
||||
expectedParts: []string{"https", "example", "com", "8080"},
|
||||
expectedPath: "/callback",
|
||||
name: "wildcard scheme",
|
||||
pattern: "*://example.com/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with subdomain",
|
||||
input: "https://api.example.com/callback",
|
||||
expectedParts: []string{"https", "api", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
name: "wildcard port",
|
||||
pattern: "https://example.com:*/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with credentials",
|
||||
input: "https://user:pass@example.com/callback",
|
||||
expectedParts: []string{"https", "user", "pass", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
name: "partial wildcard port",
|
||||
pattern: "https://example.com:80*/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "URL without path",
|
||||
input: "https://example.com",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "",
|
||||
name: "wildcard userinfo",
|
||||
pattern: "https://user:*@example.com/callback",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with deep path",
|
||||
input: "https://example.com/api/v1/callback",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/api/v1/callback",
|
||||
name: "glob wildcard",
|
||||
pattern: "*",
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with path and query",
|
||||
input: "https://example.com/callback?code=123",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/callback?code=123",
|
||||
name: "relative URL",
|
||||
pattern: "/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "URL with trailing slash",
|
||||
input: "https://example.com/",
|
||||
expectedParts: []string{"https", "example", "com"},
|
||||
expectedPath: "/",
|
||||
name: "missing scheme separator",
|
||||
pattern: "https//example.com/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "URL with multiple subdomains",
|
||||
input: "https://api.v1.staging.example.com/callback",
|
||||
expectedParts: []string{"https", "api", "v1", "staging", "example", "com"},
|
||||
expectedPath: "/callback",
|
||||
name: "malformed wildcard host glob",
|
||||
pattern: "https://exa[mple.com/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "URL with port and credentials",
|
||||
input: "https://user:pass@example.com:8080/callback",
|
||||
expectedParts: []string{"https", "user", "pass", "example", "com", "8080"},
|
||||
expectedPath: "/callback",
|
||||
name: "malformed wildcard query glob",
|
||||
pattern: "https://example.com/callback?code=[abc",
|
||||
shouldError: true,
|
||||
},
|
||||
{
|
||||
name: "scheme with authority separator but no slash",
|
||||
input: "http://example.com",
|
||||
expectedParts: []string{"http", "example", "com"},
|
||||
expectedPath: "",
|
||||
name: "malformed authority",
|
||||
pattern: "https://[::1/callback",
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parts, path := splitParts(tt.input)
|
||||
assert.Equal(t, tt.expectedParts, parts, "parts mismatch")
|
||||
assert.Equal(t, tt.expectedPath, path, "path mismatch")
|
||||
err := ValidateCallbackURLPattern(tt.pattern)
|
||||
if tt.shouldError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user