mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-19 19:29:53 +00:00
Compare commits
127 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68fc9c0659 | ||
|
|
2952b15755 | ||
|
|
ef1d599662 | ||
|
|
4e49d3932a | ||
|
|
86d3c08494 | ||
|
|
7b4ccd1f30 | ||
|
|
f145903eb0 | ||
|
|
d3bc1797b6 | ||
|
|
db94f81937 | ||
|
|
b03e91b653 | ||
|
|
505bdcb8ba | ||
|
|
f103a54790 | ||
|
|
e1de593dcd | ||
|
|
45f42772b1 | ||
|
|
98152640b1 | ||
|
|
04e235e805 | ||
|
|
ae737dddaa | ||
|
|
f565c702e5 | ||
|
|
f945b44bc9 | ||
|
|
857b9cc864 | ||
|
|
bf042563e9 | ||
|
|
49f1ab2f75 | ||
|
|
e46f60ac8d | ||
|
|
5c9e504291 | ||
|
|
7fe83f8087 | ||
|
|
43f0114c57 | ||
|
|
1a41b05f60 | ||
|
|
81315790a8 | ||
|
|
8c8fc2304d | ||
|
|
15ece0ab30 | ||
|
|
5550729120 | ||
|
|
9872608d61 | ||
|
|
be52660227 | ||
|
|
237342e876 | ||
|
|
cfbfbc9753 | ||
|
|
aefb308536 | ||
|
|
031181ad2a | ||
|
|
dbf3da41f3 | ||
|
|
3a2902789e | ||
|
|
459a4fd727 | ||
|
|
2ecc1abbad | ||
|
|
92c57ada1a | ||
|
|
fceb6fa7b4 | ||
|
|
c290c027fb | ||
|
|
ca205a8c73 | ||
|
|
968cf0b307 | ||
|
|
fd8bee94a4 | ||
|
|
41ac1be082 | ||
|
|
dd9b1d26ea | ||
|
|
4b829757b2 | ||
|
|
b5b01cb6dd | ||
|
|
287314f016 | ||
|
|
73e7e0b1c5 | ||
|
|
d070b9a778 | ||
|
|
d976bf5965 | ||
|
|
052ac008c3 | ||
|
|
57a2b2bc83 | ||
|
|
043f82ad79 | ||
|
|
ba61cdba4e | ||
|
|
dcd1ae96e0 | ||
|
|
1fdb058386 | ||
|
|
29cb5513a0 | ||
|
|
6db57d9f27 | ||
|
|
1a77bd9914 | ||
|
|
350335711b | ||
|
|
988c425150 | ||
|
|
23827ba1d1 | ||
|
|
7d36bda769 | ||
|
|
8c559ea067 | ||
|
|
88832d4bc9 | ||
|
|
f5cece3b0e | ||
|
|
d5485238b8 | ||
|
|
ac5a121f66 | ||
|
|
481df3bcb9 | ||
|
|
7677a3de2c | ||
|
|
1f65c01b04 | ||
|
|
d5928f6fea | ||
|
|
bef77ac8dc | ||
|
|
c8eb034c49 | ||
|
|
c77167df46 | ||
|
|
3717a663d9 | ||
|
|
5814549cbe | ||
|
|
2e5d268798 | ||
|
|
4ed312251e | ||
|
|
946c534b08 | ||
|
|
883877adec | ||
|
|
215531d65c | ||
|
|
c0f055c3c0 | ||
|
|
d77044882d | ||
|
|
d6795300b1 | ||
|
|
fd3c76ffa3 | ||
|
|
698bc3a35a | ||
|
|
1bcb50edc3 | ||
|
|
9700afb9cb | ||
|
|
9ce82fb205 | ||
|
|
2935236ace | ||
|
|
c821b675b8 | ||
|
|
a09d529027 | ||
|
|
b62b61fb01 | ||
|
|
df5c1ed1f8 | ||
|
|
f4af35f86b | ||
|
|
657a51f7ed | ||
|
|
575b2f71e9 | ||
|
|
97f7326da4 | ||
|
|
242d87a54b | ||
|
|
c111b79147 | ||
|
|
61bf14225b | ||
|
|
c1e98411b6 | ||
|
|
b25e95fc4a | ||
|
|
3cc82d8522 | ||
|
|
ea4e48680c | ||
|
|
f403eed12c | ||
|
|
388a874922 | ||
|
|
9a4aab465a | ||
|
|
a052cd6619 | ||
|
|
31a803b243 | ||
|
|
1d2e41c04e | ||
|
|
b650d6d423 | ||
|
|
156aad3057 | ||
|
|
05bfe00924 | ||
|
|
035b2c022b | ||
|
|
61b62d4612 | ||
|
|
dc5d7bb2f3 | ||
|
|
5e9096e328 | ||
|
|
34b4ba514f | ||
|
|
d217083059 | ||
|
|
bdcef60cab |
1
.github/CODEOWNERS
vendored
Normal file
1
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* @pocket-id/maintainers
|
||||
1
.github/ISSUE_TEMPLATE/bug.yml
vendored
1
.github/ISSUE_TEMPLATE/bug.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: "🐛 Bug Report"
|
||||
description: "Report something that is not working as expected"
|
||||
title: "🐛 Bug Report: "
|
||||
type: 'Bug'
|
||||
labels: [bug]
|
||||
body:
|
||||
- type: markdown
|
||||
|
||||
1
.github/ISSUE_TEMPLATE/feature.yml
vendored
1
.github/ISSUE_TEMPLATE/feature.yml
vendored
@@ -2,6 +2,7 @@ name: 🚀 Feature
|
||||
description: "Submit a proposal for a new feature"
|
||||
title: "🚀 Feature: "
|
||||
labels: [feature]
|
||||
type: 'Feature'
|
||||
body:
|
||||
- type: textarea
|
||||
id: feature-description
|
||||
|
||||
1
.github/ISSUE_TEMPLATE/language-request.yml
vendored
1
.github/ISSUE_TEMPLATE/language-request.yml
vendored
@@ -2,6 +2,7 @@ name: "🌐 Language request"
|
||||
description: "You want to contribute to a language that isn't on Crowdin yet?"
|
||||
title: "🌐 Language Request: <language name in english>"
|
||||
labels: [language-request]
|
||||
type: 'Language Request'
|
||||
body:
|
||||
- type: input
|
||||
id: language-name-native
|
||||
|
||||
96
.github/workflows/build-next.yml
vendored
Normal file
96
.github/workflows/build-next.yml
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
name: Build Next Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: build-next-image
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-next:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write
|
||||
attestations: write
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "backend/go.mod"
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Set DOCKER_IMAGE_NAME
|
||||
run: |
|
||||
# Lowercase REPO_OWNER which is required for containers
|
||||
REPO_OWNER=${{ github.repository_owner }}
|
||||
DOCKER_IMAGE_NAME="ghcr.io/${REPO_OWNER,,}/pocket-id"
|
||||
echo "DOCKER_IMAGE_NAME=${DOCKER_IMAGE_NAME}" >>${GITHUB_ENV}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install frontend dependencies
|
||||
working-directory: frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Build frontend
|
||||
working-directory: frontend
|
||||
run: npm run build
|
||||
|
||||
- name: Build binaries
|
||||
run: sh scripts/development/build-binaries.sh --docker-only
|
||||
|
||||
- name: Build and push container image
|
||||
id: build-push-image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_IMAGE_NAME }}:next
|
||||
file: Dockerfile-prebuilt
|
||||
- name: Build and push container image (distroless)
|
||||
uses: docker/build-push-action@v6
|
||||
id: container-build-push-distroless
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_IMAGE_NAME }}:next-distroless
|
||||
file: Dockerfile-distroless
|
||||
- name: Container image attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||
subject-digest: ${{ steps.build-push-image.outputs.digest }}
|
||||
push-to-registry: true
|
||||
- name: Container image attestation (distroless)
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
|
||||
push-to-registry: true
|
||||
36
.github/workflows/release.yml
vendored
36
.github/workflows/release.yml
vendored
@@ -29,14 +29,12 @@ jobs:
|
||||
uses: docker/setup-qemu-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Set DOCKER_IMAGE_NAME
|
||||
run: |
|
||||
# Lowercase REPO_OWNER which is required for containers
|
||||
REPO_OWNER=${{ github.repository_owner }}
|
||||
DOCKER_IMAGE_NAME="ghcr.io/${REPO_OWNER,,}/pocket-id"
|
||||
echo "DOCKER_IMAGE_NAME=${DOCKER_IMAGE_NAME}" >>${GITHUB_ENV}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
@@ -53,17 +51,26 @@ jobs:
|
||||
type=semver,pattern={{version}},prefix=v
|
||||
type=semver,pattern={{major}}.{{minor}},prefix=v
|
||||
type=semver,pattern={{major}},prefix=v
|
||||
|
||||
- name: Docker metadata (distroless)
|
||||
id: meta-distroless
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
${{ env.DOCKER_IMAGE_NAME }}
|
||||
flavor: |
|
||||
suffix=-distroless,onlatest=true
|
||||
tags: |
|
||||
type=semver,pattern={{version}},prefix=v
|
||||
type=semver,pattern={{major}}.{{minor}},prefix=v
|
||||
type=semver,pattern={{major}},prefix=v
|
||||
- name: Install frontend dependencies
|
||||
working-directory: frontend
|
||||
run: npm ci
|
||||
- name: Build frontend
|
||||
working-directory: frontend
|
||||
run: npm run build
|
||||
|
||||
- name: Build binaries
|
||||
run: sh scripts/development/build-binaries.sh
|
||||
|
||||
- name: Build and push container image
|
||||
uses: docker/build-push-action@v6
|
||||
id: container-build-push
|
||||
@@ -74,19 +81,32 @@ jobs:
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: Dockerfile-prebuilt
|
||||
|
||||
- name: Build and push container image (distroless)
|
||||
uses: docker/build-push-action@v6
|
||||
id: container-build-push-distroless
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta-distroless.outputs.tags }}
|
||||
labels: ${{ steps.meta-distroless.outputs.labels }}
|
||||
file: Dockerfile-distroless
|
||||
- name: Binary attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-path: "backend/.bin/pocket-id-**"
|
||||
|
||||
- name: Container image attestation
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||
subject-digest: ${{ steps.container-build-push.outputs.digest }}
|
||||
push-to-registry: true
|
||||
|
||||
- name: Container image attestation (distroless)
|
||||
uses: actions/attest-build-provenance@v2
|
||||
with:
|
||||
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
|
||||
push-to-registry: true
|
||||
- name: Upload binaries to release
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -10,6 +10,7 @@ node_modules
|
||||
/frontend/build
|
||||
/backend/bin
|
||||
pocket-id
|
||||
/tests/test-results/*.json
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
114
CHANGELOG.md
114
CHANGELOG.md
@@ -1,3 +1,117 @@
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.6.2...v) (2025-07-21)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* allow passkey names up to 50 characters ([b03e91b](https://github.com/pocket-id/pocket-id/commit/b03e91b6530c2393ad20ac49aa2cb2b4962651b2))
|
||||
* ensure user inputs are normalized ([#724](https://github.com/pocket-id/pocket-id/issues/724)) ([7b4ccd1](https://github.com/pocket-id/pocket-id/commit/7b4ccd1f306f4882c52fe30133fcda114ef0d18b))
|
||||
* show rename and delete buttons for passkeys without hovering over the row ([2952b15](https://github.com/pocket-id/pocket-id/commit/2952b1575542ecd0062fe740e2d6a3caad05190d))
|
||||
* use object-contain for images on oidc-client list ([d3bc179](https://github.com/pocket-id/pocket-id/commit/d3bc1797b65ec8bc9201c55d06f3612093f3a873))
|
||||
* use user-agent for identifying known device signins ([ef1d599](https://github.com/pocket-id/pocket-id/commit/ef1d5996624fc534190f80a26f2c48bbad206f49))
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.6.1...v) (2025-07-09)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* ensure confirmation dialog shows on top of other components ([f103a54](https://github.com/pocket-id/pocket-id/commit/f103a547904070c5b192e519c8b5a8fed9d80e96))
|
||||
* login failures on Postgres when IP is null ([#737](https://github.com/pocket-id/pocket-id/issues/737)) ([e1de593](https://github.com/pocket-id/pocket-id/commit/e1de593dcd30b7b04da3b003455134992b702595))
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.5.0...v) (2025-07-06)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add "key-rotate" command ([#709](https://github.com/pocket-id/pocket-id/issues/709)) ([8c8fc23](https://github.com/pocket-id/pocket-id/commit/8c8fc2304d8f33c1fea54b1138b109f282e78b8b))
|
||||
* add support for OAuth 2.0 Authorization Server Issuer Identification ([bf04256](https://github.com/pocket-id/pocket-id/commit/bf042563e997d57bb087705a5789fd72ffbed467))
|
||||
* distroless container additional variant + healthcheck command ([#716](https://github.com/pocket-id/pocket-id/issues/716)) ([1a41b05](https://github.com/pocket-id/pocket-id/commit/1a41b05f60d487fff78703bec1d4e832f96fd071))
|
||||
* encrypt private keys saved on disk and in database ([#682](https://github.com/pocket-id/pocket-id/issues/682)) ([5550729](https://github.com/pocket-id/pocket-id/commit/5550729120ac9f5e9361c7f9cf25b9075a33a94a))
|
||||
* enhance language selection message and add translation contribution link ([be52660](https://github.com/pocket-id/pocket-id/commit/be526602273c1689cb4057ca96d4214e7f817d1d))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* actually fix linter issues ([#720](https://github.com/pocket-id/pocket-id/issues/720)) ([7fe83f8](https://github.com/pocket-id/pocket-id/commit/7fe83f8087f033f957bb6e0eee5e0c159417e1cd))
|
||||
* add missing error check in initial user setup ([fceb6fa](https://github.com/pocket-id/pocket-id/commit/fceb6fa7b4701a3645c4c2353bcd108b15d69ded))
|
||||
* allow profile picture update even if "allow own account edit" enabled ([9872608](https://github.com/pocket-id/pocket-id/commit/9872608d61a486f7b775f314d9392e0620bcd891))
|
||||
* app config forms not updating with latest values ([#696](https://github.com/pocket-id/pocket-id/issues/696)) ([92c57ad](https://github.com/pocket-id/pocket-id/commit/92c57ada1a11f76963e36ca0a81bca8f52dbc84e))
|
||||
* auth fails when client IP is empty on Postgres ([#695](https://github.com/pocket-id/pocket-id/issues/695)) ([031181a](https://github.com/pocket-id/pocket-id/commit/031181ad2ae8fae94cc5793dd1c614e79476a766))
|
||||
* custom claims input suggestions flickering ([49f1ab2](https://github.com/pocket-id/pocket-id/commit/49f1ab2f75df97d551fff5acbadcd55df74af617))
|
||||
* keep sidebar in settings sticky ([e46f60a](https://github.com/pocket-id/pocket-id/commit/e46f60ac8d6944bcea54d0708af1950d98f66c3c))
|
||||
* linter issues ([#719](https://github.com/pocket-id/pocket-id/issues/719)) ([43f0114](https://github.com/pocket-id/pocket-id/commit/43f0114c579f7b5b32b372e09f46bcb2a9d7796e))
|
||||
* show friendly name in user group selection ([5c9e504](https://github.com/pocket-id/pocket-id/commit/5c9e504291b3bffe947bcbe907701806e301d1fe))
|
||||
* support non UTF-8 LDAP IDs ([#714](https://github.com/pocket-id/pocket-id/issues/714)) ([8131579](https://github.com/pocket-id/pocket-id/commit/81315790a8aa601a2565a1b54807df1e68f06dc5))
|
||||
* token introspection authentication not handled correctly ([#704](https://github.com/pocket-id/pocket-id/issues/704)) ([aefb308](https://github.com/pocket-id/pocket-id/commit/aefb30853677baf7ed29ac8b539e1aadf56e14a4))
|
||||
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.4.1...v) (2025-06-27)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* improve initial admin creation workflow ([287314f](https://github.com/pocket-id/pocket-id/commit/287314f01644e42ddb2ce1b1115bd14f2f0c1768))
|
||||
* redact sensitive app config variables if set with env variable ([ba61cdb](https://github.com/pocket-id/pocket-id/commit/ba61cdba4eb3d5659f3ae6b6c21249985c0aa630))
|
||||
* self-service user signup ([#672](https://github.com/pocket-id/pocket-id/issues/672)) ([dcd1ae9](https://github.com/pocket-id/pocket-id/commit/dcd1ae96e048115be34b0cce275054e990462ebf))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* double double full stops for certain error messages ([d070b9a](https://github.com/pocket-id/pocket-id/commit/d070b9a778d7d1a51f2fa62d003f2331a96d6c91))
|
||||
* error page flickering after sign out ([1a77bd9](https://github.com/pocket-id/pocket-id/commit/1a77bd9914ea01e445ff3d6e116c9ed3bcfbf153))
|
||||
* improve accent color picker disabled state ([d976bf5](https://github.com/pocket-id/pocket-id/commit/d976bf5965eda10e3ecb71821c23e93e5d712a02))
|
||||
* less noisy logging for certain GET requests ([#681](https://github.com/pocket-id/pocket-id/issues/681)) ([043f82a](https://github.com/pocket-id/pocket-id/commit/043f82ad794eb64a5550d8b80703114a055701d9))
|
||||
* margin of user sign up description ([052ac00](https://github.com/pocket-id/pocket-id/commit/052ac008c3a8c910d1ce79ee99b2b2f75e4090f4))
|
||||
* remove duplicate request logging ([#678](https://github.com/pocket-id/pocket-id/issues/678)) ([988c425](https://github.com/pocket-id/pocket-id/commit/988c425150556b32cff1d341a21fcc9c69d9aaf8))
|
||||
* users can't be updated by admin if self account editing is disabled ([29cb551](https://github.com/pocket-id/pocket-id/commit/29cb5513a03d1a9571969c8a42deec9b2bdee037))
|
||||
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.4.0...v) (2025-06-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* app not starting if UI config is disabled and Postgres is used ([7d36bda](https://github.com/pocket-id/pocket-id/commit/7d36bda769e25497dec6b76206a4f7e151b0bd72))
|
||||
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.3.1...v) (2025-06-19)
|
||||
|
||||
### Features
|
||||
|
||||
* allow setting unix socket mode ([#661](https://github.com/pocket-id/pocket-id/issues/661)) ([7677a3d](https://github.com/pocket-id/pocket-id/commit/7677a3de2c923c11a58bc8c4d1b2121d403a1504))
|
||||
* auto-focus on the login buttons ([#647](https://github.com/pocket-id/pocket-id/issues/647)) ([d679530](https://github.com/pocket-id/pocket-id/commit/d6795300b158b85dd9feadd561b6ecd891f5db0d))
|
||||
* configurable local ipv6 ranges for audit log ([#657](https://github.com/pocket-id/pocket-id/issues/657)) ([d548523](https://github.com/pocket-id/pocket-id/commit/d5485238b8fd4cc566af00eae2b17d69a119f991))
|
||||
* location filter for global audit log ([#662](https://github.com/pocket-id/pocket-id/issues/662)) ([ac5a121](https://github.com/pocket-id/pocket-id/commit/ac5a121f664b8127d0faf30c0f93432f30e7f33a))
|
||||
* ui accent colors ([#643](https://github.com/pocket-id/pocket-id/issues/643)) ([883877a](https://github.com/pocket-id/pocket-id/commit/883877adec6fc3e65bd5a705499449959b894fb5))
|
||||
* use icon instead of text on application image update hover state ([215531d](https://github.com/pocket-id/pocket-id/commit/215531d65c6683609b0b4a5505fdb72696fdb93e))
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* allow images with uppercase file extension ([1bcb50e](https://github.com/pocket-id/pocket-id/commit/1bcb50edc335886dd722a4c69960c48cc3cd1687))
|
||||
* center oidc client images if they are smaller than the box ([946c534](https://github.com/pocket-id/pocket-id/commit/946c534b0877a074a6b658060f9af27e4061397c))
|
||||
* explicitly cache images to prevent unexpected behavior ([2e5d268](https://github.com/pocket-id/pocket-id/commit/2e5d2687982186c12e530492292d49895cb6043a))
|
||||
* reduce duration of animations on login and signin page ([#648](https://github.com/pocket-id/pocket-id/issues/648)) ([d770448](https://github.com/pocket-id/pocket-id/commit/d77044882d5a41da22df1c0099c1eb1f20bcbc5b))
|
||||
* use inline style for dynamic background image URL instead of Tailwind class ([bef77ac](https://github.com/pocket-id/pocket-id/commit/bef77ac8dca2b98b6732677aaafbc28f79d00487))
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.3.0...v) (2025-06-09)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* change timestamp of `client_credentials.sql` migration ([2935236](https://github.com/pocket-id/pocket-id/commit/2935236acee9c78c2fe6787ec8b5f53ae0eca047))
|
||||
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.2.0...v) (2025-06-09)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add API endpoint for user authorized clients ([d217083](https://github.com/pocket-id/pocket-id/commit/d217083059120171d5c555b09eefe6ba3c8a8d42))
|
||||
* add unix socket support ([#615](https://github.com/pocket-id/pocket-id/issues/615)) ([035b2c0](https://github.com/pocket-id/pocket-id/commit/035b2c022bfd2b98f13355ec7a126e0f1ab3ebd8))
|
||||
* allow introspection and device code endpoints to use Federated Client Credentials ([#640](https://github.com/pocket-id/pocket-id/issues/640)) ([b62b61f](https://github.com/pocket-id/pocket-id/commit/b62b61fb017dba31a6fc612c138bebf370d3956c))
|
||||
* JWT bearer assertions for client authentication ([#566](https://github.com/pocket-id/pocket-id/issues/566)) ([05bfe00](https://github.com/pocket-id/pocket-id/commit/05bfe0092450c9bc26d03c6a54c21050eef8f63a))
|
||||
* new color theme for the UI ([97f7326](https://github.com/pocket-id/pocket-id/commit/97f7326da40265a954340d519661969530f097a0))
|
||||
* oidc client data preview ([#624](https://github.com/pocket-id/pocket-id/issues/624)) ([c111b79](https://github.com/pocket-id/pocket-id/commit/c111b7914731a3cafeaa55102b515f84a1ad74dc))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* don't load app config and user on every route change ([bdcef60](https://github.com/pocket-id/pocket-id/commit/bdcef60cab6a61e1717661e918c42e3650d23fee))
|
||||
* misleading text for disable animations option ([657a51f](https://github.com/pocket-id/pocket-id/commit/657a51f7ed8a77e8a937971032091058aacfded6))
|
||||
* OIDC client image can't be deleted ([61b62d4](https://github.com/pocket-id/pocket-id/commit/61b62d461200c1359a16c92c9c62530362a4785c))
|
||||
* UI config overridden by env variables don't apply on first start ([5e9096e](https://github.com/pocket-id/pocket-id/commit/5e9096e328741ba2a0e03835927fe62e6aea2a89))
|
||||
* use full width for audit log filters ([575b2f7](https://github.com/pocket-id/pocket-id/commit/575b2f71e9f1ff9c4f6fd411b136676c213b7201))
|
||||
|
||||
## [](https://github.com/pocket-id/pocket-id/compare/v1.1.0...v) (2025-06-03)
|
||||
|
||||
|
||||
|
||||
@@ -48,5 +48,7 @@ RUN chmod +x /app/pocket-id && \
|
||||
EXPOSE 1411
|
||||
ENV APP_ENV=production
|
||||
|
||||
HEALTHCHECK --interval=90s --timeout=5s --start-period=10s --retries=3 CMD [ "/app/pocket-id", "healthcheck" ]
|
||||
|
||||
ENTRYPOINT ["sh", "/app/docker/entrypoint.sh"]
|
||||
CMD ["/app/pocket-id"]
|
||||
|
||||
18
Dockerfile-distroless
Normal file
18
Dockerfile-distroless
Normal file
@@ -0,0 +1,18 @@
|
||||
# This Dockerfile embeds a pre-built binary for the given Linux architecture
|
||||
# Binaries must be built using "./scripts/development/build-binaries.sh --docker-only"
|
||||
|
||||
FROM gcr.io/distroless/static-debian12:nonroot
|
||||
|
||||
# TARGETARCH can be "amd64" or "arm64"
|
||||
ARG TARGETARCH
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY ./backend/.bin/pocket-id-linux-${TARGETARCH} /app/pocket-id
|
||||
|
||||
EXPOSE 1411
|
||||
ENV APP_ENV=production
|
||||
|
||||
HEALTHCHECK --interval=90s --timeout=5s --start-period=10s --retries=3 CMD [ "/app/pocket-id", "healthcheck" ]
|
||||
|
||||
CMD ["/app/pocket-id"]
|
||||
@@ -1,5 +1,5 @@
|
||||
# This Dockerfile embeds a pre-built binary for the given Linux architecture
|
||||
# Binaries must be built using ./scripts/development/build-binaries.sh first
|
||||
# Binaries must be built using "./scripts/development/build-binaries.sh --docker-only"
|
||||
|
||||
FROM alpine
|
||||
|
||||
@@ -16,5 +16,7 @@ COPY ./scripts/docker /app/docker
|
||||
EXPOSE 1411
|
||||
ENV APP_ENV=production
|
||||
|
||||
HEALTHCHECK --interval=90s --timeout=5s --start-period=10s --retries=3 CMD [ "/app/pocket-id", "healthcheck" ]
|
||||
|
||||
ENTRYPOINT ["/app/docker/entrypoint.sh"]
|
||||
CMD ["/app/pocket-id"]
|
||||
|
||||
12
backend/.air.toml
Normal file
12
backend/.air.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
root = "."
|
||||
tmp_dir = ".bin"
|
||||
|
||||
[build]
|
||||
bin = "./.bin/pocket-id"
|
||||
cmd = "CGO_ENABLED=0 go build -o ./.bin/pocket-id ./cmd"
|
||||
exclude_dir = ["resources", ".bin", "data"]
|
||||
exclude_regex = [".*_test\\.go"]
|
||||
stop_on_error = true
|
||||
|
||||
[misc]
|
||||
clean_on_exit = true
|
||||
@@ -1,15 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "time/tzdata"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/cmds"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
// @title Pocket ID API
|
||||
@@ -17,27 +11,5 @@ import (
|
||||
// @description.markdown
|
||||
|
||||
func main() {
|
||||
// Get the command
|
||||
// By default, this starts the server
|
||||
var cmd string
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) > 0 {
|
||||
cmd = args[0]
|
||||
}
|
||||
|
||||
var err error
|
||||
switch cmd {
|
||||
case "version":
|
||||
fmt.Println("pocket-id " + common.Version)
|
||||
case "one-time-access-token":
|
||||
err = cmds.OneTimeAccessToken(args)
|
||||
default:
|
||||
// Start the server
|
||||
err = bootstrap.Bootstrap()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
cmds.Execute()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ require (
|
||||
github.com/emersion/go-smtp v0.21.3
|
||||
github.com/fxamacker/cbor/v2 v2.7.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/glebarez/go-sqlite v1.21.2
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/go-co-op/gocron/v2 v2.15.0
|
||||
github.com/go-ldap/ldap/v3 v3.4.10
|
||||
@@ -19,10 +20,13 @@ require (
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hashicorp/go-uuid v1.0.3
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1
|
||||
github.com/mileusna/useragent v1.3.5
|
||||
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.opentelemetry.io/contrib/exporters/autoexport v0.59.0
|
||||
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.60.0
|
||||
@@ -32,8 +36,9 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.35.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||
go.opentelemetry.io/otel/trace v1.35.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/image v0.24.0
|
||||
golang.org/x/text v0.26.0
|
||||
golang.org/x/time v0.9.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.25.12
|
||||
@@ -54,7 +59,6 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gin-contrib/sse v1.0.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -67,6 +71,7 @@ require (
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.2 // indirect
|
||||
@@ -77,9 +82,8 @@ require (
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
|
||||
github.com/lestrrat-go/option v1.0.1 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@@ -98,6 +102,7 @@ require (
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
@@ -121,16 +126,15 @@ require (
|
||||
golang.org/x/arch v0.14.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sync v0.14.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
|
||||
google.golang.org/grpc v1.71.0 // indirect
|
||||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.65.6 // indirect
|
||||
modernc.org/libc v1.65.10 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.10.0 // indirect
|
||||
modernc.org/sqlite v1.37.0 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.0 // indirect
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
|
||||
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
|
||||
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -120,6 +121,8 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -140,6 +143,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
|
||||
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
|
||||
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -164,14 +169,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
|
||||
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
|
||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY=
|
||||
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY=
|
||||
github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo=
|
||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
@@ -225,8 +230,13 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
||||
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -309,8 +319,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
|
||||
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
@@ -321,8 +331,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
@@ -343,8 +353,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
|
||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -377,8 +387,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -413,22 +423,22 @@ modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
|
||||
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
|
||||
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/fileutil v1.3.3 h1:3qaU+7f7xxTUmvU1pJTZiDLAIoJVdUSSauJNHg9yXoA=
|
||||
modernc.org/fileutil v1.3.3/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/libc v1.65.6 h1:OhJUhmuJ6MVZdqL5qmnd0/my46DKGFhSX4WOR7ijfyE=
|
||||
modernc.org/libc v1.65.6/go.mod h1:MOiGAM9lrMBT9L8xT1nO41qYl5eg9gCp9/kWhz5L7WA=
|
||||
modernc.org/libc v1.65.10 h1:ZwEk8+jhW7qBjHIT+wd0d9VjitRyQef9BnzlzGwMODc=
|
||||
modernc.org/libc v1.65.10/go.mod h1:StFvYpx7i/mXtBAfVOjaU0PWZOvIRoZSgXhrwXzr8Po=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.10.0 h1:fzumd51yQ1DxcOxSO+S6X7+QTuVU+n8/Aj7swYjFfC4=
|
||||
modernc.org/memory v1.10.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.37.0 h1:s1TMe7T3Q3ovQiK2Ouz4Jwh7dw4ZDqbebSDTlSJdfjI=
|
||||
modernc.org/sqlite v1.37.0/go.mod h1:5YiWv+YviqGMuGw4V+PNplcyaJ5v+vQd7TQOgkACoJM=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
@@ -11,13 +11,9 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/job"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
|
||||
)
|
||||
|
||||
func Bootstrap() error {
|
||||
// Get a context that is canceled when the application is stopping
|
||||
ctx := signals.SignalContext(context.Background())
|
||||
|
||||
func Bootstrap(ctx context.Context) error {
|
||||
initApplicationImages()
|
||||
|
||||
// Initialize the tracer and metrics exporter
|
||||
@@ -59,11 +55,12 @@ func Bootstrap() error {
|
||||
|
||||
// Invoke all shutdown functions
|
||||
// We give these a timeout of 5s
|
||||
// Note: we use a background context because the run context has been canceled already
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
err = utils.
|
||||
NewServiceRunner(shutdownFns...).
|
||||
Run(shutdownCtx)
|
||||
Run(shutdownCtx) //nolint:contextcheck
|
||||
if err != nil {
|
||||
log.Printf("Error shutting down services: %v", err)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
@@ -88,6 +89,7 @@ func connectDatabase() (db *gorm.DB, err error) {
|
||||
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
|
||||
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
|
||||
}
|
||||
sqliteutil.RegisterSqliteFunctions()
|
||||
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -14,7 +16,12 @@ import (
|
||||
func init() {
|
||||
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
|
||||
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
|
||||
testService := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
||||
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize test service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
controller.NewTestController(apiGroup, testService)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/frontend"
|
||||
@@ -45,8 +48,26 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
r.Use(gin.Logger())
|
||||
// do not log these URLs
|
||||
loggerSkipPathsPrefix := []string{
|
||||
"GET /application-configuration/logo",
|
||||
"GET /application-configuration/background-image",
|
||||
"GET /application-configuration/favicon",
|
||||
"GET /_app",
|
||||
"GET /fonts",
|
||||
"GET /healthz",
|
||||
"HEAD /healthz",
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.LoggerWithConfig(gin.LoggerConfig{Skip: func(c *gin.Context) bool {
|
||||
for _, prefix := range loggerSkipPathsPrefix {
|
||||
if strings.HasPrefix(c.Request.Method+" "+c.Request.URL.String(), prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}}))
|
||||
|
||||
if !common.EnvConfig.TrustProxy {
|
||||
_ = r.SetTrustedProxies(nil)
|
||||
@@ -101,21 +122,39 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
|
||||
// Set up the server
|
||||
srv := &http.Server{
|
||||
Addr: net.JoinHostPort(common.EnvConfig.Host, common.EnvConfig.Port),
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
// Set up the listener
|
||||
listener, err := net.Listen("tcp", srv.Addr)
|
||||
network := "tcp"
|
||||
addr := net.JoinHostPort(common.EnvConfig.Host, common.EnvConfig.Port)
|
||||
if common.EnvConfig.UnixSocket != "" {
|
||||
network = "unix"
|
||||
addr = common.EnvConfig.UnixSocket
|
||||
}
|
||||
|
||||
listener, err := net.Listen(network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TCP listener: %w", err)
|
||||
return nil, fmt.Errorf("failed to create %s listener: %w", network, err)
|
||||
}
|
||||
|
||||
// Set the socket mode if using a Unix socket
|
||||
if network == "unix" && common.EnvConfig.UnixSocketMode != "" {
|
||||
mode, err := strconv.ParseUint(common.EnvConfig.UnixSocketMode, 8, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse UNIX socket mode '%s': %w", common.EnvConfig.UnixSocketMode, err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(addr, os.FileMode(mode)); err != nil {
|
||||
return nil, fmt.Errorf("failed to set UNIX socket mode '%s': %w", common.EnvConfig.UnixSocketMode, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Service runner function
|
||||
runFn := func(ctx context.Context) error {
|
||||
log.Printf("Server listening on %s", srv.Addr)
|
||||
log.Printf("Server listening on %s", addr)
|
||||
|
||||
// Start the server in a background goroutine
|
||||
go func() {
|
||||
|
||||
@@ -26,23 +26,27 @@ type services struct {
|
||||
}
|
||||
|
||||
// Initializes all services
|
||||
// The context should be used by services only for initialization, and not for running
|
||||
func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
||||
svc = &services{}
|
||||
|
||||
svc.appConfigService = service.NewAppConfigService(initCtx, db)
|
||||
svc.appConfigService = service.NewAppConfigService(ctx, db)
|
||||
|
||||
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create email service: %w", err)
|
||||
return nil, fmt.Errorf("failed to create email service: %w", err)
|
||||
}
|
||||
|
||||
svc.geoLiteService = service.NewGeoLiteService(httpClient)
|
||||
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
|
||||
svc.jwtService = service.NewJwtService(svc.appConfigService)
|
||||
svc.jwtService = service.NewJwtService(db, svc.appConfigService)
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
|
||||
svc.customClaimService = service.NewCustomClaimService(db)
|
||||
svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
|
||||
|
||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
||||
}
|
||||
|
||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||
|
||||
83
backend/internal/cmds/healthcheck.go
Normal file
83
backend/internal/cmds/healthcheck.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package cmds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
type healthcheckFlags struct {
|
||||
Endpoint string
|
||||
Verbose bool
|
||||
}
|
||||
|
||||
func init() {
|
||||
var flags healthcheckFlags
|
||||
|
||||
healthcheckCmd := &cobra.Command{
|
||||
Use: "healthcheck",
|
||||
Short: "Performs a healthcheck of a running Pocket ID instance",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
start := time.Now()
|
||||
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
url := flags.Endpoint + "/healthz"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx,
|
||||
"Failed to create request object",
|
||||
"error", err,
|
||||
"url", url,
|
||||
"ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx,
|
||||
"Failed to perform request",
|
||||
"error", err,
|
||||
"url", url,
|
||||
"ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx,
|
||||
"Healthcheck failed",
|
||||
"status", res.StatusCode,
|
||||
"url", url,
|
||||
"ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if flags.Verbose {
|
||||
slog.InfoContext(ctx,
|
||||
"Healthcheck succeeded",
|
||||
"status", res.StatusCode,
|
||||
"url", url,
|
||||
"ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
healthcheckCmd.Flags().StringVarP(&flags.Endpoint, "endpoint", "e", "http://localhost:"+common.EnvConfig.Port, "Endpoint for Pocket ID")
|
||||
healthcheckCmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode")
|
||||
|
||||
rootCmd.AddCommand(healthcheckCmd)
|
||||
}
|
||||
107
backend/internal/cmds/key_rotate.go
Normal file
107
backend/internal/cmds/key_rotate.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package cmds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/spf13/cobra"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
)
|
||||
|
||||
type keyRotateFlags struct {
|
||||
Alg string
|
||||
Crv string
|
||||
Yes bool
|
||||
}
|
||||
|
||||
func init() {
|
||||
var flags keyRotateFlags
|
||||
|
||||
keyRotateCmd := &cobra.Command{
|
||||
Use: "key-rotate",
|
||||
Short: "Generates a new token signing key and replaces the current one",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
db := bootstrap.NewDatabase()
|
||||
|
||||
return keyRotate(cmd.Context(), flags, db, &common.EnvConfig)
|
||||
},
|
||||
}
|
||||
|
||||
keyRotateCmd.Flags().StringVarP(&flags.Alg, "alg", "a", "RS256", "Key algorithm. Supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA")
|
||||
keyRotateCmd.Flags().StringVarP(&flags.Crv, "crv", "c", "", "Curve name when using EdDSA keys. Supported values: Ed25519")
|
||||
keyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation")
|
||||
|
||||
rootCmd.AddCommand(keyRotateCmd)
|
||||
}
|
||||
|
||||
func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
|
||||
// Validate the flags
|
||||
switch strings.ToUpper(flags.Alg) {
|
||||
case jwa.RS256().String(), jwa.RS384().String(), jwa.RS512().String(),
|
||||
jwa.ES256().String(), jwa.ES384().String(), jwa.ES512().String():
|
||||
// All good, but uppercase it for consistency
|
||||
flags.Alg = strings.ToUpper(flags.Alg)
|
||||
case strings.ToUpper(jwa.EdDSA().String()):
|
||||
// Ensure Crv is set and valid
|
||||
switch strings.ToUpper(flags.Crv) {
|
||||
case strings.ToUpper(jwa.Ed25519().String()):
|
||||
// All good, but ensure consistency in casing
|
||||
flags.Crv = jwa.Ed25519().String()
|
||||
case "":
|
||||
return errors.New("a curve name is required when algorithm is EdDSA")
|
||||
default:
|
||||
return errors.New("unsupported EdDSA curve; supported values: Ed25519")
|
||||
}
|
||||
case "":
|
||||
return errors.New("key algorithm is required")
|
||||
default:
|
||||
return errors.New("unsupported key algorithm; supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA")
|
||||
}
|
||||
|
||||
if !flags.Yes {
|
||||
fmt.Println("WARNING: Rotating the private key will invalidate all existing tokens. Both pocket-id and all client applications will likely need to be restarted.")
|
||||
ok, err := utils.PromptForConfirmation("Confirm")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
fmt.Println("Aborted")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Init the services we need
|
||||
appConfigService := service.NewAppConfigService(ctx, db)
|
||||
|
||||
// Get the key provider
|
||||
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfigService.GetDbConfig().InstanceID.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key provider: %w", err)
|
||||
}
|
||||
|
||||
// Generate a new key
|
||||
key, err := jwkutils.GenerateKey(flags.Alg, flags.Crv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
// Save the key
|
||||
err = keyProvider.SaveKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store new key: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("Key rotated successfully")
|
||||
fmt.Println("Note: if pocket-id is running, you will need to restart it for the new key to be loaded")
|
||||
|
||||
return nil
|
||||
}
|
||||
214
backend/internal/cmds/key_rotate_test.go
Normal file
214
backend/internal/cmds/key_rotate_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package cmds
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestKeyRotate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
flags keyRotateFlags
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid RS256",
|
||||
flags: keyRotateFlags{
|
||||
Alg: "RS256",
|
||||
Yes: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid EdDSA with Ed25519",
|
||||
flags: keyRotateFlags{
|
||||
Alg: "EdDSA",
|
||||
Crv: "Ed25519",
|
||||
Yes: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid algorithm",
|
||||
flags: keyRotateFlags{
|
||||
Alg: "INVALID",
|
||||
Yes: true,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "unsupported key algorithm",
|
||||
},
|
||||
{
|
||||
name: "EdDSA without curve",
|
||||
flags: keyRotateFlags{
|
||||
Alg: "EdDSA",
|
||||
Yes: true,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "a curve name is required when algorithm is EdDSA",
|
||||
},
|
||||
{
|
||||
name: "empty algorithm",
|
||||
flags: keyRotateFlags{
|
||||
Alg: "",
|
||||
Yes: true,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "key algorithm is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Run("file storage", func(t *testing.T) {
|
||||
testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg)
|
||||
})
|
||||
|
||||
t.Run("database storage", func(t *testing.T) {
|
||||
testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
|
||||
// Create temporary directory for keys
|
||||
tempDir := t.TempDir()
|
||||
keysPath := filepath.Join(tempDir, "keys")
|
||||
err := os.MkdirAll(keysPath, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up file storage config
|
||||
envConfig := &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: keysPath,
|
||||
}
|
||||
|
||||
// Create test database
|
||||
db := testingutils.NewDatabaseForTest(t)
|
||||
|
||||
// Initialize app config service and create instance
|
||||
appConfigService := service.NewAppConfigService(t.Context(), db)
|
||||
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
||||
|
||||
// Check if key exists before rotation
|
||||
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run the key rotation
|
||||
err = keyRotate(t.Context(), flags, db, envConfig)
|
||||
|
||||
if wantErr {
|
||||
require.Error(t, err)
|
||||
if errMsg != "" {
|
||||
require.ErrorContains(t, err, errMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify key was created
|
||||
key, err := keyProvider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
// Verify the algorithm matches what we requested
|
||||
alg, _ := key.Algorithm()
|
||||
assert.NotEmpty(t, alg)
|
||||
if flags.Alg != "" {
|
||||
expectedAlg := flags.Alg
|
||||
if expectedAlg == "EdDSA" {
|
||||
// EdDSA keys should have the EdDSA algorithm
|
||||
assert.Equal(t, "EdDSA", alg.String())
|
||||
} else {
|
||||
assert.Equal(t, expectedAlg, alg.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
|
||||
// Set up database storage config
|
||||
envConfig := &common.EnvConfigSchema{
|
||||
KeysStorage: "database",
|
||||
EncryptionKey: "test-encryption-key-characters-long",
|
||||
}
|
||||
|
||||
// Create test database
|
||||
db := testingutils.NewDatabaseForTest(t)
|
||||
|
||||
// Initialize app config service and create instance
|
||||
appConfigService := service.NewAppConfigService(t.Context(), db)
|
||||
instanceID := appConfigService.GetDbConfig().InstanceID.Value
|
||||
|
||||
// Get key provider
|
||||
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run the key rotation
|
||||
err = keyRotate(t.Context(), flags, db, envConfig)
|
||||
|
||||
if wantErr {
|
||||
require.Error(t, err)
|
||||
if errMsg != "" {
|
||||
require.ErrorContains(t, err, errMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify key was created
|
||||
key, err := keyProvider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
// Verify the algorithm matches what we requested
|
||||
alg, _ := key.Algorithm()
|
||||
assert.NotEmpty(t, alg)
|
||||
if flags.Alg != "" {
|
||||
expectedAlg := flags.Alg
|
||||
if expectedAlg == "EdDSA" {
|
||||
// EdDSA keys should have the EdDSA algorithm
|
||||
assert.Equal(t, "EdDSA", alg.String())
|
||||
} else {
|
||||
assert.Equal(t, expectedAlg, alg.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyRotateMultipleAlgorithms(t *testing.T) {
|
||||
algorithms := []struct {
|
||||
alg string
|
||||
crv string
|
||||
}{
|
||||
{"RS256", ""},
|
||||
{"RS384", ""},
|
||||
// Skip RSA-4096 key generation test as it can take a long time
|
||||
// {"RS512", ""},
|
||||
{"ES256", ""},
|
||||
{"ES384", ""},
|
||||
{"ES512", ""},
|
||||
{"EdDSA", "Ed25519"},
|
||||
}
|
||||
|
||||
for _, algo := range algorithms {
|
||||
t.Run(algo.alg, func(t *testing.T) {
|
||||
// Test with database storage for all algorithms
|
||||
testKeyRotateWithDatabaseStorage(t, keyRotateFlags{
|
||||
Alg: algo.alg,
|
||||
Crv: algo.crv,
|
||||
Yes: true,
|
||||
}, false, "")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,77 +6,77 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
|
||||
)
|
||||
|
||||
// OneTimeAccessToken creates a one-time access token for the given user
|
||||
// Args must contain the username or email of the user
|
||||
func OneTimeAccessToken(args []string) error {
|
||||
// Get a context that is canceled when the application is stopping
|
||||
ctx := signals.SignalContext(context.Background())
|
||||
var oneTimeAccessTokenCmd = &cobra.Command{
|
||||
Use: "one-time-access-token [username or email]",
|
||||
Short: "Generates a one-time access token for the given user",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Get the username or email of the user
|
||||
userArg := args[0]
|
||||
|
||||
// Get the username or email of the user
|
||||
// Note length is 2 because the first argument is always the command (one-time-access-token)
|
||||
if len(args) != 2 {
|
||||
return errors.New("missing username or email of user; usage: one-time-access-token <username or email>")
|
||||
}
|
||||
userArg := args[1]
|
||||
// Connect to the database
|
||||
db := bootstrap.NewDatabase()
|
||||
|
||||
// Connect to the database
|
||||
db := bootstrap.NewDatabase()
|
||||
// Create the access token
|
||||
var oneTimeAccessToken *model.OneTimeAccessToken
|
||||
err := db.Transaction(func(tx *gorm.DB) error {
|
||||
// Load the user to retrieve the user ID
|
||||
var user model.User
|
||||
queryCtx, queryCancel := context.WithTimeout(cmd.Context(), 10*time.Second)
|
||||
defer queryCancel()
|
||||
txErr := tx.
|
||||
WithContext(queryCtx).
|
||||
Where("username = ? OR email = ?", userArg, userArg).
|
||||
First(&user).
|
||||
Error
|
||||
switch {
|
||||
case errors.Is(txErr, gorm.ErrRecordNotFound):
|
||||
return errors.New("user not found")
|
||||
case txErr != nil:
|
||||
return fmt.Errorf("failed to query for user: %w", txErr)
|
||||
case user.ID == "":
|
||||
return errors.New("invalid user loaded: ID is empty")
|
||||
}
|
||||
|
||||
// Create the access token
|
||||
var oneTimeAccessToken *model.OneTimeAccessToken
|
||||
err := db.Transaction(func(tx *gorm.DB) error {
|
||||
// Load the user to retrieve the user ID
|
||||
var user model.User
|
||||
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer queryCancel()
|
||||
txErr := tx.
|
||||
WithContext(queryCtx).
|
||||
Where("username = ? OR email = ?", userArg, userArg).
|
||||
First(&user).
|
||||
Error
|
||||
switch {
|
||||
case errors.Is(txErr, gorm.ErrRecordNotFound):
|
||||
return errors.New("user not found")
|
||||
case txErr != nil:
|
||||
return fmt.Errorf("failed to query for user: %w", txErr)
|
||||
case user.ID == "":
|
||||
return errors.New("invalid user loaded: ID is empty")
|
||||
// Create a new access token that expires in 1 hour
|
||||
oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour))
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("failed to generate access token: %w", txErr)
|
||||
}
|
||||
|
||||
queryCtx, queryCancel = context.WithTimeout(cmd.Context(), 10*time.Second)
|
||||
defer queryCancel()
|
||||
txErr = tx.
|
||||
WithContext(queryCtx).
|
||||
Create(oneTimeAccessToken).
|
||||
Error
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("failed to save access token: %w", txErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a new access token that expires in 1 hour
|
||||
oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour))
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("failed to generate access token: %w", txErr)
|
||||
}
|
||||
|
||||
queryCtx, queryCancel = context.WithTimeout(ctx, 10*time.Second)
|
||||
defer queryCancel()
|
||||
txErr = tx.
|
||||
WithContext(queryCtx).
|
||||
Create(oneTimeAccessToken).
|
||||
Error
|
||||
if txErr != nil {
|
||||
return fmt.Errorf("failed to save access token: %w", txErr)
|
||||
}
|
||||
// Print the result
|
||||
fmt.Printf(`A one-time access token valid for 1 hour has been created for "%s".`+"\n", userArg)
|
||||
fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print the result
|
||||
fmt.Printf(`A one-time access token valid for 1 hour has been created for "%s".`+"\n", userArg)
|
||||
fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(oneTimeAccessTokenCmd)
|
||||
}
|
||||
|
||||
36
backend/internal/cmds/root.go
Normal file
36
backend/internal/cmds/root.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package cmds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "pocket-id",
|
||||
Short: "A simple and easy-to-use OIDC provider that allows users to authenticate with their passkeys to your services.",
|
||||
Long: "By default, this command starts the pocket-id server.",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
// Start the server
|
||||
err := bootstrap.Bootstrap(cmd.Context())
|
||||
if err != nil {
|
||||
slog.Error("Failed to run pocket-id", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func Execute() {
|
||||
// Get a context that is canceled when the application is stopping
|
||||
ctx := signals.SignalContext(context.Background())
|
||||
|
||||
err := rootCmd.ExecuteContext(ctx)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
19
backend/internal/cmds/version.go
Normal file
19
backend/internal/cmds/version.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package cmds
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(&cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print the version number",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Println("pocket-id " + common.Version)
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
|
||||
@@ -18,9 +20,10 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
DbProviderSqlite DbProvider = "sqlite"
|
||||
DbProviderPostgres DbProvider = "postgres"
|
||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||
DbProviderSqlite DbProvider = "sqlite"
|
||||
DbProviderPostgres DbProvider = "postgres"
|
||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||
defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate"
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
@@ -30,11 +33,17 @@ type EnvConfigSchema struct {
|
||||
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
KeysPath string `env:"KEYS_PATH"`
|
||||
KeysStorage string `env:"KEYS_STORAGE"`
|
||||
EncryptionKey string `env:"ENCRYPTION_KEY"`
|
||||
EncryptionKeyFile string `env:"ENCRYPTION_KEY_FILE"`
|
||||
Port string `env:"PORT"`
|
||||
Host string `env:"HOST"`
|
||||
UnixSocket string `env:"UNIX_SOCKET"`
|
||||
UnixSocketMode string `env:"UNIX_SOCKET_MODE"`
|
||||
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
|
||||
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
|
||||
GeoLiteDBUrl string `env:"GEOLITE_DB_URL"`
|
||||
LocalIPv6Ranges string `env:"LOCAL_IPV6_RANGES"`
|
||||
UiConfigDisabled bool `env:"UI_CONFIG_DISABLED"`
|
||||
MetricsEnabled bool `env:"METRICS_ENABLED"`
|
||||
TracingEnabled bool `env:"TRACING_ENABLED"`
|
||||
@@ -42,49 +51,83 @@ type EnvConfigSchema struct {
|
||||
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
AppURL: "http://localhost:1411",
|
||||
Port: "1411",
|
||||
Host: "0.0.0.0",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
TrustProxy: false,
|
||||
AnalyticsDisabled: false,
|
||||
}
|
||||
var EnvConfig = defaultConfig()
|
||||
|
||||
func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
err := parseEnvConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("Configuration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func defaultConfig() EnvConfigSchema {
|
||||
return EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
KeysStorage: "", // "database" or "file"
|
||||
EncryptionKey: "",
|
||||
AppURL: "http://localhost:1411",
|
||||
Port: "1411",
|
||||
Host: "0.0.0.0",
|
||||
UnixSocket: "",
|
||||
UnixSocketMode: "",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
LocalIPv6Ranges: "",
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
TrustProxy: false,
|
||||
AnalyticsDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
func parseEnvConfig() error {
|
||||
err := env.ParseWithOptions(&EnvConfig, env.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing env config: %w", err)
|
||||
}
|
||||
|
||||
// Validate the environment variables
|
||||
switch EnvConfig.DbProvider {
|
||||
case DbProviderSqlite:
|
||||
if EnvConfig.DbConnectionString == "" {
|
||||
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
|
||||
EnvConfig.DbConnectionString = defaultSqliteConnString
|
||||
}
|
||||
case DbProviderPostgres:
|
||||
if EnvConfig.DbConnectionString == "" {
|
||||
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
||||
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
||||
}
|
||||
default:
|
||||
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||
}
|
||||
|
||||
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
|
||||
if err != nil {
|
||||
log.Fatal("APP_URL is not a valid URL")
|
||||
return errors.New("APP_URL is not a valid URL")
|
||||
}
|
||||
if parsedAppUrl.Path != "" {
|
||||
log.Fatal("APP_URL must not contain a path")
|
||||
return errors.New("APP_URL must not contain a path")
|
||||
}
|
||||
|
||||
switch EnvConfig.KeysStorage {
|
||||
// KeysStorage defaults to "file" if empty
|
||||
case "":
|
||||
EnvConfig.KeysStorage = "file"
|
||||
case "database":
|
||||
// If KeysStorage is "database", a key must be specified
|
||||
if EnvConfig.EncryptionKey == "" && EnvConfig.EncryptionKeyFile == "" {
|
||||
return errors.New("ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty when KEYS_STORAGE is database")
|
||||
}
|
||||
case "file":
|
||||
// All good, these are valid values
|
||||
default:
|
||||
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
188
backend/internal/common/env_config_test.go
Normal file
188
backend/internal/common/env_config_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseEnvConfig(t *testing.T) {
|
||||
// Store original config to restore later
|
||||
originalConfig := EnvConfig
|
||||
t.Cleanup(func() {
|
||||
EnvConfig = originalConfig
|
||||
})
|
||||
|
||||
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
|
||||
})
|
||||
|
||||
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "postgres")
|
||||
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
|
||||
t.Setenv("APP_URL", "https://example.com")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "invalid")
|
||||
t.Setenv("DB_CONNECTION_STRING", "test")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
|
||||
})
|
||||
|
||||
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
|
||||
})
|
||||
|
||||
t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "postgres")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid APP_URL", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "€://not-a-valid-url")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "APP_URL is not a valid URL")
|
||||
})
|
||||
|
||||
t.Run("should fail when APP_URL contains path", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000/path")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "APP_URL must not contain a path")
|
||||
})
|
||||
|
||||
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file", EnvConfig.KeysStorage)
|
||||
})
|
||||
|
||||
t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("KEYS_STORAGE", "database")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty")
|
||||
})
|
||||
|
||||
t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
|
||||
validStorageTypes := []string{"file", "database"}
|
||||
|
||||
for _, storage := range validStorageTypes {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("KEYS_STORAGE", storage)
|
||||
if storage == "database" {
|
||||
t.Setenv("ENCRYPTION_KEY", "test-key")
|
||||
}
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, storage, EnvConfig.KeysStorage)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("KEYS_STORAGE", "invalid")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
|
||||
})
|
||||
|
||||
t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("UI_CONFIG_DISABLED", "true")
|
||||
t.Setenv("METRICS_ENABLED", "true")
|
||||
t.Setenv("TRACING_ENABLED", "false")
|
||||
t.Setenv("TRUST_PROXY", "true")
|
||||
t.Setenv("ANALYTICS_DISABLED", "false")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, EnvConfig.UiConfigDisabled)
|
||||
assert.True(t, EnvConfig.MetricsEnabled)
|
||||
assert.False(t, EnvConfig.TracingEnabled)
|
||||
assert.True(t, EnvConfig.TrustProxy)
|
||||
assert.False(t, EnvConfig.AnalyticsDisabled)
|
||||
})
|
||||
|
||||
t.Run("should parse string environment variables correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "postgres")
|
||||
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
|
||||
t.Setenv("APP_URL", "https://prod.example.com")
|
||||
t.Setenv("APP_ENV", "staging")
|
||||
t.Setenv("UPLOAD_PATH", "/custom/uploads")
|
||||
t.Setenv("KEYS_PATH", "/custom/keys")
|
||||
t.Setenv("PORT", "8080")
|
||||
t.Setenv("HOST", "127.0.0.1")
|
||||
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
|
||||
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
|
||||
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "staging", EnvConfig.AppEnv)
|
||||
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
|
||||
assert.Equal(t, "8080", EnvConfig.Port)
|
||||
assert.Equal(t, "127.0.0.1", EnvConfig.Host)
|
||||
})
|
||||
}
|
||||
@@ -65,6 +65,11 @@ type OidcClientSecretInvalidError struct{}
|
||||
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
|
||||
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcClientAssertionInvalidError struct{}
|
||||
|
||||
func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
|
||||
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type OidcInvalidAuthorizationCodeError struct{}
|
||||
|
||||
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
|
||||
@@ -344,3 +349,13 @@ func (e *OidcAuthorizationPendingError) Error() string {
|
||||
func (e *OidcAuthorizationPendingError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
type OpenSignupDisabledError struct{}
|
||||
|
||||
func (e *OpenSignupDisabledError) Error() string {
|
||||
return "Open user signup is not enabled"
|
||||
}
|
||||
|
||||
func (e *OpenSignupDisabledError) HttpStatusCode() int {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
|
||||
@@ -38,10 +38,10 @@ func NewApiKeyController(group *gin.RouterGroup, authMiddleware *middleware.Auth
|
||||
// @Summary List API keys
|
||||
// @Description Get a paginated list of API keys belonging to the current user
|
||||
// @Tags API Keys
|
||||
// @Param page query int false "Page number, starting from 1" default(1)
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param sort_column query string false "Column to sort by" default("created_at")
|
||||
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.ApiKeyDto]
|
||||
// @Router /api/api-keys [get]
|
||||
func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
|
||||
@@ -82,7 +82,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
|
||||
userID := ctx.GetString("userID")
|
||||
|
||||
var input dto.ApiKeyCreateDto
|
||||
if err := ctx.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(ctx, &input); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
@@ -57,7 +58,6 @@ type AppConfigController struct {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {array} dto.PublicAppConfigVariableDto
|
||||
// @Failure 500 {object} object "{"error": "error message"}"
|
||||
// @Router /application-configuration [get]
|
||||
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
configuration := acc.appConfigService.ListAppConfig(false)
|
||||
@@ -85,7 +85,6 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {array} dto.AppConfigVariableDto
|
||||
// @Security BearerAuth
|
||||
// @Router /application-configuration/all [get]
|
||||
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
|
||||
configuration := acc.appConfigService.ListAppConfig(true)
|
||||
@@ -107,11 +106,10 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param body body dto.AppConfigUpdateDto true "Application Configuration"
|
||||
// @Success 200 {array} dto.AppConfigVariableDto
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration [put]
|
||||
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
|
||||
var input dto.AppConfigUpdateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -164,7 +162,6 @@ func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
|
||||
// @Tags Application Configuration
|
||||
// @Produce image/x-icon
|
||||
// @Success 200 {file} binary "Favicon image"
|
||||
// @Failure 404 {object} object "{"error": "File not found"}"
|
||||
// @Router /api/application-configuration/favicon [get]
|
||||
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
|
||||
acc.getImage(c, "favicon", "ico")
|
||||
@@ -177,7 +174,6 @@ func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
|
||||
// @Produce image/png
|
||||
// @Produce image/jpeg
|
||||
// @Success 200 {file} binary "Background image"
|
||||
// @Failure 404 {object} object "{"error": "File not found"}"
|
||||
// @Router /api/application-configuration/background-image [get]
|
||||
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
|
||||
imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
|
||||
@@ -192,7 +188,6 @@ func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
|
||||
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
|
||||
// @Param file formData file true "Logo image file"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration/logo [put]
|
||||
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
||||
dbConfig := acc.appConfigService.GetDbConfig()
|
||||
@@ -218,7 +213,6 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
||||
// @Accept multipart/form-data
|
||||
// @Param file formData file true "Favicon file (.ico)"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration/favicon [put]
|
||||
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
@@ -242,7 +236,6 @@ func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
|
||||
// @Accept multipart/form-data
|
||||
// @Param file formData file true "Background image file"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration/background-image [put]
|
||||
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
|
||||
imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
|
||||
@@ -255,6 +248,8 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
|
||||
utils.SetCacheControlHeader(c, 15*time.Minute, 24*time.Hour)
|
||||
c.File(imagePath)
|
||||
}
|
||||
|
||||
@@ -280,7 +275,6 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
|
||||
// @Description Manually trigger LDAP synchronization
|
||||
// @Tags Application Configuration
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration/sync-ldap [post]
|
||||
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
|
||||
err := acc.ldapService.SyncAll(c.Request.Context())
|
||||
@@ -297,7 +291,6 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
|
||||
// @Description Send a test email to verify email configuration
|
||||
// @Tags Application Configuration
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/application-configuration/test-email [post]
|
||||
func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
|
||||
@@ -34,10 +34,10 @@ type AuditLogController struct {
|
||||
// @Summary List audit logs
|
||||
// @Description Get a paginated list of audit logs for the current user
|
||||
// @Tags Audit Logs
|
||||
// @Param page query int false "Page number, starting from 1" default(1)
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param sort_column query string false "Column to sort by" default("created_at")
|
||||
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
|
||||
// @Router /api/audit-logs [get]
|
||||
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
@@ -82,13 +82,14 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
|
||||
// @Summary List all audit logs
|
||||
// @Description Get a paginated list of all audit logs (admin only)
|
||||
// @Tags Audit Logs
|
||||
// @Param page query int false "Page number, starting from 1" default(1)
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param sort_column query string false "Column to sort by" default("created_at")
|
||||
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
|
||||
// @Param user_id query string false "Filter by user ID"
|
||||
// @Param event query string false "Filter by event type"
|
||||
// @Param client_name query string false "Filter by client name"
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Param filters[userId] query string false "Filter by user ID"
|
||||
// @Param filters[event] query string false "Filter by event type"
|
||||
// @Param filters[clientName] query string false "Filter by client name"
|
||||
// @Param filters[location] query string false "Filter by location type (external or internal)"
|
||||
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
|
||||
// @Router /api/audit-logs/all [get]
|
||||
func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {
|
||||
|
||||
@@ -35,10 +35,6 @@ type CustomClaimController struct {
|
||||
// @Tags Custom Claims
|
||||
// @Produce json
|
||||
// @Success 200 {array} string "List of suggested custom claim names"
|
||||
// @Failure 401 {object} object "Unauthorized"
|
||||
// @Failure 403 {object} object "Forbidden"
|
||||
// @Failure 500 {object} object "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/custom-claims/suggestions [get]
|
||||
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
|
||||
claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context())
|
||||
@@ -63,7 +59,7 @@ func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
|
||||
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
|
||||
var input []dto.CustomClaimCreateDto
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -93,12 +89,11 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
|
||||
// @Param userGroupId path string true "User Group ID"
|
||||
// @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user group"
|
||||
// @Success 200 {array} dto.CustomClaimDto "Updated custom claims"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/custom-claims/user-group/{userGroupId} [put]
|
||||
func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
|
||||
var input []dto.CustomClaimCreateDto
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -14,6 +14,10 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
|
||||
testController := &TestController{TestService: testService}
|
||||
|
||||
group.POST("/test/reset", testController.resetAndSeedHandler)
|
||||
group.POST("/test/refreshtoken", testController.signRefreshToken)
|
||||
|
||||
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
|
||||
group.POST("/externalidp/sign", testController.externalIdPSignToken)
|
||||
}
|
||||
|
||||
type TestController struct {
|
||||
@@ -21,6 +25,16 @@ type TestController struct {
|
||||
}
|
||||
|
||||
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
var baseURL string
|
||||
if c.Request.TLS != nil {
|
||||
baseURL = "https://" + c.Request.Host
|
||||
} else {
|
||||
baseURL = "http://" + c.Request.Host
|
||||
}
|
||||
|
||||
skipLdap := c.Query("skip-ldap") == "true"
|
||||
skipSeed := c.Query("skip-seed") == "true"
|
||||
|
||||
if err := tc.TestService.ResetDatabase(); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -31,9 +45,11 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SeedDatabase(); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
if !skipSeed {
|
||||
if err := tc.TestService.SeedDatabase(baseURL); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := tc.TestService.ResetAppConfig(c.Request.Context()); err != nil {
|
||||
@@ -41,17 +57,71 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
if !skipLdap {
|
||||
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tc.TestService.SetJWTKeys()
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (tc *TestController) externalIdPJWKS(c *gin.Context) {
|
||||
jwks, err := tc.TestService.GetExternalIdPJWKS()
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, jwks)
|
||||
}
|
||||
|
||||
func (tc *TestController) externalIdPSignToken(c *gin.Context) {
|
||||
var input struct {
|
||||
Aud string `json:"aud"`
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
}
|
||||
err := c.ShouldBindJSON(&input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := tc.TestService.SignExternalIdPToken(input.Iss, input.Sub, input.Aud)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.WriteString(token)
|
||||
}
|
||||
|
||||
func (tc *TestController) signRefreshToken(c *gin.Context) {
|
||||
var input struct {
|
||||
UserID string `json:"user"`
|
||||
ClientID string `json:"client"`
|
||||
RefreshToken string `json:"rt"`
|
||||
}
|
||||
err := c.ShouldBindJSON(&input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := tc.TestService.SignRefreshToken(input.UserID, input.ClientID, input.RefreshToken)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.WriteString(token)
|
||||
}
|
||||
|
||||
@@ -6,15 +6,16 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
|
||||
)
|
||||
|
||||
// NewOidcController creates a new controller for OIDC related endpoints
|
||||
@@ -48,9 +49,14 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
|
||||
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
|
||||
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
|
||||
|
||||
group.GET("/oidc/clients/:id/preview/:userId", authMiddleware.Add(), oc.getClientPreviewHandler)
|
||||
|
||||
group.POST("/oidc/device/authorize", oc.deviceAuthorizationHandler)
|
||||
group.POST("/oidc/device/verify", authMiddleware.WithAdminNotRequired().Add(), oc.verifyDeviceCodeHandler)
|
||||
group.GET("/oidc/device/info", authMiddleware.WithAdminNotRequired().Add(), oc.getDeviceCodeInfoHandler)
|
||||
|
||||
group.GET("/oidc/users/me/clients", authMiddleware.WithAdminNotRequired().Add(), oc.listOwnAuthorizedClientsHandler)
|
||||
group.GET("/oidc/users/:id/clients", authMiddleware.Add(), oc.listAuthorizedClientsHandler)
|
||||
}
|
||||
|
||||
type OidcController struct {
|
||||
@@ -66,7 +72,6 @@ type OidcController struct {
|
||||
// @Produce json
|
||||
// @Param request body dto.AuthorizeOidcClientRequestDto true "Authorization request parameters"
|
||||
// @Success 200 {object} dto.AuthorizeOidcClientResponseDto "Authorization code and callback URL"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/authorize [post]
|
||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
@@ -84,6 +89,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
response := dto.AuthorizeOidcClientResponseDto{
|
||||
Code: code,
|
||||
CallbackURL: callbackURL,
|
||||
Issuer: common.EnvConfig.AppURL,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
@@ -97,7 +103,6 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param request body dto.AuthorizationRequiredDto true "Authorization check parameters"
|
||||
// @Success 200 {object} object "{ \"authorizationRequired\": true/false }"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/authorization-required [post]
|
||||
func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) {
|
||||
var input dto.AuthorizationRequiredDto
|
||||
@@ -121,11 +126,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
|
||||
// @Tags OIDC
|
||||
// @Produce json
|
||||
// @Param client_id formData string false "Client ID (if not using Basic Auth)"
|
||||
// @Param client_secret formData string false "Client secret (if not using Basic Auth)"
|
||||
// @Param client_secret formData string false "Client secret (if not using Basic Auth or client assertions)"
|
||||
// @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
|
||||
// @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
|
||||
// @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
|
||||
// @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
|
||||
// @Param client_assertion formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
|
||||
// @Param client_assertion_type formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
|
||||
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
|
||||
// @Router /api/oidc/token [post]
|
||||
func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
@@ -195,7 +202,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
|
||||
token, err := oc.jwtService.VerifyOAuthAccessToken(authToken)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -224,7 +231,6 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
// @Description End user session and handle OIDC logout
|
||||
// @Tags OIDC
|
||||
// @Accept application/x-www-form-urlencoded
|
||||
// @Produce html
|
||||
// @Param id_token_hint query string false "ID token"
|
||||
// @Param post_logout_redirect_uri query string false "URL to redirect to after logout"
|
||||
// @Param state query string false "State parameter to include in the redirect"
|
||||
@@ -304,9 +310,21 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
|
||||
// find valid tokens) while still allowing it to be used by an application that is
|
||||
// supposed to interact with our IdP (since that needs to have a client_id
|
||||
// and client_secret anyway).
|
||||
clientID, clientSecret, _ := c.Request.BasicAuth()
|
||||
var (
|
||||
creds service.ClientAuthCredentials
|
||||
ok bool
|
||||
)
|
||||
creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth()
|
||||
if !ok {
|
||||
// If there's no basic auth, check if we have a bearer token
|
||||
bearer, ok := utils.BearerAuth(c.Request)
|
||||
if ok {
|
||||
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
|
||||
creds.ClientAssertion = bearer
|
||||
}
|
||||
}
|
||||
|
||||
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token)
|
||||
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), creds, input.Token)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -348,7 +366,6 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param id path string true "Client ID"
|
||||
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Client information"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id} [get]
|
||||
func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
@@ -360,12 +377,12 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
|
||||
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
|
||||
err = dto.MapStruct(client, &clientDto)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
_ = c.Error(err)
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
}
|
||||
|
||||
// listClientsHandler godoc
|
||||
@@ -373,12 +390,11 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
// @Description Get a paginated list of OIDC clients with optional search and sorting
|
||||
// @Tags OIDC
|
||||
// @Param search query string false "Search term to filter clients by name"
|
||||
// @Param page query int false "Page number, starting from 1" default(1)
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param sort_column query string false "Column to sort by" default("name")
|
||||
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.OidcClientWithAllowedGroupsCountDto]
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients [get]
|
||||
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
searchTerm := c.Query("search")
|
||||
@@ -424,7 +440,6 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param client body dto.OidcClientCreateDto true "Client information"
|
||||
// @Success 201 {object} dto.OidcClientWithAllowedUserGroupsDto "Created client"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients [post]
|
||||
func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
var input dto.OidcClientCreateDto
|
||||
@@ -454,7 +469,6 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
// @Tags OIDC
|
||||
// @Param id path string true "Client ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id} [delete]
|
||||
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id"))
|
||||
@@ -475,7 +489,6 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
// @Param id path string true "Client ID"
|
||||
// @Param client body dto.OidcClientCreateDto true "Client information"
|
||||
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Updated client"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id} [put]
|
||||
func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
var input dto.OidcClientCreateDto
|
||||
@@ -506,7 +519,6 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param id path string true "Client ID"
|
||||
// @Success 200 {object} object "{ \"secret\": \"string\" }"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/secret [post]
|
||||
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id"))
|
||||
@@ -535,6 +547,8 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
utils.SetCacheControlHeader(c, 15*time.Minute, 12*time.Hour)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
c.File(imagePath)
|
||||
}
|
||||
@@ -545,9 +559,8 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
// @Tags OIDC
|
||||
// @Accept multipart/form-data
|
||||
// @Param id path string true "Client ID"
|
||||
// @Param file formData file true "Logo image file (PNG, JPG, or SVG, max 2MB)"
|
||||
// @Param file formData file true "Logo image file (PNG, JPG, or SVG)"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/logo [post]
|
||||
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
@@ -571,7 +584,6 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
// @Tags OIDC
|
||||
// @Param id path string true "Client ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/logo [delete]
|
||||
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id"))
|
||||
@@ -592,7 +604,6 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
// @Param id path string true "Client ID"
|
||||
// @Param groups body dto.OidcUpdateAllowedUserGroupsDto true "User group IDs"
|
||||
// @Success 200 {object} dto.OidcClientDto "Updated client"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/allowed-user-groups [put]
|
||||
func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
|
||||
var input dto.OidcUpdateAllowedUserGroupsDto
|
||||
@@ -637,6 +648,62 @@ func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// listOwnAuthorizedClientsHandler godoc
|
||||
// @Summary List authorized clients for current user
|
||||
// @Description Get a paginated list of OIDC clients that the current user has authorized
|
||||
// @Tags OIDC
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.AuthorizedOidcClientDto]
|
||||
// @Router /api/oidc/users/me/clients [get]
|
||||
func (oc *OidcController) listOwnAuthorizedClientsHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
oc.listAuthorizedClients(c, userID)
|
||||
}
|
||||
|
||||
// listAuthorizedClientsHandler godoc
|
||||
// @Summary List authorized clients for a user
|
||||
// @Description Get a paginated list of OIDC clients that a specific user has authorized
|
||||
// @Tags OIDC
|
||||
// @Param id path string true "User ID"
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.AuthorizedOidcClientDto]
|
||||
// @Router /api/oidc/users/{id}/clients [get]
|
||||
func (oc *OidcController) listAuthorizedClientsHandler(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
oc.listAuthorizedClients(c, userID)
|
||||
}
|
||||
|
||||
func (oc *OidcController) listAuthorizedClients(c *gin.Context, userID string) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
authorizedClients, pagination, err := oc.oidcService.ListAuthorizedClients(c.Request.Context(), userID, sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Map the clients to DTOs
|
||||
var authorizedClientsDto []dto.AuthorizedOidcClientDto
|
||||
if err := dto.MapStructList(authorizedClients, &authorizedClientsDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dto.Paginated[dto.AuthorizedOidcClientDto]{
|
||||
Data: authorizedClientsDto,
|
||||
Pagination: pagination,
|
||||
})
|
||||
}
|
||||
|
||||
func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
|
||||
userCode := c.Query("code")
|
||||
if userCode == "" {
|
||||
@@ -672,3 +739,43 @@ func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, deviceCodeInfo)
|
||||
}
|
||||
|
||||
// getClientPreviewHandler godoc
|
||||
// @Summary Preview OIDC client data for user
|
||||
// @Description Get a preview of the OIDC data (ID token, access token, userinfo) that would be sent to the client for a specific user
|
||||
// @Tags OIDC
|
||||
// @Produce json
|
||||
// @Param id path string true "Client ID"
|
||||
// @Param userId path string true "User ID to preview data for"
|
||||
// @Param scopes query string false "Scopes to include in the preview (comma-separated)"
|
||||
// @Success 200 {object} dto.OidcClientPreviewDto "Preview data including ID token, access token, and userinfo payloads"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/preview/{userId} [get]
|
||||
func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
|
||||
clientID := c.Param("id")
|
||||
userID := c.Param("userId")
|
||||
scopes := c.Query("scopes")
|
||||
|
||||
if clientID == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "client ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "user ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if scopes == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "scopes are required"})
|
||||
return
|
||||
}
|
||||
|
||||
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, scopes)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, preview)
|
||||
}
|
||||
|
||||
@@ -44,11 +44,17 @@ func NewUserController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
|
||||
group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler)
|
||||
group.POST("/users/:id/one-time-access-email", authMiddleware.Add(), uc.RequestOneTimeAccessEmailAsAdminHandler)
|
||||
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
|
||||
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
|
||||
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.RequestOneTimeAccessEmailAsUnauthenticatedUserHandler)
|
||||
|
||||
group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler)
|
||||
group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler)
|
||||
|
||||
group.POST("/signup-tokens", authMiddleware.Add(), uc.createSignupTokenHandler)
|
||||
group.GET("/signup-tokens", authMiddleware.Add(), uc.listSignupTokensHandler)
|
||||
group.DELETE("/signup-tokens/:id", authMiddleware.Add(), uc.deleteSignupTokenHandler)
|
||||
group.POST("/signup", rateLimitMiddleware.Add(rate.Every(1*time.Minute), 10), uc.signupHandler)
|
||||
group.POST("/signup/setup", uc.signUpInitialAdmin)
|
||||
|
||||
}
|
||||
|
||||
type UserController struct {
|
||||
@@ -85,10 +91,10 @@ func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
|
||||
// @Description Get a paginated list of users with optional search and sorting
|
||||
// @Tags Users
|
||||
// @Param search query string false "Search term to filter users"
|
||||
// @Param page query int false "Page number, starting from 1" default(1)
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param sort_column query string false "Column to sort by" default("created_at")
|
||||
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.UserDto]
|
||||
// @Router /api/users [get]
|
||||
func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
@@ -187,7 +193,7 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||
// @Router /api/users [post]
|
||||
func (uc *UserController) createUserHandler(c *gin.Context) {
|
||||
var input dto.UserCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -250,10 +256,7 @@ func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
|
||||
defer picture.Close()
|
||||
}
|
||||
|
||||
_, ok := c.GetQuery("skipCache")
|
||||
if !ok {
|
||||
c.Header("Cache-Control", "public, max-age=900")
|
||||
}
|
||||
utils.SetCacheControlHeader(c, 15*time.Minute, 1*time.Hour)
|
||||
|
||||
c.DataFromReader(http.StatusOK, size, "image/png", picture, nil)
|
||||
}
|
||||
@@ -375,7 +378,7 @@ func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
// @Router /api/one-time-access-email [post]
|
||||
func (uc *UserController) RequestOneTimeAccessEmailAsUnauthenticatedUserHandler(c *gin.Context) {
|
||||
var input dto.OneTimeAccessEmailAsUnauthenticatedUserDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -443,14 +446,23 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
|
||||
// getSetupAccessTokenHandler godoc
|
||||
// @Summary Setup initial admin
|
||||
// @Description Generate setup access token for initial admin user configuration
|
||||
// signUpInitialAdmin godoc
|
||||
// @Summary Sign up initial admin user
|
||||
// @Description Sign up and generate setup access token for initial admin user
|
||||
// @Tags Users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param body body dto.SignUpDto true "User information"
|
||||
// @Success 200 {object} dto.UserDto
|
||||
// @Router /api/one-time-access-token/setup [post]
|
||||
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context())
|
||||
// @Router /api/signup/setup [post]
|
||||
func (uc *UserController) signUpInitialAdmin(c *gin.Context) {
|
||||
var input dto.SignUpDto
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
user, token, err := uc.userService.SignUpInitialAdmin(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -498,10 +510,132 @@ func (uc *UserController) updateUserGroups(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
|
||||
// createSignupTokenHandler godoc
|
||||
// @Summary Create signup token
|
||||
// @Description Create a new signup token that allows user registration
|
||||
// @Tags Users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param token body dto.SignupTokenCreateDto true "Signup token information"
|
||||
// @Success 201 {object} dto.SignupTokenDto
|
||||
// @Router /api/signup-tokens [post]
|
||||
func (uc *UserController) createSignupTokenHandler(c *gin.Context) {
|
||||
var input dto.SignupTokenCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), input.ExpiresAt, input.UsageLimit)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var tokenDto dto.SignupTokenDto
|
||||
if err := dto.MapStruct(signupToken, &tokenDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, tokenDto)
|
||||
}
|
||||
|
||||
// listSignupTokensHandler godoc
|
||||
// @Summary List signup tokens
|
||||
// @Description Get a paginated list of signup tokens
|
||||
// @Tags Users
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.SignupTokenDto]
|
||||
// @Router /api/signup-tokens [get]
|
||||
func (uc *UserController) listSignupTokensHandler(c *gin.Context) {
|
||||
var sortedPaginationRequest utils.SortedPaginationRequest
|
||||
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, pagination, err := uc.userService.ListSignupTokens(c.Request.Context(), sortedPaginationRequest)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var tokensDto []dto.SignupTokenDto
|
||||
if err := dto.MapStructList(tokens, &tokensDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dto.Paginated[dto.SignupTokenDto]{
|
||||
Data: tokensDto,
|
||||
Pagination: pagination,
|
||||
})
|
||||
}
|
||||
|
||||
// deleteSignupTokenHandler godoc
|
||||
// @Summary Delete signup token
|
||||
// @Description Delete a signup token by ID
|
||||
// @Tags Users
|
||||
// @Param id path string true "Token ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Router /api/signup-tokens/{id} [delete]
|
||||
func (uc *UserController) deleteSignupTokenHandler(c *gin.Context) {
|
||||
tokenID := c.Param("id")
|
||||
|
||||
err := uc.userService.DeleteSignupToken(c.Request.Context(), tokenID)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// signupWithTokenHandler godoc
|
||||
// @Summary Sign up
|
||||
// @Description Create a new user account
|
||||
// @Tags Users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param user body dto.SignUpDto true "User information"
|
||||
// @Success 201 {object} dto.SignUpDto
|
||||
// @Router /api/signup [post]
|
||||
func (uc *UserController) signupHandler(c *gin.Context) {
|
||||
var input dto.SignUpDto
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
user, accessToken, err := uc.userService.SignUp(c.Request.Context(), input, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
|
||||
cookie.AddAccessTokenCookie(c, maxAge, accessToken)
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, userDto)
|
||||
}
|
||||
|
||||
// updateUser is an internal helper method, not exposed as an API endpoint
|
||||
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||
var input dto.UserCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,10 +40,10 @@ type UserGroupController struct {
|
||||
// @Description Get a paginated list of user groups with optional search and sorting
|
||||
// @Tags User Groups
|
||||
// @Param search query string false "Search term to filter user groups by name"
|
||||
// @Param page query int false "Page number, starting from 1" default(1)
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param sort_column query string false "Column to sort by" default("name")
|
||||
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Param pagination[page] query int false "Page number for pagination" default(1)
|
||||
// @Param pagination[limit] query int false "Number of items per page" default(20)
|
||||
// @Param sort[column] query string false "Column to sort by"
|
||||
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
|
||||
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
|
||||
// @Router /api/user-groups [get]
|
||||
func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
@@ -92,7 +92,6 @@ func (ugc *UserGroupController) list(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Success 200 {object} dto.UserGroupDtoWithUsers
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups/{id} [get]
|
||||
func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
|
||||
@@ -118,11 +117,10 @@ func (ugc *UserGroupController) get(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
|
||||
// @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups [post]
|
||||
func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
var input dto.UserGroupCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -151,11 +149,10 @@ func (ugc *UserGroupController) create(c *gin.Context) {
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
|
||||
// @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups/{id} [put]
|
||||
func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
var input dto.UserGroupCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -183,7 +180,6 @@ func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
// @Produce json
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Success 204 "No Content"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups/{id} [delete]
|
||||
func (ugc *UserGroupController) delete(c *gin.Context) {
|
||||
if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil {
|
||||
@@ -203,7 +199,6 @@ func (ugc *UserGroupController) delete(c *gin.Context) {
|
||||
// @Param id path string true "User Group ID"
|
||||
// @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group"
|
||||
// @Success 200 {object} dto.UserGroupDtoWithUsers
|
||||
// @Security BearerAuth
|
||||
// @Router /api/user-groups/{id}/users [put]
|
||||
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
|
||||
var input dto.UserGroupUpdateUsersDto
|
||||
|
||||
@@ -69,20 +69,21 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
||||
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
|
||||
}
|
||||
config := map[string]any{
|
||||
"issuer": appUrl,
|
||||
"authorization_endpoint": appUrl + "/authorize",
|
||||
"token_endpoint": appUrl + "/api/oidc/token",
|
||||
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
|
||||
"end_session_endpoint": appUrl + "/api/oidc/end-session",
|
||||
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
||||
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||
"response_types_supported": []string{"code", "id_token"},
|
||||
"subject_types_supported": []string{"public"},
|
||||
"id_token_signing_alg_values_supported": []string{alg.String()},
|
||||
"issuer": appUrl,
|
||||
"authorization_endpoint": appUrl + "/authorize",
|
||||
"token_endpoint": appUrl + "/api/oidc/token",
|
||||
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
|
||||
"end_session_endpoint": appUrl + "/api/oidc/end-session",
|
||||
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
||||
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||
"response_types_supported": []string{"code", "id_token"},
|
||||
"subject_types_supported": []string{"public"},
|
||||
"id_token_signing_alg_values_supported": []string{alg.String()},
|
||||
"authorization_response_iss_parameter_supported": true,
|
||||
}
|
||||
return json.Marshal(config)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
)
|
||||
|
||||
type ApiKeyCreateDto struct {
|
||||
Name string `json:"name" binding:"required,min=3,max=50"`
|
||||
Description string `json:"description"`
|
||||
Name string `json:"name" binding:"required,min=3,max=50" unorm:"nfc"`
|
||||
Description string `json:"description" unorm:"nfc"`
|
||||
ExpiresAt datatype.DateTime `json:"expiresAt" binding:"required"`
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,13 @@ type AppConfigVariableDto struct {
|
||||
}
|
||||
|
||||
type AppConfigUpdateDto struct {
|
||||
AppName string `json:"appName" binding:"required,min=1,max=30"`
|
||||
AppName string `json:"appName" binding:"required,min=1,max=30" unorm:"nfc"`
|
||||
SessionDuration string `json:"sessionDuration" binding:"required"`
|
||||
EmailsVerified string `json:"emailsVerified" binding:"required"`
|
||||
DisableAnimations string `json:"disableAnimations" binding:"required"`
|
||||
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
|
||||
AllowUserSignups string `json:"allowUserSignups" binding:"required,oneof=disabled withToken open"`
|
||||
AccentColor string `json:"accentColor"`
|
||||
SmtpHost string `json:"smtpHost"`
|
||||
SmtpPort string `json:"smtpPort"`
|
||||
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
@@ -9,18 +8,19 @@ type AuditLogDto struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
|
||||
Event model.AuditLogEvent `json:"event"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Country string `json:"country"`
|
||||
City string `json:"city"`
|
||||
Device string `json:"device"`
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Data model.AuditLogData `json:"data"`
|
||||
Event string `json:"event"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Country string `json:"country"`
|
||||
City string `json:"city"`
|
||||
Device string `json:"device"`
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Data map[string]string `json:"data"`
|
||||
}
|
||||
|
||||
type AuditLogFilterDto struct {
|
||||
UserID string `form:"filters[userId]"`
|
||||
Event string `form:"filters[event]"`
|
||||
ClientName string `form:"filters[clientName]"`
|
||||
Location string `form:"filters[location]"`
|
||||
}
|
||||
|
||||
@@ -6,6 +6,6 @@ type CustomClaimDto struct {
|
||||
}
|
||||
|
||||
type CustomClaimCreateDto struct {
|
||||
Key string `json:"key" binding:"required"`
|
||||
Value string `json:"value" binding:"required"`
|
||||
Key string `json:"key" binding:"required" unorm:"nfc"`
|
||||
Value string `json:"value" binding:"required" unorm:"nfc"`
|
||||
}
|
||||
|
||||
@@ -1,109 +1,27 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"time"
|
||||
"fmt"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/jinzhu/copier"
|
||||
)
|
||||
|
||||
// MapStructList maps a list of source structs to a list of destination structs
|
||||
func MapStructList[S any, D any](source []S, destination *[]D) error {
|
||||
*destination = make([]D, 0, len(source))
|
||||
func MapStructList[S any, D any](source []S, destination *[]D) (err error) {
|
||||
*destination = make([]D, len(source))
|
||||
|
||||
for _, item := range source {
|
||||
var destItem D
|
||||
if err := MapStruct(item, &destItem); err != nil {
|
||||
return err
|
||||
for i, item := range source {
|
||||
err = MapStruct(item, &((*destination)[i]))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to map field %d: %w", i, err)
|
||||
}
|
||||
*destination = append(*destination, destItem)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MapStruct maps a source struct to a destination struct
|
||||
func MapStruct[S any, D any](source S, destination *D) error {
|
||||
// Ensure destination is a non-nil pointer
|
||||
destValue := reflect.ValueOf(destination)
|
||||
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
|
||||
return errors.New("destination must be a non-nil pointer to a struct")
|
||||
}
|
||||
|
||||
// Ensure source is a struct
|
||||
sourceValue := reflect.ValueOf(source)
|
||||
if sourceValue.Kind() != reflect.Struct {
|
||||
return errors.New("source must be a struct")
|
||||
}
|
||||
|
||||
return mapStructInternal(sourceValue, destValue.Elem())
|
||||
}
|
||||
|
||||
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
|
||||
for i := 0; i < destVal.NumField(); i++ {
|
||||
destField := destVal.Field(i)
|
||||
destFieldType := destVal.Type().Field(i)
|
||||
|
||||
if destFieldType.Anonymous {
|
||||
if err := mapStructInternal(sourceVal, destField); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
sourceField := sourceVal.FieldByName(destFieldType.Name)
|
||||
|
||||
if sourceField.IsValid() && destField.CanSet() {
|
||||
if err := mapField(sourceField, destField); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapField(sourceField reflect.Value, destField reflect.Value) error {
|
||||
switch {
|
||||
case sourceField.Type() == destField.Type():
|
||||
destField.Set(sourceField)
|
||||
case sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice:
|
||||
return mapSlice(sourceField, destField)
|
||||
case sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct:
|
||||
return mapStructInternal(sourceField, destField)
|
||||
default:
|
||||
return mapSpecialTypes(sourceField, destField)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapSlice(sourceField reflect.Value, destField reflect.Value) error {
|
||||
if sourceField.Type().Elem() == destField.Type().Elem() {
|
||||
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
|
||||
for j := 0; j < sourceField.Len(); j++ {
|
||||
newSlice.Index(j).Set(sourceField.Index(j))
|
||||
}
|
||||
destField.Set(newSlice)
|
||||
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
|
||||
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
|
||||
for j := 0; j < sourceField.Len(); j++ {
|
||||
sourceElem := sourceField.Index(j)
|
||||
destElem := reflect.New(destField.Type().Elem()).Elem()
|
||||
if err := mapStructInternal(sourceElem, destElem); err != nil {
|
||||
return err
|
||||
}
|
||||
newSlice.Index(j).Set(destElem)
|
||||
}
|
||||
destField.Set(newSlice)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapSpecialTypes(sourceField reflect.Value, destField reflect.Value) error {
|
||||
if _, ok := sourceField.Interface().(datatype.DateTime); ok {
|
||||
if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) {
|
||||
dateValue := sourceField.Interface().(datatype.DateTime)
|
||||
destField.Set(reflect.ValueOf(dateValue.ToTime()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func MapStruct(source any, destination any) error {
|
||||
return copier.CopyWithOption(destination, source, copier.Option{
|
||||
DeepCopy: true,
|
||||
})
|
||||
}
|
||||
|
||||
197
backend/internal/dto/dto_mapper_test.go
Normal file
197
backend/internal/dto/dto_mapper_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
type sourceStruct struct {
|
||||
AString string
|
||||
AStringPtr *string
|
||||
ABool bool
|
||||
ABoolPtr *bool
|
||||
ACustomDateTime datatype.DateTime
|
||||
ACustomDateTimePtr *datatype.DateTime
|
||||
ANilStringPtr *string
|
||||
ASlice []string
|
||||
AMap map[string]int
|
||||
AStruct embeddedStruct
|
||||
AStructPtr *embeddedStruct
|
||||
|
||||
StringPtrToString *string
|
||||
EmptyStringPtrToString *string
|
||||
NilStringPtrToString *string
|
||||
IntToInt64 int
|
||||
AuditLogEventToString model.AuditLogEvent
|
||||
}
|
||||
|
||||
type destStruct struct {
|
||||
AString string
|
||||
AStringPtr *string
|
||||
ABool bool
|
||||
ABoolPtr *bool
|
||||
ACustomDateTime datatype.DateTime
|
||||
ACustomDateTimePtr *datatype.DateTime
|
||||
ANilStringPtr *string
|
||||
ASlice []string
|
||||
AMap map[string]int
|
||||
AStruct embeddedStruct
|
||||
AStructPtr *embeddedStruct
|
||||
|
||||
StringPtrToString string
|
||||
EmptyStringPtrToString string
|
||||
NilStringPtrToString string
|
||||
IntToInt64 int64
|
||||
AuditLogEventToString string
|
||||
}
|
||||
|
||||
type embeddedStruct struct {
|
||||
Foo string
|
||||
Bar int64
|
||||
}
|
||||
|
||||
func TestMapStruct(t *testing.T) {
|
||||
src := sourceStruct{
|
||||
AString: "abcd",
|
||||
AStringPtr: utils.Ptr("xyz"),
|
||||
ABool: true,
|
||||
ABoolPtr: utils.Ptr(false),
|
||||
ACustomDateTime: datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)),
|
||||
ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC))),
|
||||
ANilStringPtr: nil,
|
||||
ASlice: []string{"a", "b", "c"},
|
||||
AMap: map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
},
|
||||
AStruct: embeddedStruct{
|
||||
Foo: "bar",
|
||||
Bar: 42,
|
||||
},
|
||||
AStructPtr: &embeddedStruct{
|
||||
Foo: "quo",
|
||||
Bar: 111,
|
||||
},
|
||||
|
||||
StringPtrToString: utils.Ptr("foobar"),
|
||||
EmptyStringPtrToString: utils.Ptr(""),
|
||||
NilStringPtrToString: nil,
|
||||
IntToInt64: 99,
|
||||
AuditLogEventToString: model.AuditLogEventAccountCreated,
|
||||
}
|
||||
var dst destStruct
|
||||
err := MapStruct(src, &dst)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, src.AString, dst.AString)
|
||||
_ = assert.NotNil(t, src.AStringPtr) &&
|
||||
assert.Equal(t, *src.AStringPtr, *dst.AStringPtr)
|
||||
assert.Equal(t, src.ABool, dst.ABool)
|
||||
_ = assert.NotNil(t, src.ABoolPtr) &&
|
||||
assert.Equal(t, *src.ABoolPtr, *dst.ABoolPtr)
|
||||
assert.Equal(t, src.ACustomDateTime, dst.ACustomDateTime)
|
||||
_ = assert.NotNil(t, src.ACustomDateTimePtr) &&
|
||||
assert.Equal(t, *src.ACustomDateTimePtr, *dst.ACustomDateTimePtr)
|
||||
assert.Nil(t, dst.ANilStringPtr)
|
||||
assert.Equal(t, src.ASlice, dst.ASlice)
|
||||
assert.Equal(t, src.AMap, dst.AMap)
|
||||
assert.Equal(t, "bar", dst.AStruct.Foo)
|
||||
assert.Equal(t, int64(42), dst.AStruct.Bar)
|
||||
_ = assert.NotNil(t, src.AStructPtr) &&
|
||||
assert.Equal(t, "quo", dst.AStructPtr.Foo) &&
|
||||
assert.Equal(t, int64(111), dst.AStructPtr.Bar)
|
||||
assert.Equal(t, "foobar", dst.StringPtrToString)
|
||||
assert.Empty(t, dst.EmptyStringPtrToString)
|
||||
assert.Empty(t, dst.NilStringPtrToString)
|
||||
assert.Equal(t, int64(99), dst.IntToInt64)
|
||||
assert.Equal(t, "ACCOUNT_CREATED", dst.AuditLogEventToString)
|
||||
}
|
||||
|
||||
func TestMapStructList(t *testing.T) {
|
||||
sources := []sourceStruct{
|
||||
{
|
||||
AString: "first",
|
||||
AStringPtr: utils.Ptr("one"),
|
||||
ABool: true,
|
||||
ABoolPtr: utils.Ptr(false),
|
||||
ACustomDateTime: datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)),
|
||||
ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC))),
|
||||
ASlice: []string{"a", "b"},
|
||||
AMap: map[string]int{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
},
|
||||
AStruct: embeddedStruct{
|
||||
Foo: "first_struct",
|
||||
Bar: 10,
|
||||
},
|
||||
IntToInt64: 10,
|
||||
},
|
||||
{
|
||||
AString: "second",
|
||||
AStringPtr: utils.Ptr("two"),
|
||||
ABool: false,
|
||||
ABoolPtr: utils.Ptr(true),
|
||||
ACustomDateTime: datatype.DateTime(time.Date(2026, 6, 7, 8, 9, 10, 0, time.UTC)),
|
||||
ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2023, 6, 7, 8, 9, 10, 0, time.UTC))),
|
||||
ASlice: []string{"c", "d", "e"},
|
||||
AMap: map[string]int{
|
||||
"c": 3,
|
||||
"d": 4,
|
||||
},
|
||||
AStruct: embeddedStruct{
|
||||
Foo: "second_struct",
|
||||
Bar: 20,
|
||||
},
|
||||
IntToInt64: 20,
|
||||
},
|
||||
}
|
||||
|
||||
var destinations []destStruct
|
||||
err := MapStructList(sources, &destinations)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, destinations, 2)
|
||||
|
||||
// Verify first element
|
||||
assert.Equal(t, "first", destinations[0].AString)
|
||||
assert.Equal(t, "one", *destinations[0].AStringPtr)
|
||||
assert.True(t, destinations[0].ABool)
|
||||
assert.False(t, *destinations[0].ABoolPtr)
|
||||
assert.Equal(t, datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)), destinations[0].ACustomDateTime)
|
||||
assert.Equal(t, datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)), *destinations[0].ACustomDateTimePtr)
|
||||
assert.Equal(t, []string{"a", "b"}, destinations[0].ASlice)
|
||||
assert.Equal(t, map[string]int{"a": 1, "b": 2}, destinations[0].AMap)
|
||||
assert.Equal(t, "first_struct", destinations[0].AStruct.Foo)
|
||||
assert.Equal(t, int64(10), destinations[0].AStruct.Bar)
|
||||
assert.Equal(t, int64(10), destinations[0].IntToInt64)
|
||||
|
||||
// Verify second element
|
||||
assert.Equal(t, "second", destinations[1].AString)
|
||||
assert.Equal(t, "two", *destinations[1].AStringPtr)
|
||||
assert.False(t, destinations[1].ABool)
|
||||
assert.True(t, *destinations[1].ABoolPtr)
|
||||
assert.Equal(t, datatype.DateTime(time.Date(2026, 6, 7, 8, 9, 10, 0, time.UTC)), destinations[1].ACustomDateTime)
|
||||
assert.Equal(t, datatype.DateTime(time.Date(2023, 6, 7, 8, 9, 10, 0, time.UTC)), *destinations[1].ACustomDateTimePtr)
|
||||
assert.Equal(t, []string{"c", "d", "e"}, destinations[1].ASlice)
|
||||
assert.Equal(t, map[string]int{"c": 3, "d": 4}, destinations[1].AMap)
|
||||
assert.Equal(t, "second_struct", destinations[1].AStruct.Foo)
|
||||
assert.Equal(t, int64(20), destinations[1].AStruct.Bar)
|
||||
assert.Equal(t, int64(20), destinations[1].IntToInt64)
|
||||
}
|
||||
|
||||
func TestMapStructList_EmptySource(t *testing.T) {
|
||||
var sources []sourceStruct
|
||||
var destinations []destStruct
|
||||
|
||||
err := MapStructList(sources, &destinations)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, destinations)
|
||||
}
|
||||
94
backend/internal/dto/dto_normalize.go
Normal file
94
backend/internal/dto/dto_normalize.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
)
|
||||
|
||||
// Normalize iterates through an object and performs Unicode normalization on all string fields with the `unorm` tag.
|
||||
func Normalize(obj any) {
|
||||
v := reflect.ValueOf(obj)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return
|
||||
}
|
||||
v = v.Elem()
|
||||
|
||||
// Handle case where obj is a slice of models
|
||||
if v.Kind() == reflect.Slice {
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i)
|
||||
if elem.Kind() == reflect.Ptr && !elem.IsNil() && elem.Elem().Kind() == reflect.Struct {
|
||||
Normalize(elem.Interface())
|
||||
} else if elem.Kind() == reflect.Struct && elem.CanAddr() {
|
||||
Normalize(elem.Addr().Interface())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate through all fields looking for those with the "unorm" tag
|
||||
t := v.Type()
|
||||
loop:
|
||||
for i := range t.NumField() {
|
||||
field := t.Field(i)
|
||||
|
||||
unormTag := field.Tag.Get("unorm")
|
||||
if unormTag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fv := v.Field(i)
|
||||
if !fv.CanSet() || fv.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
|
||||
var form norm.Form
|
||||
switch unormTag {
|
||||
case "nfc":
|
||||
form = norm.NFC
|
||||
case "nfkc":
|
||||
form = norm.NFKC
|
||||
case "nfd":
|
||||
form = norm.NFD
|
||||
case "nfkd":
|
||||
form = norm.NFKD
|
||||
default:
|
||||
continue loop
|
||||
}
|
||||
|
||||
val := fv.String()
|
||||
val = form.String(val)
|
||||
fv.SetString(val)
|
||||
}
|
||||
}
|
||||
|
||||
func ShouldBindWithNormalizedJSON(ctx *gin.Context, obj any) error {
|
||||
return ctx.ShouldBindWith(obj, binding.JSON)
|
||||
}
|
||||
|
||||
type NormalizerJSONBinding struct{}
|
||||
|
||||
func (NormalizerJSONBinding) Name() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
func (NormalizerJSONBinding) Bind(req *http.Request, obj any) error {
|
||||
// Use the default JSON binder
|
||||
err := binding.JSON.Bind(req, obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform normalization
|
||||
Normalize(obj)
|
||||
|
||||
return nil
|
||||
}
|
||||
84
backend/internal/dto/dto_normalize_test.go
Normal file
84
backend/internal/dto/dto_normalize_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
)
|
||||
|
||||
type testDto struct {
|
||||
Name string `unorm:"nfc"`
|
||||
Description string `unorm:"nfd"`
|
||||
Other string
|
||||
BadForm string `unorm:"bad"`
|
||||
}
|
||||
|
||||
func TestNormalize(t *testing.T) {
|
||||
input := testDto{
|
||||
// Is in NFC form already
|
||||
Name: norm.NFC.String("Café"),
|
||||
// NFC form will be normalized to NFD
|
||||
Description: norm.NFC.String("vërø"),
|
||||
// Should be unchanged
|
||||
Other: "NöTag",
|
||||
// Should be unchanged
|
||||
BadForm: "BåD",
|
||||
}
|
||||
|
||||
Normalize(&input)
|
||||
|
||||
assert.Equal(t, norm.NFC.String("Café"), input.Name)
|
||||
assert.Equal(t, norm.NFD.String("vërø"), input.Description)
|
||||
assert.Equal(t, "NöTag", input.Other)
|
||||
assert.Equal(t, "BåD", input.BadForm)
|
||||
}
|
||||
|
||||
func TestNormalizeSlice(t *testing.T) {
|
||||
obj1 := testDto{
|
||||
Name: norm.NFC.String("Café1"),
|
||||
Description: norm.NFC.String("vërø1"),
|
||||
Other: "NöTag1",
|
||||
BadForm: "BåD1",
|
||||
}
|
||||
obj2 := testDto{
|
||||
Name: norm.NFD.String("Résumé2"),
|
||||
Description: norm.NFD.String("accéléré2"),
|
||||
Other: "NöTag2",
|
||||
BadForm: "BåD2",
|
||||
}
|
||||
|
||||
t.Run("slice of structs", func(t *testing.T) {
|
||||
slice := []testDto{obj1, obj2}
|
||||
Normalize(&slice)
|
||||
|
||||
// Verify first element
|
||||
assert.Equal(t, norm.NFC.String("Café1"), slice[0].Name)
|
||||
assert.Equal(t, norm.NFD.String("vërø1"), slice[0].Description)
|
||||
assert.Equal(t, "NöTag1", slice[0].Other)
|
||||
assert.Equal(t, "BåD1", slice[0].BadForm)
|
||||
|
||||
// Verify second element
|
||||
assert.Equal(t, norm.NFC.String("Résumé2"), slice[1].Name)
|
||||
assert.Equal(t, norm.NFD.String("accéléré2"), slice[1].Description)
|
||||
assert.Equal(t, "NöTag2", slice[1].Other)
|
||||
assert.Equal(t, "BåD2", slice[1].BadForm)
|
||||
})
|
||||
|
||||
t.Run("slice of pointers to structs", func(t *testing.T) {
|
||||
slice := []*testDto{&obj1, &obj2}
|
||||
Normalize(&slice)
|
||||
|
||||
// Verify first element
|
||||
assert.Equal(t, norm.NFC.String("Café1"), slice[0].Name)
|
||||
assert.Equal(t, norm.NFD.String("vërø1"), slice[0].Description)
|
||||
assert.Equal(t, "NöTag1", slice[0].Other)
|
||||
assert.Equal(t, "BåD1", slice[0].BadForm)
|
||||
|
||||
// Verify second element
|
||||
assert.Equal(t, norm.NFC.String("Résumé2"), slice[1].Name)
|
||||
assert.Equal(t, norm.NFD.String("accéléré2"), slice[1].Description)
|
||||
assert.Equal(t, "NöTag2", slice[1].Other)
|
||||
assert.Equal(t, "BåD2", slice[1].BadForm)
|
||||
})
|
||||
}
|
||||
@@ -8,10 +8,11 @@ type OidcClientMetaDataDto struct {
|
||||
|
||||
type OidcClientDto struct {
|
||||
OidcClientMetaDataDto
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||
}
|
||||
|
||||
type OidcClientWithAllowedUserGroupsDto struct {
|
||||
@@ -25,11 +26,23 @@ type OidcClientWithAllowedGroupsCountDto struct {
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
Name string `json:"name" binding:"required,max=50"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||
}
|
||||
|
||||
type OidcClientCredentialsDto struct {
|
||||
FederatedIdentities []OidcClientFederatedIdentityDto `json:"federatedIdentities,omitempty"`
|
||||
}
|
||||
|
||||
type OidcClientFederatedIdentityDto struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
Audience string `json:"audience,omitempty"`
|
||||
JWKS string `json:"jwks,omitempty"`
|
||||
}
|
||||
|
||||
type AuthorizeOidcClientRequestDto struct {
|
||||
@@ -44,6 +57,7 @@ type AuthorizeOidcClientRequestDto struct {
|
||||
type AuthorizeOidcClientResponseDto struct {
|
||||
Code string `json:"code"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
Issuer string `json:"issuer"`
|
||||
}
|
||||
|
||||
type AuthorizationRequiredDto struct {
|
||||
@@ -52,13 +66,15 @@ type AuthorizationRequiredDto struct {
|
||||
}
|
||||
|
||||
type OidcCreateTokensDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code"`
|
||||
DeviceCode string `form:"device_code"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
RefreshToken string `form:"refresh_token"`
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code"`
|
||||
DeviceCode string `form:"device_code"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
RefreshToken string `form:"refresh_token"`
|
||||
ClientAssertion string `form:"client_assertion"`
|
||||
ClientAssertionType string `form:"client_assertion_type"`
|
||||
}
|
||||
|
||||
type OidcIntrospectDto struct {
|
||||
@@ -98,9 +114,11 @@ type OidcIntrospectionResponseDto struct {
|
||||
}
|
||||
|
||||
type OidcDeviceAuthorizationRequestDto struct {
|
||||
ClientID string `form:"client_id" binding:"required"`
|
||||
Scope string `form:"scope" binding:"required"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
ClientID string `form:"client_id" binding:"required"`
|
||||
Scope string `form:"scope" binding:"required"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
ClientAssertion string `form:"client_assertion"`
|
||||
ClientAssertionType string `form:"client_assertion_type"`
|
||||
}
|
||||
|
||||
type OidcDeviceAuthorizationResponseDto struct {
|
||||
@@ -125,3 +143,14 @@ type DeviceCodeInfoDto struct {
|
||||
AuthorizationRequired bool `json:"authorizationRequired"`
|
||||
Client OidcClientMetaDataDto `json:"client"`
|
||||
}
|
||||
|
||||
type AuthorizedOidcClientDto struct {
|
||||
Scope string `json:"scope"`
|
||||
Client OidcClientMetaDataDto `json:"client"`
|
||||
}
|
||||
|
||||
type OidcClientPreviewDto struct {
|
||||
IdToken map[string]any `json:"idToken"`
|
||||
AccessToken map[string]any `json:"accessToken"`
|
||||
UserInfo map[string]any `json:"userInfo"`
|
||||
}
|
||||
|
||||
21
backend/internal/dto/signup_token_dto.go
Normal file
21
backend/internal/dto/signup_token_dto.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type SignupTokenCreateDto struct {
|
||||
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
|
||||
UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"`
|
||||
}
|
||||
|
||||
type SignupTokenDto struct {
|
||||
ID string `json:"id"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt datatype.DateTime `json:"expiresAt"`
|
||||
UsageLimit int `json:"usageLimit"`
|
||||
UsageCount int `json:"usageCount"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package dto
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserDto struct {
|
||||
ID string `json:"id"`
|
||||
@@ -17,10 +19,10 @@ type UserDto struct {
|
||||
}
|
||||
|
||||
type UserCreateDto struct {
|
||||
Username string `json:"username" binding:"required,username,min=2,max=50"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
FirstName string `json:"firstName" binding:"required,min=1,max=50"`
|
||||
LastName string `json:"lastName" binding:"max=50"`
|
||||
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
||||
Email string `json:"email" binding:"required,email" unorm:"nfc"`
|
||||
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
||||
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Locale *string `json:"locale"`
|
||||
Disabled bool `json:"disabled"`
|
||||
@@ -33,7 +35,7 @@ type OneTimeAccessTokenCreateDto struct {
|
||||
}
|
||||
|
||||
type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Email string `json:"email" binding:"required,email" unorm:"nfc"`
|
||||
RedirectPath string `json:"redirectPath"`
|
||||
}
|
||||
|
||||
@@ -44,3 +46,11 @@ type OneTimeAccessEmailAsAdminDto struct {
|
||||
type UserUpdateUserGroupDto struct {
|
||||
UserGroupIds []string `json:"userGroupIds" binding:"required"`
|
||||
}
|
||||
|
||||
type SignUpDto struct {
|
||||
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
||||
Email string `json:"email" binding:"required,email" unorm:"nfc"`
|
||||
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
||||
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
@@ -34,8 +34,8 @@ type UserGroupDtoWithUserCount struct {
|
||||
}
|
||||
|
||||
type UserGroupCreateDto struct {
|
||||
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50"`
|
||||
Name string `json:"name" binding:"required,min=2,max=255"`
|
||||
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50" unorm:"nfc"`
|
||||
Name string `json:"name" binding:"required,min=2,max=255" unorm:"nfc"`
|
||||
LdapID string `json:"-"`
|
||||
}
|
||||
|
||||
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
||||
|
||||
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||
regex := "^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$"
|
||||
matched, _ := regexp.MatchString(regex, fl.Field().String())
|
||||
return matched
|
||||
return validateUsernameRegex.MatchString(fl.Field().String())
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -19,5 +19,5 @@ type WebauthnCredentialDto struct {
|
||||
}
|
||||
|
||||
type WebauthnCredentialUpdateDto struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=30"`
|
||||
Name string `json:"name" binding:"required,min=1,max=50"`
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) erro
|
||||
return errors.Join(
|
||||
s.registerJob(ctx, "ClearWebauthnSessions", def, jobs.clearWebauthnSessions, true),
|
||||
s.registerJob(ctx, "ClearOneTimeAccessTokens", def, jobs.clearOneTimeAccessTokens, true),
|
||||
s.registerJob(ctx, "ClearSignupTokens", def, jobs.clearSignupTokens, true),
|
||||
s.registerJob(ctx, "ClearOidcAuthorizationCodes", def, jobs.clearOidcAuthorizationCodes, true),
|
||||
s.registerJob(ctx, "ClearOidcRefreshTokens", def, jobs.clearOidcRefreshTokens, true),
|
||||
s.registerJob(ctx, "ClearAuditLogs", def, jobs.clearAuditLogs, true),
|
||||
@@ -60,6 +61,21 @@ func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearSignupTokens deletes signup tokens that have expired
|
||||
func (j *DbCleanupJobs) clearSignupTokens(ctx context.Context) error {
|
||||
// Delete tokens that are expired OR have reached their usage limit
|
||||
st := j.db.
|
||||
WithContext(ctx).
|
||||
Delete(&model.SignupToken{}, "expires_at < ?", datatype.DateTime(time.Now()))
|
||||
if st.Error != nil {
|
||||
return fmt.Errorf("failed to clean expired tokens: %w", st.Error)
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "Cleaned expired tokens", slog.Int64("count", st.RowsAffected))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
|
||||
func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error {
|
||||
st := j.db.
|
||||
|
||||
@@ -29,7 +29,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||
|
||||
// Skip rate limiting for localhost and test environment
|
||||
// If the client ip is localhost the request comes from the frontend
|
||||
if ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv == "test" {
|
||||
if ip == "" || ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv == "test" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
type AppConfigVariable struct {
|
||||
@@ -35,8 +37,10 @@ type AppConfig struct {
|
||||
AppName AppConfigVariable `key:"appName,public"` // Public
|
||||
SessionDuration AppConfigVariable `key:"sessionDuration"`
|
||||
EmailsVerified AppConfigVariable `key:"emailsVerified"`
|
||||
AccentColor AppConfigVariable `key:"accentColor,public"` // Public
|
||||
DisableAnimations AppConfigVariable `key:"disableAnimations,public"` // Public
|
||||
AllowOwnAccountEdit AppConfigVariable `key:"allowOwnAccountEdit,public"` // Public
|
||||
AllowUserSignups AppConfigVariable `key:"allowUserSignups,public"` // Public
|
||||
// Internal
|
||||
BackgroundImageType AppConfigVariable `key:"backgroundImageType,internal"` // Internal
|
||||
LogoLightImageType AppConfigVariable `key:"logoLightImageType,internal"` // Internal
|
||||
@@ -47,7 +51,7 @@ type AppConfig struct {
|
||||
SmtpPort AppConfigVariable `key:"smtpPort"`
|
||||
SmtpFrom AppConfigVariable `key:"smtpFrom"`
|
||||
SmtpUser AppConfigVariable `key:"smtpUser"`
|
||||
SmtpPassword AppConfigVariable `key:"smtpPassword"`
|
||||
SmtpPassword AppConfigVariable `key:"smtpPassword,sensitive"`
|
||||
SmtpTls AppConfigVariable `key:"smtpTls"`
|
||||
SmtpSkipCertVerify AppConfigVariable `key:"smtpSkipCertVerify"`
|
||||
EmailLoginNotificationEnabled AppConfigVariable `key:"emailLoginNotificationEnabled"`
|
||||
@@ -58,7 +62,7 @@ type AppConfig struct {
|
||||
LdapEnabled AppConfigVariable `key:"ldapEnabled,public"` // Public
|
||||
LdapUrl AppConfigVariable `key:"ldapUrl"`
|
||||
LdapBindDn AppConfigVariable `key:"ldapBindDn"`
|
||||
LdapBindPassword AppConfigVariable `key:"ldapBindPassword"`
|
||||
LdapBindPassword AppConfigVariable `key:"ldapBindPassword,sensitive"`
|
||||
LdapBase AppConfigVariable `key:"ldapBase"`
|
||||
LdapUserSearchFilter AppConfigVariable `key:"ldapUserSearchFilter"`
|
||||
LdapUserGroupSearchFilter AppConfigVariable `key:"ldapUserGroupSearchFilter"`
|
||||
@@ -76,7 +80,7 @@ type AppConfig struct {
|
||||
LdapSoftDeleteUsers AppConfigVariable `key:"ldapSoftDeleteUsers"`
|
||||
}
|
||||
|
||||
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
|
||||
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool, redactSensitiveValues bool) []AppConfigVariable {
|
||||
// Use reflection to iterate through all fields
|
||||
cfgValue := reflect.ValueOf(c).Elem()
|
||||
cfgType := cfgValue.Type()
|
||||
@@ -96,11 +100,16 @@ func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldValue := cfgValue.Field(i)
|
||||
value := cfgValue.Field(i).FieldByName("Value").String()
|
||||
|
||||
// Redact sensitive values if the value isn't empty, the UI config is disabled, and redactSensitiveValues is true
|
||||
if value != "" && common.EnvConfig.UiConfigDisabled && redactSensitiveValues && attrs == "sensitive" {
|
||||
value = "XXXXXXXXXX"
|
||||
}
|
||||
|
||||
appConfigVariable := AppConfigVariable{
|
||||
Key: key,
|
||||
Value: fieldValue.FieldByName("Value").String(),
|
||||
Value: value,
|
||||
}
|
||||
|
||||
res = append(res, appConfigVariable)
|
||||
|
||||
@@ -10,7 +10,7 @@ type AuditLog struct {
|
||||
Base
|
||||
|
||||
Event AuditLogEvent `sortable:"true"`
|
||||
IpAddress string `sortable:"true"`
|
||||
IpAddress *string `sortable:"true"`
|
||||
Country string `sortable:"true"`
|
||||
City string `sortable:"true"`
|
||||
UserAgent string `sortable:"true"`
|
||||
@@ -28,6 +28,7 @@ type AuditLogEvent string //nolint:recvcheck
|
||||
const (
|
||||
AuditLogEventSignIn AuditLogEvent = "SIGN_IN"
|
||||
AuditLogEventOneTimeAccessTokenSignIn AuditLogEvent = "TOKEN_SIGN_IN"
|
||||
AuditLogEventAccountCreated AuditLogEvent = "ACCOUNT_CREATED"
|
||||
AuditLogEventClientAuthorization AuditLogEvent = "CLIENT_AUTHORIZATION"
|
||||
AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION"
|
||||
AuditLogEventDeviceCodeAuthorization AuditLogEvent = "DEVICE_CODE_AUTHORIZATION"
|
||||
|
||||
11
backend/internal/model/kv.go
Normal file
11
backend/internal/model/kv.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package model
|
||||
|
||||
type KV struct {
|
||||
Key string `gorm:"primaryKey;not null"`
|
||||
Value *string
|
||||
}
|
||||
|
||||
// TableName overrides the table name used by KV to `kv`
|
||||
func (KV) TableName() string {
|
||||
return "kv"
|
||||
}
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"gorm.io/gorm"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type UserAuthorizedOidcClient struct {
|
||||
@@ -45,6 +46,7 @@ type OidcClient struct {
|
||||
HasLogo bool `gorm:"-"`
|
||||
IsPublic bool
|
||||
PkceEnabled bool
|
||||
Credentials OidcClientCredentials
|
||||
|
||||
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
|
||||
CreatedByID string
|
||||
@@ -71,9 +73,49 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
type OidcClientCredentials struct { //nolint:recvcheck
|
||||
FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
|
||||
}
|
||||
|
||||
type OidcClientFederatedIdentity struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Subject string `json:"subject,omitempty"`
|
||||
Audience string `json:"audience,omitempty"`
|
||||
JWKS string `json:"jwks,omitempty"` // URL of the JWKS
|
||||
}
|
||||
|
||||
func (occ OidcClientCredentials) FederatedIdentityForIssuer(issuer string) (OidcClientFederatedIdentity, bool) {
|
||||
if issuer == "" {
|
||||
return OidcClientFederatedIdentity{}, false
|
||||
}
|
||||
|
||||
for _, fi := range occ.FederatedIdentities {
|
||||
if fi.Issuer == issuer {
|
||||
return fi, true
|
||||
}
|
||||
}
|
||||
|
||||
return OidcClientFederatedIdentity{}, false
|
||||
}
|
||||
|
||||
func (occ *OidcClientCredentials) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(v, occ)
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), occ)
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
func (occ OidcClientCredentials) Value() (driver.Value, error) {
|
||||
return json.Marshal(occ)
|
||||
}
|
||||
|
||||
type UrlList []string //nolint:recvcheck
|
||||
|
||||
func (cu *UrlList) Scan(value interface{}) error {
|
||||
func (cu *UrlList) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(v, cu)
|
||||
|
||||
28
backend/internal/model/signup_token.go
Normal file
28
backend/internal/model/signup_token.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type SignupToken struct {
|
||||
Base
|
||||
|
||||
Token string `json:"token"`
|
||||
ExpiresAt datatype.DateTime `json:"expiresAt" sortable:"true"`
|
||||
UsageLimit int `json:"usageLimit" sortable:"true"`
|
||||
UsageCount int `json:"usageCount" sortable:"true"`
|
||||
}
|
||||
|
||||
func (st *SignupToken) IsExpired() bool {
|
||||
return time.Time(st.ExpiresAt).Before(time.Now())
|
||||
}
|
||||
|
||||
func (st *SignupToken) IsUsageLimitReached() bool {
|
||||
return st.UsageCount >= st.UsageLimit
|
||||
}
|
||||
|
||||
func (st *SignupToken) IsValid() bool {
|
||||
return !st.IsExpired() && !st.IsUsageLimitReached()
|
||||
}
|
||||
@@ -29,17 +29,17 @@ type AppConfigService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService {
|
||||
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
|
||||
err := service.LoadDbConfig(initCtx)
|
||||
err := service.LoadDbConfig(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize app config service: %v", err)
|
||||
}
|
||||
|
||||
err = service.initInstanceID(initCtx)
|
||||
err = service.initInstanceID(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize instance ID: %v", err)
|
||||
}
|
||||
@@ -68,6 +68,8 @@ func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
|
||||
EmailsVerified: model.AppConfigVariable{Value: "false"},
|
||||
DisableAnimations: model.AppConfigVariable{Value: "false"},
|
||||
AllowOwnAccountEdit: model.AppConfigVariable{Value: "true"},
|
||||
AllowUserSignups: model.AppConfigVariable{Value: "disabled"},
|
||||
AccentColor: model.AppConfigVariable{Value: "default"},
|
||||
// Internal
|
||||
BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
|
||||
LogoLightImageType: model.AppConfigVariable{Value: "svg"},
|
||||
@@ -232,11 +234,11 @@ func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppCon
|
||||
s.dbConfig.Store(cfg)
|
||||
|
||||
// Return the updated config
|
||||
res := cfg.ToAppConfigVariableSlice(true)
|
||||
res := cfg.ToAppConfigVariableSlice(true, false)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// UpdateAppConfigValues
|
||||
// UpdateAppConfigValues updates the application configuration values in the database.
|
||||
func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndValues ...string) error {
|
||||
// Count of keysAndValues must be even
|
||||
if len(keysAndValues)%2 != 0 {
|
||||
@@ -317,11 +319,11 @@ func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndVal
|
||||
}
|
||||
|
||||
func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable {
|
||||
return s.GetDbConfig().ToAppConfigVariableSlice(showAll)
|
||||
return s.GetDbConfig().ToAppConfigVariableSlice(showAll, true)
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
|
||||
fileType := utils.GetFileExtension(uploadedFile.Filename)
|
||||
fileType := strings.ToLower(utils.GetFileExtension(uploadedFile.Filename))
|
||||
mimeType := utils.GetImageMimeType(fileType)
|
||||
if mimeType == "" {
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
@@ -355,24 +357,52 @@ func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multip
|
||||
|
||||
// LoadDbConfig loads the configuration values from the database into the DbConfig struct.
|
||||
func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
|
||||
var dest *model.AppConfig
|
||||
|
||||
// If the UI config is disabled, only load from the env
|
||||
if common.EnvConfig.UiConfigDisabled {
|
||||
dest, err = s.loadDbConfigFromEnv(ctx, s.db)
|
||||
} else {
|
||||
dest, err = s.loadDbConfigInternal(ctx, s.db)
|
||||
}
|
||||
dest, err := s.loadDbConfigInternal(ctx, s.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the value in the object
|
||||
s.dbConfig.Store(dest)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
|
||||
// If the UI config is disabled, only load from the env
|
||||
if common.EnvConfig.UiConfigDisabled {
|
||||
dest, err := s.loadDbConfigFromEnv(ctx, tx)
|
||||
return dest, err
|
||||
}
|
||||
|
||||
// First, start from the default configuration
|
||||
dest := s.getDefaultDbConfig()
|
||||
|
||||
// Load all configuration values from the database
|
||||
// This loads all values in a single shot
|
||||
var loaded []model.AppConfigVariable
|
||||
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer queryCancel()
|
||||
err := tx.
|
||||
WithContext(queryCtx).
|
||||
Find(&loaded).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load configuration from the database: %w", err)
|
||||
}
|
||||
|
||||
// Iterate through all values loaded from the database
|
||||
for _, v := range loaded {
|
||||
// Find the field in the struct whose "key" tag matches, then update that
|
||||
err = dest.UpdateField(v.Key, v.Value, false)
|
||||
|
||||
// We ignore the case of fields that don't exist, as there may be leftover data in the database
|
||||
if err != nil && !errors.Is(err, model.AppConfigKeyNotFoundError{}) {
|
||||
return nil, fmt.Errorf("failed to process config for key '%s': %w", v.Key, err)
|
||||
}
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) loadDbConfigFromEnv(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
|
||||
// First, start from the default configuration
|
||||
dest := s.getDefaultDbConfig()
|
||||
@@ -414,36 +444,6 @@ func (s *AppConfigService) loadDbConfigFromEnv(ctx context.Context, tx *gorm.DB)
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
|
||||
// First, start from the default configuration
|
||||
dest := s.getDefaultDbConfig()
|
||||
|
||||
// Load all configuration values from the database
|
||||
// This loads all values in a single shot
|
||||
var loaded []model.AppConfigVariable
|
||||
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer queryCancel()
|
||||
err := tx.
|
||||
WithContext(queryCtx).
|
||||
Find(&loaded).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load configuration from the database: %w", err)
|
||||
}
|
||||
|
||||
// Iterate through all values loaded from the database
|
||||
for _, v := range loaded {
|
||||
// Find the field in the struct whose "key" tag matches, then update that
|
||||
err = dest.UpdateField(v.Key, v.Value, false)
|
||||
|
||||
// We ignore the case of fields that don't exist, as there may be leftover data in the database
|
||||
if err != nil && !errors.Is(err, model.AppConfigKeyNotFoundError{}) {
|
||||
return nil, fmt.Errorf("failed to process config for key '%s': %w", v.Key, err)
|
||||
}
|
||||
}
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) initInstanceID(ctx context.Context) error {
|
||||
// Check if the instance ID is already set
|
||||
instanceID := s.GetDbConfig().InstanceID.Value
|
||||
|
||||
@@ -3,17 +3,13 @@ package service
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"github.com/stretchr/testify/require"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
|
||||
@@ -28,7 +24,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
|
||||
|
||||
func TestLoadDbConfig(t *testing.T) {
|
||||
t.Run("empty config table", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
@@ -42,7 +38,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("loads value from config table", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Populate the config table with some initial values
|
||||
err := db.
|
||||
@@ -72,7 +68,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ignores unknown config keys", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Add an entry with a key that doesn't exist in the config struct
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
@@ -93,7 +89,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("loading config multiple times", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Initial state
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
@@ -135,7 +131,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
common.EnvConfig.UiConfigDisabled = true
|
||||
|
||||
// Create database with config that should be ignored
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
{Key: "appName", Value: "DB App"},
|
||||
{Key: "sessionDuration", Value: "120"},
|
||||
@@ -171,7 +167,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
common.EnvConfig.UiConfigDisabled = false
|
||||
|
||||
// Create database with config values that should take precedence
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
{Key: "appName", Value: "DB App"},
|
||||
{Key: "sessionDuration", Value: "120"},
|
||||
@@ -195,7 +191,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
|
||||
func TestUpdateAppConfigValues(t *testing.T) {
|
||||
t.Run("update single value", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -220,7 +216,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update multiple values", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -264,7 +260,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty value resets to default", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -285,7 +281,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error with odd number of arguments", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -301,7 +297,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error with invalid key", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -319,7 +315,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
|
||||
func TestUpdateAppConfig(t *testing.T) {
|
||||
t.Run("updates configuration values from DTO", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -392,7 +388,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty values reset to defaults", func(t *testing.T) {
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config and modify some values
|
||||
service := &AppConfigService{
|
||||
@@ -457,7 +453,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
// Disable UI config
|
||||
common.EnvConfig.UiConfigDisabled = true
|
||||
|
||||
db := newAppConfigTestDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
@@ -475,49 +471,3 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
require.ErrorAs(t, err, &uiConfigDisabledErr)
|
||||
})
|
||||
}
|
||||
|
||||
// Implements gorm's logger.Writer interface
|
||||
type testLoggerAdapter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||
l.t.Logf(format, args...)
|
||||
}
|
||||
|
||||
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
// Get a name for this in-memory database that is specific to the test
|
||||
dbName := utils.CreateSha256Hash(t.Name())
|
||||
|
||||
// Connect to a new in-memory SQL database
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
|
||||
&gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: logger.New(
|
||||
testLoggerAdapter{t: t},
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
ParameterizedQueries: false,
|
||||
Colorful: false,
|
||||
},
|
||||
),
|
||||
})
|
||||
require.NoError(t, err, "Failed to connect to test database")
|
||||
|
||||
// Create the app_config_variables table
|
||||
err = db.Exec(`
|
||||
CREATE TABLE app_config_variables
|
||||
(
|
||||
key VARCHAR(100) NOT NULL PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
)
|
||||
`).Error
|
||||
require.NoError(t, err, "Failed to create test config table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
userAgentParser "github.com/mileusna/useragent"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
@@ -25,15 +26,15 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
|
||||
}
|
||||
|
||||
// Create creates a new audit log entry in the database
|
||||
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
|
||||
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) (model.AuditLog, bool) {
|
||||
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get IP location: %v", err)
|
||||
// Log the error but don't interrupt the operation
|
||||
slog.Warn("Failed to get IP location", "error", err)
|
||||
}
|
||||
|
||||
auditLog := model.AuditLog{
|
||||
Event: event,
|
||||
IpAddress: ipAddress,
|
||||
Country: country,
|
||||
City: city,
|
||||
UserAgent: userAgent,
|
||||
@@ -41,33 +42,47 @@ func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if ipAddress != "" {
|
||||
// Only set ipAddress if not empty, because on Postgres we use INET columns that don't allow non-null empty values
|
||||
auditLog.IpAddress = &ipAddress
|
||||
}
|
||||
|
||||
// Save the audit log in the database
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&auditLog).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to create audit log: %v", err)
|
||||
return model.AuditLog{}
|
||||
slog.Error("Failed to create audit log", "error", err)
|
||||
return model.AuditLog{}, false
|
||||
}
|
||||
|
||||
return auditLog
|
||||
return auditLog, true
|
||||
}
|
||||
|
||||
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
|
||||
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
|
||||
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
|
||||
createdAuditLog, ok := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
|
||||
if !ok {
|
||||
// At this point the transaction has been canceled already, and error has been logged
|
||||
return createdAuditLog
|
||||
}
|
||||
|
||||
// Count the number of times the user has logged in from the same device
|
||||
var count int64
|
||||
err := tx.
|
||||
stmt := tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.AuditLog{}).
|
||||
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
|
||||
Count(&count).
|
||||
Error
|
||||
Where("user_id = ? AND user_agent = ?", userID, userAgent)
|
||||
if ipAddress == "" {
|
||||
// An empty IP address is stored as NULL in the database
|
||||
stmt = stmt.Where("ip_address IS NULL")
|
||||
} else {
|
||||
stmt = stmt.Where("ip_address = ?", ipAddress)
|
||||
}
|
||||
err := stmt.Count(&count).Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to count audit logs: %v\n", err)
|
||||
log.Printf("Failed to count audit logs: %v", err)
|
||||
return createdAuditLog
|
||||
}
|
||||
|
||||
@@ -150,6 +165,14 @@ func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPagination
|
||||
return nil, utils.PaginationResponse{}, fmt.Errorf("unsupported database dialect: %s", dialect)
|
||||
}
|
||||
}
|
||||
if filters.Location != "" {
|
||||
switch filters.Location {
|
||||
case "external":
|
||||
query = query.Where("country != 'Internal Network'")
|
||||
case "internal":
|
||||
query = query.Where("country = 'Internal Network'")
|
||||
}
|
||||
}
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@@ -15,13 +17,16 @@ import (
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
@@ -30,14 +35,43 @@ type TestService struct {
|
||||
jwtService *JwtService
|
||||
appConfigService *AppConfigService
|
||||
ldapService *LdapService
|
||||
externalIdPKey jwk.Key
|
||||
}
|
||||
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) *TestService {
|
||||
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService, ldapService: ldapService}
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
|
||||
s := &TestService{
|
||||
db: db,
|
||||
appConfigService: appConfigService,
|
||||
jwtService: jwtService,
|
||||
ldapService: ldapService,
|
||||
}
|
||||
err := s.initExternalIdP()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize external IdP: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Initializes the "external IdP"
|
||||
// This creates a new "issuing authority" containing a public JWKS
|
||||
// It also stores the private key internally that will be used to issue JWTs
|
||||
func (s *TestService) initExternalIdP() error {
|
||||
// Generate a new ECDSA key
|
||||
rawKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
s.externalIdPKey, err = jwkutils.ImportRawKey(rawKey, jwa.ES256().String(), "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *TestService) SeedDatabase() error {
|
||||
func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
users := []model.User{
|
||||
{
|
||||
@@ -138,6 +172,26 @@ func (s *TestService) SeedDatabase() error {
|
||||
userGroups[1],
|
||||
},
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
|
||||
},
|
||||
Name: "Federated",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.UrlList{"http://federated/auth/callback"},
|
||||
CreatedByID: users[1].ID,
|
||||
AllowedUserGroups: []model.UserGroup{},
|
||||
Credentials: model.OidcClientCredentials{
|
||||
FederatedIdentities: []model.OidcClientFederatedIdentity{
|
||||
{
|
||||
Issuer: "https://external-idp.local",
|
||||
Audience: "api://PocketID",
|
||||
Subject: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
|
||||
JWKS: baseURL + "/api/externalidp/jwks.json",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, client := range oidcClients {
|
||||
if err := tx.Create(&client).Error; err != nil {
|
||||
@@ -145,16 +199,28 @@ func (s *TestService) SeedDatabase() error {
|
||||
}
|
||||
}
|
||||
|
||||
authCode := model.OidcAuthorizationCode{
|
||||
Code: "auth-code",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
authCodes := []model.OidcAuthorizationCode{
|
||||
{
|
||||
Code: "auth-code",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
},
|
||||
{
|
||||
Code: "federated",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[2].ID,
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&authCode).Error; err != nil {
|
||||
return err
|
||||
for _, authCode := range authCodes {
|
||||
if err := tx.Create(&authCode).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
refreshToken := model.OidcRefreshToken{
|
||||
@@ -177,13 +243,22 @@ func (s *TestService) SeedDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
Scope: "openid profile email",
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
userAuthorizedClients := []model.UserAuthorizedOidcClient{
|
||||
{
|
||||
Scope: "openid profile email",
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
},
|
||||
{
|
||||
Scope: "openid profile email",
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[2].ID,
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
|
||||
return err
|
||||
for _, userAuthorizedClient := range userAuthorizedClients {
|
||||
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// To generate a new key pair, run the following command:
|
||||
@@ -237,6 +312,50 @@ func (s *TestService) SeedDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
signupTokens := []model.SignupToken{
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
},
|
||||
Token: "VALID1234567890A",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
|
||||
UsageLimit: 1,
|
||||
UsageCount: 0,
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "b2c3d4e5-f6g7-8901-bcde-f12345678901",
|
||||
},
|
||||
Token: "PARTIAL567890ABC",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(7 * 24 * time.Hour)),
|
||||
UsageLimit: 5,
|
||||
UsageCount: 2,
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "c3d4e5f6-g7h8-9012-cdef-123456789012",
|
||||
},
|
||||
Token: "EXPIRED34567890B",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(-24 * time.Hour)), // Expired
|
||||
UsageLimit: 3,
|
||||
UsageCount: 1,
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "d4e5f6g7-h8i9-0123-def0-234567890123",
|
||||
},
|
||||
Token: "FULLYUSED567890C",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
|
||||
UsageLimit: 1,
|
||||
UsageCount: 1, // Usage limit reached
|
||||
},
|
||||
}
|
||||
for _, token := range signupTokens {
|
||||
if err := tx.Create(&token).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -405,3 +524,45 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TestService) SignRefreshToken(userID, clientID, refreshToken string) (string, error) {
|
||||
return s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
}
|
||||
|
||||
// GetExternalIdPJWKS returns the JWKS for the "external IdP".
|
||||
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
|
||||
pubKey, err := s.externalIdPKey.PublicKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
set := jwk.NewSet()
|
||||
err = set.AddKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add public key to set: %w", err)
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func (s *TestService) SignExternalIdPToken(iss, sub, aud string) (string, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(sub).
|
||||
Expiration(now.Add(time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(iss).
|
||||
Audience([]string{aud}).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.externalIdPKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.externalIdPKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,9 +23,10 @@ import (
|
||||
)
|
||||
|
||||
type GeoLiteService struct {
|
||||
httpClient *http.Client
|
||||
disableUpdater bool
|
||||
mutex sync.RWMutex
|
||||
httpClient *http.Client
|
||||
disableUpdater bool
|
||||
mutex sync.RWMutex
|
||||
localIPv6Ranges []*net.IPNet
|
||||
}
|
||||
|
||||
var localhostIPNets = []*net.IPNet{
|
||||
@@ -54,17 +56,84 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService {
|
||||
service.disableUpdater = true
|
||||
}
|
||||
|
||||
// Initialize IPv6 local ranges
|
||||
if err := service.initializeIPv6LocalRanges(); err != nil {
|
||||
log.Printf("Warning: Failed to initialize IPv6 local ranges: %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// initializeIPv6LocalRanges parses the LOCAL_IPV6_RANGES environment variable
|
||||
func (s *GeoLiteService) initializeIPv6LocalRanges() error {
|
||||
rangesEnv := common.EnvConfig.LocalIPv6Ranges
|
||||
if rangesEnv == "" {
|
||||
return nil // No local IPv6 ranges configured
|
||||
}
|
||||
|
||||
ranges := strings.Split(rangesEnv, ",")
|
||||
localRanges := make([]*net.IPNet, 0, len(ranges))
|
||||
|
||||
for _, rangeStr := range ranges {
|
||||
rangeStr = strings.TrimSpace(rangeStr)
|
||||
if rangeStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(rangeStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IPv6 range '%s': %w", rangeStr, err)
|
||||
}
|
||||
|
||||
// Ensure it's an IPv6 range
|
||||
if ipNet.IP.To4() != nil {
|
||||
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
|
||||
}
|
||||
|
||||
localRanges = append(localRanges, ipNet)
|
||||
}
|
||||
|
||||
s.localIPv6Ranges = localRanges
|
||||
|
||||
if len(localRanges) > 0 {
|
||||
log.Printf("Initialized %d IPv6 local ranges", len(localRanges))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLocalIPv6 checks if the given IPv6 address is within any of the configured local ranges
|
||||
func (s *GeoLiteService) isLocalIPv6(ip net.IP) bool {
|
||||
if ip.To4() != nil {
|
||||
return false // Not an IPv6 address
|
||||
}
|
||||
|
||||
for _, localRange := range s.localIPv6Ranges {
|
||||
if localRange.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GeoLiteService) DisableUpdater() bool {
|
||||
return s.disableUpdater
|
||||
}
|
||||
|
||||
// GetLocationByIP returns the country and city of the given IP address.
|
||||
func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) {
|
||||
if ipAddress == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
// Check the IP address against known private IP ranges
|
||||
if ip := net.ParseIP(ipAddress); ip != nil {
|
||||
// Check IPv6 local ranges first
|
||||
if s.isLocalIPv6(ip) {
|
||||
return "Internal Network", "LAN", nil
|
||||
}
|
||||
|
||||
// Check existing IPv4 ranges
|
||||
for _, ipNet := range tailscaleIPNets {
|
||||
if ipNet.Contains(ip) {
|
||||
return "Internal Network", "Tailscale", nil
|
||||
@@ -82,6 +151,11 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
||||
}
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
// Race condition between reading and writing the database.
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
@@ -92,11 +166,6 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
addr, err := netip.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
var record struct {
|
||||
City struct {
|
||||
Names map[string]string `maxminddb:"names"`
|
||||
|
||||
220
backend/internal/service/geolite_service_test.go
Normal file
220
backend/internal/service/geolite_service_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
localRanges string
|
||||
testIP string
|
||||
expectedCountry string
|
||||
expectedCity string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "IPv6 in local range",
|
||||
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "LAN",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 not in local range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:ffff:000::1",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - second range match",
|
||||
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
||||
testIP: "2001:0db8:abcd:001::1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "LAN",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty local ranges",
|
||||
localRanges: "",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 private address still works",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "192.168.1.1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "LAN",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "::1",
|
||||
expectedCountry: "Internal Network",
|
||||
expectedCity: "localhost",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
||||
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
|
||||
defer func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
||||
}()
|
||||
|
||||
service := NewGeoLiteService(&http.Client{})
|
||||
|
||||
country, city, err := service.GetLocationByIP(tt.testIP)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil && country != "Internal Network" {
|
||||
t.Errorf("Expected error or internal network classification for external IP")
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedCountry, country)
|
||||
assert.Equal(t, tt.expectedCity, city)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoLiteService_isLocalIPv6(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
localRanges string
|
||||
testIP string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Valid IPv6 in range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Valid IPv6 not in range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:ffff:000::1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv4 address should return false",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "192.168.1.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No ranges configured",
|
||||
localRanges: "",
|
||||
testIP: "2001:0db8:abcd:000::1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Edge of range",
|
||||
localRanges: "2001:0db8:abcd:000::/56",
|
||||
testIP: "2001:0db8:abcd:00ff:ffff:ffff:ffff:ffff",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
||||
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
|
||||
defer func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
||||
}()
|
||||
|
||||
service := NewGeoLiteService(&http.Client{})
|
||||
ip := net.ParseIP(tt.testIP)
|
||||
if ip == nil {
|
||||
t.Fatalf("Invalid test IP: %s", tt.testIP)
|
||||
}
|
||||
|
||||
result := service.isLocalIPv6(ip)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expectError bool
|
||||
expectCount int
|
||||
}{
|
||||
{
|
||||
name: "Valid IPv6 ranges",
|
||||
envValue: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
||||
expectError: false,
|
||||
expectCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Empty environment variable",
|
||||
envValue: "",
|
||||
expectError: false,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Invalid CIDR notation",
|
||||
envValue: "2001:0db8:abcd:000::/999",
|
||||
expectError: true,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "IPv4 range in IPv6 env var",
|
||||
envValue: "192.168.1.0/24",
|
||||
expectError: true,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Mixed valid and invalid ranges",
|
||||
envValue: "2001:0db8:abcd:000::/56,invalid-range",
|
||||
expectError: true,
|
||||
expectCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Whitespace handling",
|
||||
envValue: " 2001:0db8:abcd:000::/56 , 2001:0db8:abcd:001::/56 ",
|
||||
expectError: false,
|
||||
expectCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
||||
common.EnvConfig.LocalIPv6Ranges = tt.envValue
|
||||
defer func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
||||
}()
|
||||
|
||||
service := &GeoLiteService{
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
err := service.initializeIPv6LocalRanges()
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Len(t, service.localIPv6Ranges, tt.expectCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,25 +2,20 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,8 +23,9 @@ const (
|
||||
// This is a JSON file containing a key encoded as JWK
|
||||
PrivateKeyFile = "jwt_private_key.json"
|
||||
|
||||
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
|
||||
RsaKeySize = 2048
|
||||
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||
// This is a encrypted JSON file containing a key encoded as JWK
|
||||
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
@@ -41,9 +37,15 @@ const (
|
||||
// TokenTypeClaim is the claim used to identify the type of token
|
||||
TokenTypeClaim = "type"
|
||||
|
||||
// RefreshTokenClaim is the claim used for the refresh token's value
|
||||
RefreshTokenClaim = "rt"
|
||||
|
||||
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
|
||||
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
|
||||
|
||||
// OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token
|
||||
OAuthRefreshTokenJWTType = "refresh-token"
|
||||
|
||||
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
|
||||
AccessTokenJWTType = "access-token"
|
||||
|
||||
@@ -55,58 +57,74 @@ const (
|
||||
)
|
||||
|
||||
type JwtService struct {
|
||||
envConfig *common.EnvConfigSchema
|
||||
privateKey jwk.Key
|
||||
keyId string
|
||||
appConfigService *AppConfigService
|
||||
jwksEncoded []byte
|
||||
}
|
||||
|
||||
func NewJwtService(appConfigService *AppConfigService) *JwtService {
|
||||
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) *JwtService {
|
||||
service := &JwtService{}
|
||||
|
||||
// Ensure keys are generated or loaded
|
||||
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
|
||||
err := service.init(db, appConfigService, &common.EnvConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize jwt service: %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
|
||||
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
|
||||
s.appConfigService = appConfigService
|
||||
s.envConfig = envConfig
|
||||
|
||||
// Ensure keys are generated or loaded
|
||||
return s.loadOrGenerateKey(keysPath)
|
||||
return s.loadOrGenerateKey(db)
|
||||
}
|
||||
|
||||
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
|
||||
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||
var key jwk.Key
|
||||
|
||||
// First, check if we have a JWK file
|
||||
// If we do, then we just load that
|
||||
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
|
||||
ok, err := utils.FileExists(jwkPath)
|
||||
func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
|
||||
// Get the key provider
|
||||
keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
|
||||
return fmt.Errorf("failed to get key provider: %w", err)
|
||||
}
|
||||
if ok {
|
||||
key, err = s.loadKeyJWK(jwkPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
// Set the key, and we are done
|
||||
// Try loading a key
|
||||
key, err := keyProvider.LoadKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
// If we have a key, store it in the object and we're done
|
||||
if key != nil {
|
||||
err = s.SetKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we are here, we need to generate a new key
|
||||
key, err = s.generateNewRSAKey()
|
||||
err = s.generateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
// Save the newly-generated key
|
||||
err = keyProvider.SaveKey(s.privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateKey generates a new key and stores it in the object
|
||||
func (s *JwtService) generateKey() error {
|
||||
// Default is to generate RS256 (RSA-2048) keys
|
||||
key, err := jwkutils.GenerateKey(jwa.RS256().String(), "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate new private key: %w", err)
|
||||
}
|
||||
@@ -117,12 +135,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||
return fmt.Errorf("failed to set private key: %w", err)
|
||||
}
|
||||
|
||||
// Save the key as JWK
|
||||
err = SaveKeyJWK(s.privateKey, jwkPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -188,13 +200,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
Subject(user.ID).
|
||||
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, common.EnvConfig.AppURL)
|
||||
err = SetAudienceString(token, s.envConfig.AppURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
@@ -225,8 +237,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithAudience(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithAudience(s.envConfig.AppURL),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -236,41 +248,52 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||
// BuildIDToken creates an ID token with all claims
|
||||
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetTokenType(token, IDTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range userClaims {
|
||||
err = token.Set(k, v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
|
||||
return nil, fmt.Errorf("failed to set claim '%s': %w", k, err)
|
||||
}
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
err = token.Set("nonce", nonce)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
|
||||
return nil, fmt.Errorf("failed to set claim 'nonce': %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GenerateIDToken creates and signs an ID token
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||
token, err := s.BuildIDToken(userClaims, clientID, nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
@@ -290,7 +313,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
|
||||
)
|
||||
|
||||
@@ -313,24 +336,88 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
// BuildOAuthAccessToken creates an OAuth access token with all claims
|
||||
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetTokenType(token, OAuthAccessTokenJWTType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GenerateOAuthAccessToken creates and signs an OAuth access token
|
||||
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
|
||||
token, err := s.BuildOAuthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(userID).
|
||||
Expiration(now.Add(RefreshTokenDuration)).
|
||||
IssuedAt(now).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = token.Set(RefreshTokenClaim, refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetTokenType(token, OAuthAccessTokenJWTType)
|
||||
err = SetTokenType(token, OAuthRefreshTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
@@ -344,21 +431,58 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
|
||||
func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
return "", "", "", fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
err = token.Get(RefreshTokenClaim, &rt)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err)
|
||||
}
|
||||
|
||||
audiences, ok := token.Audience()
|
||||
if !ok || len(audiences) != 1 || audiences[0] == "" {
|
||||
return "", "", "", errors.New("failed to get 'aud' claim from token")
|
||||
}
|
||||
clientID = audiences[0]
|
||||
|
||||
userID, ok = token.Subject()
|
||||
if !ok {
|
||||
return "", "", "", errors.New("failed to get 'sub' claim from token")
|
||||
}
|
||||
|
||||
return userID, clientID, rt, nil
|
||||
}
|
||||
|
||||
// GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**.
|
||||
func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) {
|
||||
// Disable validation and verification to parse the token without checking it
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(false),
|
||||
jwt.WithVerify(false),
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
var tokenType string
|
||||
err = token.Get(TokenTypeClaim, &tokenType)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to get token type claim: %w", err)
|
||||
}
|
||||
|
||||
return tokenType, token, nil
|
||||
}
|
||||
|
||||
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
|
||||
@@ -372,7 +496,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
EnsureAlgInKey(pubKey)
|
||||
jwkutils.EnsureAlgInKey(pubKey, "", "")
|
||||
|
||||
return pubKey, nil
|
||||
}
|
||||
@@ -401,107 +525,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
|
||||
return alg, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key data: %w", err)
|
||||
}
|
||||
|
||||
key, err := jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
|
||||
// We generate RSA keys only
|
||||
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
|
||||
}
|
||||
|
||||
// Import the raw key
|
||||
return importRawKey(rawKey)
|
||||
}
|
||||
|
||||
func importRawKey(rawKey any) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key)
|
||||
|
||||
return key, err
|
||||
}
|
||||
|
||||
// SaveKeyJWK saves a JWK to a file
|
||||
func SaveKeyJWK(key jwk.Key, path string) error {
|
||||
dir := filepath.Dir(path)
|
||||
err := os.MkdirAll(dir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
|
||||
}
|
||||
|
||||
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create key file: %w", err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
// Write the JSON file to disk
|
||||
enc := json.NewEncoder(keyFile)
|
||||
enc.SetEscapeHTML(false)
|
||||
err = enc.Encode(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write key file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
||||
func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
if !token.Has(IsAdminClaim) {
|
||||
@@ -509,7 +532,10 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
}
|
||||
var isAdmin bool
|
||||
err := token.Get(IsAdminClaim, &isAdmin)
|
||||
return isAdmin, err
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err)
|
||||
}
|
||||
return isAdmin, nil
|
||||
}
|
||||
|
||||
// SetTokenType sets the "type" claim in the token
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
)
|
||||
|
||||
func TestJwtService_Init(t *testing.T) {
|
||||
@@ -32,9 +33,16 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Initialize the JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify the private key was set
|
||||
@@ -65,9 +73,16 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// First create a service to generate a key
|
||||
firstService := &JwtService{}
|
||||
err := firstService.init(mockConfig, tempDir)
|
||||
err := firstService.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the key ID of the first service
|
||||
@@ -76,7 +91,7 @@ func TestJwtService_Init(t *testing.T) {
|
||||
|
||||
// Now create a new service that should load the existing key
|
||||
secondService := &JwtService{}
|
||||
err = secondService.init(mockConfig, tempDir)
|
||||
err = secondService.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the loaded key has the same ID as the original
|
||||
@@ -89,12 +104,19 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create a new JWK and save it to disk
|
||||
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Now create a new service that should load the existing key
|
||||
svc := &JwtService{}
|
||||
err := svc.init(mockConfig, tempDir)
|
||||
err := svc.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure loaded key has the right algorithm
|
||||
@@ -112,12 +134,19 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create a new JWK and save it to disk
|
||||
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Now create a new service that should load the existing key
|
||||
svc := &JwtService{}
|
||||
err := svc.init(mockConfig, tempDir)
|
||||
err := svc.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure loaded key has the right algorithm and curve
|
||||
@@ -146,9 +175,16 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create a JWT service with initialized key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Get the JWK (public key)
|
||||
@@ -177,12 +213,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create an ECDSA key and save it as JWK
|
||||
originalKeyID := createECDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Create a JWT service that loads the ECDSA key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Get the JWK (public key)
|
||||
@@ -215,12 +258,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create an EdDSA key and save it as JWK
|
||||
originalKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Create a JWT service that loads the EdDSA key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Get the JWK (public key)
|
||||
@@ -275,16 +325,16 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates token for regular user", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -327,7 +377,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
t.Run("generates token for admin user", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test admin user
|
||||
@@ -363,7 +413,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
service := &JwtService{}
|
||||
err := service.init(customMockConfig, tempDir)
|
||||
err := service.init(nil, customMockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -398,7 +448,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -452,7 +505,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -506,7 +562,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -562,16 +621,16 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create test claims
|
||||
@@ -600,7 +659,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Check token expiration time is approximately 1 hour from now
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
@@ -613,7 +672,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
t.Run("can accept expired tokens if told so", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create test claims
|
||||
@@ -627,7 +686,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
// Create a token that's already expired
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(userClaims["sub"].(string)).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(service.envConfig.AppURL).
|
||||
Audience([]string{clientID}).
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
@@ -665,13 +724,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
})
|
||||
|
||||
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create test claims with nonce
|
||||
@@ -702,7 +761,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Generate a token with standard claims
|
||||
@@ -713,7 +772,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to generate ID token")
|
||||
|
||||
// Temporarily change the app URL to simulate wrong issuer
|
||||
common.EnvConfig.AppURL = "https://wrong-issuer.com"
|
||||
service.envConfig.AppURL = "https://wrong-issuer.com"
|
||||
|
||||
// Verify should fail due to issuer mismatch
|
||||
_, err = service.VerifyIdToken(tokenString, false)
|
||||
@@ -730,7 +789,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -761,7 +823,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Verify the key type is OKP
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
@@ -783,7 +845,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -794,7 +859,6 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
// Create test claims
|
||||
userClaims := map[string]interface{}{
|
||||
"sub": "ecdsauser456",
|
||||
"name": "ECDSA User",
|
||||
"email": "ecdsauser@example.com",
|
||||
}
|
||||
const clientID = "ecdsa-client-123"
|
||||
@@ -814,7 +878,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Verify the key type is EC
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
@@ -836,7 +900,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -867,21 +934,11 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Verify the key type is RSA
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
|
||||
|
||||
// Verify the algorithm is RS256
|
||||
alg, ok := publicKey.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
@@ -891,16 +948,16 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -913,12 +970,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "test-client-123"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token")
|
||||
|
||||
// Check the claims
|
||||
@@ -930,7 +987,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Check token expiration time is approximately 1 hour from now
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
@@ -943,7 +1000,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||
// Create a JWT service with a mock function to generate an expired token
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -960,7 +1017,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Audience([]string{clientID}).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(service.envConfig.AppURL).
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
@@ -971,7 +1028,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
// Verify should fail due to expiration
|
||||
_, err = service.VerifyOauthAccessToken(string(signed))
|
||||
_, err = service.VerifyOAuthAccessToken(string(signed))
|
||||
require.Error(t, err, "Verification should fail with expired token")
|
||||
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
|
||||
})
|
||||
@@ -979,11 +1036,17 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||
// Create two JWT services with different keys
|
||||
service1 := &JwtService{}
|
||||
err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir
|
||||
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||
|
||||
service2 := &JwtService{}
|
||||
err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir
|
||||
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -995,11 +1058,11 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "test-client-789"
|
||||
|
||||
// Generate a token with the first service
|
||||
tokenString, err := service1.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
|
||||
// Verify with the second service should fail due to different keys
|
||||
_, err = service2.VerifyOauthAccessToken(tokenString)
|
||||
_, err = service2.VerifyOAuthAccessToken(tokenString)
|
||||
require.Error(t, err, "Verification should fail with invalid signature")
|
||||
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
|
||||
})
|
||||
@@ -1013,7 +1076,10 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -1031,12 +1097,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "eddsa-oauth-client"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||
|
||||
// Check the claims
|
||||
@@ -1067,7 +1133,10 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -1085,12 +1154,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "ecdsa-oauth-client"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||
|
||||
// Check the claims
|
||||
@@ -1121,7 +1190,10 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -1139,12 +1211,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "rsa-oauth-client"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||
|
||||
// Check the claims
|
||||
@@ -1167,6 +1239,98 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Initialize the JWT service with a mock AppConfigService
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates and verifies refresh token", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
const (
|
||||
userID = "user123"
|
||||
clientID = "client123"
|
||||
refreshToken = "rt-123"
|
||||
)
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
require.NoError(t, err, "Failed to generate refresh token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated token")
|
||||
assert.Equal(t, userID, resUser, "Should return correct user ID")
|
||||
assert.Equal(t, clientID, resClient, "Should return correct client ID")
|
||||
assert.Equal(t, refreshToken, resRT, "Should return correct refresh token")
|
||||
})
|
||||
|
||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Generate a token using JWT directly to create an expired token
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject("user789").
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Audience([]string{"client123"}).
|
||||
Issuer(service.envConfig.AppURL).
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
// Verify should fail due to expiration
|
||||
_, _, _, err = service.VerifyOAuthRefreshToken(string(signed))
|
||||
require.Error(t, err, "Verification should fail with expired token")
|
||||
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
|
||||
})
|
||||
|
||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||
// Create two JWT services with different keys
|
||||
service1 := &JwtService{}
|
||||
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||
|
||||
service2 := &JwtService{}
|
||||
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||
|
||||
// Generate a token with the first service
|
||||
tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123")
|
||||
require.NoError(t, err, "Failed to generate refresh token")
|
||||
|
||||
// Verify with the second service should fail due to different keys
|
||||
_, _, _, err = service2.VerifyOAuthRefreshToken(tokenString)
|
||||
require.Error(t, err, "Verification should fail with invalid signature")
|
||||
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenTypeValidator(t *testing.T) {
|
||||
// Create a context for the validator function
|
||||
ctx := context.Background()
|
||||
@@ -1212,16 +1376,125 @@ func TestTokenTypeValidator(t *testing.T) {
|
||||
require.Error(t, err, "Validator should reject token without type claim")
|
||||
assert.Contains(t, err.Error(), "failed to get token type claim")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetTokenType(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Initialize the JWT service
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||
service := &JwtService{}
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
|
||||
t.Helper()
|
||||
|
||||
b := jwt.NewBuilder()
|
||||
b.Subject("user123")
|
||||
if setClaimsFn != nil {
|
||||
setClaimsFn(b)
|
||||
}
|
||||
|
||||
token, err := b.Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
err = SetTokenType(token, typ)
|
||||
require.NoError(t, err, "Failed to set token type")
|
||||
|
||||
alg, _ := service.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, service.privateKey))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
return string(signed)
|
||||
}
|
||||
|
||||
t.Run("correctly identifies access tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, AccessTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token")
|
||||
})
|
||||
|
||||
t.Run("correctly identifies ID tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, IDTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token")
|
||||
})
|
||||
|
||||
t.Run("correctly identifies OAuth access tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token")
|
||||
})
|
||||
|
||||
t.Run("correctly identifies refresh tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token")
|
||||
})
|
||||
|
||||
t.Run("works with expired tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) {
|
||||
b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago
|
||||
})
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error for expired tokens")
|
||||
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens")
|
||||
})
|
||||
|
||||
t.Run("returns error for malformed tokens", func(t *testing.T) {
|
||||
// Try to get the token type of a malformed token
|
||||
tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token")
|
||||
require.Error(t, err, "GetTokenType should return an error for malformed tokens")
|
||||
assert.Empty(t, tokenType, "Token type should be empty for malformed tokens")
|
||||
})
|
||||
|
||||
t.Run("returns error for tokens without type claim", func(t *testing.T) {
|
||||
// Create a token without type claim
|
||||
tokenString := buildTokenForType(t, "", nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.Error(t, err, "GetTokenType should return an error for tokens without type claim")
|
||||
assert.Empty(t, tokenType, "Token type should be empty when type claim is missing")
|
||||
assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim")
|
||||
})
|
||||
}
|
||||
|
||||
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := importRawKey(privateKeyRaw)
|
||||
privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "")
|
||||
require.NoError(t, err, "Failed to import private key")
|
||||
|
||||
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
|
||||
keyProvider := &jwkutils.KeyProviderFile{}
|
||||
err = keyProvider.Init(jwkutils.KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: path,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "Failed to init file key provider")
|
||||
|
||||
err = keyProvider.SaveKey(privateKey)
|
||||
require.NoError(t, err, "Failed to save key")
|
||||
|
||||
kid, _ := privateKey.KeyID()
|
||||
|
||||
@@ -13,12 +13,16 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LdapService struct {
|
||||
@@ -122,7 +126,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
|
||||
// Skip groups without a valid LDAP ID
|
||||
if ldapId == "" {
|
||||
@@ -178,7 +182,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
var databaseUser model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", username).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -194,8 +198,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
syncGroup := dto.UserGroupCreateDto{
|
||||
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||
LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value),
|
||||
LdapID: ldapId,
|
||||
}
|
||||
dto.Normalize(syncGroup)
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
@@ -286,7 +291,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
|
||||
// Skip users without a valid LDAP ID
|
||||
if ldapId == "" {
|
||||
@@ -306,7 +311,6 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
|
||||
// If a user is found (even if disabled), enable them since they're now back in LDAP
|
||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||
// Use the transaction instead of the direct context
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
@@ -315,7 +319,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to enable user %s: %v", databaseUser.Username, err)
|
||||
return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,6 +345,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
IsAdmin: isAdmin,
|
||||
LdapID: ldapId,
|
||||
}
|
||||
dto.Normalize(newUser)
|
||||
|
||||
if databaseUser.ID == "" {
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
@@ -468,3 +473,21 @@ func getDNProperty(property string, str string) string {
|
||||
// CN not found, return an empty string
|
||||
return ""
|
||||
}
|
||||
|
||||
// convertLdapIdToString converts LDAP IDs to valid UTF-8 strings.
|
||||
// LDAP servers may return binary UUIDs (16 bytes) or other non-UTF-8 data.
|
||||
func convertLdapIdToString(ldapId string) string {
|
||||
if utf8.ValidString(ldapId) {
|
||||
return norm.NFC.String(ldapId)
|
||||
}
|
||||
|
||||
// Try to parse as binary UUID (16 bytes)
|
||||
if len(ldapId) == 16 {
|
||||
if parsedUUID, err := uuid.FromBytes([]byte(ldapId)); err == nil {
|
||||
return parsedUUID.String()
|
||||
}
|
||||
}
|
||||
|
||||
// As a last resort, encode as base64 to make it UTF-8 safe
|
||||
return base64.StdEncoding.EncodeToString([]byte(ldapId))
|
||||
}
|
||||
|
||||
@@ -71,3 +71,36 @@ func TestGetDNProperty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertLdapIdToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid UTF-8 string",
|
||||
input: "simple-utf8-id",
|
||||
expected: "simple-utf8-id",
|
||||
},
|
||||
{
|
||||
name: "binary UUID (16 bytes)",
|
||||
input: string([]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1}),
|
||||
expected: "12345678-9abc-def0-1234-56789abcdef1",
|
||||
},
|
||||
{
|
||||
name: "non-UTF8, non-UUID returns base64",
|
||||
input: string([]byte{0xff, 0xfe, 0xfd, 0xfc}),
|
||||
expected: "//79/A==",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := convertLdapIdToString(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,18 +3,25 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/httprc/v3"
|
||||
"github.com/lestrrat-go/httprc/v3/errsink"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
@@ -31,6 +38,11 @@ const (
|
||||
GrantTypeAuthorizationCode = "authorization_code"
|
||||
GrantTypeRefreshToken = "refresh_token"
|
||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
||||
|
||||
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
|
||||
DeviceCodeDuration = 15 * time.Minute
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
@@ -39,16 +51,61 @@ type OidcService struct {
|
||||
appConfigService *AppConfigService
|
||||
auditLogService *AuditLogService
|
||||
customClaimService *CustomClaimService
|
||||
|
||||
httpClient *http.Client
|
||||
jwkCache *jwk.Cache
|
||||
}
|
||||
|
||||
func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppConfigService, auditLogService *AuditLogService, customClaimService *CustomClaimService) *OidcService {
|
||||
return &OidcService{
|
||||
func NewOidcService(
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
jwtService *JwtService,
|
||||
appConfigService *AppConfigService,
|
||||
auditLogService *AuditLogService,
|
||||
customClaimService *CustomClaimService,
|
||||
) (s *OidcService, err error) {
|
||||
s = &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
appConfigService: appConfigService,
|
||||
auditLogService: auditLogService,
|
||||
customClaimService: customClaimService,
|
||||
}
|
||||
|
||||
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
||||
s.jwkCache, err = s.getJWKCache(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
|
||||
// We need to create a custom HTTP client to set a timeout.
|
||||
client := s.httpClient
|
||||
if client == nil {
|
||||
client = &http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
|
||||
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
// Indicates a development-time error
|
||||
panic("Default transport is not of type *http.Transport")
|
||||
}
|
||||
transport := defaultTransport.Clone()
|
||||
transport.TLSClientConfig.MinVersion = tls.VersionTLS12
|
||||
client.Transport = transport
|
||||
}
|
||||
|
||||
// Create the JWKS cache
|
||||
return jwk.NewCache(ctx,
|
||||
httprc.NewClient(
|
||||
httprc.WithErrorSink(errsink.NewSlog(slog.Default())),
|
||||
httprc.WithHTTPClient(client),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
@@ -198,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -249,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -279,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -321,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -352,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
|
||||
userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// The ID of the client that made the call must match the client ID in the token
|
||||
if client.ID != clientID {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Verify refresh token
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
|
||||
Where(
|
||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||
utils.CreateSha256Hash(rt),
|
||||
datatype.DateTime(time.Now()),
|
||||
userID,
|
||||
input.ClientID,
|
||||
).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -383,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -415,70 +489,122 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds, false)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
|
||||
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
|
||||
// Get the type of the token and the client ID
|
||||
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ParseError()) {
|
||||
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
|
||||
return s.introspectRefreshToken(ctx, tokenString)
|
||||
}
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
// Get the audience from the token
|
||||
tokenAudiences, _ := token.Audience()
|
||||
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
// Audience must match the client ID
|
||||
if client.ID != tokenAudiences[0] {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Introspect the token
|
||||
switch tokenType {
|
||||
case OAuthAccessTokenJWTType:
|
||||
return s.introspectAccessToken(client.ID, tokenString)
|
||||
case OAuthRefreshTokenJWTType:
|
||||
return s.introspectRefreshToken(ctx, client.ID, tokenString)
|
||||
default:
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
token, err := s.jwtService.VerifyOAuthAccessToken(tokenString)
|
||||
if err != nil {
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// The ID of the client that made the request must match the client ID in the token
|
||||
audience, ok := token.Audience()
|
||||
if !ok || len(audience) != 1 || audience[0] == "" {
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
if audience[0] != clientID {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "access_token"
|
||||
introspectDto.Audience = audience
|
||||
if token.Has("scope") {
|
||||
var asString string
|
||||
var asStrings []string
|
||||
var (
|
||||
asString string
|
||||
asStrings []string
|
||||
)
|
||||
if err := token.Get("scope", &asString); err == nil {
|
||||
introspectDto.Scope = asString
|
||||
} else if err := token.Get("scope", &asStrings); err == nil {
|
||||
introspectDto.Scope = strings.Join(asStrings, " ")
|
||||
}
|
||||
}
|
||||
if expiration, hasExpiration := token.Expiration(); hasExpiration {
|
||||
if expiration, ok := token.Expiration(); ok {
|
||||
introspectDto.Expiration = expiration.Unix()
|
||||
}
|
||||
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
|
||||
if issuedAt, ok := token.IssuedAt(); ok {
|
||||
introspectDto.IssuedAt = issuedAt.Unix()
|
||||
}
|
||||
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
|
||||
if notBefore, ok := token.NotBefore(); ok {
|
||||
introspectDto.NotBefore = notBefore.Unix()
|
||||
}
|
||||
if subject, hasSubject := token.Subject(); hasSubject {
|
||||
if subject, ok := token.Subject(); ok {
|
||||
introspectDto.Subject = subject
|
||||
}
|
||||
if audience, hasAudience := token.Audience(); hasAudience {
|
||||
introspectDto.Audience = audience
|
||||
}
|
||||
if issuer, hasIssuer := token.Issuer(); hasIssuer {
|
||||
if issuer, ok := token.Issuer(); ok {
|
||||
introspectDto.Issuer = issuer
|
||||
}
|
||||
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
|
||||
if identifier, ok := token.JwtID(); ok {
|
||||
introspectDto.Identifier = identifier
|
||||
}
|
||||
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
|
||||
tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return introspectDto, fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// The ID of the client that made the call must match the client ID in the token
|
||||
if tokenClientID != clientID {
|
||||
return introspectDto, errors.New("invalid refresh token: client ID does not match")
|
||||
}
|
||||
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||
Where(
|
||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||
utils.CreateSha256Hash(tokenRT),
|
||||
datatype.DateTime(time.Now()),
|
||||
tokenUserID,
|
||||
tokenClientID,
|
||||
).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -542,13 +668,9 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
||||
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
LogoutCallbackURLs: input.LogoutCallbackURLs,
|
||||
CreatedByID: userID,
|
||||
IsPublic: input.IsPublic,
|
||||
PkceEnabled: input.PkceEnabled,
|
||||
CreatedByID: userID,
|
||||
}
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
@@ -577,11 +699,7 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
@@ -599,6 +717,29 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
|
||||
// Base fields
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
client.LogoutCallbackURLs = input.LogoutCallbackURLs
|
||||
client.IsPublic = input.IsPublic
|
||||
// PKCE is required for public clients
|
||||
client.PkceEnabled = input.IsPublic || input.PkceEnabled
|
||||
|
||||
// Credentials
|
||||
if len(input.Credentials.FederatedIdentities) > 0 {
|
||||
client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
|
||||
for i, fi := range input.Credentials.FederatedIdentities {
|
||||
client.Credentials.FederatedIdentities[i] = model.OidcClientFederatedIdentity{
|
||||
Issuer: fi.Issuer,
|
||||
Audience: fi.Audience,
|
||||
Subject: fi.Subject,
|
||||
JWKS: fi.JWKS,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
var client model.OidcClient
|
||||
err := s.db.
|
||||
@@ -676,7 +817,7 @@ func (s *OidcService) GetClientLogo(ctx context.Context, clientID string) (strin
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader) error {
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
fileType := strings.ToLower(utils.GetFileExtension(file.Filename))
|
||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||
return &common.FileTypeNotSupportedError{}
|
||||
}
|
||||
@@ -744,6 +885,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.ImageType
|
||||
client.ImageType = nil
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
@@ -753,7 +895,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
|
||||
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + oldImageType
|
||||
if err := os.Remove(imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -766,97 +908,6 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("User.UserGroups").
|
||||
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := authorizedOidcClient.User
|
||||
scopes := strings.Split(authorizedOidcClient.Scope, " ")
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "groups") {
|
||||
userGroups := make([]string, len(user.UserGroups))
|
||||
for i, group := range user.UserGroups {
|
||||
userGroups[i] = group.Name
|
||||
}
|
||||
claims["groups"] = userGroups
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"name": user.FullName(),
|
||||
"preferred_username": user.Username,
|
||||
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
// The value of the custom claim can be a JSON object or a string
|
||||
var jsonValue interface{}
|
||||
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
|
||||
if err == nil {
|
||||
// It's JSON so we store it as an object
|
||||
claims[customClaim.Key] = jsonValue
|
||||
} else {
|
||||
// Marshalling failed, so we store it as a string
|
||||
claims[customClaim.Key] = customClaim.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
@@ -1078,7 +1129,12 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{
|
||||
ClientID: input.ClientID,
|
||||
ClientSecret: input.ClientSecret,
|
||||
ClientAssertionType: input.ClientAssertionType,
|
||||
ClientAssertion: input.ClientAssertion,
|
||||
}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1098,7 +1154,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
Scope: input.Scope,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)),
|
||||
IsAuthorized: false,
|
||||
ClientID: client.ID,
|
||||
}
|
||||
@@ -1112,7 +1168,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
UserCode: userCode,
|
||||
VerificationURI: common.EnvConfig.AppURL + "/device",
|
||||
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
|
||||
ExpiresIn: 900, // 15 minutes
|
||||
ExpiresIn: int(DeviceCodeDuration.Seconds()),
|
||||
Interval: 5,
|
||||
}, nil
|
||||
}
|
||||
@@ -1243,6 +1299,20 @@ func (s *OidcService) GetAllowedGroupsCountOfClient(ctx context.Context, id stri
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) ListAuthorizedClients(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.UserAuthorizedOidcClient, utils.PaginationResponse, error) {
|
||||
|
||||
query := s.db.
|
||||
WithContext(ctx).
|
||||
Model(&model.UserAuthorizedOidcClient{}).
|
||||
Preload("Client").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
var authorizedClients []model.UserAuthorizedOidcClient
|
||||
response, err := utils.PaginateAndSort(sortedPaginationRequest, query, &authorizedClients)
|
||||
|
||||
return authorizedClients, response, err
|
||||
}
|
||||
|
||||
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
|
||||
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
|
||||
if err != nil {
|
||||
@@ -1254,7 +1324,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
|
||||
|
||||
m := model.OidcRefreshToken{
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
|
||||
Token: refreshTokenHash,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
@@ -1269,7 +1339,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
return "", err
|
||||
}
|
||||
|
||||
return refreshToken, nil
|
||||
// Sign the refresh token
|
||||
signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
|
||||
@@ -1290,33 +1366,332 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if clientID == "" {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
type ClientAuthCredentials struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ClientAssertion string
|
||||
ClientAssertionType string
|
||||
}
|
||||
|
||||
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
|
||||
return ClientAuthCredentials{
|
||||
ClientID: d.ClientID,
|
||||
ClientSecret: d.ClientSecret,
|
||||
ClientAssertion: d.ClientAssertion,
|
||||
ClientAssertionType: d.ClientAssertionType,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
|
||||
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
|
||||
|
||||
// Determine the client ID based on the authentication method
|
||||
var clientID string
|
||||
switch {
|
||||
case isClientAssertion:
|
||||
// Extract client ID from the JWT assertion's 'sub' claim
|
||||
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
|
||||
if err != nil {
|
||||
slog.Error("Failed to extract client ID from assertion", "error", err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
case input.ClientID != "":
|
||||
// Use the provided client ID for other authentication methods
|
||||
clientID = input.ClientID
|
||||
default:
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Load the OIDC client's configuration
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If we have a client secret, we validate it
|
||||
// Otherwise, we require the client to be public
|
||||
if clientSecret != "" {
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
// Validate credentials based on the authentication method
|
||||
switch {
|
||||
// First, if we have a client secret, we validate it
|
||||
case input.ClientSecret != "":
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(input.ClientSecret))
|
||||
if err != nil {
|
||||
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
||||
return nil, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
return client, nil
|
||||
} else if !client.IsPublic {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
|
||||
// Next, check if we want to use client assertions from federated identities
|
||||
case isClientAssertion:
|
||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
|
||||
if err != nil {
|
||||
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
return client, nil
|
||||
|
||||
// There's no credentials
|
||||
// This is allowed only if the client is public
|
||||
case client.IsPublic && allowPublicClientsWithoutAuth:
|
||||
return client, nil
|
||||
|
||||
// If we're here, we have no credentials AND the client is not public, so credentials are required
|
||||
default:
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
|
||||
// Check if we have already registered the URL
|
||||
if !s.jwkCache.IsRegistered(ctx, url) {
|
||||
// We set a timeout because otherwise Register will keep trying in case of errors
|
||||
registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer registerCancel()
|
||||
// We need to register the URL
|
||||
err = s.jwkCache.Register(
|
||||
registerCtx,
|
||||
url,
|
||||
jwk.WithMaxInterval(24*time.Hour),
|
||||
jwk.WithMinInterval(15*time.Minute),
|
||||
jwk.WithWaitReady(true),
|
||||
)
|
||||
// In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error
|
||||
if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) {
|
||||
return nil, fmt.Errorf("failed to register JWK set: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
jwks, err := s.jwkCache.CachedSet(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
|
||||
}
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error {
|
||||
// First, parse the assertion JWT, without validating it, to check the issuer
|
||||
assertion := []byte(input.ClientAssertion)
|
||||
insecureToken, err := jwt.ParseInsecure(assertion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse client assertion JWT: %w", err)
|
||||
}
|
||||
|
||||
issuer, _ := insecureToken.Issuer()
|
||||
if issuer == "" {
|
||||
return errors.New("client assertion does not contain an issuer claim")
|
||||
}
|
||||
|
||||
// Ensure that this client is federated with the one that issued the token
|
||||
ocfi, ok := client.Credentials.FederatedIdentityForIssuer(issuer)
|
||||
if !ok {
|
||||
return fmt.Errorf("client assertion is not from an allowed issuer: %s", issuer)
|
||||
}
|
||||
|
||||
// Get the JWK set for the issuer
|
||||
jwksURL := ocfi.JWKS
|
||||
if jwksURL == "" {
|
||||
// Default URL is from the issuer
|
||||
if strings.HasSuffix(issuer, "/") {
|
||||
jwksURL = issuer + ".well-known/jwks.json"
|
||||
} else {
|
||||
jwksURL = issuer + "/.well-known/jwks.json"
|
||||
}
|
||||
}
|
||||
jwks, err := s.jwkSetForURL(ctx, jwksURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWK set for issuer '%s': %w", issuer, err)
|
||||
}
|
||||
|
||||
// Set default audience and subject if missing
|
||||
audience := ocfi.Audience
|
||||
if audience == "" {
|
||||
// Default to the Pocket ID's URL
|
||||
audience = common.EnvConfig.AppURL
|
||||
}
|
||||
subject := ocfi.Subject
|
||||
if subject == "" {
|
||||
// Default to the client ID, per RFC 7523
|
||||
subject = client.ID
|
||||
}
|
||||
|
||||
// Now re-parse the token with proper validation
|
||||
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
|
||||
_, err = jwt.Parse(assertion,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithKeySet(jwks, jws.WithInferAlgorithmFromKey(true), jws.WithUseDefault(true)),
|
||||
jwt.WithAudience(audience),
|
||||
jwt.WithSubject(subject),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client assertion is not valid: %w", err)
|
||||
}
|
||||
|
||||
// If we're here, the assertion is valid
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
|
||||
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
|
||||
// Parse the JWT without verification first to get the claims
|
||||
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
|
||||
}
|
||||
|
||||
// Extract the subject claim which must be the client_id according to RFC 7523
|
||||
sub, ok := insecureToken.Subject()
|
||||
if !ok || sub == "" {
|
||||
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.getClientInternal(ctx, clientID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("UserGroups").
|
||||
First(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, client) {
|
||||
return nil, &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
dummyAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
Scope: scopes,
|
||||
User: user,
|
||||
}
|
||||
|
||||
userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &dummyAuthorizedClient, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Commit the transaction before signing tokens to avoid locking the database for longer
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idTokenPayload, err := utils.GetClaimsFromToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessTokenPayload, err := utils.GetClaimsFromToken(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.OidcClientPreviewDto{
|
||||
IdToken: idTokenPayload,
|
||||
AccessToken: accessTokenPayload,
|
||||
UserInfo: userClaims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) {
|
||||
return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("User.UserGroups").
|
||||
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.getUserClaimsFromAuthorizedClient(ctx, &authorizedOidcClient, tx)
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
|
||||
user := authorizedClient.User
|
||||
scopes := strings.Split(authorizedClient.Scope, " ")
|
||||
|
||||
claims := make(map[string]any, 10)
|
||||
|
||||
claims["sub"] = user.ID
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "groups") {
|
||||
userGroups := make([]string, len(user.UserGroups))
|
||||
for i, group := range user.UserGroups {
|
||||
userGroups[i] = group.Name
|
||||
}
|
||||
claims["groups"] = userGroups
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
claims["given_name"] = user.FirstName
|
||||
claims["family_name"] = user.LastName
|
||||
claims["name"] = user.FullName()
|
||||
claims["preferred_username"] = user.Username
|
||||
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
// The value of the custom claim can be a JSON object or a string
|
||||
var jsonValue any
|
||||
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
|
||||
if err == nil {
|
||||
// It's JSON, so we store it as an object
|
||||
claims[customClaim.Key] = jsonValue
|
||||
} else {
|
||||
// Marshaling failed, so we store it as a string
|
||||
claims[customClaim.Key] = customClaim.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
381
backend/internal/service/oidc_service_test.go
Normal file
381
backend/internal/service/oidc_service_test.go
Normal file
@@ -0,0 +1,381 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
// generateTestECDSAKey creates an ECDSA key for testing
|
||||
func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateJwk, err := jwk.Import(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
|
||||
require.NoError(t, err)
|
||||
err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
|
||||
require.NoError(t, err)
|
||||
err = privateJwk.Set("use", "sig")
|
||||
require.NoError(t, err)
|
||||
|
||||
publicJwk, err := jwk.PublicKeyOf(privateJwk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a JWK Set with the public key
|
||||
jwkSet := jwk.NewSet()
|
||||
err = jwkSet.AddKey(publicJwk)
|
||||
require.NoError(t, err)
|
||||
jwkSetJSON, err := json.Marshal(jwkSet)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateJwk, jwkSetJSON
|
||||
}
|
||||
|
||||
func TestOidcService_jwkSetForURL(t *testing.T) {
|
||||
// Generate a test key for JWKS
|
||||
_, jwkSetJSON1 := generateTestECDSAKey(t)
|
||||
_, jwkSetJSON2 := generateTestECDSAKey(t)
|
||||
|
||||
// Create a mock HTTP client with responses for different URLs
|
||||
const (
|
||||
url1 = "https://example.com/.well-known/jwks.json"
|
||||
url2 = "https://other-issuer.com/jwks"
|
||||
)
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
||||
//nolint:bodyclose
|
||||
url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
// Create the OidcService with our mock client
|
||||
s := &OidcService{
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
var err error
|
||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Fetches and caches JWK set", func(t *testing.T) {
|
||||
jwks, err := s.jwkSetForURL(t.Context(), url1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, jwks)
|
||||
|
||||
// Verify the JWK set contains our key
|
||||
require.Equal(t, 1, jwks.Len())
|
||||
})
|
||||
|
||||
t.Run("Fails with invalid URL", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, err := s.jwkSetForURL(ctx, "https://bad-url.com")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
})
|
||||
|
||||
t.Run("Safe for concurrent use", func(t *testing.T) {
|
||||
const concurrency = 20
|
||||
|
||||
// Channel to collect errors
|
||||
errChan := make(chan error, concurrency)
|
||||
|
||||
// Start concurrent requests
|
||||
for range concurrency {
|
||||
go func() {
|
||||
jwks, err := s.jwkSetForURL(t.Context(), url2)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the JWK set is valid
|
||||
if jwks == nil || jwks.Len() != 1 {
|
||||
errChan <- assert.AnError
|
||||
return
|
||||
}
|
||||
|
||||
errChan <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
for range concurrency {
|
||||
assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
const (
|
||||
federatedClientIssuer = "https://external-idp.com"
|
||||
federatedClientAudience = "https://pocket-id.com"
|
||||
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
||||
)
|
||||
|
||||
var err error
|
||||
// Create a test database
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create two JWKs for testing
|
||||
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
||||
require.NoError(t, err)
|
||||
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a mock HTTP client with custom transport to return the JWKS
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Init the OidcService
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create the test clients
|
||||
// 1. Confidential client
|
||||
confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Confidential Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a client secret for the confidential client
|
||||
confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Public client
|
||||
publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Public Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
IsPublic: true,
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Confidential client with federated identity
|
||||
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Federated Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{
|
||||
Name: federatedClient.Name,
|
||||
CallbackURLs: federatedClient.CallbackURLs,
|
||||
Credentials: dto.OidcClientCredentialsDto{
|
||||
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
||||
{
|
||||
Issuer: federatedClientIssuer,
|
||||
Audience: federatedClientAudience,
|
||||
Subject: federatedClient.ID,
|
||||
JWKS: federatedClientIssuer + "/jwks.json",
|
||||
},
|
||||
{Issuer: federatedClientIssuerDefaults},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test cases for confidential client (using client secret)
|
||||
t.Run("Confidential client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
||||
// Test with valid client credentials
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: confidentialSecret,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, confidentialClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with invalid secret", func(t *testing.T) {
|
||||
// Test with invalid client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: "invalid-secret",
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
|
||||
t.Run("Fails with missing secret", func(t *testing.T) {
|
||||
// Test with missing client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for public client
|
||||
t.Run("Public client", func(t *testing.T) {
|
||||
t.Run("Succeeds with no credentials", func(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: publicClient.ID,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, publicClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: publicClient.ID,
|
||||
}, false)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for federated client using JWT assertion
|
||||
t.Run("Federated client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid JWT", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with malformed JWT", func(t *testing.T) {
|
||||
// Test with invalid JWT assertion (just a random string)
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: "invalid.jwt.token",
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
|
||||
testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// Populate all claims with valid values
|
||||
builder := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute))
|
||||
|
||||
// Call builderFn to override the claims
|
||||
builderFn(builder)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with invalid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
require.Nil(t, client)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Expiration(time.Now().Add(-30 * time.Minute))
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Issuer("https://bad-issuer.com")
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Audience([]string{"bad-audience"})
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Subject("bad-subject")
|
||||
}))
|
||||
|
||||
t.Run("Uses default values for audience and subject", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuerDefaults).
|
||||
Audience([]string{common.EnvConfig.AppURL}).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -296,15 +296,21 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
isLdapUser := user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue()
|
||||
allowOwnAccountEdit := s.appConfigService.GetDbConfig().AllowOwnAccountEdit.IsTrue()
|
||||
|
||||
// For LDAP users or if own account editing is not allowed, only allow updating the locale unless it's an LDAP sync
|
||||
if !isLdapSync && (isLdapUser || (!allowOwnAccountEdit && !updateOwnUser)) {
|
||||
if !isLdapSync && (isLdapUser || (!allowOwnAccountEdit && updateOwnUser)) {
|
||||
// Restricted update: Only locale can be changed when:
|
||||
// - User is from LDAP, OR
|
||||
// - User is editing their own account but global setting disallows self-editing
|
||||
// (Exception: LDAP sync operations can update everything)
|
||||
user.Locale = updatedUser.Locale
|
||||
} else {
|
||||
// Full update: Allow updating all personal fields
|
||||
user.FirstName = updatedUser.FirstName
|
||||
user.LastName = updatedUser.LastName
|
||||
user.Email = updatedUser.Email
|
||||
user.Username = updatedUser.Username
|
||||
user.Locale = updatedUser.Locale
|
||||
|
||||
// Admin-only fields: Only allow updates when not updating own account
|
||||
if !updateOwnUser {
|
||||
user.IsAdmin = updatedUser.IsAdmin
|
||||
user.Disabled = updatedUser.Disabled
|
||||
@@ -463,9 +469,7 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if ipAddress != "" && userAgent != "" {
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
|
||||
}
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
@@ -523,7 +527,7 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
|
||||
func (s *UserService) SignUpInitialAdmin(ctx context.Context, signUpData dto.SignUpDto) (model.User, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
@@ -533,26 +537,23 @@ func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string
|
||||
if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
if userCount > 1 {
|
||||
if userCount != 0 {
|
||||
return model.User{}, "", &common.SetupAlreadyCompletedError{}
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
FirstName: "Admin",
|
||||
LastName: "Admin",
|
||||
Username: "admin",
|
||||
Email: "admin@admin.com",
|
||||
userToCreate := dto.UserCreateDto{
|
||||
FirstName: signUpData.FirstName,
|
||||
LastName: signUpData.LastName,
|
||||
Username: signUpData.Username,
|
||||
Email: signUpData.Email,
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
||||
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if len(user.Credentials) > 0 {
|
||||
return model.User{}, "", &common.SetupAlreadyCompletedError{}
|
||||
}
|
||||
|
||||
token, err := s.jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
@@ -630,6 +631,110 @@ func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx
|
||||
Error
|
||||
}
|
||||
|
||||
func (s *UserService) CreateSignupToken(ctx context.Context, expiresAt time.Time, usageLimit int) (model.SignupToken, error) {
|
||||
return s.createSignupTokenInternal(ctx, expiresAt, usageLimit, s.db)
|
||||
}
|
||||
|
||||
func (s *UserService) createSignupTokenInternal(ctx context.Context, expiresAt time.Time, usageLimit int, tx *gorm.DB) (model.SignupToken, error) {
|
||||
signupToken, err := NewSignupToken(expiresAt, usageLimit)
|
||||
if err != nil {
|
||||
return model.SignupToken{}, err
|
||||
}
|
||||
|
||||
if err := tx.WithContext(ctx).Create(signupToken).Error; err != nil {
|
||||
return model.SignupToken{}, err
|
||||
}
|
||||
|
||||
return *signupToken, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAddress, userAgent string) (model.User, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
tokenProvided := signupData.Token != ""
|
||||
|
||||
config := s.appConfigService.GetDbConfig()
|
||||
if config.AllowUserSignups.Value != "open" && !tokenProvided {
|
||||
return model.User{}, "", &common.OpenSignupDisabledError{}
|
||||
}
|
||||
|
||||
var signupToken model.SignupToken
|
||||
if tokenProvided {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("token = ?", signupData.Token).
|
||||
First(&signupToken).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
||||
}
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if !signupToken.IsValid() {
|
||||
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
||||
}
|
||||
}
|
||||
|
||||
userToCreate := dto.UserCreateDto{
|
||||
Username: signupData.Username,
|
||||
Email: signupData.Email,
|
||||
FirstName: signupData.FirstName,
|
||||
LastName: signupData.LastName,
|
||||
}
|
||||
|
||||
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if tokenProvided {
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
|
||||
"signupToken": signupToken.Token,
|
||||
}, tx)
|
||||
|
||||
signupToken.UsageCount++
|
||||
|
||||
err = tx.WithContext(ctx).Save(&signupToken).Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
|
||||
}
|
||||
} else {
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
|
||||
"method": "open_signup",
|
||||
}, tx)
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return user, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *UserService) ListSignupTokens(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.SignupToken, utils.PaginationResponse, error) {
|
||||
var tokens []model.SignupToken
|
||||
query := s.db.WithContext(ctx).Model(&model.SignupToken{})
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &tokens)
|
||||
return tokens, pagination, err
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) error {
|
||||
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
|
||||
}
|
||||
|
||||
func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAccessToken, error) {
|
||||
// If expires at is less than 15 minutes, use a 6-character token instead of 16
|
||||
tokenLength := 16
|
||||
@@ -650,3 +755,20 @@ func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAc
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func NewSignupToken(expiresAt time.Time, usageLimit int) (*model.SignupToken, error) {
|
||||
// Generate a random token
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token := &model.SignupToken{
|
||||
Token: randomString,
|
||||
ExpiresAt: datatype.DateTime(expiresAt),
|
||||
UsageLimit: usageLimit,
|
||||
UsageCount: 0,
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
24
backend/internal/utils/cmd_util.go
Normal file
24
backend/internal/utils/cmd_util.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PromptForConfirmation prompts the user to answer "y" in the terminal
|
||||
func PromptForConfirmation(prompt string) (bool, error) {
|
||||
fmt.Print(prompt + " [y/N]: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
r, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
r = strings.TrimSpace(strings.ToLower(r))
|
||||
|
||||
ok := r == "yes" || r == "y"
|
||||
|
||||
return ok, nil
|
||||
}
|
||||
69
backend/internal/utils/crypto/crypto.go
Normal file
69
backend/internal/utils/crypto/crypto.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrDecrypt is returned by Decrypt when the operation failed for any reason
|
||||
var ErrDecrypt = errors.New("failed to decrypt data")
|
||||
|
||||
// Encrypt a byte slice using AES-GCM and a random nonce
|
||||
// Important: do not encrypt more than ~4 billion messages with the same key!
|
||||
func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
_, err = io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
|
||||
}
|
||||
|
||||
// Allocate the slice for the result, with additional space for the nonce and overhead
|
||||
ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead())
|
||||
ciphertext = append(ciphertext, nonce...)
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Tag is automatically added at the end
|
||||
ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData)
|
||||
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt a byte slice using AES-GCM
|
||||
func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||
}
|
||||
|
||||
// Extract the nonce
|
||||
if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) {
|
||||
return nil, ErrDecrypt
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData)
|
||||
if err != nil {
|
||||
// Note: we do not return the exact error here, to avoid disclosing information
|
||||
return nil, ErrDecrypt
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
208
backend/internal/utils/crypto/crypto_test.go
Normal file
208
backend/internal/utils/crypto/crypto_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keySize int
|
||||
plaintext string
|
||||
associatedData []byte
|
||||
}{
|
||||
{
|
||||
name: "AES-128 with short plaintext",
|
||||
keySize: 16,
|
||||
plaintext: "Hello, World!",
|
||||
associatedData: []byte("test-aad"),
|
||||
},
|
||||
{
|
||||
name: "AES-192 with medium plaintext",
|
||||
keySize: 24,
|
||||
plaintext: "This is a longer message to test encryption and decryption",
|
||||
associatedData: []byte("associated-data-192"),
|
||||
},
|
||||
{
|
||||
name: "AES-256 with unicode",
|
||||
keySize: 32,
|
||||
plaintext: "Hello 世界! 🌍 Testing unicode characters", //nolint:gosmopolitan
|
||||
associatedData: []byte("unicode-test"),
|
||||
},
|
||||
{
|
||||
name: "No associated data",
|
||||
keySize: 32,
|
||||
plaintext: "Testing without associated data",
|
||||
associatedData: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate random key
|
||||
key := make([]byte, tt.keySize)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
plaintext := []byte(tt.plaintext)
|
||||
|
||||
// Test encryption
|
||||
ciphertext, err := Encrypt(key, plaintext, tt.associatedData)
|
||||
require.NoError(t, err, "Encrypt should succeed")
|
||||
|
||||
// Verify ciphertext is different from plaintext (unless empty)
|
||||
if len(plaintext) > 0 {
|
||||
assert.NotEqual(t, plaintext, ciphertext)
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := Decrypt(key, ciphertext, tt.associatedData)
|
||||
require.NoError(t, err, "Decrypt should succeed")
|
||||
|
||||
// Verify decrypted text matches original
|
||||
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptWithInvalidKeySize(t *testing.T) {
|
||||
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||
|
||||
for _, keySize := range invalidKeySizes {
|
||||
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||
key := make([]byte, keySize)
|
||||
plaintext := []byte("test message")
|
||||
|
||||
_, err := Encrypt(key, plaintext, nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid key size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidKeySize(t *testing.T) {
|
||||
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||
|
||||
for _, keySize := range invalidKeySizes {
|
||||
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||
key := make([]byte, keySize)
|
||||
ciphertext := []byte("fake ciphertext")
|
||||
|
||||
_, err := Decrypt(key, ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid key size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidCiphertext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ciphertext []byte
|
||||
}{
|
||||
{
|
||||
name: "empty ciphertext",
|
||||
ciphertext: []byte{},
|
||||
},
|
||||
{
|
||||
name: "too short ciphertext",
|
||||
ciphertext: []byte("short"),
|
||||
},
|
||||
{
|
||||
name: "random invalid data",
|
||||
ciphertext: []byte("this is not valid encrypted data"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Decrypt(key, tt.ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongKey(t *testing.T) {
|
||||
// Generate two different keys
|
||||
key1 := make([]byte, 32)
|
||||
key2 := make([]byte, 32)
|
||||
_, err := rand.Read(key1)
|
||||
require.NoError(t, err)
|
||||
_, err = rand.Read(key2)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
|
||||
// Encrypt with key1
|
||||
ciphertext, err := Encrypt(key1, plaintext, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with key2
|
||||
_, err = Decrypt(key2, ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongAssociatedData(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
correctAAD := []byte("correct-aad")
|
||||
wrongAAD := []byte("wrong-aad")
|
||||
|
||||
// Encrypt with correct AAD
|
||||
ciphertext, err := Encrypt(key, plaintext, correctAAD)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with wrong AAD
|
||||
_, err = Decrypt(key, ciphertext, wrongAAD)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
|
||||
// Verify correct AAD works
|
||||
decrypted, err := Decrypt(key, ciphertext, correctAAD)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD")
|
||||
}
|
||||
|
||||
func TestEncryptDecryptConsistency(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("consistency test message")
|
||||
associatedData := []byte("test-aad")
|
||||
|
||||
// Encrypt multiple times and verify we get different ciphertexts (due to random IV)
|
||||
ciphertext1, err := Encrypt(key, plaintext, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
ciphertext2, err := Encrypt(key, plaintext, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ciphertexts should be different (due to random IV)
|
||||
assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts")
|
||||
|
||||
// Both should decrypt to the same plaintext
|
||||
decrypted1, err := Decrypt(key, ciphertext1, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted2, err := Decrypt(key, ciphertext2, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original")
|
||||
assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original")
|
||||
assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical")
|
||||
}
|
||||
33
backend/internal/utils/http_util.go
Normal file
33
backend/internal/utils/http_util.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BearerAuth returns the value of the bearer token in the Authorization header if present
|
||||
func BearerAuth(r *http.Request) (string, bool) {
|
||||
const prefix = "bearer "
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if len(authHeader) >= len(prefix) && strings.ToLower(authHeader[:len(prefix)]) == prefix {
|
||||
return authHeader[len(prefix):], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// SetCacheControlHeader sets the Cache-Control header for the response.
|
||||
func SetCacheControlHeader(ctx *gin.Context, maxAge, staleWhileRevalidate time.Duration) {
|
||||
_, ok := ctx.GetQuery("skipCache")
|
||||
if !ok {
|
||||
maxAgeSeconds := strconv.Itoa(int(maxAge.Seconds()))
|
||||
staleWhileRevalidateSeconds := strconv.Itoa(int(staleWhileRevalidate.Seconds()))
|
||||
ctx.Header("Cache-Control", "public, max-age="+maxAgeSeconds+", stale-while-revalidate="+staleWhileRevalidateSeconds)
|
||||
}
|
||||
|
||||
}
|
||||
65
backend/internal/utils/http_util_test.go
Normal file
65
backend/internal/utils/http_util_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBearerAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectedToken string
|
||||
expectedFound bool
|
||||
}{
|
||||
{
|
||||
name: "Valid bearer token",
|
||||
authHeader: "Bearer token123",
|
||||
expectedToken: "token123",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "Valid bearer token with mixed case",
|
||||
authHeader: "beARer token456",
|
||||
expectedToken: "token456",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "No bearer prefix",
|
||||
authHeader: "Basic dXNlcjpwYXNz",
|
||||
expectedToken: "",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "Empty auth header",
|
||||
authHeader: "",
|
||||
expectedToken: "",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "Bearer prefix only",
|
||||
authHeader: "Bearer ",
|
||||
expectedToken: "",
|
||||
expectedFound: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil)
|
||||
require.NoError(t, err, "Failed to create request")
|
||||
|
||||
if tt.authHeader != "" {
|
||||
req.Header.Set("Authorization", tt.authHeader)
|
||||
}
|
||||
|
||||
token, found := BearerAuth(req)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
assert.Equal(t, tt.expectedToken, token)
|
||||
})
|
||||
}
|
||||
}
|
||||
50
backend/internal/utils/jwk/key_provider.go
Normal file
50
backend/internal/utils/jwk/key_provider.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
type KeyProviderOpts struct {
|
||||
EnvConfig *common.EnvConfigSchema
|
||||
DB *gorm.DB
|
||||
Kek []byte
|
||||
}
|
||||
|
||||
type KeyProvider interface {
|
||||
Init(opts KeyProviderOpts) error
|
||||
LoadKey() (jwk.Key, error)
|
||||
SaveKey(key jwk.Key) error
|
||||
}
|
||||
|
||||
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
|
||||
// Load the encryption key (KEK) if present
|
||||
kek, err := LoadKeyEncryptionKey(envConfig, instanceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load encryption key: %w", err)
|
||||
}
|
||||
|
||||
// Get the key provider
|
||||
switch envConfig.KeysStorage {
|
||||
case "file", "":
|
||||
keyProvider = &KeyProviderFile{}
|
||||
case "database":
|
||||
keyProvider = &KeyProviderDatabase{}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
|
||||
}
|
||||
err = keyProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
EnvConfig: envConfig,
|
||||
Kek: kek,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
return keyProvider, nil
|
||||
}
|
||||
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
const PrivateKeyDBKey = "jwt_private_key.json"
|
||||
|
||||
type KeyProviderDatabase struct {
|
||||
db *gorm.DB
|
||||
kek []byte
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
|
||||
if len(opts.Kek) == 0 {
|
||||
return errors.New("an encryption key is required when using the 'database' key provider")
|
||||
}
|
||||
|
||||
f.db = opts.DB
|
||||
f.kek = opts.Kek
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).First(&row).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Key not present in the database - return nil so a new one can be generated
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err)
|
||||
}
|
||||
|
||||
if row.Value == nil || *row.Value == "" {
|
||||
// Key not present in the database - return nil so a new one can be generated
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
enc, err := base64.StdEncoding.DecodeString(*row.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(f.kek, enc, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse encrypted private key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the key then encode to Base64
|
||||
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||
}
|
||||
encB64 := base64.StdEncoding.EncodeToString(enc)
|
||||
|
||||
// Save to database
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &encB64,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).Create(&row).Error
|
||||
if err != nil {
|
||||
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
|
||||
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
|
||||
return fmt.Errorf("failed to store private key in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile-time interface check
|
||||
var _ KeyProvider = (*KeyProviderDatabase)(nil)
|
||||
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestKeyProviderDatabase_Init(t *testing.T) {
|
||||
t.Run("Init fails when KEK is not provided", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: nil, // No KEK
|
||||
})
|
||||
require.Error(t, err, "Expected error when KEK is not provided")
|
||||
require.ErrorContains(t, err, "encryption key is required")
|
||||
})
|
||||
|
||||
t.Run("Init succeeds with KEK", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: generateTestKEK(t),
|
||||
})
|
||||
require.NoError(t, err, "Expected no error when KEK is provided")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with existing key", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid base64", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert invalid base64 data
|
||||
invalidBase64 := "not-valid-base64"
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &invalidBase64,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with invalid base64")
|
||||
require.ErrorContains(t, err, "not a valid base64-encoded value")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid encrypted data", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert valid base64 but invalid encrypted data
|
||||
invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data"))
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &invalidData,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
originalKek := generateTestKEK(t)
|
||||
|
||||
// Save a key with the original KEK
|
||||
originalProvider := &KeyProviderDatabase{}
|
||||
err := originalProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: originalKek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = originalProvider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to load with a different KEK
|
||||
differentKek := generateTestKEK(t)
|
||||
differentProvider := &KeyProviderDatabase{}
|
||||
err = differentProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: differentKek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key with the wrong KEK
|
||||
loadedKey, err := differentProvider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with wrong KEK")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid key data", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create invalid key data (valid JSON but not a valid JWK)
|
||||
invalidKeyData := []byte(`{"not": "a valid jwk"}`)
|
||||
|
||||
// Encrypt the invalid key data
|
||||
encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Base64 encode the encrypted data
|
||||
encodedData := base64.StdEncoding.EncodeToString(encryptedData)
|
||||
|
||||
// Save to database
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &encodedData,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading invalid key data")
|
||||
require.ErrorContains(t, err, "failed to parse")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderDatabase_SaveKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("SaveKey and verify database record", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err, "Expected no error when saving key")
|
||||
|
||||
// Verify record exists in database
|
||||
var kv model.KV
|
||||
err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error
|
||||
require.NoError(t, err, "Expected to find key in database")
|
||||
require.NotNil(t, kv.Value, "Expected non-nil value in database")
|
||||
assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database")
|
||||
|
||||
// Decode and decrypt to verify content
|
||||
encBytes, err := base64.StdEncoding.DecodeString(*kv.Value)
|
||||
require.NoError(t, err, "Expected valid base64 encoding")
|
||||
|
||||
decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil)
|
||||
require.NoError(t, err, "Expected valid encrypted data")
|
||||
|
||||
parsedKey, err := jwk.ParseKey(decBytes)
|
||||
require.NoError(t, err, "Expected valid JWK data")
|
||||
|
||||
// Compare keys
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||
})
|
||||
}
|
||||
|
||||
func generateTestKEK(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Generate a 32-byte kek
|
||||
kek := make([]byte, 32)
|
||||
_, err := rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
return kek
|
||||
}
|
||||
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrivateKeyFile is the path in the data/keys folder where the key is stored
|
||||
// This is a JSON file containing a key encoded as JWK
|
||||
PrivateKeyFile = "jwt_private_key.json"
|
||||
|
||||
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||
// This is a encrypted JSON file containing a key encoded as JWK
|
||||
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||
)
|
||||
|
||||
type KeyProviderFile struct {
|
||||
envConfig *common.EnvConfigSchema
|
||||
kek []byte
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
|
||||
f.envConfig = opts.EnvConfig
|
||||
f.kek = opts.Kek
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
|
||||
if len(f.kek) > 0 {
|
||||
return f.loadEncryptedKey()
|
||||
}
|
||||
return f.loadKey()
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
|
||||
if len(f.kek) > 0 {
|
||||
return f.saveKeyEncrypted(key)
|
||||
}
|
||||
return f.saveKey(key)
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
|
||||
var key jwk.Key
|
||||
|
||||
// First, check if we have a JWK file
|
||||
// If we do, then we just load that
|
||||
jwkPath := f.jwkPath()
|
||||
ok, err := utils.FileExists(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
if !ok {
|
||||
// File doesn't exist, no key was loaded
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
|
||||
// First, check if we have an encrypted JWK file
|
||||
// If we do, then we just load that
|
||||
encJwkPath := f.encJwkPath()
|
||||
ok, err := utils.FileExists(encJwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
if ok {
|
||||
encB64, err := os.ReadFile(encJwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check if we have an un-encrypted JWK file
|
||||
key, err = f.loadKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
|
||||
}
|
||||
if key == nil {
|
||||
// No key exists, encrypted or un-encrypted
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we are here, we have loaded a key that was un-encrypted
|
||||
// We need to replace the plaintext key with the encrypted one before we return
|
||||
err = f.saveKeyEncrypted(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
|
||||
}
|
||||
jwkPath := f.jwkPath()
|
||||
err = os.Remove(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
|
||||
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
|
||||
}
|
||||
|
||||
jwkPath := f.jwkPath()
|
||||
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
// Write the JSON file to disk
|
||||
err = EncodeJWK(keyFile, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
|
||||
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
|
||||
}
|
||||
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the key then encode to Base64
|
||||
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||
}
|
||||
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
|
||||
base64.StdEncoding.Encode(encB64, enc)
|
||||
|
||||
// Write to disk
|
||||
encJwkPath := f.encJwkPath()
|
||||
err = os.WriteFile(encJwkPath, encB64, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) jwkPath() string {
|
||||
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) encJwkPath() string {
|
||||
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
|
||||
}
|
||||
|
||||
// Compile-time interface check
|
||||
var _ KeyProvider = (*KeyProviderFile)(nil)
|
||||
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
func TestKeyProviderFile_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with encrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Make sure the unencrypted key file does not exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// First, create an unencrypted key
|
||||
providerNoKek := &KeyProviderFile{}
|
||||
err := providerNoKek.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save an unencrypted key
|
||||
err = providerNoKek.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify unencrypted key exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected unencrypted key file to exist")
|
||||
|
||||
// Now create a provider with a kek
|
||||
kek := make([]byte, 32)
|
||||
_, err = rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
|
||||
providerWithKek := &KeyProviderFile{}
|
||||
err = providerWithKek.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key - this should convert the unencrypted key to encrypted
|
||||
loadedKey, err := providerWithKek.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
|
||||
|
||||
// Verify the unencrypted key no longer exists
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to be removed")
|
||||
|
||||
// Verify the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err = utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
|
||||
|
||||
// Verify the key data
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderFile_SaveKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("SaveKey unencrypted", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Verify the content of the key file
|
||||
data, err := os.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKey, err := jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the saved key with the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||
})
|
||||
|
||||
t.Run("SaveKey encrypted", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate a 64-byte kek
|
||||
kek := makeKEK(t)
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Verify the unencrypted key file doesn't exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||
|
||||
// Manually decrypt the encrypted key file to verify it contains the correct key
|
||||
encB64, err := os.ReadFile(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode from base64
|
||||
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||
require.NoError(t, err)
|
||||
enc = enc[:n] // Trim any padding
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(kek, enc, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the key
|
||||
parsedKey, err := jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the decrypted key with the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
|
||||
})
|
||||
}
|
||||
|
||||
func makeKEK(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Generate a 32-byte kek
|
||||
kek := make([]byte, 32)
|
||||
_, err := rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
return kek
|
||||
}
|
||||
180
backend/internal/utils/jwk/utils.go
Normal file
180
backend/internal/utils/jwk/utils.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha3"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
)
|
||||
|
||||
// EncodeJWK encodes a jwk.Key to a writable stream.
|
||||
func EncodeJWK(w io.Writer, key jwk.Key) error {
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(key)
|
||||
}
|
||||
|
||||
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
|
||||
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
|
||||
b := &bytes.Buffer{}
|
||||
err := EncodeJWK(b, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// LoadKeyEncryptionKey loads the key encryption key for JWKs
|
||||
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
|
||||
// Try getting the key from the env var as string
|
||||
kekInput := []byte(envConfig.EncryptionKey)
|
||||
|
||||
// If there's nothing in the env, try loading from file
|
||||
if len(kekInput) == 0 && envConfig.EncryptionKeyFile != "" {
|
||||
kekInput, err = os.ReadFile(envConfig.EncryptionKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key file '%s': %w", envConfig.EncryptionKeyFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If there's still no key, return
|
||||
if len(kekInput) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// We need a 256-bit key for encryption with AES-GCM-256
|
||||
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
|
||||
// The key is tied to a specific instance of Pocket ID
|
||||
h := hmac.New(func() hash.Hash { return sha3.New256() }, kekInput)
|
||||
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
|
||||
kek = h.Sum(nil)
|
||||
|
||||
return kek, nil
|
||||
}
|
||||
|
||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||
// It also populates additional fields such as the key ID, usage, and alg.
|
||||
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key, alg, crv)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
if alg != "" {
|
||||
_ = key.Set(jwk.AlgorithmKey, alg)
|
||||
if crv != "" {
|
||||
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
|
||||
if ok {
|
||||
switch key.KeyType() {
|
||||
case jwa.EC():
|
||||
_ = key.Set(jwk.ECDSACrvKey, eca)
|
||||
case jwa.OKP():
|
||||
_ = key.Set(jwk.OKPCrvKey, eca)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't have an algorithm, set the default for the key type
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA and Ed25519 for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateKey generates a new jwk.Key
|
||||
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
|
||||
var rawKey any
|
||||
switch alg {
|
||||
case jwa.RS256().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
case jwa.RS384().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
|
||||
case jwa.RS512().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
case jwa.ES256().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case jwa.ES384().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case jwa.ES512().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
case jwa.EdDSA().String():
|
||||
switch crv {
|
||||
case jwa.Ed25519().String():
|
||||
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
default:
|
||||
return nil, errors.New("unsupported curve for EdDSA algorithm")
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported key algorithm")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
// Import the raw key
|
||||
return ImportRawKey(rawKey, alg, crv)
|
||||
}
|
||||
324
backend/internal/utils/jwk/utils_test.go
Normal file
324
backend/internal/utils/jwk/utils_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
alg string
|
||||
crv string
|
||||
expectError bool
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
}{
|
||||
{
|
||||
name: "RS256",
|
||||
alg: jwa.RS256().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS256(),
|
||||
},
|
||||
{
|
||||
name: "RS384",
|
||||
alg: jwa.RS384().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS384(),
|
||||
},
|
||||
// Skip the RS512 test as generating a RSA-4096 key can take some time
|
||||
/* {
|
||||
name: "RS512",
|
||||
alg: jwa.RS512().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS512(),
|
||||
}, */
|
||||
{
|
||||
name: "ES256",
|
||||
alg: jwa.ES256().String(),
|
||||
crv: jwa.P256().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES256(),
|
||||
},
|
||||
{
|
||||
name: "ES384",
|
||||
alg: jwa.ES384().String(),
|
||||
crv: jwa.P384().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES384(),
|
||||
},
|
||||
{
|
||||
name: "ES512",
|
||||
alg: jwa.ES512().String(),
|
||||
crv: jwa.P521().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES512(),
|
||||
},
|
||||
{
|
||||
name: "EdDSA with Ed25519",
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: jwa.Ed25519().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
},
|
||||
{
|
||||
name: "EdDSA with unsupported curve",
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: "unsupported",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Unsupported algorithm",
|
||||
alg: "UNSUPPORTED",
|
||||
crv: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key, err := GenerateKey(tt.alg, tt.crv)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
// Verify the algorithm is set correctly
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok, "algorithm should be set in the key")
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify other required fields are set
|
||||
kid, ok := key.KeyID()
|
||||
assert.True(t, ok, "key ID should be set")
|
||||
assert.NotEmpty(t, kid, "key ID should not be empty")
|
||||
|
||||
usage, ok := key.KeyUsage()
|
||||
assert.True(t, ok, "key usage should be set")
|
||||
assert.Equal(t, KeyUsageSigning, usage)
|
||||
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
|
||||
// Verify key type matches expected algorithm
|
||||
switch tt.expectedAlg {
|
||||
case jwa.RS256(), jwa.RS384(), jwa.RS512():
|
||||
assert.Equal(t, jwa.RSA(), key.KeyType())
|
||||
assert.Nil(t, crv)
|
||||
case jwa.ES256(), jwa.ES384(), jwa.ES512():
|
||||
assert.Equal(t, jwa.EC(), key.KeyType())
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
_ = assert.NotNil(t, crv) &&
|
||||
assert.True(t, ok) &&
|
||||
assert.Equal(t, tt.crv, eca.String())
|
||||
case jwa.EdDSA():
|
||||
assert.Equal(t, jwa.OKP(), key.KeyType())
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
_ = assert.NotNil(t, crv) &&
|
||||
assert.True(t, ok) &&
|
||||
assert.Equal(t, tt.crv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAlgInKey(t *testing.T) {
|
||||
// Generate an RSA-2048 key
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("does not change alg already set", func(t *testing.T) {
|
||||
// Import the RSA key
|
||||
key, err := jwk.Import(rsaKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pre-set the algorithm
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
|
||||
// Call EnsureAlgInKey with a different algorithm
|
||||
EnsureAlgInKey(key, jwa.RS384().String(), "")
|
||||
|
||||
// Verify the algorithm wasn't changed
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||
})
|
||||
|
||||
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyGen func() (any, error)
|
||||
alg string
|
||||
crv string
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
expectedCrv string
|
||||
}{
|
||||
{
|
||||
name: "RSA key with RS384",
|
||||
keyGen: func() (any, error) {
|
||||
return rsaKey, nil
|
||||
},
|
||||
alg: jwa.RS384().String(),
|
||||
crv: "",
|
||||
expectedAlg: jwa.RS384(),
|
||||
expectedCrv: "",
|
||||
},
|
||||
{
|
||||
name: "ECDSA key with ES384",
|
||||
keyGen: func() (any, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
alg: jwa.ES384().String(),
|
||||
crv: jwa.P384().String(),
|
||||
expectedAlg: jwa.ES384(),
|
||||
expectedCrv: jwa.P384().String(),
|
||||
},
|
||||
{
|
||||
name: "Ed25519 key with EdDSA",
|
||||
keyGen: func() (any, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
return priv, err
|
||||
},
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: jwa.Ed25519().String(),
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
expectedCrv: jwa.Ed25519().String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rawKey, err := tt.keyGen()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rawKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure no algorithm is set initially
|
||||
_, ok := key.Algorithm()
|
||||
assert.False(t, ok)
|
||||
|
||||
// Call EnsureAlgInKey
|
||||
EnsureAlgInKey(key, tt.alg, tt.crv)
|
||||
|
||||
// Verify the algorithm was set correctly
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify curve if expected
|
||||
if tt.expectedCrv != "" {
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
require.NotNil(t, crv)
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set default algorithms if not present", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyGen func() (any, error)
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
expectedCrv string
|
||||
}{
|
||||
{
|
||||
name: "RSA key defaults to RS256",
|
||||
keyGen: func() (any, error) {
|
||||
return rsaKey, nil
|
||||
},
|
||||
expectedAlg: jwa.RS256(),
|
||||
expectedCrv: "",
|
||||
},
|
||||
{
|
||||
name: "ECDSA key defaults to ES256 with P256",
|
||||
keyGen: func() (any, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
expectedAlg: jwa.ES256(),
|
||||
expectedCrv: jwa.P256().String(),
|
||||
},
|
||||
{
|
||||
name: "Ed25519 key defaults to EdDSA with Ed25519",
|
||||
keyGen: func() (any, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
return priv, err
|
||||
},
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
expectedCrv: jwa.Ed25519().String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rawKey, err := tt.keyGen()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rawKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure no algorithm is set initially
|
||||
_, ok := key.Algorithm()
|
||||
assert.False(t, ok)
|
||||
|
||||
// Call EnsureAlgInKey with empty parameters
|
||||
EnsureAlgInKey(key, "", "")
|
||||
|
||||
// Verify the default algorithm was set
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify curve if expected
|
||||
if tt.expectedCrv != "" {
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
require.NotNil(t, crv)
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rsaKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Call EnsureAlgInKey with invalid curve
|
||||
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
|
||||
|
||||
// Verify algorithm was set but curve was not
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
assert.Nil(t, crv)
|
||||
})
|
||||
}
|
||||
20
backend/internal/utils/jwt_util.go
Normal file
20
backend/internal/utils/jwt_util.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
)
|
||||
|
||||
func GetClaimsFromToken(token jwt.Token) (map[string]any, error) {
|
||||
keys := token.Keys()
|
||||
claims := make(map[string]any, len(keys))
|
||||
for _, key := range keys {
|
||||
var value any
|
||||
if err := token.Get(key, &value); err != nil {
|
||||
return nil, fmt.Errorf("failed to get claim %s: %w", key, err)
|
||||
}
|
||||
claims[key] = value
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
@@ -34,9 +34,12 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
|
||||
|
||||
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
|
||||
isSortable, _ := strconv.ParseBool(sortField.Tag.Get("sortable"))
|
||||
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
|
||||
|
||||
if sortFieldFound && isSortable && isValidSortOrder {
|
||||
if sort.Direction == "" || (sort.Direction != "asc" && sort.Direction != "desc") {
|
||||
sort.Direction = "asc"
|
||||
}
|
||||
|
||||
if sortFieldFound && isSortable {
|
||||
columnName := CamelCaseToSnakeCase(sort.Column)
|
||||
query = query.Clauses(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{
|
||||
|
||||
51
backend/internal/utils/sqlite/sqlite_util.go
Normal file
51
backend/internal/utils/sqlite/sqlite_util.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
sqlitelib "github.com/glebarez/go-sqlite"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
)
|
||||
|
||||
func RegisterSqliteFunctions() {
|
||||
// Register the `normalize(text, form)` function, which performs Unicode normalization on the text
|
||||
// This is currently only used in migration functions
|
||||
sqlitelib.MustRegisterDeterministicScalarFunction("normalize", 2, func(ctx *sqlitelib.FunctionContext, args []driver.Value) (driver.Value, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, errors.New("normalize requires 2 arguments")
|
||||
}
|
||||
|
||||
arg0, ok := args[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("first argument for normalize is not a string: %T", args[0])
|
||||
}
|
||||
|
||||
arg1, ok := args[1].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("second argument for normalize is not a string: %T", args[1])
|
||||
}
|
||||
|
||||
var form norm.Form
|
||||
switch strings.ToLower(arg1) {
|
||||
case "nfc":
|
||||
form = norm.NFC
|
||||
case "nfd":
|
||||
form = norm.NFD
|
||||
case "nfkc":
|
||||
form = norm.NFKC
|
||||
case "nfkd":
|
||||
form = norm.NFKD
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported form: %s", arg1)
|
||||
}
|
||||
|
||||
if len(arg0) == 0 {
|
||||
return arg0, nil
|
||||
}
|
||||
|
||||
return form.String(arg0), nil
|
||||
})
|
||||
}
|
||||
77
backend/internal/utils/testing/database.go
Normal file
77
backend/internal/utils/testing/database.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// This file is only imported by unit tests
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sqliteutil.RegisterSqliteFunctions()
|
||||
}
|
||||
|
||||
// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database.
|
||||
// Each database connection is unique for the test.
|
||||
// All migrations are automatically performed.
|
||||
func NewDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
// Get a name for this in-memory database that is specific to the test
|
||||
dbName := utils.CreateSha256Hash(t.Name())
|
||||
|
||||
// Connect to a new in-memory SQL database
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
|
||||
&gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: logger.New(
|
||||
testLoggerAdapter{t: t},
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: logger.Info,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
ParameterizedQueries: false,
|
||||
Colorful: false,
|
||||
},
|
||||
),
|
||||
})
|
||||
require.NoError(t, err, "Failed to connect to test database")
|
||||
|
||||
// Perform migrations with the embedded migrations
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err, "Failed to get sql.DB")
|
||||
driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
|
||||
require.NoError(t, err, "Failed to create migration driver")
|
||||
source, err := iofs.New(resources.FS, "migrations/sqlite")
|
||||
require.NoError(t, err, "Failed to create embedded migration source")
|
||||
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
|
||||
require.NoError(t, err, "Failed to create migration instance")
|
||||
err = m.Up()
|
||||
require.NoError(t, err, "Failed to perform migrations")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Implements gorm's logger.Writer interface
|
||||
type testLoggerAdapter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||
l.t.Logf(format, args...)
|
||||
}
|
||||
38
backend/internal/utils/testing/round_tripper.go
Normal file
38
backend/internal/utils/testing/round_tripper.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// This file is only imported by unit tests
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
||||
type MockRoundTripper struct {
|
||||
Err error
|
||||
Responses map[string]*http.Response
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Check if we have a specific response for this URL
|
||||
for url, resp := range m.Responses {
|
||||
if req.URL.String() == url {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewMockResponse(http.StatusNotFound, ""), nil
|
||||
}
|
||||
|
||||
// NewMockResponse creates an http.Response with the given status code and body
|
||||
func NewMockResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients DROP COLUMN credentials;
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE oidc_clients ADD COLUMN credentials JSONB NULL;
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX idx_audit_logs_country;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_audit_logs_country ON audit_logs(country);
|
||||
@@ -0,0 +1,3 @@
|
||||
DROP INDEX IF EXISTS idx_signup_tokens_expires_at;
|
||||
DROP INDEX IF EXISTS idx_signup_tokens_token;
|
||||
DROP TABLE IF EXISTS signup_tokens;
|
||||
@@ -0,0 +1,11 @@
|
||||
CREATE TABLE signup_tokens (
|
||||
id UUID NOT NULL PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
token VARCHAR(255) NOT NULL UNIQUE,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
usage_limit INTEGER NOT NULL DEFAULT 1,
|
||||
usage_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX idx_signup_tokens_token ON signup_tokens(token);
|
||||
CREATE INDEX idx_signup_tokens_expires_at ON signup_tokens(expires_at);
|
||||
@@ -0,0 +1,4 @@
|
||||
ALTER TABLE audit_logs ALTER COLUMN ip_address SET NOT NULL;
|
||||
|
||||
DROP INDEX IF EXISTS idx_audit_logs_created_at;
|
||||
DROP INDEX IF EXISTS idx_audit_logs_user_agent;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE audit_logs ALTER COLUMN ip_address DROP NOT NULL;
|
||||
|
||||
-- Add missing indexes
|
||||
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE kv;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user