mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-05 18:30:34 +00:00
228 lines
5.9 KiB
Go
228 lines
5.9 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
|
)
|
|
|
|
func TestCreateTokensHandler(t *testing.T) {
|
|
createTestContext := func(t *testing.T, rawURL string, form url.Values, authHeader string, noCT bool) (*gin.Context, *httptest.ResponseRecorder) {
|
|
t.Helper()
|
|
|
|
mode := gin.Mode()
|
|
gin.SetMode(gin.TestMode)
|
|
t.Cleanup(func() { gin.SetMode(mode) })
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
|
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, rawURL, strings.NewReader(form.Encode()))
|
|
require.NoError(t, err)
|
|
|
|
if !noCT {
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
}
|
|
if authHeader != "" {
|
|
req.Header.Set("Authorization", authHeader)
|
|
}
|
|
|
|
c.Request = req
|
|
return c, recorder
|
|
}
|
|
|
|
t.Run("Ignores Query String Parameters For Binding", func(t *testing.T) {
|
|
oc := &OidcController{}
|
|
|
|
c, _ := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token?grant_type=refresh_token&refresh_token=query-value",
|
|
url.Values{},
|
|
"",
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Len(t, c.Errors, 1)
|
|
assert.Contains(t, c.Errors[0].Err.Error(), "GrantType")
|
|
})
|
|
|
|
t.Run("Missing Authorization Code", func(t *testing.T) {
|
|
oc := &OidcController{}
|
|
|
|
c, _ := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token",
|
|
url.Values{
|
|
"grant_type": {service.GrantTypeAuthorizationCode},
|
|
},
|
|
"",
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Len(t, c.Errors, 1)
|
|
var missingCodeErr *common.OidcMissingAuthorizationCodeError
|
|
require.ErrorAs(t, c.Errors[0].Err, &missingCodeErr)
|
|
})
|
|
|
|
t.Run("Missing Refresh Token", func(t *testing.T) {
|
|
oc := &OidcController{}
|
|
|
|
c, _ := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token",
|
|
url.Values{
|
|
"grant_type": {service.GrantTypeRefreshToken},
|
|
},
|
|
"",
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Len(t, c.Errors, 1)
|
|
var missingRefreshErr *common.OidcMissingRefreshTokenError
|
|
require.ErrorAs(t, c.Errors[0].Err, &missingRefreshErr)
|
|
})
|
|
|
|
t.Run("Uses Basic Auth Credentials When Body Credentials Missing", func(t *testing.T) {
|
|
var capturedInput dto.OidcCreateTokensDto
|
|
oc := &OidcController{
|
|
createTokens: func(_ context.Context, input dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
|
capturedInput = input
|
|
return service.CreatedTokens{
|
|
AccessToken: "access-token",
|
|
IdToken: "id-token",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresIn: 2 * time.Minute,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("client-id:client-secret"))
|
|
c, recorder := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token",
|
|
url.Values{
|
|
"grant_type": {service.GrantTypeRefreshToken},
|
|
"refresh_token": {"input-refresh-token"},
|
|
},
|
|
basicAuth,
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Empty(t, c.Errors)
|
|
assert.Equal(t, "client-id", capturedInput.ClientID)
|
|
assert.Equal(t, "client-secret", capturedInput.ClientSecret)
|
|
assert.Equal(t, "input-refresh-token", capturedInput.RefreshToken)
|
|
|
|
require.Equal(t, http.StatusOK, recorder.Code)
|
|
var response dto.OidcTokenResponseDto
|
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
|
assert.Equal(t, "access-token", response.AccessToken)
|
|
assert.Equal(t, "Bearer", response.TokenType)
|
|
assert.Equal(t, "id-token", response.IdToken)
|
|
assert.Equal(t, "refresh-token", response.RefreshToken)
|
|
assert.Equal(t, 120, response.ExpiresIn)
|
|
})
|
|
|
|
t.Run("Maps Authorization Pending Error", func(t *testing.T) {
|
|
oc := &OidcController{
|
|
createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
|
return service.CreatedTokens{}, &common.OidcAuthorizationPendingError{}
|
|
},
|
|
}
|
|
|
|
c, recorder := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token",
|
|
url.Values{
|
|
"grant_type": {service.GrantTypeRefreshToken},
|
|
"refresh_token": {"input-refresh-token"},
|
|
},
|
|
"",
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Empty(t, c.Errors)
|
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
|
var response map[string]string
|
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
|
assert.Equal(t, "authorization_pending", response["error"])
|
|
})
|
|
|
|
t.Run("Maps Slow Down Error", func(t *testing.T) {
|
|
oc := &OidcController{
|
|
createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
|
return service.CreatedTokens{}, &common.OidcSlowDownError{}
|
|
},
|
|
}
|
|
|
|
c, recorder := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token",
|
|
url.Values{
|
|
"grant_type": {service.GrantTypeRefreshToken},
|
|
"refresh_token": {"input-refresh-token"},
|
|
},
|
|
"",
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Empty(t, c.Errors)
|
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
|
var response map[string]string
|
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
|
assert.Equal(t, "slow_down", response["error"])
|
|
})
|
|
|
|
t.Run("Returns Generic Service Error In Context", func(t *testing.T) {
|
|
expectedErr := errors.New("boom")
|
|
oc := &OidcController{
|
|
createTokens: func(context.Context, dto.OidcCreateTokensDto) (service.CreatedTokens, error) {
|
|
return service.CreatedTokens{}, expectedErr
|
|
},
|
|
}
|
|
|
|
c, _ := createTestContext(
|
|
t,
|
|
"http://example.com/oidc/token",
|
|
url.Values{
|
|
"grant_type": {service.GrantTypeRefreshToken},
|
|
"refresh_token": {"input-refresh-token"},
|
|
},
|
|
"",
|
|
false,
|
|
)
|
|
|
|
oc.createTokensHandler(c)
|
|
|
|
require.Len(t, c.Errors, 1)
|
|
assert.ErrorIs(t, c.Errors[0].Err, expectedErr)
|
|
})
|
|
}
|