Compare commits

...

1 Commits

Author SHA1 Message Date
Elias Schneider
718c3e74eb fix: use structured URL component matching for callback URLs 2026-02-22 22:55:59 +01:00
3 changed files with 409 additions and 123 deletions

View File

@@ -1,9 +1,7 @@
package dto
import (
"net/url"
"regexp"
"strings"
"time"
"github.com/pocket-id/pocket-id/backend/internal/utils"
@@ -67,19 +65,5 @@ func ValidateClientID(clientID string) bool {
// ValidateCallbackURL validates callback URLs with support for wildcards
func ValidateCallbackURL(raw string) bool {
// Don't validate if it contains a wildcard
if strings.Contains(raw, "*") {
return true
}
u, err := url.Parse(raw)
if err != nil {
return false
}
if !u.IsAbs() {
return false
}
return true
return utils.ValidateCallbackURLPattern(raw) == nil
}

View File

@@ -1,6 +1,7 @@
package utils
import (
"errors"
"net"
"net/url"
"path"
@@ -8,7 +9,38 @@ import (
"strings"
)
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL
const (
patternParseSchemePlaceholder = "https"
patternParsePortPlaceholder = "65535"
)
var errInvalidCallbackURLPattern = errors.New("invalid callback URL pattern")
type callbackURLPattern struct {
SchemePattern string
HasUserInfo bool
UsernamePattern string
HasPassword bool
PasswordPattern string
HostnamePattern string
HasPort bool
PortPattern string
PathPattern string
}
type callbackURLValue struct {
Scheme string
HasUserInfo bool
Username string
HasPassword bool
Password string
Hostname string
HasPort bool
Port string
Path string
}
// GetCallbackURLFromList returns the first callback URL that matches the input callback URL.
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
// Special case for Loopback Interface Redirection. Quoting from RFC 8252 section 7.3:
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
@@ -24,7 +56,12 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
host := u.Hostname()
ip := net.ParseIP(host)
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
u.Host = host
// For IPv6 loopback hosts, brackets are required when serializing without a port.
if strings.Contains(host, ":") {
u.Host = "[" + host + "]"
} else {
u.Host = host
}
loopbackCallbackURLWithoutPort = u.String()
}
}
@@ -54,6 +91,61 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
return "", nil
}
// ValidateCallbackURLPattern validates callback URL patterns, including wildcard patterns.
func ValidateCallbackURLPattern(raw string) error {
if raw == "*" {
return nil
}
raw, _, _ = strings.Cut(raw, "#")
base, rawQuery, hasQuery := strings.Cut(raw, "?")
if hasQuery {
query, err := url.ParseQuery(rawQuery)
if err != nil {
return err
}
for _, values := range query {
for _, value := range values {
if err := validateGlobPattern(value); err != nil {
return err
}
}
}
}
pattern, err := parseCallbackURLPattern(base)
if err != nil {
return err
}
if err := validateGlobPattern(pattern.SchemePattern); err != nil {
return err
}
if pattern.HasUserInfo {
if err := validateGlobPattern(pattern.UsernamePattern); err != nil {
return err
}
}
if pattern.HasPassword {
if err := validateGlobPattern(pattern.PasswordPattern); err != nil {
return err
}
}
for _, segment := range splitHostLabels(pattern.HostnamePattern) {
if err := validateGlobPattern(segment); err != nil {
return err
}
}
if pattern.HasPort {
if err := validateGlobPattern(pattern.PortPattern); err != nil {
return err
}
}
return nil
}
// matchCallbackURL checks if the input callback URL matches the given pattern.
// It supports wildcard matching for paths and query parameters.
//
@@ -64,53 +156,92 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
return true, nil
}
// Strip fragment part
// Strip fragment part.
// The endpoint URI MUST NOT include a fragment component.
// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
pattern, _, _ = strings.Cut(pattern, "#")
inputCallbackURL, _, _ = strings.Cut(inputCallbackURL, "#")
// Store and strip query part
patternBase, rawPatternQuery, patternHasQuery := strings.Cut(pattern, "?")
inputBase, rawInputQuery, inputHasQuery := strings.Cut(inputCallbackURL, "?")
// Store and parse query parts.
var patternQuery url.Values
if i := strings.Index(pattern, "?"); i >= 0 {
patternQuery, err = url.ParseQuery(pattern[i+1:])
if patternHasQuery {
patternQuery, err = url.ParseQuery(rawPatternQuery)
if err != nil {
return false, err
}
pattern = pattern[:i]
}
var inputQuery url.Values
if i := strings.Index(inputCallbackURL, "?"); i >= 0 {
inputQuery, err = url.ParseQuery(inputCallbackURL[i+1:])
if inputHasQuery {
inputQuery, err = url.ParseQuery(rawInputQuery)
if err != nil {
return false, err
}
inputCallbackURL = inputCallbackURL[:i]
}
// Split both pattern and input parts
patternParts, patternPath := splitParts(pattern)
inputParts, inputPath := splitParts(inputCallbackURL)
// Verify everything except the path and query parameters
if len(patternParts) != len(inputParts) {
patternURL, err := parseCallbackURLPattern(patternBase)
if err != nil {
return false, nil
}
for i, patternPart := range patternParts {
matched, err := path.Match(patternPart, inputParts[i])
inputURL, err := parseCallbackURLValue(inputBase)
if err != nil {
return false, nil
}
// Verify scheme.
matched, err := path.Match(patternURL.SchemePattern, inputURL.Scheme)
if err != nil || !matched {
return false, err
}
// Verify userinfo.
if patternURL.HasUserInfo != inputURL.HasUserInfo {
return false, nil
}
if patternURL.HasUserInfo {
matched, err = path.Match(patternURL.UsernamePattern, inputURL.Username)
if err != nil || !matched {
return false, err
}
if patternURL.HasPassword != inputURL.HasPassword {
return false, nil
}
if patternURL.HasPassword {
matched, err = path.Match(patternURL.PasswordPattern, inputURL.Password)
if err != nil || !matched {
return false, err
}
}
}
// Verify host.
matched, err = matchHostPattern(patternURL.HostnamePattern, inputURL.Hostname)
if err != nil || !matched {
return false, err
}
// Verify port.
if patternURL.HasPort != inputURL.HasPort {
return false, nil
}
if patternURL.HasPort {
matched, err = path.Match(patternURL.PortPattern, inputURL.Port)
if err != nil || !matched {
return false, err
}
}
// Verify path with wildcard support
matched, err := matchPath(patternPath, inputPath)
// Verify path with wildcard support.
matched, err = matchPath(patternURL.PathPattern, inputURL.Path)
if err != nil || !matched {
return false, err
}
// Verify query parameters
// Verify query parameters.
if len(patternQuery) != len(inputQuery) {
return false, nil
}
@@ -126,7 +257,7 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
}
for i := range patternValues {
matched, err := path.Match(patternValues[i], inputValues[i])
matched, err = path.Match(patternValues[i], inputValues[i])
if err != nil || !matched {
return false, err
}
@@ -136,6 +267,205 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
return true, nil
}
func parseCallbackURLPattern(raw string) (callbackURLPattern, error) {
schemePattern, rest, hasScheme := strings.Cut(raw, "://")
if !hasScheme || schemePattern == "" {
return callbackURLPattern{}, errInvalidCallbackURLPattern
}
authority := rest
pathPattern := ""
if i := strings.IndexRune(rest, '/'); i >= 0 {
authority = rest[:i]
pathPattern = rest[i:]
}
if authority == "" {
return callbackURLPattern{}, errInvalidCallbackURLPattern
}
userinfoPattern, hostPort, hasUserinfo := splitAuthority(authority)
if hostPort == "" {
return callbackURLPattern{}, errInvalidCallbackURLPattern
}
hasPassword := false
usernamePattern := ""
passwordPattern := ""
if hasUserinfo {
usernamePattern, passwordPattern, hasPassword = strings.Cut(userinfoPattern, ":")
}
sanitizedAuthority, wildcardPortPattern, err := sanitizePatternAuthority(authority)
if err != nil {
return callbackURLPattern{}, err
}
sanitizedScheme := schemePattern
if strings.ContainsRune(schemePattern, '*') {
sanitizedScheme = patternParseSchemePlaceholder
}
sanitizedURL := sanitizedScheme + "://" + sanitizedAuthority + pathPattern
u, err := url.Parse(sanitizedURL)
if err != nil || !u.IsAbs() || u.Hostname() == "" {
return callbackURLPattern{}, errInvalidCallbackURLPattern
}
portPattern := u.Port()
hasPort := portPattern != ""
if wildcardPortPattern != "" {
portPattern = wildcardPortPattern
hasPort = true
}
return callbackURLPattern{
SchemePattern: schemePattern,
HasUserInfo: hasUserinfo,
UsernamePattern: usernamePattern,
HasPassword: hasPassword,
PasswordPattern: passwordPattern,
HostnamePattern: u.Hostname(),
HasPort: hasPort,
PortPattern: portPattern,
PathPattern: pathPattern,
}, nil
}
func parseCallbackURLValue(raw string) (callbackURLValue, error) {
u, err := url.Parse(raw)
if err != nil || !u.IsAbs() || u.Hostname() == "" {
return callbackURLValue{}, errInvalidCallbackURLPattern
}
hasUserinfo := u.User != nil
username := ""
password := ""
hasPassword := false
if hasUserinfo {
username = u.User.Username()
password, hasPassword = u.User.Password()
}
resolvedPath := u.EscapedPath()
if resolvedPath == "" {
resolvedPath = u.Path
}
return callbackURLValue{
Scheme: u.Scheme,
HasUserInfo: hasUserinfo,
Username: username,
HasPassword: hasPassword,
Password: password,
Hostname: u.Hostname(),
HasPort: u.Port() != "",
Port: u.Port(),
Path: resolvedPath,
}, nil
}
func sanitizePatternAuthority(authority string) (sanitizedAuthority string, wildcardPortPattern string, err error) {
userinfo, hostPort, hasUserinfo := splitAuthority(authority)
if hostPort == "" {
return "", "", errInvalidCallbackURLPattern
}
sanitizedHostPort := hostPort
if strings.HasPrefix(hostPort, "[") {
end := strings.Index(hostPort, "]")
if end < 0 {
return "", "", errInvalidCallbackURLPattern
}
rest := hostPort[end+1:]
if rest != "" {
if !strings.HasPrefix(rest, ":") {
return "", "", errInvalidCallbackURLPattern
}
port := rest[1:]
if port == "" {
return "", "", errInvalidCallbackURLPattern
}
if strings.Contains(port, "*") {
sanitizedHostPort = hostPort[:end+1] + ":" + patternParsePortPlaceholder
wildcardPortPattern = port
}
}
} else {
lastColon := strings.LastIndex(hostPort, ":")
if lastColon >= 0 {
hostCandidate := hostPort[:lastColon]
portCandidate := hostPort[lastColon+1:]
isBareIPv6WithoutPort := strings.Count(hostPort, ":") > 1 && net.ParseIP(hostPort) != nil
if !isBareIPv6WithoutPort {
if hostCandidate == "" || portCandidate == "" {
return "", "", errInvalidCallbackURLPattern
}
if strings.Contains(portCandidate, "*") {
sanitizedHostPort = hostCandidate + ":" + patternParsePortPlaceholder
wildcardPortPattern = portCandidate
}
}
}
}
if hasUserinfo {
sanitizedAuthority = userinfo + "@" + sanitizedHostPort
} else {
sanitizedAuthority = sanitizedHostPort
}
return sanitizedAuthority, wildcardPortPattern, nil
}
func splitAuthority(authority string) (userinfo string, hostPort string, hasUserinfo bool) {
lastAt := strings.LastIndex(authority, "@")
if lastAt < 0 {
return "", authority, false
}
return authority[:lastAt], authority[lastAt+1:], true
}
func matchHostPattern(patternHost, inputHost string) (bool, error) {
patternSegments := splitHostLabels(patternHost)
inputSegments := splitHostLabels(inputHost)
if len(patternSegments) != len(inputSegments) {
return false, nil
}
for i := range patternSegments {
matched, err := path.Match(patternSegments[i], inputSegments[i])
if err != nil || !matched {
return false, err
}
}
return true, nil
}
func splitHostLabels(host string) []string {
if strings.Contains(host, ":") {
return []string{host}
}
return strings.Split(host, ".")
}
func validateGlobPattern(pattern string) error {
if _, err := path.Match(pattern, ""); err != nil {
return err
}
return nil
}
// matchPath matches the input path against the pattern with wildcard support
// Supported wildcards:
//
@@ -172,35 +502,3 @@ func matchPath(pattern string, input string) (matches bool, err error) {
matched, err := regexp.MatchString(regexPattern.String(), input)
return matched, err
}
// splitParts splits the URL into parts by special characters and returns the path separately
func splitParts(s string) (parts []string, path string) {
split := func(r rune) bool {
return r == ':' || r == '/' || r == '[' || r == ']' || r == '@' || r == '.'
}
pathStart := -1
// Look for scheme:// first
if i := strings.Index(s, "://"); i >= 0 {
// Look for the next slash after scheme://
rest := s[i+3:]
if j := strings.IndexRune(rest, '/'); j >= 0 {
pathStart = i + 3 + j
}
} else {
// Otherwise, first slash is path start
pathStart = strings.IndexRune(s, '/')
}
if pathStart >= 0 {
path = s[pathStart:]
base := s[:pathStart]
parts = strings.FieldsFunc(base, split)
} else {
parts = strings.FieldsFunc(s, split)
path = ""
}
return parts, path
}

View File

@@ -91,6 +91,18 @@ func TestMatchCallbackURL(t *testing.T) {
"https://еxample.com/callback",
false,
},
{
"userinfo prefix doesn't bypass exact hostname",
"https://auth.company.com/callback",
"https://auth@company.com/callback",
false,
},
{
"userinfo prefix doesn't bypass wildcard subdomain",
"https://*.victim.com/callback",
"https://wildcard@victim.com/callback",
false,
},
// Port
{
@@ -706,86 +718,78 @@ func TestMatchPath(t *testing.T) {
}
}
func TestSplitParts(t *testing.T) {
func TestValidateCallbackURLPattern(t *testing.T) {
tests := []struct {
name string
input string
expectedParts []string
expectedPath string
name string
pattern string
shouldError bool
}{
{
name: "simple https URL",
input: "https://example.com/callback",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/callback",
name: "exact URL",
pattern: "https://example.com/callback",
shouldError: false,
},
{
name: "URL with port",
input: "https://example.com:8080/callback",
expectedParts: []string{"https", "example", "com", "8080"},
expectedPath: "/callback",
name: "wildcard scheme",
pattern: "*://example.com/callback",
shouldError: false,
},
{
name: "URL with subdomain",
input: "https://api.example.com/callback",
expectedParts: []string{"https", "api", "example", "com"},
expectedPath: "/callback",
name: "wildcard port",
pattern: "https://example.com:*/callback",
shouldError: false,
},
{
name: "URL with credentials",
input: "https://user:pass@example.com/callback",
expectedParts: []string{"https", "user", "pass", "example", "com"},
expectedPath: "/callback",
name: "partial wildcard port",
pattern: "https://example.com:80*/callback",
shouldError: false,
},
{
name: "URL without path",
input: "https://example.com",
expectedParts: []string{"https", "example", "com"},
expectedPath: "",
name: "wildcard userinfo",
pattern: "https://user:*@example.com/callback",
shouldError: false,
},
{
name: "URL with deep path",
input: "https://example.com/api/v1/callback",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/api/v1/callback",
name: "glob wildcard",
pattern: "*",
shouldError: false,
},
{
name: "URL with path and query",
input: "https://example.com/callback?code=123",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/callback?code=123",
name: "relative URL",
pattern: "/callback",
shouldError: true,
},
{
name: "URL with trailing slash",
input: "https://example.com/",
expectedParts: []string{"https", "example", "com"},
expectedPath: "/",
name: "missing scheme separator",
pattern: "https//example.com/callback",
shouldError: true,
},
{
name: "URL with multiple subdomains",
input: "https://api.v1.staging.example.com/callback",
expectedParts: []string{"https", "api", "v1", "staging", "example", "com"},
expectedPath: "/callback",
name: "malformed wildcard host glob",
pattern: "https://exa[mple.com/callback",
shouldError: true,
},
{
name: "URL with port and credentials",
input: "https://user:pass@example.com:8080/callback",
expectedParts: []string{"https", "user", "pass", "example", "com", "8080"},
expectedPath: "/callback",
name: "malformed wildcard query glob",
pattern: "https://example.com/callback?code=[abc",
shouldError: true,
},
{
name: "scheme with authority separator but no slash",
input: "http://example.com",
expectedParts: []string{"http", "example", "com"},
expectedPath: "",
name: "malformed authority",
pattern: "https://[::1/callback",
shouldError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parts, path := splitParts(tt.input)
assert.Equal(t, tt.expectedParts, parts, "parts mismatch")
assert.Equal(t, tt.expectedPath, path, "path mismatch")
err := ValidateCallbackURLPattern(tt.pattern)
if tt.shouldError {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}