diff --git a/package-lock.json b/package-lock.json index efae2e8..08b7af7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17995,7 +17995,7 @@ }, "packages/appdev-common": { "name": "@dvsa/appdev-api-common", - "version": "2.0.1", + "version": "2.0.2", "dependencies": { "@aws-sdk/client-sts": "^3.989.0", "ajv": "^8.18.0", diff --git a/packages/appdev-common/package.json b/packages/appdev-common/package.json index 0ac63a7..308e08f 100644 --- a/packages/appdev-common/package.json +++ b/packages/appdev-common/package.json @@ -1,6 +1,6 @@ { "name": "@dvsa/appdev-api-common", - "version": "2.0.1", + "version": "2.0.2", "keywords": [ "dvsa", "nodejs", diff --git a/packages/appdev-common/src/auth/__tests__/verify-jwt.spec.ts b/packages/appdev-common/src/auth/__tests__/verify-jwt.spec.ts new file mode 100644 index 0000000..dbe95aa --- /dev/null +++ b/packages/appdev-common/src/auth/__tests__/verify-jwt.spec.ts @@ -0,0 +1,64 @@ +const createRemoteJWKSetMock = jest.fn((url: URL) => ({ url })); +const jwtVerifyMock = jest.fn(); + +jest.mock("jose", () => ({ + createRemoteJWKSet: (url: URL) => createRemoteJWKSetMock(url), + jwtVerify: (...args: unknown[]) => jwtVerifyMock(...args), +})); + +describe("JwtAuthoriser", () => { + beforeEach(() => { + jest.resetModules(); + createRemoteJWKSetMock.mockClear(); + jwtVerifyMock.mockReset(); + process.env.environment = "PRODUCTION"; + }); + + it("should use a tenant-specific JWKS endpoint when tenantId is provided", async () => { + jwtVerifyMock.mockResolvedValue({ payload: { sub: "user-1" } }); + + const { JwtAuthoriser } = await import("../verify-jwt"); + + await new JwtAuthoriser("client-id", "tenant-123").verify("token"); + + expect(createRemoteJWKSetMock).toHaveBeenCalledWith( + new URL("https://login.microsoftonline.com/tenant-123/discovery/keys"), + ); + expect(jwtVerifyMock).toHaveBeenCalledWith( + "token", + { + url: new URL( + "https://login.microsoftonline.com/tenant-123/discovery/keys", + ), + }, + expect.objectContaining({ + audience: ["client-id"], + issuer: [ + "https://sts.windows.net/tenant-123/", + "https://login.microsoftonline.com/tenant-123/v2.0", + ], + }), + ); + }); + + it("should fall back to the common JWKS endpoint when tenantId is not provided", async () => { + jwtVerifyMock.mockResolvedValue({ payload: { sub: "user-1" } }); + + const { JwtAuthoriser } = await import("../verify-jwt"); + + await new JwtAuthoriser("client-id").verify("token"); + + expect(createRemoteJWKSetMock).toHaveBeenCalledWith( + new URL("https://login.microsoftonline.com/common/discovery/keys"), + ); + expect(jwtVerifyMock).toHaveBeenCalledWith( + "token", + { + url: new URL("https://login.microsoftonline.com/common/discovery/keys"), + }, + expect.objectContaining({ + audience: ["client-id"], + }), + ); + }); +}); diff --git a/packages/appdev-common/src/auth/verify-jwt.ts b/packages/appdev-common/src/auth/verify-jwt.ts index fa970d3..6990892 100644 --- a/packages/appdev-common/src/auth/verify-jwt.ts +++ b/packages/appdev-common/src/auth/verify-jwt.ts @@ -15,10 +15,13 @@ export class JwtAuthoriser { "DEVELOPMENT", "NON-PROD", ]; - private static readonly JWKS_URI = new URL( - "https://login.microsoftonline.com/common/discovery/keys", - ); - private static JWKS = createRemoteJWKSet(JwtAuthoriser.JWKS_URI); + private static readonly DEFAULT_TENANT = "common"; + private static readonly MICROSOFT_LOGIN_BASE_URL = + "https://login.microsoftonline.com"; + private static readonly jwksByTenant = new Map< + string, + ReturnType + >(); /** * Create a new instance of the JwtAuthoriser class @@ -33,6 +36,31 @@ export class JwtAuthoriser { this.tenantId = tenantId; } + private static getTenantSegment(tenantId: string | null): string { + return tenantId?.trim() || JwtAuthoriser.DEFAULT_TENANT; + } + + private static getJwks( + tenantId: string | null, + ): ReturnType { + const tenantSegment = JwtAuthoriser.getTenantSegment(tenantId); + const cachedJwks = JwtAuthoriser.jwksByTenant.get(tenantSegment); + + if (cachedJwks) { + return cachedJwks; + } + + const jwks = createRemoteJWKSet( + new URL( + `${JwtAuthoriser.MICROSOFT_LOGIN_BASE_URL}/${tenantSegment}/discovery/keys`, + ), + ); + + JwtAuthoriser.jwksByTenant.set(tenantSegment, jwks); + + return jwks; + } + /** * Validate a JWT and return the decoded payload * @param {string} token - the JWT token to validate @@ -67,7 +95,11 @@ export class JwtAuthoriser { opts.maxTokenAge = Number.POSITIVE_INFINITY; } - const { payload } = await jwtVerify(token, JwtAuthoriser.JWKS, opts); + const { payload } = await jwtVerify( + token, + JwtAuthoriser.getJwks(this.tenantId), + opts, + ); return payload; } catch (err) {