Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packages/appdev-common/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@dvsa/appdev-api-common",
"version": "2.0.1",
"version": "2.0.2",
"keywords": [
"dvsa",
"nodejs",
Expand Down
64 changes: 64 additions & 0 deletions packages/appdev-common/src/auth/__tests__/verify-jwt.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
const createRemoteJWKSetMock = jest.fn((url: URL) => ({ url }));
const jwtVerifyMock = jest.fn();

jest.mock("jose", () => ({
createRemoteJWKSet: (url: URL) => createRemoteJWKSetMock(url),
jwtVerify: (...args: unknown[]) => jwtVerifyMock(...args),
}));

describe("JwtAuthoriser", () => {
beforeEach(() => {
jest.resetModules();
createRemoteJWKSetMock.mockClear();
jwtVerifyMock.mockReset();
process.env.environment = "PRODUCTION";
});

it("should use a tenant-specific JWKS endpoint when tenantId is provided", async () => {
jwtVerifyMock.mockResolvedValue({ payload: { sub: "user-1" } });

const { JwtAuthoriser } = await import("../verify-jwt");

await new JwtAuthoriser("client-id", "tenant-123").verify("token");

expect(createRemoteJWKSetMock).toHaveBeenCalledWith(
new URL("https://login.microsoftonline.com/tenant-123/discovery/keys"),
);
expect(jwtVerifyMock).toHaveBeenCalledWith(
"token",
{
url: new URL(
"https://login.microsoftonline.com/tenant-123/discovery/keys",
),
},
expect.objectContaining({
audience: ["client-id"],
issuer: [
"https://sts.windows.net/tenant-123/",
"https://login.microsoftonline.com/tenant-123/v2.0",
],
}),
);
});

it("should fall back to the common JWKS endpoint when tenantId is not provided", async () => {
jwtVerifyMock.mockResolvedValue({ payload: { sub: "user-1" } });

const { JwtAuthoriser } = await import("../verify-jwt");

await new JwtAuthoriser("client-id").verify("token");

expect(createRemoteJWKSetMock).toHaveBeenCalledWith(
new URL("https://login.microsoftonline.com/common/discovery/keys"),
);
expect(jwtVerifyMock).toHaveBeenCalledWith(
"token",
{
url: new URL("https://login.microsoftonline.com/common/discovery/keys"),
},
expect.objectContaining({
audience: ["client-id"],
}),
);
});
});
Comment thread
dhcrees marked this conversation as resolved.
42 changes: 37 additions & 5 deletions packages/appdev-common/src/auth/verify-jwt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ export class JwtAuthoriser {
"DEVELOPMENT",
"NON-PROD",
];
private static readonly JWKS_URI = new URL(
"https://login.microsoftonline.com/common/discovery/keys",
);
private static JWKS = createRemoteJWKSet(JwtAuthoriser.JWKS_URI);
private static readonly DEFAULT_TENANT = "common";
private static readonly MICROSOFT_LOGIN_BASE_URL =
"https://login.microsoftonline.com";
private static readonly jwksByTenant = new Map<
string,
ReturnType<typeof createRemoteJWKSet>
>();

/**
* Create a new instance of the JwtAuthoriser class
Expand All @@ -33,6 +36,31 @@ export class JwtAuthoriser {
this.tenantId = tenantId;
}

private static getTenantSegment(tenantId: string | null): string {
return tenantId?.trim() || JwtAuthoriser.DEFAULT_TENANT;
}

private static getJwks(
tenantId: string | null,
): ReturnType<typeof createRemoteJWKSet> {
const tenantSegment = JwtAuthoriser.getTenantSegment(tenantId);
const cachedJwks = JwtAuthoriser.jwksByTenant.get(tenantSegment);

if (cachedJwks) {
return cachedJwks;
}

const jwks = createRemoteJWKSet(
new URL(
`${JwtAuthoriser.MICROSOFT_LOGIN_BASE_URL}/${tenantSegment}/discovery/keys`,
),
);

JwtAuthoriser.jwksByTenant.set(tenantSegment, jwks);

return jwks;
}

/**
* Validate a JWT and return the decoded payload
* @param {string} token - the JWT token to validate
Expand Down Expand Up @@ -67,7 +95,11 @@ export class JwtAuthoriser {
opts.maxTokenAge = Number.POSITIVE_INFINITY;
}

const { payload } = await jwtVerify(token, JwtAuthoriser.JWKS, opts);
const { payload } = await jwtVerify(
token,
JwtAuthoriser.getJwks(this.tenantId),
opts,
);

return payload;
} catch (err) {
Expand Down