mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-03 01:10:20 +00:00
Compare commits
1 Commits
main
...
fix-callba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b088fa94e0 |
@@ -34,22 +34,7 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
|
|||||||
// time of the request for loopback IP redirect URIs, to accommodate
|
// time of the request for loopback IP redirect URIs, to accommodate
|
||||||
// clients that obtain an available ephemeral port from the operating
|
// clients that obtain an available ephemeral port from the operating
|
||||||
// system at the time of the request.
|
// system at the time of the request.
|
||||||
loopbackCallbackURLWithoutPort := ""
|
loopbackCallbackURLWithoutPort := loopbackURLWithWildcardPort(inputCallbackURL)
|
||||||
u, _ := url.Parse(inputCallbackURL)
|
|
||||||
|
|
||||||
if u != nil && u.Scheme == "http" {
|
|
||||||
host := u.Hostname()
|
|
||||||
ip := net.ParseIP(host)
|
|
||||||
if host == "localhost" || (ip != nil && ip.IsLoopback()) {
|
|
||||||
// For IPv6 loopback hosts, brackets are required when serializing without a port.
|
|
||||||
if strings.Contains(host, ":") {
|
|
||||||
u.Host = "[" + host + "]"
|
|
||||||
} else {
|
|
||||||
u.Host = host
|
|
||||||
}
|
|
||||||
loopbackCallbackURLWithoutPort = u.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pattern := range urls {
|
for _, pattern := range urls {
|
||||||
// Try the original callback first
|
// Try the original callback first
|
||||||
@@ -76,6 +61,28 @@ func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loopbackURLWithWildcardPort(input string) string {
|
||||||
|
u, _ := url.Parse(input)
|
||||||
|
|
||||||
|
if u == nil || u.Scheme != "http" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
host := u.Hostname()
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if host != "localhost" && (ip == nil || !ip.IsLoopback()) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// For IPv6 loopback hosts, brackets are required when serializing without a port.
|
||||||
|
if strings.Contains(host, ":") {
|
||||||
|
u.Host = "[" + host + "]"
|
||||||
|
} else {
|
||||||
|
u.Host = host
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
// matchCallbackURL checks if the input callback URL matches the given pattern.
|
// matchCallbackURL checks if the input callback URL matches the given pattern.
|
||||||
// It supports wildcard matching for paths and query parameters.
|
// It supports wildcard matching for paths and query parameters.
|
||||||
//
|
//
|
||||||
@@ -125,10 +132,57 @@ func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, er
|
|||||||
// normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards
|
// normalizeToURLPatternStandard converts patterns with single asterisk wildcards and globstar wildcards
|
||||||
// into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards
|
// into a format that can be parsed by the urlpattern package, which uses :param for single segment wildcards
|
||||||
// and ** for multi-segment wildcards.
|
// and ** for multi-segment wildcards.
|
||||||
|
// Additionally, it escapes ":" with a backslash inside IPv6 addresses
|
||||||
func normalizeToURLPatternStandard(pattern string) string {
|
func normalizeToURLPatternStandard(pattern string) string {
|
||||||
patternBase, patternPath := extractPath(pattern)
|
patternBase, patternPath := extractPath(pattern)
|
||||||
|
|
||||||
var result strings.Builder
|
var result strings.Builder
|
||||||
|
result.Grow(len(pattern) + 5) // Add 5 for some extra capacity, hoping to avoid many re-allocations
|
||||||
|
|
||||||
|
// First, process the base
|
||||||
|
|
||||||
|
// 0 = scheme
|
||||||
|
// 1 = hostname (optionally with username/password) - before IPv6 start (no `[` found)
|
||||||
|
// 2 = is matching IPv6 (until `]`)
|
||||||
|
// 3 = after hostname
|
||||||
|
var step int
|
||||||
|
for i := 0; i < len(patternBase); i++ {
|
||||||
|
switch step {
|
||||||
|
case 0:
|
||||||
|
if i > 3 && patternBase[i] == '/' && patternBase[i-1] == '/' && patternBase[i-2] == ':' {
|
||||||
|
// We just passed the scheme
|
||||||
|
step = 1
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
switch patternBase[i] {
|
||||||
|
case '/', ']':
|
||||||
|
// No IPv6, skip to end of this logic
|
||||||
|
step = 3
|
||||||
|
case '[':
|
||||||
|
// Start of IPv6 match
|
||||||
|
step = 2
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
if patternBase[i] == '/' || patternBase[i] == ']' || patternBase[i] == '[' {
|
||||||
|
// End of IPv6 match
|
||||||
|
step = 3
|
||||||
|
}
|
||||||
|
|
||||||
|
switch patternBase[i] {
|
||||||
|
case ':':
|
||||||
|
// We are matching an IPv6 block and there's a colon, so escape that
|
||||||
|
result.WriteByte('\\')
|
||||||
|
case '/', ']', '[':
|
||||||
|
// End of IPv6 match
|
||||||
|
step = 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the byte
|
||||||
|
result.WriteByte(patternBase[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, process the path
|
||||||
for i := 0; i < len(patternPath); i++ {
|
for i := 0; i < len(patternPath); i++ {
|
||||||
if patternPath[i] == '*' {
|
if patternPath[i] == '*' {
|
||||||
// Replace globstar with a single asterisk
|
// Replace globstar with a single asterisk
|
||||||
@@ -141,19 +195,19 @@ func normalizeToURLPatternStandard(pattern string) string {
|
|||||||
result.WriteString(strconv.Itoa(i))
|
result.WriteString(strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// Add the byte
|
||||||
result.WriteByte(patternPath[i])
|
result.WriteByte(patternPath[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
patternPath = result.String()
|
return result.String()
|
||||||
|
|
||||||
return patternBase + patternPath
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractPath(url string) (base string, path string) {
|
func extractPath(url string) (base string, path string) {
|
||||||
pathStart := -1
|
pathStart := -1
|
||||||
|
|
||||||
// Look for scheme:// first
|
// Look for scheme:// first
|
||||||
if i := strings.Index(url, "://"); i >= 0 {
|
i := strings.Index(url, "://")
|
||||||
|
if i >= 0 {
|
||||||
// Look for the next slash after scheme://
|
// Look for the next slash after scheme://
|
||||||
rest := url[i+3:]
|
rest := url[i+3:]
|
||||||
if j := strings.IndexByte(rest, '/'); j >= 0 {
|
if j := strings.IndexByte(rest, '/'); j >= 0 {
|
||||||
|
|||||||
@@ -58,11 +58,6 @@ func TestValidateCallbackURLPattern(t *testing.T) {
|
|||||||
pattern: "https://exa[mple.com/callback",
|
pattern: "https://exa[mple.com/callback",
|
||||||
shouldError: true,
|
shouldError: true,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "malformed authority",
|
|
||||||
pattern: "https://[::1/callback",
|
|
||||||
shouldError: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -78,6 +73,76 @@ func TestValidateCallbackURLPattern(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeToURLPatternStandard(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact URL unchanged",
|
||||||
|
input: "https://example.com/callback",
|
||||||
|
expected: "https://example.com/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single wildcard path segment converted to named parameter",
|
||||||
|
input: "https://example.com/api/*/callback",
|
||||||
|
expected: "https://example.com/api/:p5/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single wildcard in path suffix converted to named parameter",
|
||||||
|
input: "https://example.com/test*",
|
||||||
|
expected: "https://example.com/test:p5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "globstar converted to single asterisk",
|
||||||
|
input: "https://example.com/**/callback",
|
||||||
|
expected: "https://example.com/*/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed globstar and single wildcard conversion",
|
||||||
|
input: "https://example.com/**/v1/**/callback/*",
|
||||||
|
expected: "https://example.com/*/v1/*/callback/:p19",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL without path unchanged",
|
||||||
|
input: "https://example.com",
|
||||||
|
expected: "https://example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relative path conversion",
|
||||||
|
input: "/foo/*/bar",
|
||||||
|
expected: "/foo/:p5/bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard in hostname is not normalized by this function",
|
||||||
|
input: "https://*.example.com/callback",
|
||||||
|
expected: "https://*.example.com/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 hostname escapes all colons inside address",
|
||||||
|
input: "https://[2001:db8:1:1::a:1]/callback",
|
||||||
|
expected: "https://[2001\\:db8\\:1\\:1\\:\\:a\\:1]/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 hostname with port escapes only address colons",
|
||||||
|
input: "https://[::1]:8080/callback",
|
||||||
|
expected: "https://[\\:\\:1]:8080/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard in query is converted when query is part of input",
|
||||||
|
input: "https://example.com/callback?code=*",
|
||||||
|
expected: "https://example.com/callback?code=:p15",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, normalizeToURLPatternStandard(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMatchCallbackURL(t *testing.T) {
|
func TestMatchCallbackURL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -98,6 +163,18 @@ func TestMatchCallbackURL(t *testing.T) {
|
|||||||
"https://example.com/callback",
|
"https://example.com/callback",
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"exact match - IPv4",
|
||||||
|
"https://10.1.0.1/callback",
|
||||||
|
"https://10.1.0.1/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"exact match - IPv6",
|
||||||
|
"https://[2001:db8:1:1::a:1]/callback",
|
||||||
|
"https://[2001:db8:1:1::a:1]/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
|
||||||
// Scheme
|
// Scheme
|
||||||
{
|
{
|
||||||
@@ -182,6 +259,30 @@ func TestMatchCallbackURL(t *testing.T) {
|
|||||||
"https://example.com:8080/callback",
|
"https://example.com:8080/callback",
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"wildcard port - IPv4",
|
||||||
|
"https://10.1.0.1:*/callback",
|
||||||
|
"https://10.1.0.1:8080/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"partial wildcard in port prefix - IPv4",
|
||||||
|
"https://10.1.0.1:80*/callback",
|
||||||
|
"https://10.1.0.1:8080/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wildcard port - IPv6",
|
||||||
|
"https://[2001:db8:1:1::a:1]:*/callback",
|
||||||
|
"https://[2001:db8:1:1::a:1]:8080/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"partial wildcard in port prefix - IPv6",
|
||||||
|
"https://[2001:db8:1:1::a:1]:80*/callback",
|
||||||
|
"https://[2001:db8:1:1::a:1]:8080/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
|
||||||
// Path
|
// Path
|
||||||
{
|
{
|
||||||
@@ -202,6 +303,18 @@ func TestMatchCallbackURL(t *testing.T) {
|
|||||||
"https://example.com/callback",
|
"https://example.com/callback",
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"wildcard entire path - IPv4",
|
||||||
|
"https://10.1.0.1/*",
|
||||||
|
"https://10.1.0.1/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wildcard entire path - IPv6",
|
||||||
|
"https://[2001:db8:1:1::a:1]/*",
|
||||||
|
"https://[2001:db8:1:1::a:1]/callback",
|
||||||
|
true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"partial wildcard in path prefix",
|
"partial wildcard in path prefix",
|
||||||
"https://example.com/test*",
|
"https://example.com/test*",
|
||||||
@@ -435,10 +548,11 @@ func TestMatchCallbackURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
matches, err := matchCallbackURL(tt.pattern, tt.input)
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
require.NoError(t, err, tt.name)
|
matches, err := matchCallbackURL(tt.pattern, tt.input)
|
||||||
assert.Equal(t, tt.shouldMatch, matches, tt.name)
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.shouldMatch, matches)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,14 +586,21 @@ func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) {
|
|||||||
expectMatch: true,
|
expectMatch: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 loopback with dynamic port",
|
name: "IPv6 loopback with dynamic port - exact match",
|
||||||
urls: []string{"http://[::1]/callback"},
|
urls: []string{"http://[::1]/callback"},
|
||||||
inputCallbackURL: "http://[::1]:8080/callback",
|
inputCallbackURL: "http://[::1]:8080/callback",
|
||||||
expectedURL: "http://[::1]:8080/callback",
|
expectedURL: "http://[::1]:8080/callback",
|
||||||
expectMatch: true,
|
expectMatch: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 loopback with wildcard path",
|
name: "IPv6 loopback with same port - exact match",
|
||||||
|
urls: []string{"http://[::1]:8080/callback"},
|
||||||
|
inputCallbackURL: "http://[::1]:8080/callback",
|
||||||
|
expectedURL: "http://[::1]:8080/callback",
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback with path match",
|
||||||
urls: []string{"http://[::1]/auth/*"},
|
urls: []string{"http://[::1]/auth/*"},
|
||||||
inputCallbackURL: "http://[::1]:8080/auth/callback",
|
inputCallbackURL: "http://[::1]:8080/auth/callback",
|
||||||
expectedURL: "http://[::1]:8080/auth/callback",
|
expectedURL: "http://[::1]:8080/auth/callback",
|
||||||
@@ -506,6 +627,20 @@ func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) {
|
|||||||
expectedURL: "http://127.0.0.1:3000/auth/callback",
|
expectedURL: "http://127.0.0.1:3000/auth/callback",
|
||||||
expectMatch: true,
|
expectMatch: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "loopback with path port",
|
||||||
|
urls: []string{"http://127.0.0.1:*/auth/callback"},
|
||||||
|
inputCallbackURL: "http://127.0.0.1:3000/auth/callback",
|
||||||
|
expectedURL: "http://127.0.0.1:3000/auth/callback",
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback with path port",
|
||||||
|
urls: []string{"http://[::1]:*/auth/callback"},
|
||||||
|
inputCallbackURL: "http://[::1]:3000/auth/callback",
|
||||||
|
expectedURL: "http://[::1]:3000/auth/callback",
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "loopback with path mismatch",
|
name: "loopback with path mismatch",
|
||||||
urls: []string{"http://127.0.0.1/callback"},
|
urls: []string{"http://127.0.0.1/callback"},
|
||||||
@@ -549,6 +684,76 @@ func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoopbackURLWithWildcardPort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
output string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "localhost http with port strips port",
|
||||||
|
input: "http://localhost:3000/callback",
|
||||||
|
output: "http://localhost/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localhost http without port stays same",
|
||||||
|
input: "http://localhost/callback",
|
||||||
|
output: "http://localhost/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 loopback with port strips port",
|
||||||
|
input: "http://127.0.0.1:8080/callback",
|
||||||
|
output: "http://127.0.0.1/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 loopback without port stays same",
|
||||||
|
input: "http://127.0.0.1/callback",
|
||||||
|
output: "http://127.0.0.1/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback with port strips port and keeps brackets",
|
||||||
|
input: "http://[::1]:8080/callback",
|
||||||
|
output: "http://[::1]/callback",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback preserves path query and fragment",
|
||||||
|
input: "http://[::1]:8080/auth/callback?code=123#state",
|
||||||
|
output: "http://[::1]/auth/callback?code=123#state",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https loopback returns empty",
|
||||||
|
input: "https://127.0.0.1:8080/callback",
|
||||||
|
output: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non loopback host returns empty",
|
||||||
|
input: "http://example.com:8080/callback",
|
||||||
|
output: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non loopback IP returns empty",
|
||||||
|
input: "http://192.168.1.10:8080/callback",
|
||||||
|
output: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed URL returns empty",
|
||||||
|
input: "http://[::1:8080/callback",
|
||||||
|
output: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relative URL returns empty",
|
||||||
|
input: "/callback",
|
||||||
|
output: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.output, loopbackURLWithWildcardPort(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetCallbackURLFromList_MultiplePatterns(t *testing.T) {
|
func TestGetCallbackURLFromList_MultiplePatterns(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
Reference in New Issue
Block a user