diff --git a/msal4j-mtls-extensions/README.md b/msal4j-mtls-extensions/README.md new file mode 100644 index 00000000..b1b06c50 --- /dev/null +++ b/msal4j-mtls-extensions/README.md @@ -0,0 +1,200 @@ +# msal4j-mtls-extensions + +This extension enables mTLS Proof-of-Possession (mTLS PoP) token acquisition for Azure Managed Identity in Java applications. It uses [JNA](https://github.com/java-native-access/jna) to call Windows CNG (`ncrypt.dll`) directly, creating and using KeyGuard-isolated private keys in-process — no .NET runtime or subprocess required. + +The latest code resides in the `dev` branch. + +Quick links: + +| [Docs](../../msal4j-sdk/docs/mtls-pop.md) | [Manual Testing](../../msal4j-sdk/docs/mtls-pop-manual-testing.md) | [Architecture](../../msal4j-sdk/docs/mtls-pop-architecture.md) | [Support](README.md#community-help-and-support) | +| --- | --- | --- | --- | + +## Installation + +### Requirements + +- Windows x64 Azure VM with Managed Identity enabled +- Java 8 or higher +- `AttestationClientLib.dll` on `PATH`, when using Trusted Launch VMs with attestation (see [Attestation DLL](#attestationclientlibdll) below) + +### Adding the dependency + +```xml + + com.microsoft.azure + msal4j-mtls-extensions + 1.0.0 + +``` + +`msal4j-mtls-extensions` depends on `msal4j` transitively — you do not need to declare `msal4j` separately. + +### AttestationClientLib.dll + +On Trusted Launch VMs, the IMDS `/issuecredential` endpoint requires a MAA attestation JWT proving the key was created in a VBS-isolated enclave. This JWT is produced by `AttestationClientLib.dll`, distributed via the `Microsoft.Azure.Security.KeyGuardAttestation` NuGet package. + +To obtain the DLL: + +```powershell +# Download the NuGet package +dotnet add package Microsoft.Azure.Security.KeyGuardAttestation --package-directory C:\nuget + +# Copy the DLL next to your application or to a directory on PATH +$dll = Get-ChildItem C:\nuget\microsoft.azure.security.keyguardattestation -Recurse -Filter "AttestationClientLib.dll" | Select-Object -First 1 +Copy-Item $dll.FullName C:\your-app\ +``` + +Unlike msal-dotnet, which receives this DLL automatically via NuGet, Java applications must place it on `PATH` or in the application directory. If your VM does not use Trusted Launch, pass `withAttestation: false` and no DLL is needed. + +## Usage + +Before using this extension, ensure Managed Identity is enabled on your Azure VM. + +### Path 2 — Managed Identity: Acquiring an mTLS PoP Token + +Acquiring a token follows this general pattern: + +1. Create a client and call `acquireToken()`. + + * System-assigned Managed Identity: + + ```java + import com.microsoft.aad.msal4j.mtls.*; + + MtlsMsiClient client = new MtlsMsiClient(); + MtlsMsiHelperResult result = client.acquireToken( + "https://graph.microsoft.com", // resource (confirmed enrolled for mTLS PoP) + "SystemAssigned", // identity type + null, // identity id (null for system-assigned) + false, // withAttestation — set true on Trusted Launch VMs + null // correlationId (optional) + ); + String accessToken = result.getAccessToken(); + ``` + + * User-assigned Managed Identity: + + ```java + MtlsMsiHelperResult result = client.acquireToken( + "https://graph.microsoft.com", + "UserAssigned", + "your-client-id", + false, + null + ); + String accessToken = result.getAccessToken(); + ``` + + > **Resource note:** Use `https://graph.microsoft.com` or `https://storage.azure.com`. `https://management.azure.com` may return `AADSTS392196` if that resource is not enrolled for mTLS PoP in your tenant. + +2. The binding certificate is cached in-process for the lifetime of the IMDS-issued certificate (minus a 5-minute safety margin). Subsequent calls return the cached token until it nears expiry. + +### Path 2 — Making Downstream mTLS Calls + +Once you have a token, use `httpRequest()` to make downstream calls over the same KeyGuard-backed mTLS channel: + +```java +MtlsMsiHttpResponse response = client.httpRequest( + "https://myservice.example.com/api", // URL + "GET", // method + result.getAccessToken(), // bearer token + null, // body + null, // contentType + null, // extra headers + "https://graph.microsoft.com", // resource (for cert refresh) + "SystemAssigned", null, // identity type, identity id + false, // withAttestation + null, // correlationId + false // allowInsecureTls +); +System.out.println(response.getStatus()); // e.g. 200 +System.out.println(response.getBody()); +``` + +The downstream server must be configured to *require* mutual TLS — it must send a TLS `CertificateRequest` during the handshake. + +### Path 1 — Confidential Client (SNI Certificate) + +For applications with an SNI certificate (e.g., from OneCert/DSMS), use `ConfidentialClientApplication` from the core `msal4j` library: + +```java +import com.microsoft.aad.msal4j.*; +import java.io.FileInputStream; + +// 1. Load your certificate (PKCS12) +IClientCertificate cert = ClientCredentialFactory.createFromCertificate( + new FileInputStream("cert.p12"), "password"); + +// 2. Build the app — tenanted authority and region required +ConfidentialClientApplication app = ConfidentialClientApplication + .builder("your-client-id", cert) + .authority("https://login.microsoftonline.com/your-tenant-id") + .azureRegion("centraluseuap") + .build(); + +// 3. Acquire an mTLS PoP token +IAuthenticationResult result = app.acquireToken( + ClientCredentialParameters + .builder(Collections.singleton("https://graph.microsoft.com/.default")) + .withMtlsProofOfPossession() + .build() +).get(); + +System.out.println("Token type: " + result.tokenType()); // "mtls_pop" +System.out.println("Binding cert: " + result.bindingCertificate().getSubjectX500Principal()); +System.out.println("Access token: " + result.accessToken()); +``` + +**Requirements:** Certificate credential, tenanted authority (not `/common` or `/organizations`), Azure region. + +--- + +## End-to-End Test Driver + +The `msal4j-mtls-extensions` module ships an e2e fat JAR for manual testing: + +```powershell +# Build +mvn package -DskipTests + +# Path 1 — error-case validation (no Azure credentials required) +java -jar target\msal4j-mtls-extensions-1.0.0-e2e.jar path1 --errors-only + +# Path 1 — full happy path +java -jar target\msal4j-mtls-extensions-1.0.0-e2e.jar path1 ` + --tenant --client --region centraluseuap + +# Path 2 — Managed Identity (with attestation) +java -Djava.library.path=C:\msiv2 ` + -jar target\msal4j-mtls-extensions-1.0.0-e2e.jar path2 --attest +``` + +See [Manual Testing Guide](../../msal4j-sdk/docs/mtls-pop-manual-testing.md) for full instructions. + +## Community Help and Support + +We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!) Also browse existing issues to see if someone has had your question before. Please use the "msal" tag when asking your questions. + +If you find a bug or have a feature request, please raise the issue on [GitHub Issues](https://github.com/AzureAD/microsoft-authentication-library-for-java/issues). + +## Submit Feedback + +We'd like your thoughts on this library. Please complete [this short survey.](https://forms.office.com/r/6AhHwQp3pe) + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. + +When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +## Security Library + +This library controls how users sign in and access services. We recommend you always take the latest version of our library in your app when possible. We use [semantic versioning](http://semver.org) so you can control the risk associated with updating your app. As an example, always downloading the latest minor version number (e.g. x.*y*.x) ensures you get the latest security and feature enhancements but our API surface remains the same. You can always see the latest version and release notes under the Releases tab of GitHub. + +## Security Reporting + +If you find a security issue with our libraries or services please report it to [secure@microsoft.com](mailto:secure@microsoft.com) with as much detail as possible. Your submission may be eligible for a bounty through the [Microsoft Bounty](http://aka.ms/bugbounty) program. Please do not post security issues to GitHub Issues or any other public site. We will contact you shortly upon receiving the information. We encourage you to get notifications of when security incidents occur by visiting [this page](https://technet.microsoft.com/en-us/security/dd252948) and subscribing to Security Advisory Alerts. + +Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License (the "License"). diff --git a/msal4j-mtls-extensions/pom.xml b/msal4j-mtls-extensions/pom.xml new file mode 100644 index 00000000..771964a3 --- /dev/null +++ b/msal4j-mtls-extensions/pom.xml @@ -0,0 +1,151 @@ + + + 4.0.0 + + com.microsoft.azure + msal4j-mtls-extensions + 1.0.0 + jar + + Microsoft Authentication Library for Java - mTLS Extensions + + Extension package that enables mTLS Proof-of-Possession (mTLS PoP) token acquisition + for Azure Managed Identity scenarios requiring KeyGuard-bound certificates. Uses JNA + to call Windows CNG (ncrypt.dll) and AttestationClientLib.dll directly from Java, + implementing a java.security.Provider that allows JSSE to use a non-exportable + KeyGuard RSA key during the TLS handshake. No .NET runtime or subprocess required. + + + + 8 + 8 + UTF-8 + + + + + com.microsoft.azure + msal4j + 1.24.0 + + + + + net.java.dev.jna + jna + 5.14.0 + + + + + org.junit.jupiter + junit-jupiter-api + 5.10.0 + test + + + org.junit.jupiter + junit-jupiter-engine + 5.10.0 + test + + + org.mockito + mockito-core + 5.4.0 + test + + + org.mockito + mockito-junit-jupiter + 5.4.0 + test + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.1.2 + + + + -XX:+EnableDynamicAgentLoading + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.security=ALL-UNNAMED + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 3.5.0 + + + add-e2e-sources + generate-sources + add-source + + + src/e2e/java + + + + + add-e2e-resources + generate-resources + add-resource + + + + src/e2e/resources + + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.1 + + + package + shade + + true + e2e + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + META-INF/*.EC + + + + + + com.microsoft.aad.msal4j.mtls.e2e.E2ETestRunner + + + + + + + + + diff --git a/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/E2ETestRunner.java b/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/E2ETestRunner.java new file mode 100644 index 00000000..754a3064 --- /dev/null +++ b/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/E2ETestRunner.java @@ -0,0 +1,70 @@ +// mTLS PoP E2E Test Runner +// +// Dispatches to path1 (Confidential Client) or path2 (Managed Identity) based on the +// first argument. +// +// Usage: +// java -jar target/msal4j-mtls-extensions-1.0.0-e2e.jar path1 [options] +// java -jar target/msal4j-mtls-extensions-1.0.0-e2e.jar path2 [--attest] +// +// Run with no arguments or --help for usage. + +package com.microsoft.aad.msal4j.mtls.e2e; + +import java.util.Arrays; + +/** + * Entry point for the mTLS PoP end-to-end test suite. + * Dispatches to {@link Path1ConfidentialClient} or {@link Path2ManagedIdentity}. + */ +public class E2ETestRunner { + + public static void main(String[] args) throws Exception { + if (args.length == 0 || "--help".equals(args[0]) || "-h".equals(args[0])) { + printUsage(); + return; + } + + String path = args[0].toLowerCase(); + String[] rest = Arrays.copyOfRange(args, 1, args.length); + + switch (path) { + case "path1": + Path1ConfidentialClient.run(rest); + break; + case "path2": + Path2ManagedIdentity.run(rest); + break; + default: + System.err.println("Unknown path: " + args[0]); + printUsage(); + System.exit(1); + } + } + + private static void printUsage() { + System.out.println("msal4j mTLS PoP End-to-End Test Runner"); + System.out.println(); + System.out.println("Usage:"); + System.out.println(" java -jar msal4j-mtls-extensions-*-e2e.jar [options]"); + System.out.println(); + System.out.println("Paths:"); + System.out.println(" path1 Confidential Client (SNI certificate, Azure AD app registration)"); + System.out.println(" path2 Managed Identity (IMDSv2, VBS KeyGuard, Azure VM)"); + System.out.println(); + System.out.println("Path 1 options:"); + System.out.println(" --tenant Azure AD tenant ID"); + System.out.println(" --client Azure AD app (client) ID"); + System.out.println(" --region Azure region (default: centraluseuap)"); + System.out.println(" --resource Downstream resource (default: https://graph.microsoft.com)"); + System.out.println(" --errors-only Run only error-case validation (no Azure credentials needed)"); + System.out.println(); + System.out.println("Path 2 options:"); + System.out.println(" --attest Enable attestation (requires AttestationClientLib.dll on PATH)"); + System.out.println(); + System.out.println("Examples:"); + System.out.println(" java -jar e2e.jar path1 --errors-only"); + System.out.println(" java -jar e2e.jar path1 --tenant --client --region westus2"); + System.out.println(" java -jar e2e.jar path2 --attest"); + } +} diff --git a/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/Path1ConfidentialClient.java b/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/Path1ConfidentialClient.java new file mode 100644 index 00000000..8c3cf15e --- /dev/null +++ b/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/Path1ConfidentialClient.java @@ -0,0 +1,573 @@ +// mTLS PoP Manual Test — Path 1: Confidential Client (SNI Certificate) +// +// Tests both the happy path (requires Azure AD app registration + cert upload) +// and all error cases (no Azure credentials required). +// +// Usage (from the msal4j-mtls-extensions directory): +// mvn package -DskipTests +// +// # Error cases only (no Azure credentials needed): +// java -jar target/msal4j-mtls-extensions-1.0.0-e2e.jar path1 --errors-only +// +// # Full test (requires Azure app registration): +// java -jar target/msal4j-mtls-extensions-1.0.0-e2e.jar path1 \ +// --tenant --client --region +// +// Cert files (test-cert.pem, test-key.pem) must be in the parent directory (mtls-pop/). +// test-cert.pem is committed to the repo; test-key.pem is gitignored. +// +// Generate cert + key: +// openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out ../test-key.pem +// openssl req -new -x509 -key ../test-key.pem -out ../test-cert.pem \ +// -days 365 -subj "/CN=msal-java-mtls-test" +// +// Then upload test-cert.pem to your Azure AD app registration under +// "Certificates & secrets" > "Certificates". + +package com.microsoft.aad.msal4j.mtls.e2e; + +import com.microsoft.aad.msal4j.*; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.Date; +import java.util.Enumeration; +import javax.net.ssl.SSLContext; + +/** + * End-to-end test for mTLS PoP Confidential Client (Path 1). + * + *

Mirrors msal-go's + * {@code apps/tests/devapps/mtls-pop/path1_confidential/main.go}.

+ */ +public class Path1ConfidentialClient { + + static void run(String[] args) throws Exception { + String tenantId = argValue(args, "--tenant", null); + String clientId = argValue(args, "--client", null); + String region = argValue(args, "--region", "centraluseuap"); + String resource = argValue(args, "--resource", "https://graph.microsoft.com"); + boolean errorsOnly = Arrays.asList(args).contains("--errors-only"); + + // Load cert + key from PEM files (parent directory, same layout as msal-go). + X509Certificate cert = loadCert(); + PrivateKey key = loadKey(); + + IClientCertificate certCred = ClientCredentialFactory.createFromCertificate(key, cert); + + System.out.println("=== Path 1: Error-Case Validation ==="); + System.out.println(); + int[] counts = testErrorCases(certCred, tenantId, region); + System.out.printf("%n Error cases: %d passed, %d failed%n", counts[0], counts[1]); + + if (errorsOnly) { + System.out.println(); + System.out.println("[Skipping happy-path test: --errors-only flag set]"); + System.out.println("To run the happy path, register an Azure AD app and upload the certificate at"); + System.out.println(" ../test-cert.pem"); + System.out.println("then run:"); + System.out.printf(" java -jar path1 --tenant --client --region %s%n", region); + return; + } + + if (tenantId == null || clientId == null) { + System.out.println(); + System.out.println("[Skipping happy-path test: --tenant and --client flags required]"); + System.out.println("Run with --errors-only to test only error cases, or provide --tenant/--client for the full test."); + System.exit(1); + return; + } + + System.out.println(); + System.out.println("=== Path 1: Happy Path ==="); + System.out.println(); + testHappyPath(certCred, cert, key, tenantId, clientId, region, resource); + } + + // ── Error cases ─────────────────────────────────────────────────────────── + + private static int[] testErrorCases(IClientCertificate certCred, + String tenantId, String region) { + String errorTenant = tenantId != null ? tenantId + : System.getenv("AZURE_TENANT_ID") != null ? System.getenv("AZURE_TENANT_ID") + : "00000000-0000-0000-0000-000000000000"; + String authority = "https://login.microsoftonline.com/" + errorTenant; + String placeholderId = "00000000-0000-0000-0000-000000000000"; + String scope = "https://graph.microsoft.com/.default"; + + int pass = 0, fail = 0; + + // Error case 1: missing region + try { + ConfidentialClientApplication app = ConfidentialClientApplication + .builder(placeholderId, certCred) + .authority(authority) + .build(); + app.acquireToken(ClientCredentialParameters + .builder(Collections.singleton(scope)) + .withMtlsProofOfPossession() + .build()).get(); + System.out.println(" ❌ FAIL [missing-region]: expected error, got success"); + fail++; + } catch (Exception e) { + String msg = rootCause(e).getMessage(); + if (msg != null && msg.contains("Azure region")) { + System.out.println(" ✅ PASS [missing-region]: " + msg); + pass++; + } else { + System.out.println(" ❌ FAIL [missing-region]: unexpected error: " + msg); + fail++; + } + } + + // Error case 2: non-tenanted authority (/common) + try { + ConfidentialClientApplication app = ConfidentialClientApplication + .builder(placeholderId, certCred) + .authority("https://login.microsoftonline.com/common") + .azureRegion(region) + .build(); + app.acquireToken(ClientCredentialParameters + .builder(Collections.singleton(scope)) + .withMtlsProofOfPossession() + .build()).get(); + System.out.println(" ❌ FAIL [non-tenanted(/common)]: expected error, got success"); + fail++; + } catch (Exception e) { + String msg = rootCause(e).getMessage(); + if (msg != null && (msg.contains("/common") || msg.contains("/organizations") || msg.contains("tenanted"))) { + System.out.println(" ✅ PASS [non-tenanted(/common)]: " + msg); + pass++; + } else { + System.out.println(" ❌ FAIL [non-tenanted(/common)]: unexpected error: " + msg); + fail++; + } + } + + // Error case 3: non-tenanted authority (/organizations) + try { + ConfidentialClientApplication app = ConfidentialClientApplication + .builder(placeholderId, certCred) + .authority("https://login.microsoftonline.com/organizations") + .azureRegion(region) + .build(); + app.acquireToken(ClientCredentialParameters + .builder(Collections.singleton(scope)) + .withMtlsProofOfPossession() + .build()).get(); + System.out.println(" ❌ FAIL [non-tenanted(/organizations)]: expected error, got success"); + fail++; + } catch (Exception e) { + String msg = rootCause(e).getMessage(); + if (msg != null && (msg.contains("/organizations") || msg.contains("tenanted"))) { + System.out.println(" ✅ PASS [non-tenanted(/organizations)]: " + msg); + pass++; + } else { + System.out.println(" ❌ FAIL [non-tenanted(/organizations)]: unexpected error: " + msg); + fail++; + } + } + + // Error case 4: secret credential (not cert-based) + try { + IClientSecret secretCred = ClientCredentialFactory.createFromSecret("dummy-secret"); + ConfidentialClientApplication app = ConfidentialClientApplication + .builder(placeholderId, secretCred) + .authority(authority) + .azureRegion(region) + .build(); + app.acquireToken(ClientCredentialParameters + .builder(Collections.singleton(scope)) + .withMtlsProofOfPossession() + .build()).get(); + System.out.println(" ❌ FAIL [secret-credential]: expected error, got success"); + fail++; + } catch (Exception e) { + String msg = rootCause(e).getMessage(); + if (msg != null && msg.contains("ClientCertificate")) { + System.out.println(" ✅ PASS [secret-credential]: " + msg); + pass++; + } else { + System.out.println(" ❌ FAIL [secret-credential]: unexpected error: " + msg); + fail++; + } + } + + return new int[]{pass, fail}; + } + + // ── Happy path ──────────────────────────────────────────────────────────── + + private static void testHappyPath(IClientCertificate certCred, + X509Certificate cert, PrivateKey key, + String tenantId, String clientId, + String region, String resource) throws Exception { + String authority = "https://login.microsoftonline.com/" + tenantId; + String scope = resource.replaceAll("/$", "") + "/.default"; + + ConfidentialClientApplication app = ConfidentialClientApplication + .builder(clientId, certCred) + .authority(authority) + .azureRegion(region) + .build(); + + System.out.printf(" Acquiring mTLS PoP token (region=%s, scope=%s)...%n", region, scope); + IAuthenticationResult result1 = app.acquireToken( + ClientCredentialParameters.builder(Collections.singleton(scope)) + .withMtlsProofOfPossession() + .build()).get(); + + System.out.println(); + printResult("First call (from AAD)", result1); + + // Second call — should hit cache + System.out.println(); + System.out.println(" Acquiring again (expect cache hit)..."); + IAuthenticationResult result2 = app.acquireToken( + ClientCredentialParameters.builder(Collections.singleton(scope)) + .withMtlsProofOfPossession() + .build()).get(); + + if (result2.metadata().tokenSource() == TokenSource.CACHE) { + System.out.println(" ✅ Second call returned cached token"); + } else { + System.out.println(" ⚠️ Second call did NOT return cached token (source: " + + result2.metadata().tokenSource() + ")"); + } + if (result2.accessToken() != null && result2.accessToken().equals(result1.accessToken())) { + System.out.println(" ✅ Same access token returned from cache"); + } + + // Downstream call — present the binding cert over mTLS + System.out.println(); + System.out.printf(" Making downstream call to %s...%n", resource); + makeDownstreamCall(result1.accessToken(), cert, key, resource); + + System.out.println(); + System.out.println(" Happy path complete ✅"); + } + + // ── Downstream mTLS call ────────────────────────────────────────────────── + + private static void makeDownstreamCall(String token, + X509Certificate cert, PrivateKey key, + String resource) { + // For Graph, append /v1.0/organization; any 4xx is still a TLS success. + String url = resource.replaceAll("/$", ""); + if (url.contains("graph.microsoft.com")) { + url += "/v1.0/organization"; + } + + try { + // Build an SSLSocketFactory that presents our binding cert as the client cert. + javax.net.ssl.KeyManagerFactory kmf = javax.net.ssl.KeyManagerFactory.getInstance( + javax.net.ssl.KeyManagerFactory.getDefaultAlgorithm()); + java.security.KeyStore ks = java.security.KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("mtls", key, new char[0], new X509Certificate[]{cert}); + kmf.init(ks, new char[0]); + + SSLContext sslCtx = SSLContext.getInstance("TLS"); + sslCtx.init(kmf.getKeyManagers(), null, null); + + URL reqUrl = new URL(url); + HttpURLConnection conn = (HttpURLConnection) reqUrl.openConnection(); + if (conn instanceof javax.net.ssl.HttpsURLConnection) { + ((javax.net.ssl.HttpsURLConnection) conn).setSSLSocketFactory(sslCtx.getSocketFactory()); + } + conn.setRequestMethod("GET"); + conn.setRequestProperty("Authorization", "Bearer " + token); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(10_000); + conn.connect(); + + int status = conn.getResponseCode(); + switch (status) { + case 200: + System.out.printf(" ✅ Downstream call succeeded: HTTP %d%n", status); + break; + case 401: + System.out.println(" ❌ HTTP 401 — token or mTLS cert rejected"); + break; + case 403: + System.out.println(" ⚠️ HTTP 403 — TLS handshake OK, token accepted, missing permissions"); + break; + default: + System.out.printf(" ⚠️ HTTP %d%n", status); + } + } catch (Exception e) { + System.out.println(" ❌ Downstream call failed: " + e.getMessage()); + } + } + + // ── PEM loading ─────────────────────────────────────────────────────────── + + /** + * Loads the test X.509 certificate. First tries PEM files in the standard locations + * (parent directory or current directory, same layout as msal-go); if those aren't + * present falls back to the bundled {@code mtls-test-cert.p12} so that error-case + * tests work without any manual setup. + */ + private static X509Certificate loadCert() throws Exception { + if (pemFileExists("test-cert.pem")) { + byte[] pem = readPemFile("test-cert.pem"); + return (X509Certificate) CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(pem)); + } + return loadCertFromBundledP12(); + } + + /** + * Loads the test RSA private key. Same fallback logic as {@link #loadCert()}. + */ + private static PrivateKey loadKey() throws Exception { + if (pemFileExists("test-key.pem")) { + String raw = readPemFileString("test-key.pem"); + boolean isPkcs8 = raw.contains("BEGIN PRIVATE KEY"); + String b64 = raw + .replaceAll("-----[^-]+-----", "") + .replaceAll("\\s", ""); + byte[] derBytes = Base64.getDecoder().decode(b64); + if (!isPkcs8) { + derBytes = wrapPkcs1InPkcs8(derBytes); + } + return KeyFactory.getInstance("RSA").generatePrivate(new PKCS8EncodedKeySpec(derBytes)); + } + return loadKeyFromBundledP12(); + } + + private static boolean pemFileExists(String name) { + return Files.exists(Paths.get("../" + name)) || Files.exists(Paths.get(name)); + } + + /** Loads the X.509 certificate from the bundled test PKCS#12 store. */ + private static X509Certificate loadCertFromBundledP12() throws Exception { + KeyStore ks = loadBundledKeyStore(); + Enumeration aliases = ks.aliases(); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + if (ks.isCertificateEntry(alias) || ks.isKeyEntry(alias)) { + return (X509Certificate) ks.getCertificate(alias); + } + } + throw new IOException("No certificate found in bundled mtls-test-cert.p12"); + } + + /** Loads the private key from the bundled test PKCS#12 store. */ + private static PrivateKey loadKeyFromBundledP12() throws Exception { + KeyStore ks = loadBundledKeyStore(); + Enumeration aliases = ks.aliases(); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + if (ks.isKeyEntry(alias)) { + return (PrivateKey) ks.getKey(alias, "changeit".toCharArray()); + } + } + throw new IOException("No private key found in bundled mtls-test-cert.p12"); + } + + private static KeyStore loadBundledKeyStore() throws Exception { + InputStream is = Path1ConfidentialClient.class.getResourceAsStream("/mtls-test-cert.p12"); + if (is == null) { + throw new IOException( + "Cannot find test-cert.pem or the bundled mtls-test-cert.p12.\n" + + "Generate cert + key with:\n" + + " openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out ../test-key.pem\n" + + " openssl req -new -x509 -key ../test-key.pem -out ../test-cert.pem" + + " -days 365 -subj \"/CN=msal-java-mtls-test\""); + } + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(is, "changeit".toCharArray()); + return ks; + } + + private static byte[] readPemFile(String name) throws IOException { + return readPemFileString(name).getBytes(); + } + + private static String readPemFileString(String name) throws IOException { + // Look in ../ (same as msal-go: cert files live in mtls-pop/, one level up) + String[] candidates = {"../" + name, name}; + for (String path : candidates) { + if (Files.exists(Paths.get(path))) { + return new String(Files.readAllBytes(Paths.get(path))); + } + } + throw new IOException("Cannot find " + name + " (tried ../" + name + " and " + name + + ").\nGenerate with:\n" + + " openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 -out ../test-key.pem\n" + + " openssl req -new -x509 -key ../test-key.pem -out ../test-cert.pem -days 365 -subj \"/CN=msal-java-mtls-test\""); + } + + /** + * Wraps a PKCS#1 RSA private key DER blob in a PKCS#8 PrivateKeyInfo envelope so that + * Java's {@link KeyFactory} can parse it without Bouncy Castle. + */ + private static byte[] wrapPkcs1InPkcs8(byte[] pkcs1) { + // rsaEncryption OID: 1.2.840.113549.1.1.1 + byte[] oidBytes = {0x2a, (byte)0x86, 0x48, (byte)0x86, (byte)0xf7, 0x0d, 0x01, 0x01, 0x01}; + byte[] algId = derSeq(concat(derTlv(0x06, oidBytes), new byte[]{0x05, 0x00})); + byte[] version = {0x02, 0x01, 0x00}; + byte[] privKey = derTlv(0x04, pkcs1); + return derSeq(concat(version, algId, privKey)); + } + + private static byte[] derSeq(byte[] content) { + return derTlv(0x30, content); + } + + private static byte[] derTlv(int tag, byte[] value) { + byte[] len = derLen(value.length); + byte[] out = new byte[1 + len.length + value.length]; + out[0] = (byte) tag; + System.arraycopy(len, 0, out, 1, len.length); + System.arraycopy(value, 0, out, 1 + len.length, value.length); + return out; + } + + private static byte[] derLen(int n) { + if (n < 128) return new byte[]{(byte) n}; + if (n < 256) return new byte[]{(byte) 0x81, (byte) n}; + return new byte[]{(byte) 0x82, (byte)(n >> 8), (byte)(n & 0xff)}; + } + + private static byte[] concat(byte[]... parts) { + int total = 0; + for (byte[] p : parts) total += p.length; + byte[] out = new byte[total]; + int pos = 0; + for (byte[] p : parts) { System.arraycopy(p, 0, out, pos, p.length); pos += p.length; } + return out; + } + + // ── Print helpers ───────────────────────────────────────────────────────── + + private static void printResult(String label, IAuthenticationResult result) { + System.out.println("[" + label + "]"); + + X509Certificate binding = result.bindingCertificate(); + if (binding != null) { + System.out.println(" ✅ BindingCertificate: subject=" + binding.getSubjectX500Principal().getName() + + ", expires=" + binding.getNotAfter()); + } else { + System.out.println(" ❌ BindingCertificate is null — expected non-null for mTLS PoP"); + } + + System.out.println(" TokenType: " + result.tokenType()); + System.out.println(" Scopes: " + result.scopes()); + System.out.println(" ExpiresOn: " + result.expiresOnDate()); + + printTokenSummary(result.accessToken()); + } + + private static void printTokenSummary(String jwt) { + if (jwt == null || jwt.isEmpty()) { + System.out.println(" ❌ AccessToken is null/empty"); + return; + } + String[] parts = jwt.split("\\."); + if (parts.length < 2) { + System.out.printf(" AccessToken: (opaque, %d chars)%n", jwt.length()); + return; + } + try { + String header = new String(Base64.getUrlDecoder().decode(pad(parts[0]))); + String payload = new String(Base64.getUrlDecoder().decode(pad(parts[1]))); + System.out.println(" AccessToken header: " + header); + printClaim(payload, "oid"); + printClaim(payload, "tid"); + printClaim(payload, "token_type"); + printClaim(payload, "cnf"); + long expEpoch = extractLong(payload, "exp"); + if (expEpoch > 0) { + System.out.println(" AccessToken exp: " + new Date(expEpoch * 1000)); + } + System.out.printf(" ✅ AccessToken present (%d chars)%n", jwt.length()); + } catch (Exception e) { + System.out.println(" AccessToken: (could not decode JWT: " + e.getMessage() + ")"); + } + } + + private static void printClaim(String payload, String key) { + String val = extractString(payload, key); + if (val != null) { + if (val.length() > 120) val = val.substring(0, 120) + "..."; + System.out.println(" AccessToken " + key + ": " + val); + } + } + + // ── Minimal JSON extract (no external deps) ─────────────────────────────── + + private static String extractString(String json, String key) { + String search = "\"" + key + "\""; + int idx = json.indexOf(search); + if (idx < 0) return null; + int colon = json.indexOf(':', idx + search.length()); + if (colon < 0) return null; + int vs = colon + 1; + while (vs < json.length() && Character.isWhitespace(json.charAt(vs))) vs++; + if (vs >= json.length()) return null; + char first = json.charAt(vs); + if (first == '"') { + int end = vs + 1; + while (end < json.length() && json.charAt(end) != '"') end++; + return json.substring(vs + 1, end); + } else if (first == '{' || first == '[') { + char close = first == '{' ? '}' : ']'; + int depth = 0, end = vs; + while (end < json.length()) { + char c = json.charAt(end); + if (c == first) depth++; + else if (c == close) { if (--depth == 0) { end++; break; } } + end++; + } + return json.substring(vs, end); + } + return null; + } + + private static long extractLong(String json, String key) { + String search = "\"" + key + "\""; + int idx = json.indexOf(search); + if (idx < 0) return 0; + int colon = json.indexOf(':', idx + search.length()); + if (colon < 0) return 0; + int vs = colon + 1; + while (vs < json.length() && Character.isWhitespace(json.charAt(vs))) vs++; + int ve = vs; + while (ve < json.length() && (Character.isDigit(json.charAt(ve)) || json.charAt(ve) == '-')) ve++; + try { return Long.parseLong(json.substring(vs, ve)); } catch (Exception e) { return 0; } + } + + private static String pad(String s) { + return s + "==".substring(0, (4 - s.length() % 4) % 4); + } + + // ── Utility ─────────────────────────────────────────────────────────────── + + private static String argValue(String[] args, String flag, String defaultVal) { + for (int i = 0; i < args.length - 1; i++) { + if (flag.equals(args[i])) return args[i + 1]; + } + return defaultVal; + } + + private static Throwable rootCause(Throwable t) { + while (t.getCause() != null && t.getCause() != t) t = t.getCause(); + return t; + } +} diff --git a/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/Path2ManagedIdentity.java b/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/Path2ManagedIdentity.java new file mode 100644 index 00000000..c670112f --- /dev/null +++ b/msal4j-mtls-extensions/src/e2e/java/com/microsoft/aad/msal4j/mtls/e2e/Path2ManagedIdentity.java @@ -0,0 +1,282 @@ +// mTLS PoP Manual Test — Path 2: Managed Identity (IMDSv2, Windows + VBS) +// +// Tests the managed identity mTLS PoP flow end-to-end on an Azure VM with: +// - System-assigned or user-assigned managed identity +// - Windows OS with VBS (Virtualization-Based Security) KeyGuard +// - IMDSv2 endpoint accessible at 169.254.169.254 +// +// Usage (from the msal4j-mtls-extensions directory): +// mvn package -DskipTests +// mvn exec:java -Dexec.mainClass=com.microsoft.aad.msal4j.mtls.e2e.Path2ManagedIdentity +// +// Or with attestation: +// mvn exec:java -Dexec.mainClass=com.microsoft.aad.msal4j.mtls.e2e.Path2ManagedIdentity -Dexec.args="--attest" + +package com.microsoft.aad.msal4j.mtls.e2e; + +import com.microsoft.aad.msal4j.mtls.MtlsMsiClient; +import com.microsoft.aad.msal4j.mtls.MtlsMsiException; +import com.microsoft.aad.msal4j.mtls.MtlsMsiHelperResult; +import com.microsoft.aad.msal4j.mtls.MtlsMsiHttpResponse; + +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Base64; +import java.util.UUID; + +/** + * End-to-end test for mTLS PoP Managed Identity (Path 2). + * + *

Mirrors msal-go's {@code apps/tests/devapps/mtls-pop/path2_managedidentity/main.go}.

+ */ +public class Path2ManagedIdentity { + + private static final String RESOURCE = "https://graph.microsoft.com"; + + public static void main(String[] args) throws Exception { + run(args); + } + + static void run(String[] args) throws Exception { + boolean withAttestation = Arrays.asList(args).contains("--attest"); + + System.out.println("=== Path 2: Managed Identity mTLS PoP ==="); + System.out.println(); + if (withAttestation) { + System.out.println("[Attestation mode: ON — requires AttestationClientLib.dll on PATH]"); + System.out.println(); + } + + MtlsMsiClient client = new MtlsMsiClient(); + String correlationId = UUID.randomUUID().toString(); + + // ── First call: full IMDS flow ────────────────────────────────────────── + System.out.println("Acquiring mTLS PoP token via IMDSv2 (full flow)..."); + MtlsMsiHelperResult result1; + try { + result1 = client.acquireToken(RESOURCE, "SystemAssigned", null, + withAttestation, correlationId); + } catch (MtlsMsiException e) { + System.out.println(); + System.err.println("❌ acquireToken failed: " + e.getMessage()); + System.err.println(); + // Check for tenant/resource misconfiguration (not a code bug) + if (e.getMessage() != null && e.getMessage().contains("AADSTS392196")) { + System.err.println("ℹ️ AADSTS392196: The resource application does not support certificate-bound tokens."); + System.err.println(" This is a tenant/resource configuration issue (same as MSAL.NET on this VM)."); + System.err.println(" The mTLS handshake succeeded — the code is working correctly."); + System.err.println(" To fully test, use a tenant where mTLS PoP is enabled for graph.microsoft.com."); + } else { + System.err.println("Common causes:"); + System.err.println(" - VBS/KeyGuard not running (check msinfo32.exe)"); + System.err.println(" - IMDSv2 not returning platform metadata"); + System.err.println(" - VM managed identity not configured"); + System.err.println(" - 403 from IMDS issuecredential endpoint"); + System.err.println(" - Tenant not configured for mTLS PoP (AADSTS392196)"); + } + System.exit(1); + return; + } + + System.out.println(); + printResult("First call (from IMDS)", result1); + + // ── Second call: should hit cached binding cert ───────────────────────── + System.out.println(); + System.out.println("Acquiring again (expect cert cache hit)..."); + long t0 = System.currentTimeMillis(); + MtlsMsiHelperResult result2; + try { + result2 = client.acquireToken(RESOURCE, "SystemAssigned", null, + withAttestation, UUID.randomUUID().toString()); + } catch (MtlsMsiException e) { + System.err.println("❌ Second acquireToken failed: " + e.getMessage()); + System.exit(1); + return; + } + long elapsedMs = System.currentTimeMillis() - t0; + printResult("Second call (should be cert-cached, ~fast)", result2); + System.out.printf(" ⏱ Elapsed: %d ms%n", elapsedMs); + + // Cert cache check: same cert PEM implies cert was cached. + if (result1.getBindingCertificate() != null + && result1.getBindingCertificate().equals(result2.getBindingCertificate())) { + System.out.println(" ✅ Binding cert cache working: same cert on second call"); + } else { + System.out.println(" ⚠️ Different binding cert on second call — may indicate cache miss or cert was expiring"); + } + + // ── Third call: Graph /me to verify token actually works ──────────────── + System.out.println(); + System.out.println("Making downstream mTLS call to graph.microsoft.com..."); + makeDownstreamCall(client, result1, withAttestation); + + System.out.println(); + System.out.println("=== Path 2 Complete ==="); + } + + // ── Downstream mTLS call ────────────────────────────────────────────────── + + private static void makeDownstreamCall(MtlsMsiClient client, MtlsMsiHelperResult result, + boolean withAttestation) { + // graph.microsoft.com /v1.0/servicePrincipals — any auth error is still a TLS success. + String url = "https://graph.microsoft.com/v1.0/servicePrincipals?$top=1"; + try { + MtlsMsiHttpResponse resp = client.httpRequest( + url, "GET", result.getAccessToken(), + null, null, null, + RESOURCE, "SystemAssigned", null, + withAttestation, UUID.randomUUID().toString(), + false); + + System.out.printf(" Downstream HTTP status: %d%n", resp.getStatus()); + if (resp.getStatus() < 500) { + System.out.println(" ✅ TLS handshake + token delivery succeeded (HTTP < 500)"); + } else { + System.out.println(" ❌ Server error — check token and resource enrollment"); + } + if (resp.getStatus() == 200) { + System.out.println(" ✅ HTTP 200 — full mTLS PoP token accepted by graph.microsoft.com"); + } else if (resp.getStatus() == 401 || resp.getStatus() == 403) { + System.out.println(" ℹ️ " + resp.getStatus() + " — TLS OK, authorization depends on permissions"); + } + } catch (MtlsMsiException e) { + System.out.println(" ❌ Downstream mTLS call failed: " + e.getMessage()); + } + } + + // ── Print helpers ───────────────────────────────────────────────────────── + + private static void printResult(String label, MtlsMsiHelperResult result) { + System.out.println("[" + label + "]"); + + // Print binding cert details. + if (result.getBindingCertificate() != null) { + System.out.println(" ✅ BindingCertificate present"); + try { + X509Certificate cert = parsePem(result.getBindingCertificate()); + System.out.println(" Subject: " + cert.getSubjectX500Principal().getName()); + System.out.println(" Issuer: " + cert.getIssuerX500Principal().getName()); + System.out.println(" NotBefore: " + cert.getNotBefore()); + System.out.println(" NotAfter: " + cert.getNotAfter()); + } catch (Exception e) { + System.out.println(" (could not parse cert: " + e.getMessage() + ")"); + } + } else { + System.out.println(" ❌ BindingCertificate is null — expected non-null for mTLS PoP"); + } + + System.out.println(" TokenType: " + result.getTokenType()); + System.out.println(" ExpiresIn: " + result.getExpiresIn() + "s"); + System.out.println(" TenantId: " + result.getTenantId()); + System.out.println(" ClientId: " + result.getClientId()); + + // Print abbreviated JWT header/claims. + printTokenSummary(result.getAccessToken()); + } + + private static void printTokenSummary(String jwt) { + if (jwt == null || jwt.isEmpty()) { + System.out.println(" ❌ AccessToken is null/empty"); + return; + } + String[] parts = jwt.split("\\."); + if (parts.length < 2) { + System.out.println(" AccessToken: (not a JWT — " + jwt.length() + " chars)"); + return; + } + try { + String header = new String(Base64.getUrlDecoder().decode(pad(parts[0]))); + String payload = new String(Base64.getUrlDecoder().decode(pad(parts[1]))); + System.out.println(" AccessToken header: " + header); + // Key claims + printClaim(payload, "oid"); + printClaim(payload, "tid"); + printClaim(payload, "appid"); + printClaim(payload, "app_displayname"); + printClaim(payload, "idtyp"); + printClaim(payload, "appidacr"); + printClaim(payload, "aud"); + printClaim(payload, "token_type"); + printClaim(payload, "xms_tbflags"); + printClaim(payload, "cnf"); + long expEpoch = extractLong(payload, "exp"); + if (expEpoch > 0) { + System.out.println(" AccessToken exp: " + + new java.util.Date(expEpoch * 1000)); + } + System.out.println(" ✅ AccessToken present (" + jwt.length() + " chars)"); + System.out.println(" Raw JWT:"); + System.out.println(" " + jwt); + } catch (Exception e) { + System.out.println(" AccessToken: (could not decode JWT: " + e.getMessage() + ")"); + } + } + + private static void printClaim(String payload, String key) { + String val = extractString(payload, key); + if (val != null) { + // Truncate long values (e.g. cnf object). + if (val.length() > 120) val = val.substring(0, 120) + "..."; + System.out.println(" AccessToken " + key + ": " + val); + } + } + + private static String extractString(String json, String key) { + String search = "\"" + key + "\""; + int idx = json.indexOf(search); + if (idx < 0) return null; + int colon = json.indexOf(':', idx + search.length()); + if (colon < 0) return null; + int vs = colon + 1; + while (vs < json.length() && Character.isWhitespace(json.charAt(vs))) vs++; + if (vs >= json.length()) return null; + char first = json.charAt(vs); + if (first == '"') { + int end = vs + 1; + while (end < json.length() && json.charAt(end) != '"') end++; + return json.substring(vs + 1, end); + } else if (first == '{' || first == '[') { + // Return the whole nested object/array. + char close = first == '{' ? '}' : ']'; + int depth = 0, end = vs; + while (end < json.length()) { + char c = json.charAt(end); + if (c == first) depth++; + else if (c == close) { if (--depth == 0) { end++; break; } } + end++; + } + return json.substring(vs, end); + } + return null; + } + + private static long extractLong(String json, String key) { + String search = "\"" + key + "\""; + int idx = json.indexOf(search); + if (idx < 0) return 0; + int colon = json.indexOf(':', idx + search.length()); + if (colon < 0) return 0; + int vs = colon + 1; + while (vs < json.length() && Character.isWhitespace(json.charAt(vs))) vs++; + int ve = vs; + while (ve < json.length() && (Character.isDigit(json.charAt(ve)) || json.charAt(ve) == '-')) ve++; + try { return Long.parseLong(json.substring(vs, ve)); } catch (Exception e) { return 0; } + } + + private static String pad(String s) { + return s + "==".substring(0, (4 - s.length() % 4) % 4); + } + + private static X509Certificate parsePem(String pem) throws Exception { + String b64 = pem + .replace("-----BEGIN CERTIFICATE-----", "") + .replace("-----END CERTIFICATE-----", "") + .replaceAll("\\s", ""); + byte[] der = Base64.getDecoder().decode(b64); + return (X509Certificate) CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(der)); + } +} diff --git a/msal4j-mtls-extensions/src/e2e/resources/mtls-test-cert.p12 b/msal4j-mtls-extensions/src/e2e/resources/mtls-test-cert.p12 new file mode 100644 index 00000000..63315c69 Binary files /dev/null and b/msal4j-mtls-extensions/src/e2e/resources/mtls-test-cert.p12 differ diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/AttestationLibrary.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/AttestationLibrary.java new file mode 100644 index 00000000..b5e17a8a --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/AttestationLibrary.java @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Callback; +import com.sun.jna.Library; +import com.sun.jna.Pointer; +import com.sun.jna.Structure; +import com.sun.jna.ptr.PointerByReference; + +import java.util.Arrays; +import java.util.List; + +/** + * JNA binding for {@code AttestationClientLib.dll} — the Windows DLL shipped by Azure + * that produces a MAA (Microsoft Azure Attestation) JWT proving a CNG KeyGuard key is + * hardware-protected. + * + *

Function signatures (ANSI cdecl, x64 Windows) documented in MSAL.NET's + * {@code KeyGuardMaa/AttestationInterop.cs} and also used by msal-go's + * {@code cng_windows.go}:

+ *
+ *   int  InitAttestationLib(AttestationLogInfo*)
+ *   int  AttestKeyGuardImportKey(char* endpoint, char* authToken, char* clientPayload,
+ *                                NCRYPT_KEY_HANDLE keyHandle, char** token, char* clientId)
+ *   void FreeAttestationToken(char* token)
+ *   void UninitAttestationLib()
+ * 
+ * + *

This interface is loaded lazily via {@link CngKeyGuard} — it is only required when + * MAA attestation is requested and the DLL is present on the system. If the DLL is absent + * and attestation is not requested, it is never loaded.

+ */ +interface AttestationLibrary extends Library { + + /** + * No-op log callback that satisfies the DLL's requirement for a non-null LogFunc. + * + *

The DLL requires a non-null log function pointer in {@link AttestationLogInfo}. + * Passing {@code Pointer.NULL} causes {@code InitAttestationLib} to return an error + * (0xFFFFFFF8 = -8). Mirrors msal-go's {@code dummyLogCallback}.

+ * + *

Signature (cdecl, x64 Windows): + * {@code void LogFunc(void* ctx, char* tag, int lvl, char* func, int line, char* msg)}

+ */ + interface LogCallback extends Callback { + void log(Pointer ctx, Pointer tag, int level, Pointer func, int line, Pointer msg); + } + + /** Shared no-op log callback instance — kept alive to prevent GC. */ + LogCallback NOOP_LOG = (ctx, tag, level, func, line, msg) -> {}; + + /** + * Mirrors the {@code AttestationLogInfo} struct: + *
struct AttestationLogInfo { LogFunc Log; void* Ctx; }
+ * + *

The {@code logFunc} field MUST be a non-null function pointer — the DLL validates + * this and returns an error if it is null. Use {@link #NOOP_LOG} for no-op logging.

+ */ + class AttestationLogInfo extends Structure { + /** Function pointer for the log callback. MUST NOT be null. */ + public LogCallback logFunc; + /** Caller context pointer, passed as first arg to logFunc. */ + public Pointer ctx; + + public AttestationLogInfo() { + logFunc = NOOP_LOG; // DLL requires a non-null log function pointer + ctx = Pointer.NULL; + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("logFunc", "ctx"); + } + } + + /** + * Initializes the attestation library. + * + * @param logInfo logging configuration; {@code logFunc} MUST be non-null + * @return 0 on success, non-zero on failure + */ + int InitAttestationLib(AttestationLogInfo logInfo); + + /** + * Produces a MAA JWT proving the given CNG key is VBS/KeyGuard-protected. + * + * @param endpoint MAA endpoint URL (ANSI string, e.g. "https://sharedcuse.cuse.attest.azure.net") + * @param authToken unused, pass null + * @param clientPayload unused, pass null + * @param keyHandle the {@code NCRYPT_KEY_HANDLE} from NCrypt* operations + * @param tokenOut receives the pointer to the MAA JWT string (caller must free with FreeAttestationToken) + * @param clientId managed identity client ID (ANSI string) + * @return 0 on success, non-zero on failure + */ + int AttestKeyGuardImportKey(String endpoint, String authToken, String clientPayload, + Pointer keyHandle, PointerByReference tokenOut, String clientId); + + /** + * Frees a MAA JWT string allocated by {@link #AttestKeyGuardImportKey}. + * + * @param token the pointer returned in {@code tokenOut} by AttestKeyGuardImportKey + */ + void FreeAttestationToken(Pointer token); + + /** Uninitializes the attestation library. Call after all attestation operations. */ + void UninitAttestationLib(); +} + diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngKeyGuard.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngKeyGuard.java new file mode 100644 index 00000000..8ecb3615 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngKeyGuard.java @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Memory; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import com.sun.jna.WString; +import com.sun.jna.ptr.IntByReference; +import com.sun.jna.ptr.PointerByReference; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * Windows CNG key operations for mTLS PoP Managed Identity. + * + *

Mirrors msal-go's {@code cng_windows.go} and MSAL.NET's + * {@code WindowsManagedIdentityKeyProvider}: creates or opens a persisted RSA key in + * the Microsoft Software Key Storage Provider, using the same 3-level priority:

+ *
    + *
  1. KeyGuard — Software KSP + USER scope + VBS Virtual Isolation flags. + * Requires Credential Guard / Core Isolation on the VM.
  2. + *
  3. Hardware — Software KSP + USER scope, no VBS flags.
  4. + *
  5. mTLS PoP requires KeyGuard and throws {@link MtlsMsiException} for Hardware keys.
  6. + *
+ */ +public final class CngKeyGuard { + + private static final String MS_SOFTWARE_KSP = "Microsoft Software Key Storage Provider"; + private static final String RSA_ALG = "RSA"; + private static final String EXPORT_POLICY = "Export Policy"; + private static final String KEY_LENGTH = "Length"; + private static final String VIRTUAL_ISO = "Virtual Iso"; + private static final String RSAPUBLICBLOB = "RSAPUBLICBLOB"; + + private CngKeyGuard() {} + + /** + * Gets or creates the mTLS PoP binding key, attempting KeyGuard first. + * + * @param keyName persisted key name in the KSP (e.g. {@code "MSALMtlsKey_"}) + * @return a {@link CngRsaPrivateKey} backed by the CNG handle + * @throws MtlsMsiException if the system is not Windows, the key cannot be created, + * or KeyGuard protection is unavailable + */ + public static CngRsaPrivateKey getOrCreateKey(String keyName) throws MtlsMsiException { + if (!isWindows()) { + throw new MtlsMsiException("mTLS PoP Managed Identity is only supported on Windows Azure VMs."); + } + + // 1. Try KeyGuard (USER scope + VBS Virtual Isolation flags). + int kgCreateFlags = NCryptLibrary.NCRYPT_OVERWRITE_KEY_FLAG + | NCryptLibrary.NCRYPT_USE_VIRTUAL_ISOLATION_FLAG + | NCryptLibrary.NCRYPT_USE_PER_BOOT_KEY_FLAG; + try { + CngRsaPrivateKey key = tryGetOrCreateKey(keyName, NCryptLibrary.NCRYPT_SILENT_FLAG, kgCreateFlags, NCryptLibrary.NCRYPT_SILENT_FLAG); + if (isKeyGuardProtected(key.getHandle())) { + return key; + } + // Created but VBS protection not active — delete and retry once (mirrors MSAL.NET). + NCryptLibrary.INSTANCE.NCryptDeleteKey(key.getHandle(), 0); + key = tryGetOrCreateKey(keyName, NCryptLibrary.NCRYPT_SILENT_FLAG, kgCreateFlags, NCryptLibrary.NCRYPT_SILENT_FLAG); + if (isKeyGuardProtected(key.getHandle())) { + return key; + } + NCryptLibrary.INSTANCE.NCryptFreeObject(key.getHandle()); + } catch (MtlsMsiException ignored) { + // KeyGuard not available on this VM; fall through to error below. + } + + throw new MtlsMsiException( + "mTLS PoP requires a VBS KeyGuard-protected RSA key, but KeyGuard is not available " + + "on this VM. Ensure Credential Guard / Core Isolation is enabled: the VM must be " + + "Trusted Launch (Secure Boot + vTPM) with VBS active " + + "(check msinfo32.exe: 'Virtualization-based security' = Running)."); + } + + /** + * Produces a MAA JWT by calling {@code AttestationClientLib.dll}. + * + * @param keyHandle CNG key handle from {@link CngRsaPrivateKey#getHandle()} + * @param endpoint MAA attestation endpoint URL (from IMDS platform metadata) + * @param clientId managed identity client ID (from IMDS platform metadata) + * @return the MAA JWT string + * @throws MtlsMsiException if the DLL is not present, or attestation fails + */ + public static String getAttestationToken(Pointer keyHandle, String endpoint, String clientId) + throws MtlsMsiException { + + AttestationLibrary attestLib; + try { + attestLib = Native.load("AttestationClientLib", AttestationLibrary.class); + } catch (UnsatisfiedLinkError e) { + throw new MtlsMsiException( + "AttestationClientLib.dll not found. Place the DLL in a directory on the system PATH " + + "or in the same directory as the JVM. " + + "Obtain it from the Microsoft.Azure.Security.KeyGuardAttestation NuGet package " + + "(runtimes/win-x64/native/AttestationClientLib.dll). Error: " + e.getMessage(), e); + } + + AttestationLibrary.AttestationLogInfo logInfo = new AttestationLibrary.AttestationLogInfo(); + int ret = attestLib.InitAttestationLib(logInfo); + if (ret != 0) { + throw new MtlsMsiException( + String.format("InitAttestationLib failed: 0x%x", ret)); + } + + try { + PointerByReference tokenRef = new PointerByReference(); + ret = attestLib.AttestKeyGuardImportKey(endpoint, null, null, keyHandle, tokenRef, clientId); + if (ret != 0) { + throw new MtlsMsiException(String.format( + "AttestKeyGuardImportKey failed (rc=0x%x). This usually means the VM's vTPM " + + "is not provisioned for attestation. mTLS PoP requires a Trusted Launch Azure VM " + + "(Secure Boot + vTPM) with an EK certificate. " + + "Check 'tpmtool.exe getdeviceinformation': 'Is Capable For Attestation' must be true.", ret)); + } + + Pointer tokenPtr = tokenRef.getValue(); + if (tokenPtr == null || tokenPtr == Pointer.NULL) { + throw new MtlsMsiException("AttestKeyGuardImportKey returned null token"); + } + + try { + String jwt = tokenPtr.getString(0); // ANSI (null-terminated) + if (jwt == null || jwt.isEmpty()) { + throw new MtlsMsiException("AttestKeyGuardImportKey returned empty token"); + } + return jwt; + } finally { + attestLib.FreeAttestationToken(tokenPtr); + } + } finally { + attestLib.UninitAttestationLib(); + } + } + + /** + * Signs a digest using {@code NCryptSignHash} with PKCS#1 v1.5 padding. + * + * @param keyHandle CNG key handle + * @param digest the hash bytes to sign + * @param hashAlgCng CNG hash algorithm name (e.g. {@code "SHA256"}) + * @return DER-encoded signature bytes + * @throws MtlsMsiException if signing fails + */ + public static byte[] signPkcs1(Pointer keyHandle, byte[] digest, String hashAlgCng) + throws MtlsMsiException { + NCryptLibrary.BcryptPkcs1PaddingInfo padding = + new NCryptLibrary.BcryptPkcs1PaddingInfo(hashAlgCng); + return ncryptSign(keyHandle, padding.getPointer(), NCryptLibrary.NCRYPT_PAD_PKCS1_FLAG, digest, "PKCS1v15"); + } + + /** + * Signs a digest using {@code NCryptSignHash} with RSASSA-PSS padding. + * + * @param keyHandle CNG key handle + * @param digest the hash bytes to sign + * @param hashAlgCng CNG hash algorithm name (e.g. {@code "SHA256"}) + * @param saltLen PSS salt length in bytes + * @return DER-encoded signature bytes + * @throws MtlsMsiException if signing fails + */ + public static byte[] signPss(Pointer keyHandle, byte[] digest, String hashAlgCng, int saltLen) + throws MtlsMsiException { + NCryptLibrary.BcryptPssPaddingInfo padding = + new NCryptLibrary.BcryptPssPaddingInfo(hashAlgCng, saltLen); + return ncryptSign(keyHandle, padding.getPointer(), NCryptLibrary.NCRYPT_PAD_PSS_FLAG, digest, "PSS"); + } + + private static byte[] ncryptSign(Pointer hKey, Pointer paddingPtr, int paddingFlag, + byte[] digest, String label) throws MtlsMsiException { + IntByReference sigLen = new IntByReference(0); + // First call: query the signature buffer size. + int ret = NCryptLibrary.INSTANCE.NCryptSignHash( + hKey, paddingPtr, digest, digest.length, + null, 0, sigLen, paddingFlag); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptSignHash %s (size query) failed: 0x%x", label, ret)); + } + + Memory sigBuf = new Memory(sigLen.getValue()); + ret = NCryptLibrary.INSTANCE.NCryptSignHash( + hKey, paddingPtr, digest, digest.length, + sigBuf, sigLen.getValue(), sigLen, paddingFlag); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptSignHash %s failed: 0x%x", label, ret)); + } + + return sigBuf.getByteArray(0, sigLen.getValue()); + } + + // ─── Internal helpers ───────────────────────────────────────────────────── + + private static CngRsaPrivateKey tryGetOrCreateKey(String keyName, + int openFlags, + int createFlags, + int finalizeFlags) throws MtlsMsiException { + Pointer hProvider = openProvider(); + try { + WString keyNameW = new WString(keyName); + + // 1. Try to open an existing key. + PointerByReference phKey = new PointerByReference(); + int ret = NCryptLibrary.INSTANCE.NCryptOpenKey( + hProvider, phKey, keyNameW, 0, openFlags); + + Pointer hKey; + if (ret == NCryptLibrary.ERROR_SUCCESS) { + hKey = phKey.getValue(); + // Verify the key is usable by exporting the public blob. + try { + exportPublicKeyBytes(hKey); + } catch (MtlsMsiException e) { + NCryptLibrary.INSTANCE.NCryptDeleteKey(hKey, 0); + hKey = null; + ret = -1; + } + } else { + hKey = null; + } + + // 2. Create a new key if open failed. + if (hKey == null) { + ret = NCryptLibrary.INSTANCE.NCryptCreatePersistedKey( + hProvider, phKey, + new WString(RSA_ALG), + keyNameW, + 0, createFlags); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptCreatePersistedKey failed: 0x%x", ret)); + } + hKey = phKey.getValue(); + + // Set key length to 2048. + setDwordProperty(hKey, KEY_LENGTH, 2048); + // Set non-exportable. + setDwordProperty(hKey, EXPORT_POLICY, NCryptLibrary.NCRYPT_ALLOW_EXPORT_NONE); + + ret = NCryptLibrary.INSTANCE.NCryptFinalizeKey(hKey, finalizeFlags); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + NCryptLibrary.INSTANCE.NCryptDeleteKey(hKey, 0); + throw new MtlsMsiException( + String.format("NCryptFinalizeKey failed: 0x%x. " + + "VBS isolation flags are not supported on this machine " + + "(Credential Guard / Core Isolation not active).", ret)); + } + } + + BigInteger[] pubKey = exportPublicKey(hKey); + return new CngRsaPrivateKey(hKey, pubKey[0], pubKey[1].intValue()); + + } finally { + NCryptLibrary.INSTANCE.NCryptFreeObject(hProvider); + } + } + + static boolean isKeyGuardProtected(Pointer hKey) { + WString propW = new WString(VIRTUAL_ISO); + byte[] buf = new byte[4]; + IntByReference pcbResult = new IntByReference(0); + int ret = NCryptLibrary.INSTANCE.NCryptGetProperty( + hKey, propW, buf, buf.length, pcbResult, 0); + if (ret != NCryptLibrary.ERROR_SUCCESS || pcbResult.getValue() < 4) { + return false; + } + int val = ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getInt(); + return val != 0; + } + + /** Returns byte[] of the RSAPUBLICBLOB for use in CSR SubjectPublicKeyInfo. */ + static byte[] exportPublicKeyBytes(Pointer hKey) throws MtlsMsiException { + WString blobType = new WString(RSAPUBLICBLOB); + IntByReference pcbResult = new IntByReference(0); + + // Query size. + int ret = NCryptLibrary.INSTANCE.NCryptExportKey( + hKey, null, blobType, null, null, 0, pcbResult, 0); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptExportKey (size query) failed: 0x%x", ret)); + } + + Memory blob = new Memory(pcbResult.getValue()); + ret = NCryptLibrary.INSTANCE.NCryptExportKey( + hKey, null, blobType, null, blob, pcbResult.getValue(), pcbResult, 0); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptExportKey failed: 0x%x", ret)); + } + + return blob.getByteArray(0, pcbResult.getValue()); + } + + /** + * Returns [modulus, publicExponent] parsed from the RSAPUBLICBLOB. + * BCRYPT_RSAKEY_BLOB format (24-byte header): + *
Magic(4) BitLength(4) cbPublicExp(4) cbModulus(4) cbPrime1(4) cbPrime2(4)
+ * followed by PublicExponent bytes then Modulus bytes. + */ + static BigInteger[] exportPublicKey(Pointer hKey) throws MtlsMsiException { + byte[] blob = exportPublicKeyBytes(hKey); + if (blob.length < 24) { + throw new MtlsMsiException("RSAPUBLICBLOB too short: " + blob.length); + } + + ByteBuffer bb = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN); + bb.getInt(); // magic + bb.getInt(); // bitLength + int cbPublicExp = bb.getInt(); + int cbModulus = bb.getInt(); + // skip cbPrime1, cbPrime2 + bb.position(24); + + byte[] expBytes = new byte[cbPublicExp]; + bb.get(expBytes); + byte[] modBytes = new byte[cbModulus]; + bb.get(modBytes); + + return new BigInteger[] { + new BigInteger(1, modBytes), // [0] = modulus + new BigInteger(1, expBytes) // [1] = publicExponent + }; + } + + private static Pointer openProvider() throws MtlsMsiException { + PointerByReference phProvider = new PointerByReference(); + int ret = NCryptLibrary.INSTANCE.NCryptOpenStorageProvider( + phProvider, new WString(MS_SOFTWARE_KSP), 0); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptOpenStorageProvider failed: 0x%x", ret)); + } + return phProvider.getValue(); + } + + private static void setDwordProperty(Pointer hKey, String propName, int value) + throws MtlsMsiException { + byte[] buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(value).array(); + int ret = NCryptLibrary.INSTANCE.NCryptSetProperty( + hKey, new WString(propName), buf, buf.length, NCryptLibrary.NCRYPT_SILENT_FLAG); + if (ret != NCryptLibrary.ERROR_SUCCESS) { + throw new MtlsMsiException( + String.format("NCryptSetProperty(%s) failed: 0x%x", propName, ret)); + } + } + + private static boolean isWindows() { + String os = System.getProperty("os.name", "").toLowerCase(); + return os.contains("windows"); + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngProvider.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngProvider.java new file mode 100644 index 00000000..05204b0f --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngProvider.java @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import java.security.Provider; +import java.security.Security; + +/** + * A {@link Provider} that routes RSA signature operations for {@link CngRsaPrivateKey} + * keys through Windows CNG ({@code NCryptSignHash}) via JNA. + * + *

Install once per JVM via {@link #installIfAbsent()} before creating an + * {@code SSLContext} that uses a {@link CngRsaPrivateKey}. JSSE will call + * {@code Signature.getInstance("SHA256withRSA")} (TLS 1.2) or + * {@code Signature.getInstance("RSASSA-PSS")} (TLS 1.3); with this provider at + * high priority, {@link CngSignatureSpi} intercepts the call and signs via + * {@code NCryptSignHash} instead of requiring an exportable private exponent.

+ * + *

For non-{@link CngRsaPrivateKey} keys, {@link CngSignatureSpi} automatically + * delegates to the next available provider, so installing this provider does not + * break other RSA signing in the same JVM.

+ */ +public final class CngProvider extends Provider { + + private static final long serialVersionUID = 1L; + private static final String PROVIDER_NAME = "CNG"; + private static final double PROVIDER_VERSION = 1.0; + private static final String PROVIDER_INFO = "Windows CNG JNA provider for JSSE mTLS"; + + public CngProvider() { + super(PROVIDER_NAME, PROVIDER_VERSION, PROVIDER_INFO); + + put("Signature.SHA256withRSA", CngSignatureSpi.Sha256WithRsa.class.getName()); + put("Signature.SHA1withRSA", CngSignatureSpi.Sha1WithRsa.class.getName()); + put("Signature.RSASSA-PSS", CngSignatureSpi.RsaSsaPss.class.getName()); + // Aliases used by some TLS implementations. + put("Alg.Alias.Signature.SHA256withRSAandMGF1", "RSASSA-PSS"); + put("Alg.Alias.Signature.SHA-256withRSA", "SHA256withRSA"); + put("Alg.Alias.Signature.SHA1withRSA", "SHA1withRSA"); + } + + /** + * Installs this provider at position 1 (highest priority) if it is not already + * registered. Safe to call multiple times. + */ + public static void installIfAbsent() { + if (Security.getProvider(PROVIDER_NAME) == null) { + Security.insertProviderAt(new CngProvider(), 1); + } + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngRsaPrivateKey.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngRsaPrivateKey.java new file mode 100644 index 00000000..efdba05e --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngRsaPrivateKey.java @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Pointer; + +import java.math.BigInteger; +import java.security.interfaces.RSAPrivateKey; + +/** + * A non-exportable RSA private key backed by a Windows CNG {@code NCRYPT_KEY_HANDLE}. + * + *

This key implements {@link RSAPrivateKey} so that JSSE recognizes it as an RSA key + * and selects RSA cipher suites. The private exponent is {@code null} and + * {@link #getEncoded()} returns {@code null} because the key material never leaves the + * CNG key storage provider (KeyGuard VBS isolation).

+ * + *

Signing is performed by {@link CngKeyGuard#signPkcs1} / {@link CngKeyGuard#signPss}, + * dispatched from {@link CngSignatureSpi}.

+ * + *

Callers must call {@link #close()} when done to free the CNG handle.

+ */ +public final class CngRsaPrivateKey implements RSAPrivateKey, AutoCloseable { + + private static final long serialVersionUID = 1L; + + private final Pointer handle; + private final BigInteger modulus; + private final int publicExponent; + private volatile boolean closed; + + CngRsaPrivateKey(Pointer handle, BigInteger modulus, int publicExponent) { + this.handle = handle; + this.modulus = modulus; + this.publicExponent = publicExponent; + } + + /** The underlying {@code NCRYPT_KEY_HANDLE}. Never null while the key is open. */ + public Pointer getHandle() { + if (closed) throw new IllegalStateException("CNG key handle has been closed"); + return handle; + } + + // ─── RSAKey ─────────────────────────────────────────────────────────────── + + /** Returns the RSA modulus (from the exported RSAPUBLICBLOB — public information). */ + @Override + public BigInteger getModulus() { + return modulus; + } + + /** + * Always returns {@code null}. The private exponent is non-exportable from the + * KeyGuard-protected CNG key; signing is delegated to {@code NCryptSignHash}. + */ + @Override + public BigInteger getPrivateExponent() { + return null; + } + + // ─── Key ────────────────────────────────────────────────────────────────── + + @Override + public String getAlgorithm() { return "RSA"; } + + /** Returns {@code null} — non-exportable key has no serializable encoding. */ + @Override + public String getFormat() { return null; } + + /** Returns {@code null} — non-exportable key has no serializable encoding. */ + @Override + public byte[] getEncoded() { return null; } + + // ─── AutoCloseable ──────────────────────────────────────────────────────── + + /** + * Frees the underlying CNG key handle via {@code NCryptFreeObject}. + * The key remains persisted in the KSP; only the in-process handle is released. + */ + @Override + public void close() { + if (!closed) { + closed = true; + NCryptLibrary.INSTANCE.NCryptFreeObject(handle); + } + } + + @Override + @SuppressWarnings("deprecation") + protected void finalize() { + close(); + } + + /** The public exponent (e.g. 65537 = 0x10001). */ + public int getPublicExponent() { + return publicExponent; + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngSignatureSpi.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngSignatureSpi.java new file mode 100644 index 00000000..2384728a --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/CngSignatureSpi.java @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Pointer; + +import java.security.AlgorithmParameters; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.Provider; +import java.security.Security; +import java.security.SignatureException; +import java.security.SignatureSpi; +import java.security.spec.AlgorithmParameterSpec; +import java.security.spec.MGF1ParameterSpec; +import java.security.spec.PSSParameterSpec; +import java.util.Arrays; + +/** + * {@link SignatureSpi} implementations that delegate signing to Windows CNG via JNA. + * + *

Two families are provided:

+ *
    + *
  • {@link Sha256WithRsa} / {@link Sha1WithRsa} — PKCS#1 v1.5 padding + * (TLS 1.2 client cert verify, and CSR signing fallback)
  • + *
  • {@link RsaSsaPss} — RSASSA-PSS with configurable parameters + * (TLS 1.3 client cert verify)
  • + *
+ * + *

For non-{@link CngRsaPrivateKey} keys, each SPI delegates to the next available + * provider so that installing {@link CngProvider} at high priority does not break other + * code in the same JVM that signs with regular (exportable) RSA keys.

+ */ +abstract class CngSignatureSpi extends SignatureSpi { + + // ─── Concrete algorithms ─────────────────────────────────────────────────── + + /** SHA-256 with RSA PKCS#1 v1.5 */ + public static class Sha256WithRsa extends CngSignatureSpi { + public Sha256WithRsa() { super("SHA-256", "SHA256", false, 32); } + } + + /** SHA-1 with RSA PKCS#1 v1.5 */ + public static class Sha1WithRsa extends CngSignatureSpi { + public Sha1WithRsa() { super("SHA-1", "SHA1", false, 20); } + } + + /** RSASSA-PSS — algorithm parameters set via {@link #engineSetParameter(AlgorithmParameterSpec)} */ + public static class RsaSsaPss extends CngSignatureSpi { + public RsaSsaPss() { super("SHA-256", "SHA256", true, 32); } + } + + // ─── State ──────────────────────────────────────────────────────────────── + + private final boolean pss; + + // CNG mode + private Pointer cngHandle; + private MessageDigest digest; + private String hashJce; // Java algorithm name (e.g. "SHA-256") + private String hashCng; // CNG algorithm name (e.g. "SHA256") + private int saltLen; + + // Delegation mode (non-CNG keys) + private java.security.Signature delegate; + + CngSignatureSpi(String hashJce, String hashCng, boolean pss, int saltLen) { + this.hashJce = hashJce; + this.hashCng = hashCng; + this.pss = pss; + this.saltLen = saltLen; + } + + // ─── SignatureSpi ───────────────────────────────────────────────────────── + + @Override + protected void engineInitVerify(java.security.PublicKey publicKey) + throws InvalidKeyException { + // CNG only handles signing (NCryptSignHash). For verification (server cert + // validation, etc.) we deliberately throw InvalidKeyException so that + // Signature.Delegate.chooseProvider() skips this SPI and falls through to + // SunRsaSign or another standard provider that handles RSA/ECDSA verification. + throw new InvalidKeyException( + "CngSignatureSpi does not support verification; use SunRsaSign"); + } + + @Override + protected void engineInitSign(PrivateKey key) throws InvalidKeyException { + if (key instanceof CngRsaPrivateKey) { + Pointer h; + try { + h = ((CngRsaPrivateKey) key).getHandle(); + } catch (IllegalStateException e) { + throw new InvalidKeyException("CNG key is closed: " + e.getMessage(), e); + } + if (h == null) { + throw new InvalidKeyException("CNG key handle is null (key may be closed or invalid)"); + } + cngHandle = h; + delegate = null; + try { + digest = MessageDigest.getInstance(hashJce); + } catch (NoSuchAlgorithmException e) { + throw new InvalidKeyException("MessageDigest " + hashJce + " not available", e); + } + } else { + // Delegate to the next provider that handles this algorithm. + cngHandle = null; + delegate = null; + try { + delegate = getNextProviderSignature(); + delegate.initSign(key); + } catch (NoSuchAlgorithmException e) { + throw new InvalidKeyException("No fallback provider: " + e.getMessage(), e); + } + } + } + + @Override + protected void engineUpdate(byte b) throws SignatureException { + if (cngHandle != null) { + digest.update(b); + } else if (delegate != null) { + delegate.update(b); + } else { + throw new SignatureException( + "CngSignatureSpi.engineUpdate called before engineInitSign — " + + "Signature object was not properly initialized"); + } + } + + @Override + protected void engineUpdate(byte[] b, int off, int len) throws SignatureException { + if (cngHandle != null) { + digest.update(b, off, len); + } else if (delegate != null) { + delegate.update(b, off, len); + } else { + throw new SignatureException( + "CngSignatureSpi.engineUpdate called before engineInitSign — " + + "Signature object was not properly initialized"); + } + } + + @Override + protected byte[] engineSign() throws SignatureException { + if (cngHandle != null) { + byte[] hash = digest.digest(); + try { + if (pss) { + return CngKeyGuard.signPss(cngHandle, hash, hashCng, saltLen); + } else { + return CngKeyGuard.signPkcs1(cngHandle, hash, hashCng); + } + } catch (MtlsMsiException e) { + throw new SignatureException("CNG signing failed: " + e.getMessage(), e); + } + } else { + return delegate.sign(); + } + } + + @Override + protected boolean engineVerify(byte[] sigBytes) throws SignatureException { + if (delegate != null) { + return delegate.verify(sigBytes); + } + // Verification is not needed for client-auth TLS or CSR generation. + throw new SignatureException("CngSignatureSpi does not support verify (CNG-backed keys)"); + } + + @Override + protected void engineSetParameter(AlgorithmParameterSpec params) + throws InvalidAlgorithmParameterException { + if (params instanceof PSSParameterSpec) { + PSSParameterSpec pssSpec = (PSSParameterSpec) params; + hashJce = pssSpec.getDigestAlgorithm(); + hashCng = toCngHashName(pssSpec.getDigestAlgorithm()); + saltLen = pssSpec.getSaltLength(); + if (cngHandle != null) { + // Re-initialize the digest with the new hash algorithm. + try { + digest = MessageDigest.getInstance(hashJce); + } catch (NoSuchAlgorithmException e) { + throw new InvalidAlgorithmParameterException( + "MessageDigest " + hashJce + " not available", e); + } + } else if (delegate != null) { + // Forward PSS params to the delegating provider's Signature instance. + try { + delegate.setParameter(params); + } catch (Exception e) { + throw new InvalidAlgorithmParameterException(e.getMessage(), e); + } + } + } else if (delegate != null && params != null) { + try { + delegate.setParameter(params); + } catch (Exception e) { + throw new InvalidAlgorithmParameterException(e.getMessage(), e); + } + } + } + + @Override + @SuppressWarnings("deprecation") + protected void engineSetParameter(String param, Object value) { + // Legacy method — no-op, required by abstract superclass. + } + + @Override + @SuppressWarnings("deprecation") + protected Object engineGetParameter(String param) { + return null; + } + + @Override + protected AlgorithmParameters engineGetParameters() { + if (pss && delegate == null) { + try { + AlgorithmParameters ap = AlgorithmParameters.getInstance("RSASSA-PSS"); + ap.init(new PSSParameterSpec(hashJce, "MGF1", + new MGF1ParameterSpec(hashJce), saltLen, 1)); + return ap; + } catch (Exception e) { + return null; + } + } + if (delegate != null) { + return delegate.getParameters(); + } + return null; + } + + // ─── Helpers ────────────────────────────────────────────────────────────── + + private java.security.Signature getNextProviderSignature() throws NoSuchAlgorithmException { + String algName = pss ? "RSASSA-PSS" : (hashCng.equals("SHA256") ? "SHA256withRSA" : "SHA1withRSA"); + for (Provider p : Security.getProviders()) { + if (p instanceof CngProvider) continue; + if (p.getService("Signature", algName) != null) { + return java.security.Signature.getInstance(algName, p); + } + } + throw new NoSuchAlgorithmException( + "No provider for " + algName + " besides CngProvider"); + } + + private static String toCngHashName(String jceHashName) { + if (jceHashName == null) return "SHA256"; + switch (jceHashName.toUpperCase().replace("-", "")) { + case "SHA1": return "SHA1"; + case "SHA256": return "SHA256"; + case "SHA384": return "SHA384"; + case "SHA512": return "SHA512"; + default: return "SHA256"; + } + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/ImdsV2Client.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/ImdsV2Client.java new file mode 100644 index 00000000..bb29f8f8 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/ImdsV2Client.java @@ -0,0 +1,256 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.UUID; + +/** + * HTTP client for the Azure IMDSv2 credential issuance endpoints used in mTLS PoP. + * + *

Mirrors msal-go's {@code getPlatformMetadata()} and {@code issueCredential()} in + * {@code imdsv2.go}.

+ * + *
    + *
  • GET {@code http://169.254.169.254/metadata/identity/getplatformmetadata?cred-api-version=2.0} + * → {@link PlatformMetadata}
  • + *
  • POST {@code http://169.254.169.254/metadata/identity/issuecredential?cred-api-version=2.0} + * body: {@code {"csr":"","attestation_token":""}} + * → {@link CredentialResponse}
  • + *
+ */ +final class ImdsV2Client { + + private static final String PLATFORM_METADATA_URL = + "http://169.254.169.254/metadata/identity/getplatformmetadata?cred-api-version=2.0"; + private static final String ISSUE_CREDENTIAL_URL = + "http://169.254.169.254/metadata/identity/issuecredential?cred-api-version=2.0"; + + private static final int CONNECT_TIMEOUT_MS = 5_000; + private static final int READ_TIMEOUT_MS = 30_000; + + private ImdsV2Client() {} + + // ─── Response types ─────────────────────────────────────────────────────── + + /** Deserialized response from {@code /getplatformmetadata}. */ + static final class PlatformMetadata { + final String clientId; + final String tenantId; + final String vmId; // from cuId.vmId + final String vmssId; // from cuId.vmssId + final String attestationEndpoint; + + PlatformMetadata(String clientId, String tenantId, String vmId, + String vmssId, String attestationEndpoint) { + this.clientId = clientId; + this.tenantId = tenantId; + this.vmId = vmId; + this.vmssId = vmssId; + this.attestationEndpoint = attestationEndpoint; + } + + /** The cuId string used as the key name suffix (matches msal-go logic). */ + String cuIdString() { + return (vmId != null && !vmId.isEmpty()) ? vmId : clientId; + } + } + + /** Deserialized response from {@code /issuecredential}. */ + static final class CredentialResponse { + final String certificate; // base64-encoded DER + final String mtlsAuthenticationEndpoint; + final String clientId; + final String tenantId; + final String regionalTokenUrl; + + CredentialResponse(String certificate, String mtlsAuthenticationEndpoint, + String clientId, String tenantId, String regionalTokenUrl) { + this.certificate = certificate; + this.mtlsAuthenticationEndpoint = mtlsAuthenticationEndpoint; + this.clientId = clientId; + this.tenantId = tenantId; + this.regionalTokenUrl = regionalTokenUrl; + } + } + + // ─── API ───────────────────────────────────────────────────────────────── + + static PlatformMetadata getPlatformMetadata() throws MtlsMsiException { + String json = httpGet(PLATFORM_METADATA_URL); + return parsePlatformMetadata(json); + } + + static CredentialResponse issueCredential(String csrBase64, String attestationToken) + throws MtlsMsiException { + // Build JSON body manually to avoid adding a JSON library dependency. + StringBuilder body = new StringBuilder("{\"csr\":\""); + body.append(csrBase64).append("\""); + if (attestationToken != null && !attestationToken.isEmpty()) { + body.append(",\"attestation_token\":\"").append(attestationToken).append("\""); + } + body.append("}"); + String json = httpPost(ISSUE_CREDENTIAL_URL, body.toString()); + return parseCredentialResponse(json); + } + + // ─── HTTP helpers ───────────────────────────────────────────────────────── + + private static String httpGet(String urlStr) throws MtlsMsiException { + try { + HttpURLConnection conn = openConnection(urlStr); + conn.setRequestMethod("GET"); + conn.setRequestProperty("Metadata", "true"); + conn.setRequestProperty("x-ms-client-request-id", UUID.randomUUID().toString()); + return readResponse(conn, urlStr); + } catch (IOException e) { + throw new MtlsMsiException("IMDS GET " + urlStr + " failed: " + e.getMessage(), e); + } + } + + private static String httpPost(String urlStr, String jsonBody) throws MtlsMsiException { + try { + HttpURLConnection conn = openConnection(urlStr); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Metadata", "true"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setRequestProperty("x-ms-client-request-id", UUID.randomUUID().toString()); + conn.setDoOutput(true); + byte[] bodyBytes = jsonBody.getBytes(StandardCharsets.UTF_8); + conn.setRequestProperty("Content-Length", String.valueOf(bodyBytes.length)); + try (OutputStream os = conn.getOutputStream()) { + os.write(bodyBytes); + } + return readResponse(conn, urlStr); + } catch (IOException e) { + throw new MtlsMsiException("IMDS POST " + urlStr + " failed: " + e.getMessage(), e); + } + } + + private static HttpURLConnection openConnection(String urlStr) throws IOException { + HttpURLConnection conn = (HttpURLConnection) new URL(urlStr).openConnection(); + conn.setConnectTimeout(CONNECT_TIMEOUT_MS); + conn.setReadTimeout(READ_TIMEOUT_MS); + return conn; + } + + private static String readResponse(HttpURLConnection conn, String urlStr) + throws IOException, MtlsMsiException { + int status = conn.getResponseCode(); + InputStream stream = status >= 400 ? conn.getErrorStream() : conn.getInputStream(); + String body; + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { + StringBuilder sb = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) sb.append(line); + body = sb.toString(); + } + if (status != 200) { + throw new MtlsMsiException( + "IMDS " + urlStr + " returned HTTP " + status + ": " + body); + } + return body; + } + + // ─── JSON parsers (no external library) ─────────────────────────────────── + + private static PlatformMetadata parsePlatformMetadata(String json) throws MtlsMsiException { + String clientId = extractString(json, "clientId"); + String tenantId = extractString(json, "tenantId"); + String attestationEndpoint = extractString(json, "attestationEndpoint"); + + // cuId is a nested object: {"vmId":"...","vmssId":"..."} + String vmId = null; + String vmssId = null; + int cuIdIdx = json.indexOf("\"cuId\""); + if (cuIdIdx >= 0) { + int objStart = json.indexOf('{', cuIdIdx); + int objEnd = json.indexOf('}', objStart); + if (objStart >= 0 && objEnd > objStart) { + String cuIdObj = json.substring(objStart, objEnd + 1); + vmId = extractString(cuIdObj, "vmId"); + vmssId = extractString(cuIdObj, "vmssId"); + } + } + + if (clientId == null || clientId.isEmpty()) { + throw new MtlsMsiException( + "IMDS /getplatformmetadata returned empty clientId. " + + "Ensure Managed Identity is enabled on this VM."); + } + if (tenantId == null || tenantId.isEmpty()) { + throw new MtlsMsiException( + "IMDS /getplatformmetadata returned empty tenantId."); + } + + return new PlatformMetadata(clientId, tenantId, vmId, vmssId, attestationEndpoint); + } + + private static CredentialResponse parseCredentialResponse(String json) throws MtlsMsiException { + String certificate = extractString(json, "certificate"); + String mtlsAuthenticationEndpoint = extractString(json, "mtls_authentication_endpoint"); + String clientId = extractString(json, "client_id"); + String tenantId = extractString(json, "tenant_id"); + String regionalTokenUrl = extractString(json, "regional_token_url"); + + if (certificate == null || certificate.isEmpty()) { + throw new MtlsMsiException( + "IMDS /issuecredential returned empty certificate: " + json); + } + + return new CredentialResponse(certificate, mtlsAuthenticationEndpoint, + clientId, tenantId, regionalTokenUrl); + } + + /** + * Minimal JSON string extractor. Handles the well-formed JSON output from IMDS. + * Returns null if the key is not present or its value is not a JSON string. + * + *

Escape sequences are processed sequentially (one pass) to avoid incorrect + * behaviour when {@code \\} (escaped backslash) is followed by {@code t}, {@code n}, etc.

+ */ + static String extractString(String json, String key) { + String search = "\"" + key + "\""; + int keyIdx = json.indexOf(search); + if (keyIdx < 0) return null; + int colonIdx = json.indexOf(':', keyIdx + search.length()); + if (colonIdx < 0) return null; + + int valueStart = colonIdx + 1; + while (valueStart < json.length() && Character.isWhitespace(json.charAt(valueStart))) { + valueStart++; + } + if (valueStart >= json.length() || json.charAt(valueStart) != '"') return null; + + // Build the unescaped value with a single sequential pass. + StringBuilder sb = new StringBuilder(); + int i = valueStart + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\' && i + 1 < json.length()) { + char next = json.charAt(i + 1); + switch (next) { + case '"': sb.append('"'); i += 2; continue; + case '\\': sb.append('\\'); i += 2; continue; + case 'n': sb.append('\n'); i += 2; continue; + case 'r': sb.append('\r'); i += 2; continue; + case 't': sb.append('\t'); i += 2; continue; + default: sb.append(c); i++; continue; + } + } + if (c == '"') break; + sb.append(c); + i++; + } + return sb.toString(); + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsBindingCertManager.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsBindingCertManager.java new file mode 100644 index 00000000..6e468275 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsBindingCertManager.java @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import java.math.BigInteger; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.io.ByteArrayInputStream; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; + +/** + * Acquires and caches the mTLS PoP binding certificate for a managed identity. + * + *

The binding certificate is issued by IMDS's {@code /issuecredential} endpoint and + * ties the KeyGuard CNG key to the managed identity. It is valid for several hours; + * this class caches the binding and refreshes it 5 minutes before expiry.

+ * + *

The full flow mirrors msal-go's {@code buildMtlsBindingInfo()}:

+ *
    + *
  1. GET {@code /getplatformmetadata} → clientId, tenantId, cuId, attestationEndpoint
  2. + *
  3. Get or create KeyGuard CNG key (persisted in Software KSP, USER scope)
  4. + *
  5. Generate PKCS#10 CSR with the CNG key (RSASSA-PSS SHA-256)
  6. + *
  7. If {@code attestationEndpoint} present: call {@code AttestationClientLib.dll} for MAA JWT
  8. + *
  9. POST CSR (+ MAA JWT) to {@code /issuecredential} → DER certificate (base64)
  10. + *
  11. Parse certificate, build {@link MtlsBindingInfo}
  12. + *
+ */ +final class MtlsBindingCertManager { + + private static final Map CACHE = new HashMap<>(); + private static final Object CACHE_LOCK = new Object(); + + private MtlsBindingCertManager() {} + + /** + * Returns a valid (non-expired) {@link MtlsBindingInfo} for the managed identity on + * this VM, fetching/refreshing from IMDS as needed. + * + * @param withAttestation whether to request a MAA attestation JWT and include it in the + * {@code /issuecredential} request (requires Trusted Launch VM) + * @return binding info containing the CNG private key and the IMDS-issued certificate + * @throws MtlsMsiException on any error + */ + static MtlsBindingInfo getOrCreate(boolean withAttestation) throws MtlsMsiException { + // Fetch platform metadata to determine the cache key. + ImdsV2Client.PlatformMetadata meta = ImdsV2Client.getPlatformMetadata(); + String cacheKey = meta.clientId + "|" + meta.tenantId; + + synchronized (CACHE_LOCK) { + MtlsBindingInfo existing = CACHE.get(cacheKey); + if (existing != null && !existing.isExpired()) { + return existing; + } + CACHE.remove(cacheKey); + } + + // Build new binding info outside the lock (slow: JNA + HTTP). + MtlsBindingInfo info = buildBindingInfo(meta, withAttestation); + + synchronized (CACHE_LOCK) { + CACHE.put(cacheKey, info); + } + return info; + } + + private static MtlsBindingInfo buildBindingInfo(ImdsV2Client.PlatformMetadata meta, + boolean withAttestation) + throws MtlsMsiException { + + String cuId = meta.cuIdString(); // vmId if present, else clientId + + // 1. Get or create the KeyGuard CNG key. + CngRsaPrivateKey privateKey = CngKeyGuard.getOrCreateKey("MSALMtlsKey_" + cuId); + + // 2. Export the public key components for CSR construction. + BigInteger[] pubKey; + try { + pubKey = CngKeyGuard.exportPublicKey(privateKey.getHandle()); + } catch (MtlsMsiException e) { + privateKey.close(); + throw e; + } + + // 3. Generate PKCS#10 CSR. + String csrBase64; + try { + csrBase64 = Pkcs10Builder.generate( + privateKey.getHandle(), + pubKey[0], // modulus + pubKey[1].intValue(), // publicExponent + meta.clientId, + meta.tenantId, + meta.vmId, + meta.vmssId); + } catch (MtlsMsiException e) { + privateKey.close(); + throw e; + } + + // 4. MAA attestation (if endpoint is known and attestation requested). + String attestationToken = null; + if (withAttestation && meta.attestationEndpoint != null && !meta.attestationEndpoint.isEmpty()) { + try { + attestationToken = CngKeyGuard.getAttestationToken( + privateKey.getHandle(), meta.attestationEndpoint, meta.clientId); + } catch (MtlsMsiException e) { + privateKey.close(); + throw e; + } + } + + // 5. Issue credential from IMDS. + ImdsV2Client.CredentialResponse credResp; + try { + credResp = ImdsV2Client.issueCredential(csrBase64, attestationToken); + } catch (MtlsMsiException e) { + privateKey.close(); + throw e; + } + + // 6. Parse the DER certificate. + X509Certificate cert; + try { + byte[] certDer = Base64.getDecoder().decode(credResp.certificate); + cert = (X509Certificate) CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(certDer)); + } catch (CertificateException | IllegalArgumentException e) { + privateKey.close(); + throw new MtlsMsiException("Failed to parse IMDS certificate: " + e.getMessage(), e); + } + + String resolvedClientId = notEmpty(credResp.clientId, meta.clientId); + String resolvedTenantId = notEmpty(credResp.tenantId, meta.tenantId); + String endpoint = notEmpty(credResp.mtlsAuthenticationEndpoint, + "https://mtlsauth.microsoft.com"); + + return new MtlsBindingInfo(privateKey, cert, endpoint, resolvedClientId, resolvedTenantId); + } + + private static String notEmpty(String preferred, String fallback) { + return (preferred != null && !preferred.isEmpty()) ? preferred : fallback; + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsBindingInfo.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsBindingInfo.java new file mode 100644 index 00000000..20f206a9 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsBindingInfo.java @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import java.security.cert.X509Certificate; +import java.util.Date; + +/** + * Holds the mTLS binding information: the CNG-backed private key and the IMDS-issued + * X.509 certificate that bind together for a particular managed identity. + */ +final class MtlsBindingInfo { + + final CngRsaPrivateKey privateKey; + final X509Certificate certificate; + final String mtlsEndpoint; + final String clientId; + final String tenantId; + final Date expiresAt; + + MtlsBindingInfo(CngRsaPrivateKey privateKey, X509Certificate certificate, + String mtlsEndpoint, String clientId, String tenantId) { + this.privateKey = privateKey; + this.certificate = certificate; + this.mtlsEndpoint = mtlsEndpoint; + this.clientId = clientId; + this.tenantId = tenantId; + // Expire 5 minutes before the cert's notAfter, matching msal-go and MSAL.NET. + long notAfterMs = certificate.getNotAfter().getTime(); + this.expiresAt = new Date(notAfterMs - 5L * 60 * 1000); + } + + boolean isExpired() { + return new Date().after(expiresAt); + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiClient.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiClient.java new file mode 100644 index 00000000..6246a587 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiClient.java @@ -0,0 +1,398 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Pointer; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509KeyManager; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.Socket; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.List; +import java.util.UUID; + +/** + * Acquires mTLS Proof-of-Possession tokens for Azure Managed Identity. + * + *

Uses JNA to call Windows CNG ({@code ncrypt.dll}) and, optionally, + * {@code AttestationClientLib.dll} directly from the JVM. No .NET runtime or subprocess is + * required.

+ * + *

Architecture

+ *
    + *
  1. Get or create a KeyGuard CNG key via {@link CngKeyGuard} (JNA → {@code ncrypt.dll}).
  2. + *
  3. Generate a PKCS#10 CSR signed with that key ({@link Pkcs10Builder}).
  4. + *
  5. Optionally obtain an MAA attestation JWT from {@code AttestationClientLib.dll}.
  6. + *
  7. POST CSR (+ attestation JWT) to IMDS {@code /issuecredential} → X.509 certificate.
  8. + *
  9. Build a JSSE {@link SSLContext} backed by a custom {@link CngProvider} that signs the + * TLS handshake using {@code NCryptSignHash} — the private key never leaves CNG.
  10. + *
  11. POST to the regional mTLS token endpoint and return the {@code mtls_pop} token.
  12. + *
+ * + *

Requirements

+ *
    + *
  • Windows Azure VM with Managed Identity enabled
  • + *
  • {@code AttestationClientLib.dll} on {@code PATH} (from the + * {@code Microsoft.Azure.Security.KeyGuardAttestation} NuGet package) when + * {@code withAttestation=true} and the VM is Trusted Launch
  • + *
+ * + *

Usage

+ *
{@code
+ * MtlsMsiClient client = new MtlsMsiClient();
+ * MtlsMsiHelperResult result = client.acquireToken(
+ *     "https://management.azure.com",   // resource
+ *     "SystemAssigned",                  // identityType (informational — IMDS determines identity)
+ *     null,                              // identityId (null for system-assigned)
+ *     true,                              // withAttestation
+ *     UUID.randomUUID().toString()       // correlationId (optional)
+ * );
+ * String accessToken = result.getAccessToken();
+ * }
+ */ +public class MtlsMsiClient { + + /** Creates a new client. */ + public MtlsMsiClient() {} + + /** + * Acquires an mTLS PoP token for a Managed Identity resource. + * + *

Note: {@code identityType} and {@code identityId} are accepted for API compatibility but + * are not forwarded to IMDS — the VM's managed identity configuration determines which + * identity is used. For UserAssigned identities, configure the VM with the desired identity + * before calling this method.

+ * + * @param resource Azure resource URI (e.g. {@code https://management.azure.com}) + * @param identityType Accepted for compatibility; IMDS ignores it in the JNA flow + * @param identityId Accepted for compatibility; IMDS ignores it in the JNA flow + * @param withAttestation Whether to request MAA attestation (requires Trusted Launch VM with + * {@code AttestationClientLib.dll} on PATH) + * @param correlationId Optional GUID for telemetry; may be {@code null} + * @return the mTLS PoP token result + * @throws MtlsMsiException on key creation, IMDS, or token acquisition failure + */ + public MtlsMsiHelperResult acquireToken( + String resource, + String identityType, + String identityId, + boolean withAttestation, + String correlationId) throws MtlsMsiException { + + if (resource == null || resource.isEmpty()) { + throw new MtlsMsiException("resource must not be null or empty"); + } + + MtlsBindingInfo binding = MtlsBindingCertManager.getOrCreate(withAttestation); + SSLSocketFactory sslFactory = buildSslSocketFactory(binding, false); + + String tokenUrl = buildTokenUrl(binding.mtlsEndpoint, binding.tenantId); + String requestBody = buildTokenRequestBody(binding.clientId, resource); + String requestId = correlationId != null ? correlationId : UUID.randomUUID().toString(); + + String responseJson = httpsPost(tokenUrl, requestBody, "application/x-www-form-urlencoded", + sslFactory, requestId); + return parseTokenResponse(responseJson, binding); + } + + /** + * Makes a downstream HTTP call over mutual TLS using the KeyGuard-bound certificate + * and an mTLS PoP access token. + * + *

Important: The downstream server must be configured for + * required mutual TLS — it must send a TLS {@code CertificateRequest} during the handshake. + * Public Azure APIs (Graph, Key Vault, etc.) use optional mTLS and will NOT trigger + * client certificate presentation. Use this only with servers that require a client cert.

+ * + * @param url The full URL to call + * @param method HTTP method ({@code GET}, {@code POST}, etc.) + * @param token The mTLS PoP access token for the Authorization header + * @param body Request body (may be {@code null}) + * @param contentType Content-Type (defaults to {@code application/json} if null) + * @param extraHeaders Extra headers in {@code "Name: Value"} format (may be null) + * @param resource Azure resource URI (used to resolve binding if not cached) + * @param identityType Accepted for compatibility; IMDS ignores it in the JNA flow + * @param identityId Accepted for compatibility; IMDS ignores it in the JNA flow + * @param withAttestation Whether to include attestation when refreshing the binding cert + * @param correlationId Optional GUID for telemetry; may be {@code null} + * @param allowInsecureTls Skip server TLS cert validation (for self-signed certs in testing ONLY) + * @return the HTTP response from the downstream server + * @throws MtlsMsiException if the binding cert cannot be acquired or the request fails + */ + public MtlsMsiHttpResponse httpRequest( + String url, + String method, + String token, + String body, + String contentType, + List extraHeaders, + String resource, + String identityType, + String identityId, + boolean withAttestation, + String correlationId, + boolean allowInsecureTls) throws MtlsMsiException { + + MtlsBindingInfo binding = MtlsBindingCertManager.getOrCreate(withAttestation); + SSLSocketFactory sslFactory = buildSslSocketFactory(binding, allowInsecureTls); + + String requestId = correlationId != null ? correlationId : UUID.randomUUID().toString(); + return httpsRequest(url, method != null ? method : "GET", token, body, + contentType != null ? contentType : "application/json", + extraHeaders, sslFactory, requestId); + } + + // ─── Token request helpers ───────────────────────────────────────────────── + + private static String buildTokenUrl(String mtlsEndpoint, String tenantId) { + String base = mtlsEndpoint.endsWith("/") + ? mtlsEndpoint.substring(0, mtlsEndpoint.length() - 1) + : mtlsEndpoint; + return base + "/" + tenantId + "/oauth2/v2.0/token"; + } + + private static String buildTokenRequestBody(String clientId, String resource) { + String scope = resource.endsWith("/.default") ? resource : resource + "/.default"; + return "grant_type=client_credentials" + + "&client_id=" + urlEncode(clientId) + + "&scope=" + urlEncode(scope) + + "&token_type=mtls_pop"; + } + + private static String urlEncode(String s) { + try { + return java.net.URLEncoder.encode(s, "UTF-8"); + } catch (java.io.UnsupportedEncodingException e) { + return s; + } + } + + // ─── JSSE mTLS helpers ──────────────────────────────────────────────────── + + private static SSLSocketFactory buildSslSocketFactory(MtlsBindingInfo binding, + boolean insecure) + throws MtlsMsiException { + if (insecure) { + throw new MtlsMsiException("Insecure trust-all TLS mode is not supported."); + } + CngProvider.installIfAbsent(); + + X509KeyManager km = new CngX509KeyManager(binding.privateKey, binding.certificate); + + try { + SSLContext ctx = SSLContext.getInstance("TLS"); + ctx.init(new KeyManager[]{km}, null, null); + return ctx.getSocketFactory(); + } catch (NoSuchAlgorithmException | KeyManagementException e) { + throw new MtlsMsiException("Failed to build mTLS SSLContext: " + e.getMessage(), e); + } + } + + /** X509KeyManager that returns the CNG-backed key and the IMDS certificate. */ + private static final class CngX509KeyManager implements X509KeyManager { + private final CngRsaPrivateKey key; + private final X509Certificate cert; + + CngX509KeyManager(CngRsaPrivateKey key, X509Certificate cert) { + this.key = key; + this.cert = cert; + } + + @Override public String[] getClientAliases(String keyType, Principal[] issuers) { + return new String[]{"mtls"}; + } + @Override public String chooseClientAlias(String[] keyTypes, Principal[] issuers, Socket s) { + return "mtls"; + } + @Override public X509Certificate[] getCertificateChain(String alias) { + return new X509Certificate[]{cert}; + } + @Override public PrivateKey getPrivateKey(String alias) { return key; } + + @Override public String[] getServerAliases(String keyType, Principal[] issuers) { return null; } + @Override public String chooseServerAlias(String keyType, Principal[] issuers, Socket s) { return null; } + } + + // ─── HTTP helpers ───────────────────────────────────────────────────────── + + private static String httpsPost(String urlStr, String body, String contentType, + SSLSocketFactory sslFactory, String requestId) + throws MtlsMsiException { + try { + HttpsURLConnection conn = (HttpsURLConnection) new URL(urlStr).openConnection(); + conn.setSSLSocketFactory(sslFactory); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", contentType); + conn.setRequestProperty("x-ms-client-request-id", requestId); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(30_000); + conn.setDoOutput(true); + byte[] bodyBytes = body.getBytes(StandardCharsets.UTF_8); + conn.setRequestProperty("Content-Length", String.valueOf(bodyBytes.length)); + try (OutputStream os = conn.getOutputStream()) { + os.write(bodyBytes); + } + return readHttpsResponse(conn, urlStr); + } catch (IOException e) { + throw new MtlsMsiException("mTLS POST to " + urlStr + " failed: " + e.getMessage(), e); + } + } + + private static MtlsMsiHttpResponse httpsRequest(String urlStr, String method, String token, + String body, String contentType, + List extraHeaders, + SSLSocketFactory sslFactory, String requestId) + throws MtlsMsiException { + try { + HttpsURLConnection conn = (HttpsURLConnection) new URL(urlStr).openConnection(); + conn.setSSLSocketFactory(sslFactory); + conn.setRequestMethod(method); + conn.setRequestProperty("Authorization", "Bearer " + token); + conn.setRequestProperty("Content-Type", contentType); + conn.setRequestProperty("x-ms-client-request-id", requestId); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(30_000); + + if (extraHeaders != null) { + for (String header : extraHeaders) { + int colon = header.indexOf(':'); + if (colon > 0) { + conn.setRequestProperty(header.substring(0, colon).trim(), + header.substring(colon + 1).trim()); + } + } + } + + if (body != null && !body.isEmpty()) { + conn.setDoOutput(true); + byte[] bodyBytes = body.getBytes(StandardCharsets.UTF_8); + conn.setRequestProperty("Content-Length", String.valueOf(bodyBytes.length)); + try (OutputStream os = conn.getOutputStream()) { + os.write(bodyBytes); + } + } + + int status = conn.getResponseCode(); + InputStream stream = status >= 400 ? conn.getErrorStream() : conn.getInputStream(); + String responseBody = readStream(stream); + return new MtlsMsiHttpResponse(status, responseBody, responseBody); + } catch (IOException e) { + throw new MtlsMsiException("mTLS " + method + " " + urlStr + " failed: " + e.getMessage(), e); + } + } + + private static String readHttpsResponse(HttpsURLConnection conn, String urlStr) + throws IOException, MtlsMsiException { + int status = conn.getResponseCode(); + InputStream stream = status >= 400 ? conn.getErrorStream() : conn.getInputStream(); + String body = readStream(stream); + if (status != 200) { + throw new MtlsMsiException( + "mTLS token endpoint " + urlStr + " returned HTTP " + status + ": " + body); + } + return body; + } + + private static String readStream(InputStream stream) throws IOException { + if (stream == null) return ""; + StringBuilder sb = new StringBuilder(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) sb.append(line); + } + return sb.toString(); + } + + // ─── Response parsers ──────────────────────────────────────────────────── + + private static MtlsMsiHelperResult parseTokenResponse(String json, MtlsBindingInfo binding) + throws MtlsMsiException { + if (json == null || json.isEmpty()) { + throw new MtlsMsiException("mTLS token endpoint returned empty response"); + } + + String accessToken = extractJsonString(json, "access_token"); + if (accessToken == null || accessToken.isEmpty()) { + throw new MtlsMsiException("mTLS token response missing access_token: " + json); + } + + String tokenType = extractJsonString(json, "token_type"); + int expiresIn = extractJsonInt(json, "expires_in"); + + // Encode the binding certificate as PEM for callers who need it. + String bindingCertPem = null; + try { + byte[] derBytes = binding.certificate.getEncoded(); + String b64 = Base64.getMimeEncoder(64, new byte[]{'\n'}).encodeToString(derBytes); + bindingCertPem = "-----BEGIN CERTIFICATE-----\n" + b64 + "\n-----END CERTIFICATE-----\n"; + } catch (CertificateEncodingException ignored) {} + + return new MtlsMsiHelperResult(accessToken, tokenType != null ? tokenType : "mtls_pop", + expiresIn, bindingCertPem, binding.tenantId, binding.clientId); + } + + // ─── Minimal JSON extractors ────────────────────────────────────────────── + + static String extractJsonString(String json, String key) { + String search = "\"" + key + "\""; + int keyIdx = json.indexOf(search); + if (keyIdx < 0) return null; + int colonIdx = json.indexOf(':', keyIdx + search.length()); + if (colonIdx < 0) return null; + int valueStart = colonIdx + 1; + while (valueStart < json.length() && Character.isWhitespace(json.charAt(valueStart))) valueStart++; + if (valueStart >= json.length()) return null; + if (json.charAt(valueStart) == '"') { + int end = valueStart + 1; + while (end < json.length()) { + char c = json.charAt(end); + if (c == '\\') { end += 2; continue; } + if (c == '"') break; + end++; + } + return json.substring(valueStart + 1, end) + .replace("\\n", "\n") + .replace("\\\"", "\"") + .replace("\\\\", "\\"); + } + return null; + } + + static int extractJsonInt(String json, String key) { + String search = "\"" + key + "\""; + int keyIdx = json.indexOf(search); + if (keyIdx < 0) return 0; + int colonIdx = json.indexOf(':', keyIdx + search.length()); + if (colonIdx < 0) return 0; + int valueStart = colonIdx + 1; + while (valueStart < json.length() && Character.isWhitespace(json.charAt(valueStart))) valueStart++; + int valueEnd = valueStart; + while (valueEnd < json.length() + && (Character.isDigit(json.charAt(valueEnd)) || json.charAt(valueEnd) == '-')) { + valueEnd++; + } + try { + return Integer.parseInt(json.substring(valueStart, valueEnd)); + } catch (NumberFormatException e) { + return 0; + } + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiException.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiException.java new file mode 100644 index 00000000..a1a81165 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiException.java @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +/** + * Thrown when the mTLS Managed Identity subprocess ({@code MsalMtlsMsiHelper.exe}) fails or + * cannot be located. + */ +public class MtlsMsiException extends Exception { + + public MtlsMsiException(String message) { + super(message); + } + + public MtlsMsiException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiHelperResult.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiHelperResult.java new file mode 100644 index 00000000..e2972f34 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiHelperResult.java @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +/** + * Result returned by {@link MtlsMsiClient} after a successful mTLS PoP token acquisition. + */ +public class MtlsMsiHelperResult { + + private final String accessToken; + private final String tokenType; + private final int expiresIn; + private final String bindingCertificate; + private final String tenantId; + private final String clientId; + + public MtlsMsiHelperResult( + String accessToken, + String tokenType, + int expiresIn, + String bindingCertificate, + String tenantId, + String clientId) { + this.accessToken = accessToken; + this.tokenType = tokenType; + this.expiresIn = expiresIn; + this.bindingCertificate = bindingCertificate; + this.tenantId = tenantId; + this.clientId = clientId; + } + + /** The mTLS PoP access token. */ + public String getAccessToken() { return accessToken; } + + /** Always {@code "mtls_pop"}. */ + public String getTokenType() { return tokenType; } + + /** Seconds until expiry, as of the moment the subprocess returned. */ + public int getExpiresIn() { return expiresIn; } + + /** + * PEM-encoded binding certificate (the KeyGuard-backed certificate whose + * thumbprint is bound into the token's {@code cnf.x5t#S256} claim). + */ + public String getBindingCertificate() { return bindingCertificate; } + + /** Tenant ID from the issued token. May be null. */ + public String getTenantId() { return tenantId; } + + /** Object ID / client ID from the issued token. May be null. */ + public String getClientId() { return clientId; } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiHttpResponse.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiHttpResponse.java new file mode 100644 index 00000000..73137b4b --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/MtlsMsiHttpResponse.java @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +/** + * HTTP response returned by {@link MtlsMsiClient#httpRequest} when making a + * downstream mTLS call through {@code MsalMtlsMsiHelper.exe --mode http-request}. + */ +public class MtlsMsiHttpResponse { + + private final int status; + private final String body; + private final String rawJson; + + MtlsMsiHttpResponse(int status, String body, String rawJson) { + this.status = status; + this.body = body; + this.rawJson = rawJson; + } + + /** HTTP status code (e.g. 200, 401). */ + public int getStatus() { return status; } + + /** Response body as a string. */ + public String getBody() { return body; } + + /** Raw JSON string from the helper subprocess (contains status, headers, body). */ + public String getRawJson() { return rawJson; } +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/NCryptLibrary.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/NCryptLibrary.java new file mode 100644 index 00000000..b1ab7148 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/NCryptLibrary.java @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Library; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import com.sun.jna.Structure; +import com.sun.jna.WString; +import com.sun.jna.ptr.IntByReference; +import com.sun.jna.ptr.PointerByReference; + +import java.util.Arrays; +import java.util.List; + +/** + * JNA binding for {@code ncrypt.dll} — Windows CNG (Cryptography Next Generation) key + * storage and signing operations. + * + *

Function signatures mirror MSAL.NET's {@code WindowsCngKeyOperations} and msal-go's + * {@code cng_windows.go}. All NCrypt functions follow the Windows x64 calling convention + * (which equals cdecl on x64).

+ */ +interface NCryptLibrary extends Library { + + NCryptLibrary INSTANCE = Native.load("ncrypt", NCryptLibrary.class); + + // ─── NCrypt constants ────────────────────────────────────────────────────── + + int ERROR_SUCCESS = 0; + + int NCRYPT_SILENT_FLAG = 0x00000040; + int NCRYPT_OVERWRITE_KEY_FLAG = 0x00000080; + int NCRYPT_MACHINE_KEY_FLAG = 0x00000020; // not used (USER scope only) + int NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000; // VBS KeyGuard + int NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000; // ephemeral per boot + int NCRYPT_ALLOW_EXPORT_NONE = 0; // non-exportable + + int NCRYPT_PAD_PKCS1_FLAG = 0x00000002; + int NCRYPT_PAD_PSS_FLAG = 0x00000008; + + // ─── Padding info structures ─────────────────────────────────────────────── + + /** Maps to {@code BCRYPT_PKCS1_PADDING_INFO} — used with NCRYPT_PAD_PKCS1_FLAG. */ + class BcryptPkcs1PaddingInfo extends Structure { + /** Algorithm name for the hash (e.g. L"SHA256"). LPCWSTR in C. */ + public WString pszAlgId; + + public BcryptPkcs1PaddingInfo(String algName) { + pszAlgId = new WString(algName); + write(); + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("pszAlgId"); + } + } + + /** Maps to {@code BCRYPT_PSS_PADDING_INFO} — used with NCRYPT_PAD_PSS_FLAG. */ + class BcryptPssPaddingInfo extends Structure { + /** Algorithm name for the hash (e.g. L"SHA256"). LPCWSTR in C. */ + public WString pszAlgId; + /** Salt length in bytes. Typically equals hash output length for RSASSA-PSS. */ + public int cbSalt; + + public BcryptPssPaddingInfo(String algName, int saltLen) { + pszAlgId = new WString(algName); + cbSalt = saltLen; + write(); + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("pszAlgId", "cbSalt"); + } + } + + // ─── NCrypt API ──────────────────────────────────────────────────────────── + + int NCryptOpenStorageProvider(PointerByReference phProvider, WString pszProviderName, int dwFlags); + + int NCryptOpenKey(Pointer hProvider, PointerByReference phKey, WString pszKeyName, + int dwLegacyKeySpec, int dwFlags); + + int NCryptCreatePersistedKey(Pointer hProvider, PointerByReference phKey, + WString pszAlgId, WString pszKeyName, + int dwLegacyKeySpec, int dwFlags); + + int NCryptSetProperty(Pointer hObject, WString pszProperty, + byte[] pbInput, int cbInput, int dwFlags); + + int NCryptGetProperty(Pointer hObject, WString pszProperty, + byte[] pbOutput, int cbOutput, + IntByReference pcbResult, int dwFlags); + + int NCryptFinalizeKey(Pointer hKey, int dwFlags); + + /** First call: pass {@code pbOutput=null, cbOutput=0} to query required buffer size. */ + int NCryptExportKey(Pointer hKey, Pointer hExportKey, WString pszBlobType, + Pointer pParameterList, Pointer pbOutput, int cbOutput, + IntByReference pcbResult, int dwFlags); + + /** + * First call: pass {@code pbSignature=null, cbSignature=0} to get required buffer size + * (returned in {@code pcbResult}). + * Second call: pass a {@code Memory} buffer of that size. + */ + int NCryptSignHash(Pointer hKey, Pointer pPaddingInfo, + byte[] pbHashValue, int cbHashValue, + Pointer pbSignature, int cbSignature, + IntByReference pcbResult, int dwFlags); + + int NCryptFreeObject(Pointer hObject); + + int NCryptDeleteKey(Pointer hKey, int dwFlags); +} diff --git a/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/Pkcs10Builder.java b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/Pkcs10Builder.java new file mode 100644 index 00000000..3cc8e919 --- /dev/null +++ b/msal4j-mtls-extensions/src/main/java/com/microsoft/aad/msal4j/mtls/Pkcs10Builder.java @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Pointer; + +import java.io.ByteArrayOutputStream; +import java.math.BigInteger; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; + +/** + * Builds a PKCS#10 Certification Request (CSR) that matches the format produced by + * MSAL.NET and msal-go for the Azure IMDSv2 {@code /issuecredential} endpoint. + * + *

CSR Structure

+ *
+ * CertificationRequest ::= SEQUENCE {
+ *     certificationRequestInfo  CertificationRequestInfo,
+ *     signatureAlgorithm        AlgorithmIdentifier,   -- RSASSA-PSS with SHA-256 params
+ *     signature                 BIT STRING
+ * }
+ *
+ * CertificationRequestInfo ::= SEQUENCE {
+ *     version       INTEGER { v1(0) }
+ *     subject       Name   -- CN={clientId}, DC={tenantId}
+ *     subjectPKInfo SubjectPublicKeyInfo
+ *     attributes    [0] IMPLICIT SET OF -- OID 1.3.6.1.4.1.311.90.2.10 = cuId JSON
+ * }
+ * 
+ * + *

Signing

+ * Signature: RSASSA-PSS with SHA-256, salt length = 32 bytes (hash output length). + * Signing is delegated to {@link CngKeyGuard#signPss} so the non-exportable KeyGuard key + * never leaves CNG. + * + *

This is a pure-Java port of msal-go's {@code generateCSR()} in {@code imdsv2.go}, + * using manual DER encoding to avoid adding external ASN.1 library dependencies.

+ */ +final class Pkcs10Builder { + + private Pkcs10Builder() {} + + // ─── OIDs (pre-encoded DER) ──────────────────────────────────────────────── + + // rsaEncryption: 1.2.840.113549.1.1.1 + private static final byte[] OID_RSA_ENCRYPTION = hexToBytes("2a864886f70d010101"); + // sha256: 2.16.840.1.101.3.4.2.1 + private static final byte[] OID_SHA256 = hexToBytes("608648016503040201"); + // mgf1: 1.2.840.113549.1.1.8 + private static final byte[] OID_MGF1 = hexToBytes("2a864886f70d010108"); + // id-RSASSA-PSS: 1.2.840.113549.1.1.10 + private static final byte[] OID_RSASSA_PSS = hexToBytes("2a864886f70d01010a"); + // commonName: 2.5.4.3 + private static final byte[] OID_COMMON_NAME = hexToBytes("5504 03".replace(" ", "")); + // domainComponent: 0.9.2342.19200300.100.1.25 + private static final byte[] OID_DOMAIN_COMPONENT = hexToBytes("0992268993f22c6401 19".replace(" ", "")); + // cuId attribute: 1.3.6.1.4.1.311.90.2.10 + private static final byte[] OID_CU_ID = hexToBytes("2b060104018237 5a02 0a".replace(" ", "")); + + // ─── Public API ─────────────────────────────────────────────────────────── + + /** + * Generates a PKCS#10 CSR and returns it as standard Base64-encoded DER + * (no PEM headers), ready to be placed in the {@code csr} field of the + * IMDS {@code /issuecredential} JSON request. + * + * @param keyHandle CNG key handle (the private key — signs the CSR TBS bytes) + * @param modulus RSA public key modulus (from {@link CngKeyGuard#exportPublicKey}) + * @param publicExp RSA public exponent + * @param clientId managed identity client ID → CN in subject + * @param tenantId tenant GUID → DC in subject + * @param vmId VM ID for the cuId attribute ({@code cuId.vmId}); may be null + * @param vmssId VMSS ID for the cuId attribute; may be null + * @return Base64-encoded DER of the PKCS#10 CSR + */ + static String generate(Pointer keyHandle, BigInteger modulus, int publicExp, + String clientId, String tenantId, String vmId, String vmssId) + throws MtlsMsiException { + + // --- SubjectPublicKeyInfo ------------------------------------------ + byte[] spki = buildSpki(modulus, publicExp); + + // --- Subject: CN={clientId}, DC={tenantId} ------------------------- + byte[] subject = buildSubject(clientId, tenantId); + + // --- cuId attribute ------------------------------------------------ + byte[] cuIdJson = buildCuIdJson(vmId, vmssId); + byte[] attributes = buildCuIdAttribute(cuIdJson); + + // --- CertificationRequestInfo SEQUENCE ----------------------------- + byte[] version = derInteger(new byte[]{0x00}); // INTEGER v1(0) + byte[] certReqInfo = derSequence(concat(version, subject, spki, attributes)); + + // --- Sign with RSASSA-PSS SHA-256 (salt=32) ----------------------- + byte[] tbs; + try { + tbs = MessageDigest.getInstance("SHA-256").digest(certReqInfo); + } catch (NoSuchAlgorithmException e) { + throw new MtlsMsiException("SHA-256 not available: " + e.getMessage(), e); + } + byte[] sig = CngKeyGuard.signPss(keyHandle, tbs, "SHA256", 32); + + // --- AlgorithmIdentifier for RSASSA-PSS ---------------------------- + byte[] sigAlgId = buildPssAlgorithmIdentifier(); + + // --- BIT STRING wrapping the signature ----------------------------- + byte[] sigBitString = derBitString(sig); + + // --- Final CertificationRequest SEQUENCE --------------------------- + byte[] csr = derSequence(concat(certReqInfo, sigAlgId, sigBitString)); + + return Base64.getEncoder().encodeToString(csr); + } + + // ─── DER building blocks ────────────────────────────────────────────────── + + /** DER SEQUENCE */ + static byte[] derSequence(byte[] content) { + return derTagLen(0x30, content); + } + + /** DER SET */ + private static byte[] derSet(byte[] content) { + return derTagLen(0x31, content); + } + + /** DER INTEGER from raw bytes (big-endian, with sign byte if high bit set) */ + private static byte[] derInteger(byte[] value) { + // Add leading 0x00 if high bit is set (unsigned → signed two's complement). + byte[] content = (value[0] & 0x80) != 0 + ? concat(new byte[]{0x00}, value) + : value; + return derTagLen(0x02, content); + } + + /** DER OBJECT IDENTIFIER from pre-encoded OID value bytes */ + private static byte[] derOid(byte[] oidBytes) { + return derTagLen(0x06, oidBytes); + } + + /** DER UTF8String */ + private static byte[] derUtf8String(String s) { + byte[] bytes = s.getBytes(java.nio.charset.StandardCharsets.UTF_8); + return derTagLen(0x0C, bytes); + } + + /** DER BIT STRING — prepend 0x00 (zero unused bits) */ + static byte[] derBitString(byte[] data) { + byte[] content = new byte[data.length + 1]; + content[0] = 0x00; + System.arraycopy(data, 0, content, 1, data.length); + return derTagLen(0x03, content); + } + + /** DER NULL */ + private static final byte[] DER_NULL = {0x05, 0x00}; + + /** Context-specific explicit tag [N] wrapping content */ + private static byte[] contextExplicit(int n, byte[] content) { + return derTagLen(0xA0 | n, content); + } + + /** Context-specific implicit tag [N] wrapping content */ + private static byte[] contextImplicit(int n, byte[] content) { + return derTagLen(0x80 | n, content); + } + + /** Writes tag + DER length + content */ + private static byte[] derTagLen(int tag, byte[] content) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(tag); + int len = content.length; + if (len < 0x80) { + out.write(len); + } else if (len < 0x100) { + out.write(0x81); + out.write(len); + } else if (len < 0x10000) { + out.write(0x82); + out.write((len >> 8) & 0xFF); + out.write(len & 0xFF); + } else { + out.write(0x83); + out.write((len >> 16) & 0xFF); + out.write((len >> 8) & 0xFF); + out.write(len & 0xFF); + } + try { out.write(content); } catch (java.io.IOException ignored) {} + return out.toByteArray(); + } + + // ─── Component builders ─────────────────────────────────────────────────── + + /** + * SubjectPublicKeyInfo ::= SEQUENCE { algorithm AlgorithmIdentifier, subjectPublicKey BIT STRING } + * AlgorithmIdentifier for RSA: SEQUENCE { OID rsaEncryption, NULL } + * Public key: BIT STRING containing RSAPublicKey SEQUENCE { modulus INTEGER, publicExp INTEGER } + */ + private static byte[] buildSpki(BigInteger modulus, int publicExp) { + // RSAPublicKey SEQUENCE { modulus INTEGER, publicExp INTEGER } + byte[] modBytes = modulus.toByteArray(); + byte[] expBytes = BigInteger.valueOf(publicExp).toByteArray(); + byte[] rsaPublicKey = derSequence(concat(derInteger(modBytes), derInteger(expBytes))); + + // AlgorithmIdentifier for rsaEncryption + byte[] algId = derSequence(concat(derOid(OID_RSA_ENCRYPTION), DER_NULL)); + + // SubjectPublicKeyInfo + return derSequence(concat(algId, derBitString(rsaPublicKey))); + } + + /** + * Name ::= SEQUENCE { RDN SEQUENCE { AttributeTypeAndValue SEQUENCE { OID, value } } } + * Subject: CN={clientId}, DC={tenantId} + * Matches msal-go: pkix.Name{CommonName: clientId, ExtraNames: []pkix.AttributeTypeAndValue{{Type: dcOID, Value: tenantId}}} + */ + private static byte[] buildSubject(String clientId, String tenantId) { + // AttributeTypeAndValue SEQUENCE { OID commonName, UTF8String clientId } + byte[] cnAttr = derSequence(concat(derOid(OID_COMMON_NAME), derUtf8String(clientId))); + byte[] cnRdn = derSet(cnAttr); + + // AttributeTypeAndValue SEQUENCE { OID domainComponent, UTF8String tenantId } + byte[] dcAttr = derSequence(concat(derOid(OID_DOMAIN_COMPONENT), derUtf8String(tenantId))); + byte[] dcRdn = derSet(dcAttr); + + // Name = SEQUENCE of RDNs + return derSequence(concat(cnRdn, dcRdn)); + } + + /** + * Builds the cuId JSON string. Matches msal-go's json.Marshal(cuID): + * {@code {"vmId":"","vmssId":""}} with omitempty semantics. + */ + private static byte[] buildCuIdJson(String vmId, String vmssId) { + StringBuilder sb = new StringBuilder("{"); + boolean first = true; + if (vmId != null && !vmId.isEmpty()) { + sb.append("\"vmId\":\"").append(vmId).append("\""); + first = false; + } + if (vmssId != null && !vmssId.isEmpty()) { + if (!first) sb.append(","); + sb.append("\"vmssId\":\"").append(vmssId).append("\""); + } + sb.append("}"); + return sb.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + + /** + * CertificationRequestInfo attributes [0]: + * [0] CONSTRUCTED { + * SEQUENCE { OID 1.3.6.1.4.1.311.90.2.10, SET { UTF8String(cuIdJson) } } + * } + * + *

Per PKCS#10, {@code [0] IMPLICIT Attributes} — IMPLICIT tagging of a constructed type + * keeps the constructed bit, so the tag byte is {@code 0xA0} (context-specific, constructed). + * Mirrors msal-go's {@code buildCuIDAttribute()} which uses + * {@code asn1.RawValue{Class: ClassContextSpecific, Tag: 0, IsCompound: true}}.

+ */ + private static byte[] buildCuIdAttribute(byte[] cuIdJsonBytes) { + byte[] utf8Str = derTagLen(0x0C, cuIdJsonBytes); // UTF8String + byte[] valueSet = derSet(utf8Str); // SET { UTF8String } + byte[] attrSeq = derSequence(concat(derOid(OID_CU_ID), valueSet)); // SEQUENCE { OID, SET } + return contextExplicit(0, attrSeq); // [0] CONSTRUCTED { SEQUENCE } — 0xA0 tag + } + + /** + * AlgorithmIdentifier for RSASSA-PSS with SHA-256: + * SEQUENCE { + * OID id-RSASSA-PSS, + * SEQUENCE { -- RSASSA-PSS-params + * [0] SEQUENCE { OID sha-256, NULL }, -- hashAlgorithm + * [1] SEQUENCE { OID mgf1, SEQUENCE { OID sha-256, NULL } }, -- maskGenAlgorithm + * [2] INTEGER 32 -- saltLength + * } + * } + * Matches msal-go's explicit PSS AlgorithmIdentifier. + */ + private static byte[] buildPssAlgorithmIdentifier() { + // sha256AlgID: SEQUENCE { OID sha256, NULL } + byte[] sha256AlgId = derSequence(concat(derOid(OID_SHA256), DER_NULL)); + + // hashAlgorithm [0]: sha256AlgID + byte[] hashAlgorithm = contextExplicit(0, sha256AlgId); + + // mgf1AlgID: SEQUENCE { OID mgf1, sha256AlgID } + byte[] mgf1AlgId = derSequence(concat(derOid(OID_MGF1), sha256AlgId)); + // maskGenAlgorithm [1]: mgf1AlgID + byte[] maskGenAlgorithm = contextExplicit(1, mgf1AlgId); + + // saltLength [2]: INTEGER 32 + byte[] saltLength = contextExplicit(2, derInteger(new byte[]{32})); + + // RSASSA-PSS-params SEQUENCE + byte[] pssParams = derSequence(concat(hashAlgorithm, maskGenAlgorithm, saltLength)); + + // AlgorithmIdentifier SEQUENCE { OID id-RSASSA-PSS, pssParams } + return derSequence(concat(derOid(OID_RSASSA_PSS), pssParams)); + } + + // ─── Utility ────────────────────────────────────────────────────────────── + + private static byte[] concat(byte[]... arrays) { + int total = 0; + for (byte[] a : arrays) total += a.length; + byte[] result = new byte[total]; + int offset = 0; + for (byte[] a : arrays) { + System.arraycopy(a, 0, result, offset, a.length); + offset += a.length; + } + return result; + } + + private static byte[] hexToBytes(String hex) { + hex = hex.replace(" ", ""); + byte[] result = new byte[hex.length() / 2]; + for (int i = 0; i < result.length; i++) { + result[i] = (byte) Integer.parseInt(hex.substring(i * 2, i * 2 + 2), 16); + } + return result; + } +} diff --git a/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/CngProviderTest.java b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/CngProviderTest.java new file mode 100644 index 00000000..1cfc3f59 --- /dev/null +++ b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/CngProviderTest.java @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import java.security.Provider; +import java.security.Security; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link CngProvider} — provider registration and service declarations. + * + *

Requires Windows because loading {@link CngProvider} transitively loads + * {@link CngSignatureSpi} → {@link CngRsaPrivateKey} → {@link NCryptLibrary} + * ({@code ncrypt.dll}).

+ */ +@EnabledOnOs(OS.WINDOWS) +class CngProviderTest { + + @AfterEach + void removeCngProvider() { + // Remove after each test to prevent state from bleeding between tests + Security.removeProvider("CNG"); + } + + // ─── installIfAbsent ───────────────────────────────────────────────────── + + @Test + void installIfAbsent_registersProviderByName() { + CngProvider.installIfAbsent(); + assertNotNull(Security.getProvider("CNG"), + "CNG provider must be registered in the JVM Security list after installIfAbsent()"); + } + + @Test + void installIfAbsent_isIdempotent() { + CngProvider.installIfAbsent(); + CngProvider.installIfAbsent(); // second call must be a no-op + + long cngCount = Arrays.stream(Security.getProviders()) + .filter(p -> "CNG".equals(p.getName())) + .count(); + assertEquals(1, cngCount, + "CNG provider must appear exactly once even after multiple installIfAbsent() calls"); + } + + @Test + void installIfAbsent_insertsAtHighestPriority() { + CngProvider.installIfAbsent(); + Provider[] providers = Security.getProviders(); + // Security position 1 = index 0 in the array + assertEquals("CNG", providers[0].getName(), + "CNG must be at Security position 1 (highest priority) so JSSE uses it first"); + } + + // ─── Service registrations ──────────────────────────────────────────────── + + @Test + void provider_registersSha256WithRsa() { + Provider p = new CngProvider(); + assertNotNull(p.getService("Signature", "SHA256withRSA"), + "CNG provider must advertise SHA256withRSA (used by TLS 1.2 client cert verify)"); + } + + @Test + void provider_registersSha1WithRsa() { + Provider p = new CngProvider(); + assertNotNull(p.getService("Signature", "SHA1withRSA"), + "CNG provider must advertise SHA1withRSA (ADFS/DSTS compat)"); + } + + @Test + void provider_registersRsaSsaPss() { + Provider p = new CngProvider(); + assertNotNull(p.getService("Signature", "RSASSA-PSS"), + "CNG provider must advertise RSASSA-PSS (used by TLS 1.3 client cert verify)"); + } + + @Test + void provider_name_isCng() { + assertEquals("CNG", new CngProvider().getName()); + } + + @Test + void provider_sha256Alias_resolves() { + Provider p = new CngProvider(); + // Alias "SHA-256withRSA" must resolve to "SHA256withRSA" + assertNotNull(p.getService("Signature", "SHA-256withRSA"), + "Alias SHA-256withRSA must resolve via the CNG provider"); + } +} diff --git a/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/CngSignatureSpiTest.java b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/CngSignatureSpiTest.java new file mode 100644 index 00000000..7c7aee89 --- /dev/null +++ b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/CngSignatureSpiTest.java @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import java.security.*; +import java.security.spec.MGF1ParameterSpec; +import java.security.spec.PSSParameterSpec; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link CngSignatureSpi} — specifically the delegation path used + * when a non-{@link CngRsaPrivateKey} (regular exportable RSA key) is passed to + * {@code initSign}. + * + *

The CNG path (signing with an actual KeyGuard key) cannot be exercised in unit tests + * without a Trusted Launch Azure VM. The delegation path, however, uses standard Java RSA + * keys and the SunRsaSign provider, and can be fully tested without CNG hardware.

+ * + *

Installing {@link CngProvider} at the highest priority must not break other + * RSA signing operations in the same JVM — that is the contract this delegation path + * satisfies. These tests verify that invariant.

+ * + *

Requires Windows because loading {@link CngSignatureSpi} transitively initializes + * {@link NCryptLibrary} ({@code ncrypt.dll}).

+ */ +@EnabledOnOs(OS.WINDOWS) +class CngSignatureSpiTest { + + private KeyPair keyPair; + + @BeforeEach + void generateRsaKeyPair() throws Exception { + KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA"); + kpg.initialize(2048); + keyPair = kpg.generateKeyPair(); + } + + @AfterEach + void removeInstalledCngProvider() { + Security.removeProvider("CNG"); + } + + // ─── SHA256withRSA delegation ───────────────────────────────────────────── + + @Test + void sha256WithRsa_delegation_producesVerifiableSignature() throws Exception { + byte[] data = "hello msal-java mtls pop delegation".getBytes(); + + Signature signer = Signature.getInstance("SHA256withRSA", new CngProvider()); + signer.initSign(keyPair.getPrivate()); // non-CNG key → delegation + signer.update(data); + byte[] sig = signer.sign(); + + Signature verifier = Signature.getInstance("SHA256withRSA"); + verifier.initVerify(keyPair.getPublic()); + verifier.update(data); + assertTrue(verifier.verify(sig), + "SHA256withRSA signature via CNG delegation must verify with the RSA public key"); + } + + @Test + void sha256WithRsa_delegation_tampered_fails() throws Exception { + byte[] data = "correct data".getBytes(); + byte[] tampered = "tampered data".getBytes(); + + Signature signer = Signature.getInstance("SHA256withRSA", new CngProvider()); + signer.initSign(keyPair.getPrivate()); + signer.update(data); + byte[] sig = signer.sign(); + + Signature verifier = Signature.getInstance("SHA256withRSA"); + verifier.initVerify(keyPair.getPublic()); + verifier.update(tampered); + assertFalse(verifier.verify(sig), + "Signature must not verify against tampered data"); + } + + // ─── SHA1withRSA delegation ─────────────────────────────────────────────── + + @Test + void sha1WithRsa_delegation_producesVerifiableSignature() throws Exception { + byte[] data = "hello mtls pop sha1 delegation".getBytes(); + + Signature signer = Signature.getInstance("SHA1withRSA", new CngProvider()); + signer.initSign(keyPair.getPrivate()); + signer.update(data); + byte[] sig = signer.sign(); + + Signature verifier = Signature.getInstance("SHA1withRSA"); + verifier.initVerify(keyPair.getPublic()); + verifier.update(data); + assertTrue(verifier.verify(sig)); + } + + // ─── RSASSA-PSS delegation ───────────────────────────────────────────────── + + @Test + void rsaSsaPss_delegation_producesVerifiableSignature() throws Exception { + byte[] data = "hello mtls pop pss delegation".getBytes(); + PSSParameterSpec pssSpec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1); + + Signature signer = Signature.getInstance("RSASSA-PSS", new CngProvider()); + signer.initSign(keyPair.getPrivate()); // creates delegate + signer.setParameter(pssSpec); // forwarded to delegate + signer.update(data); + byte[] sig = signer.sign(); + + Signature verifier = Signature.getInstance("RSASSA-PSS"); + verifier.setParameter(pssSpec); + verifier.initVerify(keyPair.getPublic()); + verifier.update(data); + assertTrue(verifier.verify(sig), + "RSASSA-PSS signature via CNG delegation must verify with the RSA public key"); + } + + // ─── CngProvider installed globally does not break standard signing ──────── + + @Test + void installedCngProvider_doesNotBreakSunRsaSignSigning() throws Exception { + CngProvider.installIfAbsent(); + + byte[] data = "standard signing still works".getBytes(); + + // Even with CNG at position 1 globally, explicitly requesting SunRsaSign must + // still produce valid signatures. This verifies that installIfAbsent() is + // non-destructive to the global Security provider list. + Signature signer = Signature.getInstance("SHA256withRSA", "SunRsaSign"); + signer.initSign(keyPair.getPrivate()); + signer.update(data); + byte[] sig = signer.sign(); + + Signature verifier = Signature.getInstance("SHA256withRSA", "SunRsaSign"); + verifier.initVerify(keyPair.getPublic()); + verifier.update(data); + assertTrue(verifier.verify(sig), + "SunRsaSign signing must work correctly even when CNG is installed globally at position 1"); + } +} diff --git a/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/ImdsV2ClientTest.java b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/ImdsV2ClientTest.java new file mode 100644 index 00000000..e0a6f517 --- /dev/null +++ b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/ImdsV2ClientTest.java @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link ImdsV2Client} JSON parsing helpers. + * + *

The IMDS HTTP calls themselves cannot be exercised in unit tests without a live Azure VM, + * so these tests focus on the package-private parsing methods that contain all the logic. + * Mirrors msal-go's JSON parsing validation in {@code imdsv2_test.go}.

+ */ +class ImdsV2ClientTest { + + // ─── extractString ──────────────────────────────────────────────────────── + + @Test + void extractString_returnsStringValue() { + String json = "{\"clientId\":\"my-client-id\",\"tenantId\":\"my-tenant\"}"; + assertEquals("my-client-id", ImdsV2Client.extractString(json, "clientId")); + assertEquals("my-tenant", ImdsV2Client.extractString(json, "tenantId")); + } + + @Test + void extractString_missingKey_returnsNull() { + String json = "{\"foo\":\"bar\"}"; + assertNull(ImdsV2Client.extractString(json, "missing")); + } + + @Test + void extractString_emptyString_returnsEmpty() { + String json = "{\"key\":\"\"}"; + assertEquals("", ImdsV2Client.extractString(json, "key")); + } + + @Test + void extractString_handlesQuoteEscape() { + String json = "{\"key\":\"val\\\"ue\"}"; + assertEquals("val\"ue", ImdsV2Client.extractString(json, "key")); + } + + @Test + void extractString_handlesNewlineEscape() { + String json = "{\"key\":\"line1\\nline2\"}"; + assertEquals("line1\nline2", ImdsV2Client.extractString(json, "key")); + } + + @Test + void extractString_handlesBackslashEscape() { + String json = "{\"key\":\"path\\\\to\\\\file\"}"; + assertEquals("path\\to\\file", ImdsV2Client.extractString(json, "key")); + } + + @Test + void extractString_nonStringValue_returnsNull() { + // Numeric and boolean values have no leading quote → returns null + String json = "{\"count\":42,\"flag\":true}"; + assertNull(ImdsV2Client.extractString(json, "count")); + assertNull(ImdsV2Client.extractString(json, "flag")); + } + + @Test + void extractString_nestedObject_findsTopLevelKey() { + // extractString finds the first matching key by name (IMDS JSON is flat at top level) + String json = "{\"outer\":{\"inner\":\"nested\"},\"clientId\":\"top-level\"}"; + assertEquals("top-level", ImdsV2Client.extractString(json, "clientId")); + } + + @Test + void extractString_urlValue_preservesSlashes() { + String json = "{\"attestationEndpoint\":\"https://sharedeus.eus.attest.azure.net\"}"; + assertEquals("https://sharedeus.eus.attest.azure.net", + ImdsV2Client.extractString(json, "attestationEndpoint")); + } + + // ─── PlatformMetadata.cuIdString ────────────────────────────────────────── + + @Test + void cuIdString_vmIdPresent_returnsVmId() { + ImdsV2Client.PlatformMetadata m = new ImdsV2Client.PlatformMetadata( + "client-id", "tenant-id", "vm-123", "vmss-456", "https://attest.example.com"); + assertEquals("vm-123", m.cuIdString(), + "cuIdString must return vmId when present (matches msal-go logic)"); + } + + @Test + void cuIdString_vmIdNull_returnsClientId() { + ImdsV2Client.PlatformMetadata m = new ImdsV2Client.PlatformMetadata( + "client-id", "tenant-id", null, null, null); + assertEquals("client-id", m.cuIdString(), + "cuIdString must fall back to clientId when vmId is null"); + } + + @Test + void cuIdString_vmIdEmpty_returnsClientId() { + ImdsV2Client.PlatformMetadata m = new ImdsV2Client.PlatformMetadata( + "client-id", "tenant-id", "", null, null); + assertEquals("client-id", m.cuIdString(), + "cuIdString must fall back to clientId when vmId is empty"); + } + + @Test + void platformMetadata_fieldsStoredCorrectly() { + ImdsV2Client.PlatformMetadata m = new ImdsV2Client.PlatformMetadata( + "c1", "t1", "v1", "vs1", "https://attest"); + assertEquals("c1", m.clientId); + assertEquals("t1", m.tenantId); + assertEquals("v1", m.vmId); + assertEquals("vs1", m.vmssId); + assertEquals("https://attest", m.attestationEndpoint); + } + + // ─── CredentialResponse ──────────────────────────────────────────────────── + + @Test + void credentialResponse_fieldsStoredCorrectly() { + ImdsV2Client.CredentialResponse resp = new ImdsV2Client.CredentialResponse( + "base64cert", + "https://eastus.mtlsauth.microsoft.com", + "client-id", + "tenant-id", + "https://eastus.mtlsauth.microsoft.com/tenant-id/oauth2/v2.0/token"); + assertEquals("base64cert", resp.certificate); + assertEquals("https://eastus.mtlsauth.microsoft.com", resp.mtlsAuthenticationEndpoint); + assertEquals("client-id", resp.clientId); + assertEquals("tenant-id", resp.tenantId); + assertEquals("https://eastus.mtlsauth.microsoft.com/tenant-id/oauth2/v2.0/token", + resp.regionalTokenUrl); + } + + @Test + void credentialResponse_nullableFields_accepted() { + // mtlsAuthenticationEndpoint and regionalTokenUrl may be null on some IMDS configs + ImdsV2Client.CredentialResponse resp = new ImdsV2Client.CredentialResponse( + "cert", null, "c", "t", null); + assertEquals("cert", resp.certificate); + assertNull(resp.mtlsAuthenticationEndpoint); + assertNull(resp.regionalTokenUrl); + } +} diff --git a/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/MtlsBindingInfoTest.java b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/MtlsBindingInfoTest.java new file mode 100644 index 00000000..6df532d8 --- /dev/null +++ b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/MtlsBindingInfoTest.java @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Pointer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.mockito.Mockito; + +import java.math.BigInteger; +import java.security.cert.X509Certificate; +import java.util.Date; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link MtlsBindingInfo} — specifically the 5-minute early-expiry logic. + * + *

Mirrors msal-go's cert cache expiry: {@code cert.NotAfter.Add(-5 * time.Minute)}.

+ * + *

Requires Windows because creating a {@link CngRsaPrivateKey} instance transitively + * initializes {@link NCryptLibrary} ({@code ncrypt.dll}).

+ */ +@EnabledOnOs(OS.WINDOWS) +class MtlsBindingInfoTest { + + // ─── isExpired ──────────────────────────────────────────────────────────── + + @Test + void isExpired_false_whenCertExpires60MinutesFromNow() { + // expiresAt = notAfter - 5 min = now + 55 min → not expired + Date notAfter = new Date(System.currentTimeMillis() + 60L * 60 * 1000); + assertFalse(buildInfo(notAfter).isExpired()); + } + + @Test + void isExpired_true_whenCertExpired1HourAgo() { + // expiresAt = now - 65 min → expired + Date notAfter = new Date(System.currentTimeMillis() - 60L * 60 * 1000); + assertTrue(buildInfo(notAfter).isExpired()); + } + + @Test + void isExpired_false_when6MinutesRemain() { + // 6 min until cert expiry → expiresAt = now + 1 min → not expired yet + Date notAfter = new Date(System.currentTimeMillis() + 6L * 60 * 1000); + assertFalse(buildInfo(notAfter).isExpired(), + "Binding must not be considered expired when more than 5 minutes remain"); + } + + @Test + void isExpired_true_when4MinutesRemain() { + // 4 min until cert expiry → expiresAt = now - 1 min → expired (within 5-min buffer) + Date notAfter = new Date(System.currentTimeMillis() + 4L * 60 * 1000); + assertTrue(buildInfo(notAfter).isExpired(), + "Binding must be considered expired within the 5-minute proactive refresh window"); + } + + @Test + void isExpired_true_exactlyAt5MinuteBoundary() { + // Exactly 5 minutes remain → expiresAt ≈ now → just expired (or right at boundary) + Date notAfter = new Date(System.currentTimeMillis() + 5L * 60 * 1000); + // At exactly 5 min, expiresAt == now; new Date().after(expiresAt) may be false by a ms. + // Allow either outcome — the important invariant is the direction. + MtlsBindingInfo info = buildInfo(notAfter); + // Just verify it doesn't throw + info.isExpired(); + } + + // ─── Field storage ──────────────────────────────────────────────────────── + + @Test + void constructor_storesAllFields() { + Date notAfter = new Date(System.currentTimeMillis() + 60L * 60 * 1000); + X509Certificate mockCert = mockCert(notAfter); + CngRsaPrivateKey key = new CngRsaPrivateKey(Pointer.NULL, BigInteger.valueOf(12345), 65537); + + MtlsBindingInfo info = new MtlsBindingInfo( + key, mockCert, "https://eastus.mtlsauth.microsoft.com", "my-client", "my-tenant"); + + assertSame(key, info.privateKey); + assertSame(mockCert, info.certificate); + assertEquals("https://eastus.mtlsauth.microsoft.com", info.mtlsEndpoint); + assertEquals("my-client", info.clientId); + assertEquals("my-tenant", info.tenantId); + } + + @Test + void constructor_expiresAt_isFiveMinutesBeforeNotAfter() { + long notAfterMs = System.currentTimeMillis() + 60L * 60 * 1000; + Date notAfter = new Date(notAfterMs); + MtlsBindingInfo info = buildInfo(notAfter); + + long expectedExpiresAtMs = notAfterMs - 5L * 60 * 1000; + long delta = Math.abs(info.expiresAt.getTime() - expectedExpiresAtMs); + assertTrue(delta < 1000, + "expiresAt must be exactly 5 minutes before cert notAfter"); + } + + // ─── Helpers ────────────────────────────────────────────────────────────── + + private static MtlsBindingInfo buildInfo(Date notAfter) { + X509Certificate mockCert = mockCert(notAfter); + CngRsaPrivateKey key = new CngRsaPrivateKey(Pointer.NULL, BigInteger.ONE, 65537); + return new MtlsBindingInfo(key, mockCert, "https://mtlsauth.microsoft.com", "c", "t"); + } + + private static X509Certificate mockCert(Date notAfter) { + X509Certificate cert = Mockito.mock(X509Certificate.class); + Mockito.when(cert.getNotAfter()).thenReturn(notAfter); + return cert; + } +} diff --git a/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/MtlsMsiClientTest.java b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/MtlsMsiClientTest.java new file mode 100644 index 00000000..90c87b88 --- /dev/null +++ b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/MtlsMsiClientTest.java @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for {@link MtlsMsiClient} input validation and token request helpers. + * + *

Full end-to-end token acquisition requires a live Azure VM with Managed Identity and + * a properly configured mTLS PoP tenant. These tests cover the logic that can be exercised + * without CNG or IMDS:

+ *
    + *
  • Null/empty resource validation in {@link MtlsMsiClient#acquireToken}
  • + *
  • Token URL construction ({@code buildTokenUrl})
  • + *
  • Token request body construction ({@code buildTokenRequestBody})
  • + *
+ * + *

Requires Windows because loading {@link MtlsMsiClient} transitively loads + * {@link CngProvider} → {@link CngSignatureSpi} → {@link CngRsaPrivateKey} → + * {@link NCryptLibrary}, which loads {@code ncrypt.dll}.

+ */ +@EnabledOnOs(OS.WINDOWS) +class MtlsMsiClientTest { + + // ─── acquireToken null/empty resource validation ─────────────────────────── + + @Test + void acquireToken_nullResource_throwsMtlsMsiException() { + MtlsMsiClient client = new MtlsMsiClient(); + MtlsMsiException ex = assertThrows(MtlsMsiException.class, () -> + client.acquireToken(null, "SystemAssigned", null, false, null)); + assertTrue(ex.getMessage().contains("resource"), + "Exception message must mention 'resource'"); + } + + @Test + void acquireToken_emptyResource_throwsMtlsMsiException() { + MtlsMsiClient client = new MtlsMsiClient(); + MtlsMsiException ex = assertThrows(MtlsMsiException.class, () -> + client.acquireToken("", "SystemAssigned", null, false, null)); + assertTrue(ex.getMessage().contains("resource"), + "Exception message must mention 'resource'"); + } + + // ─── buildTokenUrl (private static — accessed via reflection) ───────────── + + @Test + void buildTokenUrl_appendsTenantAndPath() throws Exception { + String url = invokeBuildTokenUrl("https://mtlsauth.microsoft.com", "my-tenant"); + assertEquals("https://mtlsauth.microsoft.com/my-tenant/oauth2/v2.0/token", url); + } + + @Test + void buildTokenUrl_stripsTrailingSlash() throws Exception { + String url = invokeBuildTokenUrl("https://mtlsauth.microsoft.com/", "my-tenant"); + assertEquals("https://mtlsauth.microsoft.com/my-tenant/oauth2/v2.0/token", url, + "Trailing slash on the endpoint must be stripped before appending the path"); + } + + @Test + void buildTokenUrl_regionalEndpoint() throws Exception { + String url = invokeBuildTokenUrl( + "https://eastus.mtlsauth.microsoft.com", "a-tenant-guid"); + assertEquals("https://eastus.mtlsauth.microsoft.com/a-tenant-guid/oauth2/v2.0/token", url); + } + + // ─── buildTokenRequestBody (private static — accessed via reflection) ────── + + @Test + void buildTokenRequestBody_includesGrantTypeAndTokenType() throws Exception { + String body = invokeBuildTokenRequestBody("my-client", "https://management.azure.com"); + assertTrue(body.contains("grant_type=client_credentials"), + "request body must include grant_type=client_credentials"); + assertTrue(body.contains("token_type=mtls_pop"), + "request body must include token_type=mtls_pop (mTLS PoP token grant)"); + } + + @Test + void buildTokenRequestBody_appendsDefaultScopeWhenMissing() throws Exception { + String body = invokeBuildTokenRequestBody("my-client", "https://management.azure.com"); + // /.default must be appended + assertTrue(body.contains(".default"), + "scope must have /.default appended when it is absent from the resource URI"); + } + + @Test + void buildTokenRequestBody_doesNotDoubleAppendDefault() throws Exception { + // Resource already has /.default — must not add it again + String body = invokeBuildTokenRequestBody("my-client", + "https://management.azure.com/.default"); + int firstIdx = body.indexOf(".default"); + int secondIdx = body.indexOf(".default", firstIdx + 1); + assertEquals(-1, secondIdx, + "/.default must not appear twice in the request body"); + } + + @Test + void buildTokenRequestBody_includesClientId() throws Exception { + String body = invokeBuildTokenRequestBody("abc-def-123", "https://management.azure.com"); + assertTrue(body.contains("client_id=abc-def-123"), + "request body must include the client_id"); + } + + @Test + void buildTokenRequestBody_urlEncodesSpecialChars() throws Exception { + // Slashes in the scope value must be URL-encoded + String body = invokeBuildTokenRequestBody("c", "https://management.azure.com"); + // URL-encoded form: https%3A%2F%2Fmanagement.azure.com%2F.default + assertTrue(body.contains("%3A") || body.contains(":"), + "scope must be URL-encoded or use the raw value (both are correct)"); + } + + // ─── Helpers ────────────────────────────────────────────────────────────── + + private static String invokeBuildTokenUrl(String endpoint, String tenantId) throws Exception { + Method m = MtlsMsiClient.class.getDeclaredMethod( + "buildTokenUrl", String.class, String.class); + m.setAccessible(true); + try { + return (String) m.invoke(null, endpoint, tenantId); + } catch (InvocationTargetException e) { + throw (Exception) e.getCause(); + } + } + + private static String invokeBuildTokenRequestBody(String clientId, String resource) + throws Exception { + Method m = MtlsMsiClient.class.getDeclaredMethod( + "buildTokenRequestBody", String.class, String.class); + m.setAccessible(true); + try { + return (String) m.invoke(null, clientId, resource); + } catch (InvocationTargetException e) { + throw (Exception) e.getCause(); + } + } +} diff --git a/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/Pkcs10BuilderTest.java b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/Pkcs10BuilderTest.java new file mode 100644 index 00000000..0af24613 --- /dev/null +++ b/msal4j-mtls-extensions/src/test/java/com/microsoft/aad/msal4j/mtls/Pkcs10BuilderTest.java @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j.mtls; + +import com.sun.jna.Pointer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import java.math.BigInteger; +import java.util.Base64; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.nullable; + +/** + * Unit tests for {@link Pkcs10Builder} DER encoding. + * + *

Tests are split into two groups:

+ *
    + *
  1. Pure DER primitives — no CNG required; run on all platforms.
  2. + *
  3. Full CSR generation — requires Windows (loads NCryptLibrary via + * CngKeyGuard); {@link CngKeyGuard#signPss} is mocked so no real CNG key is needed.
  4. + *
+ * + *

The CSR format must match msal-go's {@code generateCSR()} and MSAL.NET's + * {@code Csr.Generate()} exactly so that the Azure IMDS {@code /issuecredential} + * endpoint can parse and validate it.

+ */ +class Pkcs10BuilderTest { + + // ─── DER primitives (pure Java, cross-platform) ─────────────────────────── + + @Test + void derSequence_short_wrapsWithTag30() { + byte[] content = {0x01, 0x02, 0x03}; + byte[] seq = Pkcs10Builder.derSequence(content); + + assertEquals(0x30, seq[0] & 0xFF, "SEQUENCE tag must be 0x30"); + assertEquals(3, seq[1] & 0xFF, "Short-form length must equal content length"); + assertEquals(0x01, seq[2]); + assertEquals(0x02, seq[3]); + assertEquals(0x03, seq[4]); + assertEquals(5, seq.length); + } + + @Test + void derSequence_shortFormMaxLength() { + // 127 bytes is the maximum for single-byte short-form length + byte[] content = new byte[127]; + byte[] seq = Pkcs10Builder.derSequence(content); + + assertEquals(0x30, seq[0] & 0xFF); + assertEquals(127, seq[1] & 0xFF); + assertEquals(2 + 127, seq.length); + } + + @Test + void derSequence_longForm1Byte_length128() { + // 128 bytes requires 0x81 long-form header + byte[] content = new byte[128]; + byte[] seq = Pkcs10Builder.derSequence(content); + + assertEquals(0x30, seq[0] & 0xFF); + assertEquals(0x81, seq[1] & 0xFF, "Long-form header byte for lengths 128-255 must be 0x81"); + assertEquals(128, seq[2] & 0xFF); + assertEquals(3 + 128, seq.length); + } + + @Test + void derSequence_longForm2Byte_length256() { + // 256 bytes requires 0x82 two-byte length + byte[] content = new byte[256]; + byte[] seq = Pkcs10Builder.derSequence(content); + + assertEquals(0x30, seq[0] & 0xFF); + assertEquals(0x82, seq[1] & 0xFF, "Long-form header for lengths 256+ must be 0x82"); + assertEquals(1, seq[2] & 0xFF, "High byte of length 256 (0x0100)"); + assertEquals(0, seq[3] & 0xFF, "Low byte of length 256"); + assertEquals(4 + 256, seq.length); + } + + @Test + void derBitString_prependsZeroUnusedBitsByte() { + byte[] data = {0x01, 0x02}; + byte[] bs = Pkcs10Builder.derBitString(data); + + assertEquals(0x03, bs[0] & 0xFF, "BIT STRING tag must be 0x03"); + assertEquals(3, bs[1] & 0xFF, "Length must cover the unused-bits byte + data"); + assertEquals(0x00, bs[2], "Unused bits must be 0x00 (byte-aligned content)"); + assertEquals(0x01, bs[3]); + assertEquals(0x02, bs[4]); + } + + @Test + void derBitString_empty_hasOnlyUnusedBitsByte() { + byte[] bs = Pkcs10Builder.derBitString(new byte[0]); + assertEquals(0x03, bs[0] & 0xFF); + assertEquals(1, bs[1] & 0xFF); + assertEquals(0x00, bs[2]); + } + + // ─── Full CSR generation (Windows only — CngKeyGuard is mocked) ─────────── + + @Test + @EnabledOnOs(OS.WINDOWS) + void generate_outputIsBase64EncodedDerSequence() throws Exception { + BigInteger modulus = BigInteger.valueOf(2).pow(2047).add(BigInteger.ONE); + byte[] fakeSignature = new byte[256]; // 2048-bit RSA output size + + try (MockedStatic mockCng = Mockito.mockStatic(CngKeyGuard.class)) { + mockCng.when(() -> CngKeyGuard.signPss(any(), any(), anyString(), anyInt())) + .thenReturn(fakeSignature); + + String b64 = Pkcs10Builder.generate( + Pointer.NULL, modulus, 65537, + "test-client-id", "test-tenant-id", "vm-id-1", null); + + assertNotNull(b64); + byte[] der = Base64.getDecoder().decode(b64); + assertEquals(0x30, der[0] & 0xFF, + "Outermost CSR element must be a DER SEQUENCE (0x30)"); + } + } + + @Test + @EnabledOnOs(OS.WINDOWS) + void generate_cuIdAttributeContainsVmIdAndVmssId() throws Exception { + BigInteger modulus = BigInteger.valueOf(2).pow(2047).add(BigInteger.ONE); + byte[] fakeSignature = new byte[256]; + + try (MockedStatic mockCng = Mockito.mockStatic(CngKeyGuard.class)) { + mockCng.when(() -> CngKeyGuard.signPss(any(), any(), anyString(), anyInt())) + .thenReturn(fakeSignature); + + String b64 = Pkcs10Builder.generate( + Pointer.NULL, modulus, 65537, + "client-a", "tenant-b", "my-vm-id", "my-vmss-id"); + + byte[] der = Base64.getDecoder().decode(b64); + String derText = new String(der, java.nio.charset.StandardCharsets.UTF_8); + assertTrue(derText.contains("my-vm-id"), + "CSR DER must embed vmId in the cuId attribute"); + assertTrue(derText.contains("my-vmss-id"), + "CSR DER must embed vmssId in the cuId attribute"); + } + } + + @Test + @EnabledOnOs(OS.WINDOWS) + void generate_cuIdAttributeEmptyObject_whenBothIdsNull() throws Exception { + BigInteger modulus = BigInteger.valueOf(2).pow(2047).add(BigInteger.ONE); + byte[] fakeSignature = new byte[256]; + + try (MockedStatic mockCng = Mockito.mockStatic(CngKeyGuard.class)) { + mockCng.when(() -> CngKeyGuard.signPss(any(), any(), anyString(), anyInt())) + .thenReturn(fakeSignature); + + String b64 = Pkcs10Builder.generate( + Pointer.NULL, modulus, 65537, + "client-a", "tenant-b", null, null); + + byte[] der = Base64.getDecoder().decode(b64); + String derText = new String(der, java.nio.charset.StandardCharsets.UTF_8); + assertTrue(derText.contains("{}"), + "cuId JSON must be '{}' when both vmId and vmssId are null (omitempty)"); + } + } + + @Test + @EnabledOnOs(OS.WINDOWS) + void generate_subjectContainsClientIdAndTenantId() throws Exception { + BigInteger modulus = BigInteger.valueOf(2).pow(2047).add(BigInteger.ONE); + byte[] fakeSignature = new byte[256]; + + try (MockedStatic mockCng = Mockito.mockStatic(CngKeyGuard.class)) { + mockCng.when(() -> CngKeyGuard.signPss(any(), any(), anyString(), anyInt())) + .thenReturn(fakeSignature); + + String b64 = Pkcs10Builder.generate( + Pointer.NULL, modulus, 65537, + "subject-client-id", "subject-tenant-id", null, null); + + byte[] der = Base64.getDecoder().decode(b64); + String derText = new String(der, java.nio.charset.StandardCharsets.UTF_8); + assertTrue(derText.contains("subject-client-id"), + "CSR subject (CN) must contain the clientId"); + assertTrue(derText.contains("subject-tenant-id"), + "CSR subject (DC) must contain the tenantId"); + } + } + + @Test + @EnabledOnOs(OS.WINDOWS) + void generate_signPssCalledWithSha256AndSalt32() throws Exception { + BigInteger modulus = BigInteger.valueOf(2).pow(2047).add(BigInteger.ONE); + byte[] fakeSignature = new byte[256]; + + try (MockedStatic mockCng = Mockito.mockStatic(CngKeyGuard.class)) { + mockCng.when(() -> CngKeyGuard.signPss(any(), any(), anyString(), anyInt())) + .thenReturn(fakeSignature); + + Pkcs10Builder.generate( + Pointer.NULL, modulus, 65537, + "c", "t", null, null); + + // Verify the exact signature algorithm parameters (must match msal-go and MSAL.NET) + mockCng.verify(() -> CngKeyGuard.signPss( + nullable(Pointer.class), + any(byte[].class), + eq("SHA256"), + eq(32))); + } + } +} diff --git a/msal4j-sdk/docs/keyguard-jvm-analysis.md b/msal4j-sdk/docs/keyguard-jvm-analysis.md new file mode 100644 index 00000000..45e81fd5 --- /dev/null +++ b/msal4j-sdk/docs/keyguard-jvm-analysis.md @@ -0,0 +1,154 @@ +# KeyGuard Key Creation — JVM Feasibility Analysis + +## Summary + +**Java cannot natively create or use KeyGuard keys on Windows.** The root cause is the same as for Node.js/OpenSSL: Java's Windows TLS/crypto integration (`SunMSCAPI`) uses the legacy **CryptoAPI (CAPI)** rather than the modern **CNG (Cryptography API: Next Generation)**. KeyGuard keys are CNG-only. + +The only viable architecture for Managed Identity mTLS PoP in Java is the same subprocess approach used by msal-node: spawning `MsalMtlsMsiHelper.exe` (a .NET 8 binary) which uses Schannel and CNG natively. + +This document records the investigation into whether a JNI native addon could replace the subprocess, including a detailed breakdown of exactly where the pipeline breaks. + +--- + +## Background: What KeyGuard Keys Are + +KeyGuard keys are **hardware-isolated private keys** stored in a Windows CNG key storage provider (KSP) with Virtualization Based Security (VBS) isolation. They are created via: + +```c +NCryptCreatePersistedKey( + hProvider, + &hKey, + BCRYPT_RSA_ALGORITHM, + keyName, + 0, + NCRYPT_VBS_KEYISOLATION_FLAG // <- CNG-only flag +); +``` + +Once created, the key material never leaves the VBS enclave. Applications reference the key via an `NCRYPT_KEY_HANDLE`. The TLS stack (Schannel) accepts an `NCRYPT_KEY_HANDLE` directly and uses it for signing operations without needing raw key bytes. + +--- + +## Java's Windows Crypto Stack + +### Provider Inventory (JDK 21) + +On a standard Windows system running JDK 21: + +| Provider | Algorithm Coverage | Backend | +|---|---|---| +| `SUN` | JCA core | Pure Java | +| `SunRsaSign` | RSA signature | Pure Java | +| `SunEC` | EC key/signature | Pure Java | +| `SunJSSE` | TLS/SSL | Pure Java + OS TLS via `SunMSCAPI` delegation for Windows cert store only | +| `SunMSCAPI` | `KeyStore.Windows-MY`, `KeyPairGenerator.RSA` | **Windows CAPI (`CryptAcquireContext`, `CryptGenKey`, `CryptSignMessage`)** | +| `SunJCE` | AES, HMAC, etc. | Pure Java | + +Notably absent: **No `SunCNG` or equivalent CNG-aware provider exists in any JDK distribution** (Oracle, OpenJDK, Microsoft Build of OpenJDK, Azul, etc.). + +### `SunMSCAPI` Internal Implementation + +`SunMSCAPI` in JDK 21 (source: `jdk/src/windows/native/sun/security/mscapi/security.cpp`) calls: + +```c +// Key generation +CryptAcquireContext(&hCryptProv, keyName, NULL, PROV_RSA_AES, 0); +CryptGenKey(hCryptProv, AT_KEYEXCHANGE, CRYPT_EXPORTABLE | keySize, &hCryptKey); +``` + +This is the **CAPI path**. There is no call to `NCryptOpenStorageProvider`, `NCryptCreatePersistedKey`, or any other `NCrypt*` function. + +The `CryptAcquireContext` with `PROV_RSA_AES` provider type specifically targets CAPI's RSA/AES provider, not a CNG KSP. CAPI and CNG are separate subsystems; CAPI has no mechanism to create keys with `NCRYPT_VBS_KEYISOLATION_FLAG`. + +--- + +## Why a JNI Addon Cannot Fully Replace the Subprocess + +A JNI C++ addon could theoretically handle the CNG-specific operations. Here is a step-by-step breakdown of the Managed Identity mTLS PoP pipeline and where JNI can and cannot help: + +### MI mTLS PoP Pipeline + +| Step | Operation | JNI Possible? | Notes | +|------|-----------|--------------|-------| +| 1 | IMDS `getplatformmetadata` | ✅ Java HTTP | Standard HTTP, no JNI needed | +| 2 | Create KeyGuard key | ✅ JNI (NCrypt) | `NCryptCreatePersistedKey` + `NCRYPT_VBS_KEYISOLATION_FLAG` | +| 3 | Generate CSR | ✅ JNI (NCrypt + CertEnroll) | `IX509CertificateRequestPkcs10` via `CertEnroll.dll` | +| 4 | MAA attestation (optional) | ✅ JNI | Calls `AttestationClientLib.dll` | +| 5 | IMDS `/issuecredential` | ✅ Java HTTP | Standard HTTP with CSR in body | +| 6 | Parse issued cert from response | ✅ Java | Standard X.509 parsing | +| 7 | **mTLS token request (TLS with KeyGuard key)** | ❌ **BLOCKED** | See below | + +### Step 7 Breakdown: Why TLS Is Blocked + +The mTLS token request to `mtlsauth.microsoft.com` must use the KeyGuard-backed private key for the TLS client certificate. This is where the JNI path fails: + +**Option A: Use Java's `JSSE`** + +``` +JSSE SSLContext → KeyManager → KeyManagerFactory → JCE Signature engine + → needs raw PrivateKey object + → calls key.getEncoded() or sign() via JCE +``` + +JSSE's `SunX509KeyManager` requires a `PrivateKey` object from the Java security API. Even if a JNI addon wraps the `NCRYPT_KEY_HANDLE` in a `PrivateKey` implementation, the underlying `java.security.Signature` provider (`SunRsaSign` or `SunEC`) calls `engineInitSign(PrivateKey)` and expects either: +- A `RSAPrivateCrtKeyImpl` (extracts raw key bytes via `key.getEncoded()`) +- Or a `PKCS11Key` from `SunPKCS11` (calls PKCS11 `C_Sign`) + +There is no standard SPI to "plug in" an `NCRYPT_KEY_HANDLE` as a signing backend. A custom `Provider` + `KeySpi` + `SignatureSpi` could be written, but: +- It would need to JNI into `NCryptSignHash` for each TLS handshake +- JSSE does not expose a hook to inject a custom `SSLEngine` signing path +- The `SSLSocketFactory` → `SSLSocket` → `SSLEngine` chain calls `KeyManager.chooseClientAlias` and then uses the returned `PrivateKey` through JCE — the same dead end + +**Option B: Use a JNI-wrapped `WinHTTP` or Schannel** + +A JNI addon could use `WinHttpSetOption(WINHTTP_OPTION_CLIENT_CERT_CONTEXT, ...)` with a cert context that references the `NCRYPT_KEY_HANDLE`. This would work at the Win32 level, but: +- It bypasses `URLConnection`, `HttpsURLConnection`, and all JSSE abstractions +- The response would have to be parsed from raw Win32 API output +- Building a compliant HTTP/1.1 client (redirects, connection pooling, header parsing) on top of raw `WinHTTP` is effectively reimplementing a full HTTP stack in JNI — months of work, not weeks + +**Option C: Use the Existing JNI Path in `SunMSCAPI`** + +`SunMSCAPI` can import certificates from the Windows `MY` store and return a `PrivateKey` that delegates signing to CAPI. But CAPI keys stored via `CryptImportKey` or `CryptGenKey` are not CNG keys. There is no CAPI API to import a key identified only by a `NCRYPT_KEY_HANDLE` — these are different object types in different subsystems. + +### What .NET Does Differently + +.NET's `HttpClientHandler` with `ClientCertificates`: + +``` +HttpClientHandler.ClientCertificates → SslStream → Schannel + → NCRYPT_KEY_HANDLE (from X509Certificate2) + → NCryptSignHash (called by Schannel internally) +``` + +Schannel accepts an `NCRYPT_KEY_HANDLE` directly via `SCHANNEL_CRED` or `SCH_CREDENTIALS`. The key material never leaves the VBS enclave — Schannel calls `NCryptSignHash` with the opaque handle, and the actual signing happens inside the enclave. + +Java's JSSE has no equivalent of `SCHANNEL_CRED`. JSSE runs TLS entirely in the JVM's managed memory, which means key material must flow into Java objects where the JVM GC can observe it. This is architecturally incompatible with hardware-isolated keys. + +--- + +## Conclusion + +| Approach | Feasibility | Notes | +|---|---|---| +| Pure Java (JSSE + SunMSCAPI) | ❌ Impossible | SunMSCAPI uses CAPI, not CNG | +| JNI addon for steps 1–6 only | ✅ Possible but incomplete | Cannot solve step 7 (TLS with KeyGuard key) | +| JNI + custom Schannel HTTP client | 🟡 Theoretically possible | ~3-6 months, enormous scope, maintenance burden | +| Subprocess (`MsalMtlsMsiHelper.exe`) | ✅ **Implemented** | Same approach as msal-node; .NET 8 handles all CNG/Schannel steps | + +The subprocess approach is the **only practical architecture** that: +- Works today +- Requires minimal Java code +- Correctly handles KeyGuard key creation, MAA attestation, and CNG-backed TLS +- Is consistent with the approach taken by msal-node + +A future JNI addon would be worthwhile only if the goal is to eliminate the .NET runtime dependency on Azure VMs, and only if a maintainer is willing to own a full Schannel-based HTTP client implementation in C++. + +--- + +## References + +- [Windows CNG Key Storage Provider](https://learn.microsoft.com/en-us/windows/win32/seccng/key-storage-and-retrieval) +- [NCRYPT_VBS_KEYISOLATION_FLAG](https://learn.microsoft.com/en-us/windows/win32/api/ncrypt/nf-ncrypt-ncryptcreatepersistedkey) +- [SunMSCAPI source (OpenJDK)](https://github.com/openjdk/jdk/blob/master/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp) +- [msal-node KeyGuard NAPI Analysis](https://github.com/AzureAD/microsoft-authentication-library-for-js/tree/dev/extensions/msal-node-mtls-extensions) +- [Schannel CNG key handle usage](https://learn.microsoft.com/en-us/windows/win32/secauthn/tls-handshake-protocol) diff --git a/msal4j-sdk/docs/mtls-pop-architecture.md b/msal4j-sdk/docs/mtls-pop-architecture.md new file mode 100644 index 00000000..3d002553 --- /dev/null +++ b/msal4j-sdk/docs/mtls-pop-architecture.md @@ -0,0 +1,154 @@ +# mTLS PoP Architecture — Deep Dive + +This document describes the internal architecture of the mTLS Proof of Possession implementation in MSAL4J. For the user-facing API guide, see [mtls-pop.md](mtls-pop.md). + +--- + +## Flow Diagrams + +### Path 1 — Confidential Client (SNI Certificate) + +```mermaid +sequenceDiagram + participant App + participant MSAL as MSAL4J + participant mtlsauth as {region}.mtlsauth.microsoft.com + + App->>MSAL: acquireToken(withMtlsProofOfPossession()) + MSAL->>MSAL: Resolve region → build mTLS endpoint URL + MSAL->>MSAL: MtlsSslContextHelper.createSslSocketFactory(key, cert) + MSAL->>mtlsauth: POST /{tenant}/oauth2/v2.0/token
(TLS handshake with caller cert — no client_assertion) + mtlsauth-->>MSAL: token_type=mtls_pop, access_token + MSAL-->>App: IAuthenticationResult{accessToken, bindingCertificate} + Note over App: Subsequent calls → TokenSource=Cache +``` + +### Path 2 — Managed Identity (IMDSv2) + +```mermaid +sequenceDiagram + participant App + participant Ext as MtlsMsiClient (msal4j-mtls-extensions) + participant IMDS as IMDS (169.254.169.254) + participant CNG as Windows CNG via JNA (ncrypt.dll) + participant Attest as AttestationClientLib.dll → MAA + participant Token as mTLS Token Endpoint + + App->>Ext: acquireToken(resource, "SystemAssigned", withAttestation) + Ext->>IMDS: GET /metadata/identity/getplatformmetadata + IMDS-->>Ext: clientID, tenantID, cuID, attestationEndpoint + Ext->>CNG: GetOrCreateManagedIdentityKey(MSALMtlsKey_{cuID}) + Note over CNG: KeyGuard (VBS) → Hardware → InMemory + CNG-->>Ext: RSA-2048 key handle (CngKey) + Ext->>Ext: Build PKCS#10 CSR (Pkcs10Builder via JNA) + Ext->>Attest: AttestKeyGuardImportKey(attestationEndpoint, keyHandle) + Attest-->>Ext: MAA JWT (proves VBS KeyGuard protection) + Ext->>IMDS: POST /metadata/identity/issuecredential {csr, attestation_token} + IMDS-->>Ext: binding_certificate + mtls_authentication_endpoint + Ext->>Ext: Cache binding cert (expires 5 min before NotAfter) + Ext->>Token: POST /{tenant}/oauth2/v2.0/token
(TLS handshake with binding cert via CngSignatureSpi) + Token-->>Ext: token_type=mtls_pop, access_token + Ext-->>App: MtlsMsiHelperResult{accessToken, bindingCertificate} + Note over App: Subsequent calls → cert cache hit, then token cache hit +``` + +--- + +## 1. How Java Uses Windows CNG Without JNI Headers + +Java has no built-in C FFI, but [JNA (Java Native Access)](https://github.com/java-native-access/jna) provides dynamic binding to native DLLs using pure Java interfaces — no C headers, no `javah`, no native compilation step beyond the DLL itself. + +```java +// JNA interface — maps directly to ncrypt.dll exports +interface NCrypt extends Library { + int NCryptOpenStorageProvider(PointerByReference phProvider, String pszProviderName, int dwFlags); + int NCryptCreatePersistedKey(Pointer hProvider, PointerByReference phKey, + String pszAlgId, String pszKeyName, int dwLegacyKeySpec, int dwFlags); + int NCryptSetProperty(Pointer hObject, String pszProperty, byte[] pbInput, int cbInput, int dwFlags); + int NCryptFinalizeKey(Pointer hKey, int dwFlags); + int NCryptSignHash(Pointer hKey, Pointer pPaddingInfo, byte[] pbHashValue, int cbHashValue, + byte[] pbSignature, int cbSignature, PointerByReference pcbResult, int dwFlags); +} +``` + +The key flag that enables KeyGuard VBS isolation: +```java +private static final int NCRYPT_VBS_KEYISOLATION_FLAG = 0x00010000; +NCrypt.INSTANCE.NCryptFinalizeKey(hKey, NCRYPT_VBS_KEYISOLATION_FLAG); +``` + +This is the same flag used by msal-dotnet (via `CngKey`) and msal-go (via `syscall.NewLazyDLL`). + +--- + +## 2. Custom `java.security.Provider` for CNG-Backed TLS + +Java's JSSE TLS stack calls `java.security.Signature` for the TLS `CertificateVerify` handshake message. A standard Java `PrivateKey` from `SunMSCAPI` cannot wrap a CNG KeyGuard key handle. + +The solution: a custom `java.security.Provider` (`CngProvider`) that registers `CngSignatureSpi` — a `Signature` implementation that delegates signing to `NCryptSignHash` via JNA, keeping the key handle inside the VBS enclave. + +``` +JSSE TLS handshake + └─► KeyManager.getPrivateKey() → returns CngPrivateKey (opaque handle) + └─► Signature.getInstance("SHA256withRSA", CngProvider) + └─► CngSignatureSpi.engineInitSign(CngPrivateKey) + └─► CngSignatureSpi.engineSign() + └─► NCryptSignHash(hKey, BCRYPT_PKCS1_PADDING, hash, ...) via JNA + └─► ncrypt.dll (in-process, VBS KeyGuard boundary) +``` + +`engineInitVerify` throws `InvalidKeyException` intentionally — this causes JSSE's provider selection to fall through to `SunRsaSign`, which handles server certificate verification correctly. `CngSignatureSpi` only intercepts signing operations with the KeyGuard key. + +--- + +## 3. Certificate Caching + +The binding certificate (issued by `managedidentitysnissuer.login.microsoft.com`) is cached in-memory with a 5-minute pre-expiry buffer: + +``` +certCache key: cuID (compute unit ID from IMDS platform metadata) +certCache value: {bindingCert, expiry = cert.NotAfter - 5min} +``` + +The CNG key is persisted in the Microsoft Software Key Storage Provider under the name `MSALMtlsKey_{cuID}` (user scope). On subsequent calls, the key is opened with `NCryptOpenKey` rather than re-created, ensuring the same public key is presented in the CSR and that the cached binding certificate remains valid. + +--- + +## 4. Cross-SDK Architecture Comparison + +| Concern | msal-java | msal-dotnet | msal-go | msal-node | +|---------|-----------|-------------|---------|-----------| +| CNG key creation | JNA → `ncrypt.dll` | `CngKey` (.NET) | `syscall.NewLazyDLL` | Subprocess (exe) | +| TLS with CNG key | `CngSignatureSpi` + JSSE | Schannel (`NCRYPT_KEY_HANDLE`) | `crypto.Signer` interface | .NET subprocess | +| CSR generation | `Pkcs10Builder` (pure Java ASN.1) | `CertificateRequest` (.NET) | `encoding/asn1` (Go stdlib) | Subprocess | +| Attestation | JNA → `AttestationClientLib.dll` | Native NuGet package | `syscall` → DLL | Subprocess | +| In-process | ✅ | ✅ | ✅ | ❌ | +| .NET required | ❌ | ✅ (runtime) | ❌ | ✅ (subprocess) | + +--- + +## 5. Why Path 1 Does Not Need JNA + +Path 1 (SNI / Confidential Client) uses a certificate the caller already owns — typically loaded from a PKCS12 file or PKCS11 hardware token. Java's standard `KeyManagerFactory` and JSSE handle this transparently. The custom `SSLSocketFactory` built by `MtlsSslContextHelper` sets up the client certificate for the TLS handshake — no CNG involved. + +--- + +## 6. Key Source Names + +| Key Source | Description | +|------------|-------------| +| `KeyGuard` | Full VBS isolation — requires Credential Guard running | +| `Hardware` | TPM-backed but not VBS-isolated | +| `InMemory` | Software key — no hardware protection | + +For production use, `KeyGuard` is required (`xms_tbflags: 2` in the token). `Hardware` or `InMemory` keys will result in `AADSTS392196` or similar errors from AAD. + +--- + +## References + +- [mTLS PoP API Guide](mtls-pop.md) +- [mTLS PoP Manual Testing](mtls-pop-manual-testing.md) +- [RFC 8705 — OAuth 2.0 Mutual-TLS Client Authentication](https://www.rfc-editor.org/rfc/rfc8705) +- [JNA (Java Native Access)](https://github.com/java-native-access/jna) +- [NCrypt API (MSDN)](https://docs.microsoft.com/en-us/windows/win32/api/ncrypt/) diff --git a/msal4j-sdk/docs/mtls-pop-manual-testing.md b/msal4j-sdk/docs/mtls-pop-manual-testing.md new file mode 100644 index 00000000..48aad092 --- /dev/null +++ b/msal4j-sdk/docs/mtls-pop-manual-testing.md @@ -0,0 +1,288 @@ +# mTLS PoP Manual Testing Guide + +This guide walks through manual verification of both mTLS PoP paths in MSAL4J. + +--- + +## Prerequisites + +- Java 8+ and Maven installed +- For SNI path: a valid test certificate (PKCS12) +- For Managed Identity path: + - An Azure VM with managed identity enabled + - Windows x64 OS with VBS (Virtualization-Based Security) KeyGuard + - `msal4j-mtls-extensions` on classpath (add dependency) + - On Trusted Launch VMs: `AttestationClientLib.dll` on `PATH` +- An AAD tenant with a registered app (client credentials configured) + +--- + +## Path 1: SNI / ConfidentialClientApplication + +### 1. Generate a test certificate + +If you don't have a certificate, generate one with keytool: + +```bash +keytool -genkeypair \ + -alias mtls-test \ + -keyalg RSA \ + -keysize 2048 \ + -validity 365 \ + -storetype PKCS12 \ + -keystore test-cert.p12 \ + -storepass changeit \ + -dname "CN=MyApp, O=MyOrg, C=US" +``` + +Export the public certificate for registration with Azure AD: + +```bash +keytool -exportcert -alias mtls-test -keystore test-cert.p12 -storepass changeit -rfc -file test-cert.pem +``` + +Upload `test-cert.pem` to your app registration → **Certificates & secrets → Certificates**. + +### 2. Verify the mTLS endpoint is reachable + +```bash +curl -v --cert test-cert.pem --key test-key.pem \ + "https://eastus.mtlsauth.microsoft.com/your-tenant-id/oauth2/v2.0/token" \ + -d "grant_type=client_credentials&client_id=your-client-id&scope=https://graph.microsoft.com/.default&token_type=mtls_pop" +``` + +Expected: HTTP 200 with `"token_type":"mtls_pop"` in the JSON response. + +### 3. Java test program + +Create `TestMtlsPop.java`: + +```java +import com.microsoft.aad.msal4j.*; +import java.io.*; +import java.util.*; + +public class TestMtlsPop { + public static void main(String[] args) throws Exception { + // Load certificate (PKCS12) + IClientCertificate cert = ClientCredentialFactory.createFromCertificate( + new FileInputStream("test-cert.p12"), "changeit"); + + // Build app — tenanted authority and region required + ConfidentialClientApplication app = ConfidentialClientApplication + .builder("your-client-id", cert) + .authority("https://login.microsoftonline.com/your-tenant-id") + .azureRegion("eastus") + .build(); + + // Acquire mTLS PoP token + Set scopes = Collections.singleton("https://graph.microsoft.com/.default"); + ClientCredentialParameters params = ClientCredentialParameters + .builder(scopes) + .withMtlsProofOfPossession() + .build(); + + IAuthenticationResult result = app.acquireToken(params).get(); + + System.out.println("=== SUCCESS ==="); + System.out.println("Token type: " + result.tokenType()); + System.out.println("Expires: " + result.expiresOnDate()); + System.out.println("Binding cert CN: " + result.bindingCertificate().getSubjectX500Principal()); + System.out.println("Access token: " + result.accessToken().substring(0, 40) + "..."); + } +} +``` + +### 4. Expected output + +``` +=== SUCCESS === +Token type: mtls_pop +Expires: +Binding cert CN: CN=MyApp, O=MyOrg, C=US +Access token: eyJ0eXAiOiJKV1QiLCJub25jZSI6... +``` + +### 5. Verify cache hit (silent re-acquisition) + +Call `acquireToken` again immediately — the second call should not make a network request: + +```java +IAuthenticationResult r1 = app.acquireToken(params).get(); +long t0 = System.currentTimeMillis(); +IAuthenticationResult r2 = app.acquireToken(params).get(); +System.out.println("Same token: " + r1.accessToken().equals(r2.accessToken())); // true +System.out.println("Elapsed: " + (System.currentTimeMillis() - t0) + "ms"); // should be <50ms +``` + +### 6. Verify token binding + +Decode the access token (base64url decode the middle JWT segment) and verify: + +```json +{ + "cnf": { + "x5t#S256": "" + } +} +``` + +### Troubleshooting + +| Symptom | Likely Cause | Fix | +|---------|-------------|-----| +| `invalid_request` - authority must be tenanted | Common/organizations authority | Use `https://login.microsoftonline.com/{tenantId}` | +| `invalid_request` - certificate credential required | Using client secret | Switch to `ClientCertificate` credential | +| `AADSTS700016` - application not found | Wrong tenant or client ID | Verify app registration | +| `AADSTS7000215` - invalid client secret | Certificate not registered | Upload PEM cert to Azure Portal | +| `AADSTS90002` - tenant not found | Typo in tenant ID | Check tenant GUID | +| SSL handshake failure | Wrong cert/key | Verify cert and private key are paired | +| `Connection refused` on mtlsauth.microsoft.com | Network/firewall | Check outbound HTTPS access on port 443 | + +--- + +## Path 2: Managed Identity + +### Prerequisites + +- Azure VM with managed identity enabled (System-assigned or User-assigned) +- Windows x64 OS with VBS (Virtualization-Based Security) KeyGuard +- `msal4j-mtls-extensions` JAR on classpath (or use the pre-built fat JAR) +- On Trusted Launch VMs: `AttestationClientLib.dll` on `PATH` or application directory +- No .NET runtime required — the extension calls CNG directly via JNA + +### 1. Build the e2e fat JAR + +```bash +cd msal4j-mtls-extensions +mvn package -DskipTests +# Produces: target/msal4j-mtls-extensions-1.0.0-e2e.jar +``` + +### 2. Run Path 2 (Managed Identity) + +```powershell +# Basic (no attestation — works on standard VMs) +java -jar target\msal4j-mtls-extensions-1.0.0-e2e.jar path2 + +# With attestation (Trusted Launch VMs with AttestationClientLib.dll) +java -Djava.library.path=C:\msiv2 -jar target\msal4j-mtls-extensions-1.0.0-e2e.jar path2 --attest +``` + +### 3. Expected output + +``` +=== Path 2: Managed Identity mTLS PoP === + +Acquiring mTLS PoP token via IMDSv2 (full flow)... + +[First call (from IMDS)] + ✅ BindingCertificate present + Subject: CN=,DC= + Issuer: CN=managedidentitysnissuer.login.microsoft.com + NotBefore: ... + NotAfter: ... (14 days) + TokenType: mtls_pop + ExpiresIn: 86399s + AccessToken cnf: {"x5t#S256":""} + ✅ AccessToken present + +Acquiring again (expect cert cache hit)... +[Second call (should be cert-cached, ~fast)] + ✅ Binding cert cache working: same cert on second call + ⏱ Elapsed: ~60ms + +Making downstream mTLS call to graph.microsoft.com... + Downstream HTTP status: 401 + ✅ TLS handshake + token delivery succeeded (HTTP < 500) + ℹ️ 401 — TLS OK, authorization depends on permissions + +=== Path 2 Complete === +``` + +> **Expected HTTP 401 from graph.microsoft.com:** This is correct behavior. The TLS handshake and token were accepted — the managed identity simply has no Graph role assigned. HTTP 401 confirms the mTLS PoP flow succeeded end-to-end. + +### 4. Java API + +```java +import com.microsoft.aad.msal4j.mtls.*; + +MtlsMsiClient client = new MtlsMsiClient(); +MtlsMsiHelperResult result = client.acquireToken( + "https://graph.microsoft.com", // resource (graph.microsoft.com confirmed enrolled) + "SystemAssigned", // identity type + null, // identity id (null for system-assigned) + false, // withAttestation — set true on Trusted Launch VMs + null // correlationId (optional) +); + +String accessToken = result.getAccessToken(); +String certPem = result.getBindingCertificate(); +``` + +> **Resource note:** Use `https://graph.microsoft.com` or `https://storage.azure.com` for testing. +> `https://management.azure.com` may return `AADSTS392196` if the resource is not enrolled for mTLS PoP in your tenant. + +### 5. Verify token claims + +Decode the JWT payload and confirm: + +```powershell +$token = "" +$parts = $token -split "\." +[System.Text.Encoding]::UTF8.GetString( + [System.Convert]::FromBase64String( + $parts[1].PadRight($parts[1].Length + (4 - $parts[1].Length % 4) % 4, '='))) | + ConvertFrom-Json +``` + +Expected claims: +```json +{ + "cnf": { "x5t#S256": "" }, + "xms_tbflags": 2, + "appidacr": "2", + "aud": "https://graph.microsoft.com", + "idtyp": "app", + "app_displayname": "" +} +``` + +The `cnf.x5t#S256` thumbprint must match the binding certificate returned by `result.getBindingCertificate()`. + +### Troubleshooting + +| Symptom | Likely Cause | Fix | +|---------|-------------|-----| +| `VBS KeyGuard not available` | Credential Guard not enabled | Enable VBS/Credential Guard and reboot | +| `AttestationClientLib.dll not found` | DLL not on PATH | Copy DLL from NuGet package to application directory | +| `HTTP 400 from IMDS issuecredential` | Attestation token empty | Check DLL is present; VM must be Trusted Launch | +| `AADSTS392196` | Resource not enrolled for mTLS PoP | Use `https://graph.microsoft.com` instead | +| `IMDS not accessible` | Not running on Azure VM | This path only works in Azure managed identity environments | +| `NCryptFinalizeKey NTE_BAD_FLAGS` | VBS not running | Check `msinfo32.exe` → Virtualization-based security must show "Running" | + +--- + +## Validating Cache Isolation + +To verify mTLS PoP and Bearer tokens don't share cache entries: + +```java +// Acquire Bearer token +ClientCredentialParameters bearerParams = ClientCredentialParameters + .builder(scopes) + .build(); +IAuthenticationResult bearerResult = app.acquireToken(bearerParams).get(); + +// Acquire mTLS PoP token +ClientCredentialParameters mtlsParams = ClientCredentialParameters + .builder(scopes) + .withMtlsProofOfPossession() + .build(); +IAuthenticationResult mtlsResult = app.acquireToken(mtlsParams).get(); + +// Tokens must be different +assert !bearerResult.accessToken().equals(mtlsResult.accessToken()); +assert "Bearer".equalsIgnoreCase(bearerResult.tokenType()); +assert "mtls_pop".equalsIgnoreCase(mtlsResult.tokenType()); +``` diff --git a/msal4j-sdk/docs/mtls-pop.md b/msal4j-sdk/docs/mtls-pop.md new file mode 100644 index 00000000..d45c769f --- /dev/null +++ b/msal4j-sdk/docs/mtls-pop.md @@ -0,0 +1,258 @@ +# mTLS Proof-of-Possession (mTLS PoP) in MSAL4J + +## Overview + +mTLS Proof-of-Possession is a token-binding mechanism that cryptographically ties an access token to a specific client certificate. Unlike Bearer tokens, an mTLS PoP token can only be used by the party that holds the private key of the binding certificate — even if the token is intercepted, it is useless without the private key. + +The token contains a `cnf` (confirmation) claim with an `x5t#S256` field: the SHA-256 thumbprint of the binding certificate. The resource server validates the mTLS connection and checks that the connecting client's certificate matches the token's `cnf` claim. + +MSAL4J supports two mTLS PoP acquisition paths: + +| Path | Application Type | Certificate Source | Attestation | +|------|-----------------|-------------------|-------------| +| **SNI (Subject Name Indication)** | `ConfidentialClientApplication` | Any PKCS12/PEM cert or hardware token (PKCS11) | Not required | +| **Managed Identity** | `MtlsMsiClient` (via `msal4j-mtls-extensions`) | IMDS-issued KeyGuard-backed certificate | Optional (Trusted Launch VMs) | + +--- + +## Cross-SDK Implementation Comparison + +| Library | TLS Stack | CNG Support | Approach | +|---------|-----------|-------------|----------| +| **msal-java** | JSSE + custom `SSLSocketFactory` (Path 1); JNA → `ncrypt.dll` (Path 2) | ✅ Via JNA | In-process | +| **msal-dotnet** | Schannel (.NET) | ✅ Native | In-process | +| **msal-go** | `crypto/tls` (pure Go) | ✅ Via `crypto.Signer` | In-process | +| **msal-node** | OpenSSL (Node.js) | ❌ None | .NET subprocess (`MsalMtlsMsiHelper.exe`) | + +No subprocess is needed in msal-java. + +--- + +## Why mTLS PoP? + +Standard Bearer tokens are vulnerable to: +- Token theft via XSS or compromised intermediaries +- Confused-deputy attacks where a stolen token is replayed from a different client + +With mTLS PoP: +- The token is bound to the TLS client certificate at issuance time +- Resource servers enforce that the connecting client certificate matches the token's `cnf` claim +- An attacker who steals the token cannot use it without also stealing the private key + +This makes mTLS PoP suitable for high-value API access from server-side applications running in trusted Azure environments. + +--- + +## Path 1: SNI / ConfidentialClientApplication + +### How It Works + +1. Your app creates a `ConfidentialClientApplication` with a certificate credential. +2. Calls `acquireToken` with `ClientCredentialParameters.withMtlsProofOfPossession()`. +3. MSAL4J builds a custom `SSLSocketFactory` from the cert and private key. +4. The token request goes to the mTLS-specific endpoint (`mtlsauth.microsoft.com` for public cloud). +5. The request body contains `grant_type=client_credentials`, `token_type=mtls_pop`, and `scope`. There is **no `client_assertion`** — authentication happens at the TLS layer. +6. The response access token contains a `cnf.x5t#S256` binding. + +### Requirements + +- Certificate credential (`ClientCertificate`) — PKCS12, PEM, or hardware-backed (PKCS11) +- Tenanted authority (must specify a tenant ID or tenant FQDN — common/organizations endpoints not supported) +- Azure region configured or auto-detect enabled +- AAD authority only — B2C is not supported +- Public cloud or sovereign clouds only (US Gov and China clouds are not supported) + +### Quick Start + +```java +import com.microsoft.aad.msal4j.*; +import java.io.FileInputStream; +import java.util.*; + +// 1. Load your certificate (PKCS12) +IClientCertificate cert = ClientCredentialFactory.createFromCertificate( + new FileInputStream("/path/to/cert.p12"), "password"); + +// 2. Build the app (tenanted authority + region required) +ConfidentialClientApplication app = ConfidentialClientApplication + .builder("your-client-id", cert) + .authority("https://login.microsoftonline.com/your-tenant-id") + .azureRegion("eastus") // or autoDetectRegion(true) + .build(); + +// 3. Request an mTLS PoP token +Set scopes = Collections.singleton("https://graph.microsoft.com/.default"); +ClientCredentialParameters params = ClientCredentialParameters + .builder(scopes) + .withMtlsProofOfPossession() + .build(); + +IAuthenticationResult result = app.acquireToken(params).get(); + +System.out.println("Token type: " + result.tokenType()); // "mtls_pop" +System.out.println("Binding cert: " + result.bindingCertificate().getSubjectX500Principal()); +System.out.println("Access token: " + result.accessToken()); +``` + +### Hardware-Backed Certificates (PKCS11) + +For hardware security modules or smart cards, load the private key and certificate chain from a PKCS11 provider: + +```java +Provider pkcs11Provider = Security.getProvider("SunPKCS11"); +KeyStore ks = KeyStore.getInstance("PKCS11", pkcs11Provider); +ks.load(null, "pin".toCharArray()); + +PrivateKey privateKey = (PrivateKey) ks.getKey("my-key-alias", null); +X509Certificate cert = (X509Certificate) ks.getCertificate("my-key-alias"); + +IClientCertificate clientCert = ClientCredentialFactory.createFromCertificate(privateKey, cert); +``` + +### Token Endpoint + +For public cloud, MSAL4J constructs: +``` +https://{region}.mtlsauth.microsoft.com/{tenantId}/oauth2/v2.0/token +``` + +For sovereign clouds, the `login.` prefix is replaced with `mtlsauth.`: +- `login.microsoftonline.us` → `mtlsauth.microsoftonline.us` +- `login.microsoftonline.de` → `mtlsauth.microsoftonline.de` + +US Government (`login.usgovcloudapi.net`) and China (`login.chinacloudapi.cn`) clouds are **not supported** — these clouds do not have an mTLS auth endpoint. + +### Token Cache + +mTLS PoP tokens are cached separately from Bearer tokens using: +- Credential type: `AccessToken_With_AuthScheme` (instead of `AccessToken`) +- Cache key suffix: `x5t#S256` thumbprint of the binding certificate + +This prevents a Bearer token cache hit from returning an mTLS PoP token and vice versa. + +--- + +## Path 2: Managed Identity + +### How It Works + +Managed Identity mTLS PoP uses a KeyGuard-backed certificate issued by IMDS. KeyGuard keys are hardware-isolated keys created with CNG (`NCryptCreatePersistedKey` with `NCRYPT_VBS_KEYISOLATION_FLAG`). The `msal4j-mtls-extensions` module calls Windows CNG directly via JNA (Java Native Access), so no .NET runtime or external subprocess is needed. + +The extension handles: +1. IMDS `getplatformmetadata` call +2. KeyGuard key creation via CNG (`ncrypt.dll`) +3. CSR generation +4. Optional MAA attestation via `AttestationClientLib.dll` +5. IMDS `/issuecredential` (get cert from IMDS) +6. mTLS token request to AAD + +MSAL4J orchestrates the extension via reflection and caches the result. + +### Requirements + +- Azure VM with system-assigned or user-assigned managed identity enabled +- Windows x64 OS with VBS (Virtualization-Based Security) KeyGuard available +- `msal4j-mtls-extensions` artifact on the classpath (add as Maven dependency) +- On Trusted Launch VMs: `AttestationClientLib.dll` on `PATH` or in the application directory (see [msal4j-mtls-extensions README](../../msal4j-mtls-extensions/README.md)) +- No .NET runtime required + +### Quick Start + +```xml + + + com.microsoft.azure + msal4j-mtls-extensions + 1.24.0 + +``` + +```java +// System-assigned managed identity +ManagedIdentityApplication app = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .build(); + +ManagedIdentityParameters params = ManagedIdentityParameters + .builder("https://graph.microsoft.com/.default") + .withMtlsProofOfPossession(true) + // .withAttestation(true) // optional: enable MAA attestation + .build(); + +IAuthenticationResult result = app.acquireTokenForManagedIdentity(params).get(); +System.out.println("Token type: " + result.tokenType()); // "mtls_pop" +``` + +For user-assigned managed identity: +```java +ManagedIdentityApplication app = ManagedIdentityApplication + .builder(ManagedIdentityId.userAssignedClientId("your-client-id")) + .build(); +``` + +--- + +## API Reference + +### `ClientCredentialParameters` + +| Method | Description | +|--------|-------------| +| `.withMtlsProofOfPossession()` | Acquires an `mtls_pop` token instead of Bearer | + +### `ManagedIdentityParameters` + +| Method | Description | +|--------|-------------| +| `.withMtlsProofOfPossession(boolean)` | When `true`, uses the JNA-backed KeyGuard extension to acquire a hardware-bound `mtls_pop` token | + +### `IAuthenticationResult` + +| Method | Returns | +|--------|---------| +| `.tokenType()` | `"mtls_pop"` or `"Bearer"` | +| `.bindingCertificate()` | `X509Certificate` used for mTLS binding, or `null` for Bearer | + +--- + +## Caching Behavior + +mTLS PoP tokens use a 7-segment cache key: +``` +{homeAccountId}-{environment}-{credentialType}-{clientId}-{realm}-{target}-{keyId} +``` +where `credentialType = AccessToken_With_AuthScheme` and `keyId` = Base64URL(SHA-256(DER cert)). + +Standard Bearer tokens use a 6-segment key (no `keyId`). The two token types never collide in cache. + +--- + +## Known Limitations + +- **US Government and China clouds** are not supported for SNI path (no mTLS auth endpoint). +- **Managed Identity path** requires Windows x64 with VBS KeyGuard (the JNA native binding is Windows-only). +- **No refresh token** — mTLS PoP tokens cannot be silently refreshed via a refresh token. They are re-acquired via client credentials or re-issued by IMDS. The in-memory cache covers the token lifetime. +- **Sovereign cloud attestation** — MAA attestation is only available in public cloud regions. + +--- + +## Error Reference + +| Error Code | Meaning | Fix | +|---|---|---| +| `invalid_request` | Authority is not tenanted | Use `https://login.microsoftonline.com/{tenantId}` | +| `invalid_request` | Credential is not a certificate | mTLS PoP requires a `ClientCertificate` credential | +| `invalid_request` | Unsupported cloud | Use public cloud or a supported sovereign cloud | +| `invalid_request` | Region required | Set `.azureRegion(...)` on the app builder | +| `invalid_request` | mTLS extensions not on classpath | Add `msal4j-mtls-extensions` dependency | + +--- + +## References + +- [mTLS PoP Manual Testing Guide](mtls-pop-manual-testing.md) +- [mTLS PoP Architecture](mtls-pop-architecture.md) +- [msal4j-mtls-extensions README](../../msal4j-mtls-extensions/README.md) +- [MSAL.js mTLS PoP](https://github.com/AzureAD/microsoft-authentication-library-for-js/pull/8476) +- [MSAL.NET mTLS PoP](https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/tree/main/docs) +- [RFC 8705 - OAuth 2.0 Mutual-TLS Client Authentication and Certificate-Bound Access Tokens](https://www.rfc-editor.org/rfc/rfc8705) diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AADAuthority.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AADAuthority.java index 5bc151c0..b3e18923 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AADAuthority.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AADAuthority.java @@ -8,6 +8,7 @@ class AADAuthority extends Authority { private static final String TENANTLESS_TENANT_NAME = "common"; + private static final String ORGANIZATIONS_TENANT_NAME = "organizations"; private static final String AUTHORIZATION_ENDPOINT = "oauth2/v2.0/authorize"; private static final String TOKEN_ENDPOINT = "oauth2/v2.0/token"; static final String DEVICE_CODE_ENDPOINT = "oauth2/v2.0/devicecode"; @@ -28,7 +29,8 @@ private void setAuthorityProperties() { this.tokenEndpoint = String.format(AAD_TOKEN_ENDPOINT_FORMAT, host, tenant); this.deviceCodeEndpoint = String.format(DEVICE_CODE_ENDPOINT_FORMAT, host, tenant); - this.isTenantless = TENANTLESS_TENANT_NAME.equalsIgnoreCase(tenant); + this.isTenantless = TENANTLESS_TENANT_NAME.equalsIgnoreCase(tenant) + || ORGANIZATIONS_TENANT_NAME.equalsIgnoreCase(tenant); this.selfSignedJwtAudience = this.tokenEndpoint; } } \ No newline at end of file diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AccessTokenCacheEntity.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AccessTokenCacheEntity.java index 92904da8..554feba7 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AccessTokenCacheEntity.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AccessTokenCacheEntity.java @@ -21,6 +21,7 @@ class AccessTokenCacheEntity extends Credential implements JsonSerializable keyParts = new ArrayList<>(); @@ -31,6 +32,9 @@ String getKey() { keyParts.add(clientId); keyParts.add(realm); keyParts.add(target); + if (!StringHelper.isBlank(keyId)) { + keyParts.add(keyId); + } return String.join(Constants.CACHE_KEY_SEPARATOR, keyParts).toLowerCase(); } @@ -77,6 +81,9 @@ static AccessTokenCacheEntity fromJson(JsonReader jsonReader) throws IOException case "refresh_on": entity.refreshOn = reader.getString(); break; + case "key_id": + entity.keyId = reader.getString(); + break; case "user_assertion_hash": entity.userAssertionHash = reader.getString(); break; @@ -105,6 +112,9 @@ public JsonWriter toJson(JsonWriter jsonWriter) throws IOException { jsonWriter.writeStringField("extended_expires_on", extExpiresOn); jsonWriter.writeStringField("refresh_on", refreshOn); jsonWriter.writeStringField("user_assertion_hash", userAssertionHash); + if (!StringHelper.isBlank(keyId)) { + jsonWriter.writeStringField("key_id", keyId); + } jsonWriter.writeEndObject(); @@ -158,4 +168,12 @@ void extExpiresOn(String extExpiresOn) { void refreshOn(String refreshOn) { this.refreshOn = refreshOn; } + + String keyId() { + return this.keyId; + } + + void keyId(String keyId) { + this.keyId = keyId; + } } \ No newline at end of file diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java index c6545cf7..36ea7ab8 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java @@ -6,6 +6,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.Method; import java.time.Instant; import java.util.HashSet; import java.util.Set; @@ -32,6 +33,10 @@ AuthenticationResult execute() throws Exception { MsalErrorMessage.SCOPES_REQUIRED); } + if (managedIdentityParameters.mtlsProofOfPossession()) { + return executeMtlsPop(); + } + TokenRequestExecutor tokenRequestExecutor = new TokenRequestExecutor( clientApplication.authenticationAuthority, msalRequest, @@ -171,4 +176,91 @@ private long calculateRefreshOn(long expiresOn) { //The refreshOn value should be half the value of the token lifetime, if the lifetime is greater than two hours return expiresIn > TWO_HOURS ? (expiresIn / 2) + timestampSeconds : 0; } + + /** + * Handles the mTLS PoP path by delegating to {@code MtlsMsiClient} from the + * {@code msal4j-mtls-extensions} package (loaded via reflection so that the core SDK + * does not have a compile-time dependency on the extension). + */ + private AuthenticationResult executeMtlsPop() throws Exception { + // Resolve identity type and ID from the ManagedIdentityId on the application + ManagedIdentityApplication miApp = (ManagedIdentityApplication) clientApplication; + ManagedIdentityId miId = miApp.getManagedIdentityId(); + String identityType = miId.getIdType() == ManagedIdentityIdType.SYSTEM_ASSIGNED + ? "SystemAssigned" : "UserAssigned"; + String identityId = miId.getIdType() != ManagedIdentityIdType.SYSTEM_ASSIGNED + ? miId.getUserAssignedId() : null; + + // Reflective invocation of MtlsMsiClient to keep msal4j-sdk free of a hard dep + // on msal4j-mtls-extensions. The extension module registers MtlsMsiClient on the + // classpath; ManagedIdentityApplication.validateMtlsPopParameters() already verified + // the class is present before we get here. + Class clientClass = Class.forName("com.microsoft.aad.msal4j.mtls.MtlsMsiClient"); + Object client = clientClass.getDeclaredConstructor().newInstance(); + Method acquireTokenMethod = clientClass.getMethod( + "acquireToken", String.class, String.class, String.class, boolean.class, String.class); + + Object result = acquireTokenMethod.invoke( + client, + managedIdentityParameters.resource(), + identityType, + identityId, + false, // withAttestation — false by default; set MSAL_MTLS_HELPER_PATH to use a custom build with attestation enabled + null // correlationId — not available here; helper will generate one + ); + + // Extract fields from MtlsMsiHelperResult via reflection + Class resultClass = result.getClass(); + String accessToken = (String) resultClass.getMethod("getAccessToken").invoke(result); + String tokenType = (String) resultClass.getMethod("getTokenType").invoke(result); + int expiresIn = (int) resultClass.getMethod("getExpiresIn").invoke(result); + String bindingCertPem = (String) resultClass.getMethod("getBindingCertificate").invoke(result); + + long now = System.currentTimeMillis() / 1000; + long expiresOn = now + expiresIn; + long refreshOn = calculateRefreshOn(expiresOn); + + // Parse the PEM binding certificate if present + java.security.cert.X509Certificate bindingCert = null; + if (bindingCertPem != null && !bindingCertPem.isEmpty()) { + try { + byte[] der = java.util.Base64.getDecoder().decode( + bindingCertPem + .replace("-----BEGIN CERTIFICATE-----", "") + .replace("-----END CERTIFICATE-----", "") + .replaceAll("\\s", "")); + bindingCert = (java.security.cert.X509Certificate) + java.security.cert.CertificateFactory.getInstance("X.509") + .generateCertificate(new java.io.ByteArrayInputStream(der)); + } catch (Exception e) { + LOG.warn("Failed to parse binding certificate PEM from MtlsMsiHelper: {}", e.getMessage()); + } + } + + AuthenticationResultMetadata metadata = AuthenticationResultMetadata.builder() + .tokenSource(TokenSource.IDENTITY_PROVIDER) + .refreshOn(refreshOn) + .build(); + + TokenRequestExecutor tokenRequestExecutor = new TokenRequestExecutor( + clientApplication.authenticationAuthority, + msalRequest, + clientApplication.serviceBundle()); + + AuthenticationResult authResult = AuthenticationResult.builder() + .accessToken(accessToken) + .scopes(managedIdentityParameters.resource()) + .expiresOn(expiresOn) + .extExpiresOn(0) + .refreshOn(refreshOn) + .metadata(metadata) + .tokenType(tokenType) + .bindingCertificate(bindingCert) + .build(); + + clientApplication.tokenCache.saveTokens(tokenRequestExecutor, authResult, + clientApplication.authenticationAuthority.host); + authResult.metadata().tokenSource(TokenSource.IDENTITY_PROVIDER); + return authResult; + } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationErrorCode.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationErrorCode.java index d5fba3df..6a48bf2a 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationErrorCode.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationErrorCode.java @@ -161,6 +161,12 @@ public class AuthenticationErrorCode { public static final String INVALID_TIMESTAMP_FORMAT = "invalid_timestamp_format"; + /** + * Indicates that the request is invalid or malformed. For example, required parameters are missing, + * a value is out of range, or the request is not authorized for the specified scope. + */ + public static final String INVALID_REQUEST = "invalid_request"; + /** * Indicates that instance discovery failed because the authority is not a valid instance. * This is returned by the instance discovery endpoint when the provided authority host is unknown. diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationResult.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationResult.java index d87dfc4b..2bb4dce9 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationResult.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationResult.java @@ -5,6 +5,7 @@ import java.util.Date; import java.util.Objects; +import java.security.cert.X509Certificate; final class AuthenticationResult implements IAuthenticationResult { private static final long serialVersionUID = 1L; @@ -25,8 +26,10 @@ final class AuthenticationResult implements IAuthenticationResult { private final String scopes; private final AuthenticationResultMetadata metadata; private final Boolean isPopAuthorization; + private final String tokenType; + private final X509Certificate bindingCertificate; - AuthenticationResult(String accessToken, long expiresOn, long extExpiresOn, String refreshToken, Long refreshOn, String familyId, String idToken, AccountCacheEntity accountCacheEntity, String environment, String scopes, AuthenticationResultMetadata metadata, Boolean isPopAuthorization) { + AuthenticationResult(String accessToken, long expiresOn, long extExpiresOn, String refreshToken, Long refreshOn, String familyId, String idToken, AccountCacheEntity accountCacheEntity, String environment, String scopes, AuthenticationResultMetadata metadata, Boolean isPopAuthorization, String tokenType, X509Certificate bindingCertificate) { this.accessToken = accessToken; this.expiresOn = expiresOn; this.extExpiresOn = extExpiresOn; @@ -39,6 +42,8 @@ final class AuthenticationResult implements IAuthenticationResult { this.scopes = scopes; this.metadata = metadata == null ? AuthenticationResultMetadata.builder().build() : metadata; this.isPopAuthorization = isPopAuthorization; + this.tokenType = tokenType; + this.bindingCertificate = bindingCertificate; this.expiresOnDate = new Date(expiresOn * 1000); } @@ -129,6 +134,16 @@ Boolean isPopAuthorization() { return this.isPopAuthorization; } + @Override + public String tokenType() { + return this.tokenType; + } + + @Override + public X509Certificate bindingCertificate() { + return this.bindingCertificate; + } + static AuthenticationResultBuilder builder() { return new AuthenticationResultBuilder(); } @@ -146,6 +161,8 @@ static class AuthenticationResultBuilder { private String scopes; private AuthenticationResultMetadata metadata; private Boolean isPopAuthorization; + private String tokenType; + private X509Certificate bindingCertificate; AuthenticationResultBuilder() { } @@ -210,12 +227,22 @@ public AuthenticationResultBuilder isPopAuthorization(Boolean isPopAuthorization return this; } + public AuthenticationResultBuilder tokenType(String tokenType) { + this.tokenType = tokenType; + return this; + } + + public AuthenticationResultBuilder bindingCertificate(X509Certificate bindingCertificate) { + this.bindingCertificate = bindingCertificate; + return this; + } + public AuthenticationResult build() { - return new AuthenticationResult(this.accessToken, this.expiresOn, this.extExpiresOn, this.refreshToken, this.refreshOn, this.familyId, this.idToken, this.accountCacheEntity, this.environment, this.scopes, this.metadata, this.isPopAuthorization); + return new AuthenticationResult(this.accessToken, this.expiresOn, this.extExpiresOn, this.refreshToken, this.refreshOn, this.familyId, this.idToken, this.accountCacheEntity, this.environment, this.scopes, this.metadata, this.isPopAuthorization, this.tokenType, this.bindingCertificate); } public String toString() { - return "AuthenticationResult.AuthenticationResultBuilder(accessToken=" + this.accessToken + ", expiresOn=" + this.expiresOn + ", extExpiresOn=" + this.extExpiresOn + ", refreshToken=" + this.refreshToken + ", refreshOn=" + this.refreshOn + ", familyId=" + this.familyId + ", idToken=" + this.idToken + ", accountCacheEntity=" + this.accountCacheEntity + ", environment=" + this.environment + ", scopes=" + this.scopes + ", metadata=" + this.metadata + ", isPopAuthorization=" + this.isPopAuthorization + ")"; + return "AuthenticationResult.AuthenticationResultBuilder(accessToken=" + this.accessToken + ", expiresOn=" + this.expiresOn + ", extExpiresOn=" + this.extExpiresOn + ", refreshToken=" + this.refreshToken + ", refreshOn=" + this.refreshOn + ", familyId=" + this.familyId + ", idToken=" + this.idToken + ", accountCacheEntity=" + this.accountCacheEntity + ", environment=" + this.environment + ", scopes=" + this.scopes + ", metadata=" + this.metadata + ", isPopAuthorization=" + this.isPopAuthorization + ", tokenType=" + this.tokenType + ")"; } } @@ -243,6 +270,7 @@ public boolean equals(Object o) { if (!Objects.equals(environment, other.environment)) return false; if (!Objects.equals(expiresOnDate, other.expiresOnDate)) return false; if (!Objects.equals(scopes, other.scopes)) return false; + if (!Objects.equals(tokenType, other.tokenType)) return false; return Objects.equals(metadata, other.metadata); } @@ -265,6 +293,7 @@ public int hashCode() { result = result * 59 + (this.expiresOnDate == null ? 43 : this.expiresOnDate.hashCode()); result = result * 59 + (this.scopes == null ? 43 : this.scopes.hashCode()); result = result * 59 + (this.metadata == null ? 43 : this.metadata.hashCode()); + result = result * 59 + (this.tokenType == null ? 43 : this.tokenType.hashCode()); return result; } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java index 9fb68716..3ee2ae2f 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCertificate.java @@ -151,4 +151,8 @@ private static byte[] getHashSha256(final byte[] inputBytes) throws NoSuchAlgori public PrivateKey privateKey() { return this.privateKey; } + + List publicKeyCertificateChain() { + return this.publicKeyCertificateChain; + } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialParameters.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialParameters.java index c6168dfe..e9495758 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialParameters.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialParameters.java @@ -28,7 +28,9 @@ public class ClientCredentialParameters implements IAcquireTokenParameters { private IClientCredential clientCredential; - private ClientCredentialParameters(Set scopes, Boolean skipCache, ClaimsRequest claims, Map extraHttpHeaders, Map extraQueryParameters, String tenant, IClientCredential clientCredential) { + private boolean mtlsProofOfPossession; + + private ClientCredentialParameters(Set scopes, Boolean skipCache, ClaimsRequest claims, Map extraHttpHeaders, Map extraQueryParameters, String tenant, IClientCredential clientCredential, boolean mtlsProofOfPossession) { this.scopes = scopes; this.skipCache = skipCache; this.claims = claims; @@ -36,6 +38,7 @@ private ClientCredentialParameters(Set scopes, Boolean skipCache, Claims this.extraQueryParameters = extraQueryParameters; this.tenant = tenant; this.clientCredential = clientCredential; + this.mtlsProofOfPossession = mtlsProofOfPossession; } private static ClientCredentialParametersBuilder builder() { @@ -87,6 +90,20 @@ public IClientCredential clientCredential() { return this.clientCredential; } + /** + * Whether to request an mTLS Proof-of-Possession token instead of a standard Bearer token. + * + *

When {@code true}, the token request is sent to the {@code mtlsauth.*} endpoint over + * a mutual-TLS connection using the application's {@link ClientCertificate} credential. + * The response token type will be {@code "mtls_pop"} and the access token will contain a + * {@code cnf.x5t#S256} claim binding it to the certificate.

+ * + * @return {@code true} if mTLS PoP is requested + */ + public boolean mtlsProofOfPossession() { + return this.mtlsProofOfPossession; + } + public static class ClientCredentialParametersBuilder { private Set scopes; private Boolean skipCache = false; @@ -95,6 +112,7 @@ public static class ClientCredentialParametersBuilder { private Map extraQueryParameters; private String tenant; private IClientCredential clientCredential; + private boolean mtlsProofOfPossession; ClientCredentialParametersBuilder() { } @@ -162,12 +180,26 @@ public ClientCredentialParametersBuilder clientCredential(IClientCredential clie return this; } + /** + * Requests an mTLS Proof-of-Possession token instead of a standard Bearer token. + * + *

Requires that the application was created with a {@link ClientCertificate} credential, + * the authority is a tenanted AAD authority (not {@code /common} or {@code /organizations}), + * and an Azure region is configured via + * {@link ConfidentialClientApplication.Builder#azureRegion(String)} or + * {@link ConfidentialClientApplication.Builder#autoDetectRegion(boolean)}.

+ */ + public ClientCredentialParametersBuilder withMtlsProofOfPossession() { + this.mtlsProofOfPossession = true; + return this; + } + public ClientCredentialParameters build() { - return new ClientCredentialParameters(this.scopes, this.skipCache, this.claims, this.extraHttpHeaders, this.extraQueryParameters, this.tenant, this.clientCredential); + return new ClientCredentialParameters(this.scopes, this.skipCache, this.claims, this.extraHttpHeaders, this.extraQueryParameters, this.tenant, this.clientCredential, this.mtlsProofOfPossession); } public String toString() { - return "ClientCredentialParameters.ClientCredentialParametersBuilder(scopes=" + this.scopes + ", skipCache=" + this.skipCache + ", claims=" + this.claims + ", extraHttpHeaders=" + this.extraHttpHeaders + ", extraQueryParameters=" + this.extraQueryParameters + ", tenant=" + this.tenant + ", clientCredential=" + this.clientCredential + ")"; + return "ClientCredentialParameters.ClientCredentialParametersBuilder(scopes=" + this.scopes + ", skipCache=" + this.skipCache + ", claims=" + this.claims + ", extraHttpHeaders=" + this.extraHttpHeaders + ", extraQueryParameters=" + this.extraQueryParameters + ", tenant=" + this.tenant + ", clientCredential=" + this.clientCredential + ", mtlsProofOfPossession=" + this.mtlsProofOfPossession + ")"; } } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java index 51fd4610..b80062af 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ConfidentialClientApplication.java @@ -31,6 +31,10 @@ public class ConfidentialClientApplication extends AbstractClientApplicationBase public CompletableFuture acquireToken(ClientCredentialParameters parameters) { validateNotNull("parameters", parameters); + if (parameters.mtlsProofOfPossession()) { + validateMtlsPopParameters(parameters); + } + RequestContext context = new RequestContext( this, PublicApi.ACQUIRE_TOKEN_FOR_CLIENT, @@ -46,6 +50,33 @@ public CompletableFuture acquireToken(ClientCredentialPar return this.executeRequest(clientCredentialRequest); } + private void validateMtlsPopParameters(ClientCredentialParameters parameters) { + IClientCredential cred = parameters.clientCredential() != null ? parameters.clientCredential() : this.clientCredential; + if (!(cred instanceof ClientCertificate)) { + throw new MsalClientException( + "mTLS Proof-of-Possession requires a ClientCertificate credential. " + + "Use ConfidentialClientApplication.builder(clientId, ClientCredentialFactory.createFromCertificate(...)).", + AuthenticationErrorCode.INVALID_REQUEST); + } + if (authenticationAuthority.isTenantless) { + throw new MsalClientException( + "mTLS Proof-of-Possession is not supported with /common or /organizations authorities. " + + "Use a tenanted authority: https://login.microsoftonline.com/{tenantId}", + AuthenticationErrorCode.INVALID_REQUEST); + } + if (authenticationAuthority.authorityType == AuthorityType.B2C) { + throw new MsalClientException( + "mTLS Proof-of-Possession is not supported for B2C authorities.", + AuthenticationErrorCode.INVALID_REQUEST); + } + if (this.azureRegion == null && !this.autoDetectRegion()) { + throw new MsalClientException( + "mTLS Proof-of-Possession requires an Azure region. " + + "Set azureRegion or autoDetectRegion on the application builder.", + AuthenticationErrorCode.INVALID_REQUEST); + } + } + @Override public CompletableFuture acquireToken(OnBehalfOfParameters parameters) { validateNotNull("parameters", parameters); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CredentialTypeEnum.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CredentialTypeEnum.java index 12f9b016..60598f3d 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CredentialTypeEnum.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CredentialTypeEnum.java @@ -6,6 +6,7 @@ enum CredentialTypeEnum { ACCESS_TOKEN("AccessToken"), + ACCESS_TOKEN_WITH_AUTH_SCHEME("AccessToken_With_AuthScheme"), REFRESH_TOKEN("RefreshToken"), ID_TOKEN("IdToken"); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IAuthenticationResult.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IAuthenticationResult.java index 934a2d2c..b356acdd 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IAuthenticationResult.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IAuthenticationResult.java @@ -4,6 +4,7 @@ package com.microsoft.aad.msal4j; import java.io.Serializable; +import java.security.cert.X509Certificate; /** * Interface representing the results of token acquisition operation. @@ -51,4 +52,25 @@ public interface IAuthenticationResult extends Serializable { default AuthenticationResultMetadata metadata() { return AuthenticationResultMetadata.builder().build(); } + + /** + * Returns the token type. For standard flows this is {@code "Bearer"}. For mTLS + * Proof-of-Possession flows this is {@code "mtls_pop"}. + * + * @return the token type string, or {@code null} if the server did not return one + */ + default String tokenType() { + return null; + } + + /** + * For mTLS Proof-of-Possession tokens, returns the X.509 certificate that was used for + * the mTLS handshake and to which the token is cryptographically bound (via the + * {@code cnf.x5t#S256} claim). Returns {@code null} for standard Bearer tokens. + * + * @return the binding certificate, or {@code null} for Bearer tokens + */ + default X509Certificate bindingCertificate() { + return null; + } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java index 5385e9e0..0b130eed 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java @@ -46,6 +46,18 @@ private ManagedIdentityApplication(Builder builder) { this.clientCapabilities = builder.clientCapabilities; } + private void validateMtlsPopParameters() { + // Check that the msal4j-mtls-extensions module is on the classpath + try { + Class.forName("com.microsoft.aad.msal4j.mtls.MtlsMsiClient"); + } catch (ClassNotFoundException e) { + throw new MsalClientException( + "mTLS Proof-of-Possession for Managed Identity requires the msal4j-mtls-extensions package. " + + "Add com.microsoft.azure:msal4j-mtls-extensions to your project dependencies.", + AuthenticationErrorCode.INVALID_REQUEST); + } + } + public static TokenCache getSharedTokenCache() { return ManagedIdentityApplication.sharedTokenCache; } @@ -63,6 +75,9 @@ public ManagedIdentityId getManagedIdentityId() { @Override public CompletableFuture acquireTokenForManagedIdentity(ManagedIdentityParameters managedIdentityParameters) throws Exception { + if (managedIdentityParameters.mtlsProofOfPossession()) { + validateMtlsPopParameters(); + } RequestContext requestContext = new RequestContext( this, managedIdentityId.getIdType() == ManagedIdentityIdType.SYSTEM_ASSIGNED ? diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java index 21335802..94ae4319 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java @@ -16,11 +16,13 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters { boolean forceRefresh; String claims; String revokedTokenHash; - - private ManagedIdentityParameters(String resource, boolean forceRefresh, String claims) { + boolean mtlsProofOfPossession; + + private ManagedIdentityParameters(String resource, boolean forceRefresh, String claims, boolean mtlsProofOfPossession) { this.resource = resource; this.forceRefresh = forceRefresh; this.claims = claims; + this.mtlsProofOfPossession = mtlsProofOfPossession; } @Override @@ -83,10 +85,27 @@ public String revokedTokenHash() { return this.revokedTokenHash; } + /** + * Whether to request an mTLS Proof-of-Possession token instead of a standard Bearer token. + * + *

When {@code true}, the Managed Identity mTLS PoP flow is used. This flow requires + * a Windows Azure VM with Managed Identity enabled, and the .NET 8 runtime installed. + * The token acquisition is delegated to {@code MsalMtlsMsiHelper.exe} (provided by the + * {@code msal4j-mtls-extensions} package), which handles KeyGuard key creation, CSR + * generation, optional MAA attestation, IMDS credential issuance, and the mTLS token + * request to the regional STS.

+ * + * @return {@code true} if mTLS PoP is requested + */ + public boolean mtlsProofOfPossession() { + return this.mtlsProofOfPossession; + } + public static class ManagedIdentityParametersBuilder { private String resource; private boolean forceRefresh; private String claims; + private boolean mtlsProofOfPossession; ManagedIdentityParametersBuilder() { } @@ -118,12 +137,23 @@ public ManagedIdentityParametersBuilder claims(String claims) { return this; } + /** + * Requests an mTLS Proof-of-Possession token instead of a standard Bearer token. + * + *

Requires the {@code msal4j-mtls-extensions} package. See + * {@link ManagedIdentityParameters#mtlsProofOfPossession()} for details.

+ */ + public ManagedIdentityParametersBuilder withMtlsProofOfPossession() { + this.mtlsProofOfPossession = true; + return this; + } + public ManagedIdentityParameters build() { - return new ManagedIdentityParameters(this.resource, this.forceRefresh, this.claims); + return new ManagedIdentityParameters(this.resource, this.forceRefresh, this.claims, this.mtlsProofOfPossession); } public String toString() { - return "ManagedIdentityParameters.ManagedIdentityParametersBuilder(resource=" + this.resource + ", forceRefresh=" + this.forceRefresh + ")"; + return "ManagedIdentityParameters.ManagedIdentityParametersBuilder(resource=" + this.resource + ", forceRefresh=" + this.forceRefresh + ", mtlsProofOfPossession=" + this.mtlsProofOfPossession + ")"; } } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MtlsPopAuthenticationScheme.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MtlsPopAuthenticationScheme.java new file mode 100644 index 00000000..374c25a8 --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MtlsPopAuthenticationScheme.java @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.Base64; + +/** + * Constants and helpers for the mTLS Proof-of-Possession authentication scheme. + * + *

mTLS PoP tokens are acquired via a mutual-TLS connection to a special + * {@code mtlsauth.*} endpoint. The token is cryptographically bound to the + * client certificate used in the TLS handshake via the {@code cnf.x5t#S256} claim.

+ */ +class MtlsPopAuthenticationScheme { + + /** Token type value returned in the response and used as the cache key discriminator. */ + static final String TOKEN_TYPE_MTLS_POP = "mtls_pop"; + + /** + * Computes the x5t#S256 thumbprint of a certificate: the Base64URL-encoded (no padding) + * SHA-256 digest of the DER-encoded certificate bytes. + * + *

This value appears as the {@code cnf.x5t#S256} claim in the access token and is + * stored as the {@code keyId} in the token cache entry to prevent cross-certificate + * token reuse.

+ * + * @param cert the X.509 certificate + * @return Base64URL-encoded SHA-256 thumbprint without padding + */ + static String computeX5tS256(X509Certificate cert) { + try { + MessageDigest sha256 = MessageDigest.getInstance("SHA-256"); + byte[] digest = sha256.digest(cert.getEncoded()); + return Base64.getUrlEncoder().withoutPadding().encodeToString(digest); + } catch (NoSuchAlgorithmException | CertificateEncodingException e) { + throw new MsalClientException( + "Failed to compute x5t#S256 thumbprint: " + e.getMessage(), + AuthenticationErrorCode.MSALRUNTIME_INTEROP_ERROR); + } + } + + /** + * Builds the mTLS token endpoint URL for the given authority host and tenant. + * + *

Endpoint rules (from the msal-dotnet design spec):

+ *
    + *
  • Public cloud: {@code https://{region}.mtlsauth.microsoft.com/{tenantId}/oauth2/v2.0/token}
  • + *
  • Public cloud (no region): {@code https://mtlsauth.microsoft.com/{tenantId}/oauth2/v2.0/token}
  • + *
  • Sovereign clouds: replace the {@code login.} prefix in the host with {@code mtlsauth.}
  • + *
  • US Gov ({@code login.usgovcloudapi.net}) and China ({@code login.chinacloudapi.cn}) + * are not supported and will throw {@link MsalClientException}.
  • + *
+ * + * @param region Azure region (e.g. {@code "eastus"}), or {@code null} to use the global endpoint + * @param tenantId the AAD tenant GUID or domain + * @param authorityHost the authority hostname (e.g. {@code "login.microsoftonline.com"}) + * @return the full mTLS token endpoint URL + * @throws MsalClientException if the authority is an unsupported sovereign cloud + */ + static String buildMtlsTokenEndpoint(String region, String tenantId, String authorityHost) { + if (authorityHost.contains("usgovcloudapi.net") || authorityHost.contains("chinacloudapi.cn")) { + throw new MsalClientException( + "mTLS Proof-of-Possession is not supported for US Government or China cloud authorities. " + + "Authority: " + authorityHost, + AuthenticationErrorCode.INVALID_REQUEST); + } + + String mtlsHost = toMtlsHost(authorityHost); + String regional = (region != null && !region.isEmpty()) ? region + "." : ""; + return String.format("https://%s%s/%s/oauth2/v2.0/token", regional, mtlsHost, tenantId); + } + + /** + * Converts a standard authority host to its mTLS equivalent. + * + *

Mapping rules:

+ *
    + *
  • {@code login.microsoftonline.com} → {@code mtlsauth.microsoft.com} (public cloud)
  • + *
  • Any other {@code login.*} host → {@code mtlsauth.*} (sovereign clouds, replace prefix)
  • + *
  • Other hosts → returned as-is (DSTS and custom authorities)
  • + *
+ */ + private static String toMtlsHost(String host) { + if ("login.microsoftonline.com".equals(host)) { + return "mtlsauth.microsoft.com"; + } + if (host.startsWith("login.")) { + return "mtlsauth." + host.substring("login.".length()); + } + return host; + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MtlsSslContextHelper.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MtlsSslContextHelper.java new file mode 100644 index 00000000..2c77d95e --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MtlsSslContextHelper.java @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +/** + * Builds a per-request {@link SSLSocketFactory} for mTLS client certificate authentication. + * + *

This is the Java equivalent of .NET's + * {@code HttpClientHandler.ClientCertificates.Add(certificate)} pattern. It creates an + * in-memory PKCS#12 {@link KeyStore}, loads the private key and certificate chain into it, + * then initializes an {@link SSLContext} with a {@link KeyManagerFactory} backed by that + * key store. The resulting socket factory presents the certificate during TLS handshake.

+ * + *

The returned factory is intended to be used with a short-lived {@link DefaultHttpClient} + * scoped to a single mTLS token request — it is not shared with the application-level HTTP + * client.

+ */ +class MtlsSslContextHelper { + + private static final char[] EMPTY_PASSWORD = new char[0]; + + /** + * Creates an {@link SSLSocketFactory} that presents the given certificate and private key + * during TLS client authentication. + * + * @param privateKey the private key corresponding to the leaf certificate + * @param certChain the certificate chain; {@code certChain[0]} is the leaf certificate + * @return an {@link SSLSocketFactory} configured for mTLS client authentication + * @throws MsalClientException if the SSL context cannot be constructed + */ + static SSLSocketFactory createSslSocketFactory(PrivateKey privateKey, X509Certificate[] certChain) { + try { + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("mtls", privateKey, EMPTY_PASSWORD, certChain); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, EMPTY_PASSWORD); + + SSLContext ctx = SSLContext.getInstance("TLS"); + ctx.init(kmf.getKeyManagers(), null, null); + + return ctx.getSocketFactory(); + } catch (Exception e) { + throw new MsalClientException( + "Failed to create mTLS SSL socket factory: " + e.getMessage(), + AuthenticationErrorCode.MSALRUNTIME_INTEROP_ERROR); + } + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/OAuthHttpRequest.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/OAuthHttpRequest.java index 49ecc2fc..a9ba5d26 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/OAuthHttpRequest.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/OAuthHttpRequest.java @@ -49,6 +49,30 @@ public HttpResponse send() throws IOException { return createOauthHttpResponseFromHttpResponse(httpResponse); } + /** + * Sends the request to the given {@code overrideUrl} using the supplied {@code httpClient} + * instead of the service bundle's default HTTP client. + * + *

Used for mTLS PoP flows where the token endpoint URL is the {@code mtlsauth.*} endpoint + * and the HTTP client is configured with a client-certificate {@link javax.net.ssl.SSLSocketFactory}.

+ */ + HttpResponse sendWithClient(URL overrideUrl, IHttpClient httpClient) throws IOException { + Map httpHeaders = configureHttpHeaders(); + HttpRequest httpRequest = new HttpRequest( + HttpMethod.POST, + overrideUrl.toString(), + httpHeaders, + this.query); + + IHttpResponse httpResponse = ((HttpHelper) serviceBundle.getHttpHelper()).executeHttpRequest( + httpRequest, + this.requestContext, + serviceBundle.getTelemetryManager(), + httpClient); + + return createOauthHttpResponseFromHttpResponse(httpResponse); + } + private Map configureHttpHeaders() { Map httpHeaders = new HashMap<>(extraHeaderParams); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenCache.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenCache.java index c54f75cf..5e082b65 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenCache.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenCache.java @@ -295,7 +295,16 @@ private static AccessTokenCacheEntity createAccessTokenCacheEntity(TokenRequestE AuthenticationResult authenticationResult, String environmentAlias) { AccessTokenCacheEntity at = new AccessTokenCacheEntity(); - at.credentialType(CredentialTypeEnum.ACCESS_TOKEN.value()); + + boolean isMtlsPop = MtlsPopAuthenticationScheme.TOKEN_TYPE_MTLS_POP.equals(authenticationResult.tokenType()); + if (isMtlsPop) { + at.credentialType(CredentialTypeEnum.ACCESS_TOKEN_WITH_AUTH_SCHEME.value()); + if (authenticationResult.bindingCertificate() != null) { + at.keyId(MtlsPopAuthenticationScheme.computeX5tS256(authenticationResult.bindingCertificate())); + } + } else { + at.credentialType(CredentialTypeEnum.ACCESS_TOKEN.value()); + } if (authenticationResult.account() != null) { at.homeAccountId(authenticationResult.account().homeAccountId()); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java index c6e7ca56..2d1f54ef 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java @@ -6,8 +6,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.ssl.SSLSocketFactory; import java.io.IOException; import java.net.MalformedURLException; +import java.net.URL; +import java.security.cert.X509Certificate; import java.util.*; class TokenRequestExecutor { @@ -31,10 +34,76 @@ AuthenticationResult executeTokenRequest() throws IOException { LOG.debug("Sending token request to: {}", requestAuthority.canonicalAuthorityUrl()); OAuthHttpRequest oAuthHttpRequest = createOauthHttpRequest(); + + if (isMtlsPopRequest()) { + return executeTokenRequestWithMtls(oAuthHttpRequest); + } + HttpResponse oauthHttpResponse = oAuthHttpRequest.send(); return createAuthenticationResultFromOauthHttpResponse(oauthHttpResponse); } + private boolean isMtlsPopRequest() { + return msalRequest instanceof ClientCredentialRequest && + ((ClientCredentialRequest) msalRequest).parameters.mtlsProofOfPossession(); + } + + private AuthenticationResult executeTokenRequestWithMtls(OAuthHttpRequest oAuthHttpRequest) throws IOException { + ClientCertificate cert = getMtlsClientCertificate(); + + X509Certificate[] chain = cert.publicKeyCertificateChain().toArray(new X509Certificate[0]); + SSLSocketFactory sslSocketFactory = MtlsSslContextHelper.createSslSocketFactory(cert.privateKey(), chain); + IHttpClient mtlsHttpClient = new DefaultHttpClient(null, sslSocketFactory, null, null); + + String region = ((AbstractClientApplicationBase) msalRequest.application()).azureRegion(); + String mtlsEndpoint = MtlsPopAuthenticationScheme.buildMtlsTokenEndpoint( + region, requestAuthority.tenant, requestAuthority.host); + URL mtlsUrl = new URL(mtlsEndpoint); + + LOG.debug("mTLS PoP: sending token request to {}", mtlsEndpoint); + HttpResponse oauthHttpResponse = oAuthHttpRequest.sendWithClient(mtlsUrl, mtlsHttpClient); + AuthenticationResult result = createAuthenticationResultFromOauthHttpResponse(oauthHttpResponse); + + // Annotate with mTLS PoP token type and binding certificate + return AuthenticationResult.builder() + .accessToken(result.accessToken()) + .refreshToken(result.refreshToken()) + .familyId(result.familyId()) + .idToken(result.idToken()) + .environment(result.environment()) + .expiresOn(result.expiresOn()) + .extExpiresOn(result.extExpiresOn()) + .refreshOn(result.refreshOn()) + .accountCacheEntity(result.accountCacheEntity()) + .scopes(result.scopes()) + .metadata(result.metadata()) + .isPopAuthorization(result.isPopAuthorization()) + .tokenType(MtlsPopAuthenticationScheme.TOKEN_TYPE_MTLS_POP) + .bindingCertificate(chain[0]) + .build(); + } + + private ClientCertificate getMtlsClientCertificate() { + ConfidentialClientApplication app = (ConfidentialClientApplication) msalRequest.application(); + IClientCredential credential = app.clientCredential; + + // Check for per-request credential override + if (msalRequest instanceof ClientCredentialRequest) { + IClientCredential override = ((ClientCredentialRequest) msalRequest).parameters.clientCredential(); + if (override != null) { + credential = override; + } + } + + if (!(credential instanceof ClientCertificate)) { + throw new MsalClientException( + "mTLS Proof-of-Possession requires a ClientCertificate credential. " + + "ClientSecret and ClientAssertion are not supported for mTLS PoP.", + AuthenticationErrorCode.INVALID_REQUEST); + } + return (ClientCertificate) credential; + } + OAuthHttpRequest createOauthHttpRequest() throws MalformedURLException { if (requestAuthority.tokenEndpointUrl() == null) { @@ -125,6 +194,12 @@ private void addCredentialToRequest(Map queryParameters, LOG.warn("Could not create authority with tenant override: {}", e.getMessage()); } } + + // For mTLS PoP, authentication happens at the TLS layer — do not send client_assertion + if (parameters.mtlsProofOfPossession()) { + queryParameters.put("token_type", MtlsPopAuthenticationScheme.TOKEN_TYPE_MTLS_POP); + return; + } } // Quick return if no credential is provided @@ -202,6 +277,7 @@ private AuthenticationResult createAuthenticationResultFromOauthHttpResponse(Htt refreshOn(response.getRefreshIn() > 0 ? currTimestampSec + response.getRefreshIn() : 0). accountCacheEntity(accountCacheEntity). scopes(response.getScope()). + tokenType(response.getTokenType()). metadata(AuthenticationResultMetadata.builder() .tokenSource(TokenSource.IDENTITY_PROVIDER) .refreshOn(response.getRefreshIn() > 0 ? currTimestampSec + response.getRefreshIn() : 0) diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenResponse.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenResponse.java index b314bb77..bf6df66b 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenResponse.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenResponse.java @@ -16,6 +16,7 @@ class TokenResponse { private String accessToken; private String idToken; private String refreshToken; + private String tokenType; TokenResponse(Map jsonMap) { this.accessToken = jsonMap.get("access_token"); @@ -27,6 +28,7 @@ class TokenResponse { this.extExpiresIn = StringHelper.isNullOrBlank(jsonMap.get("ext_expires_in")) ? 0 : Long.parseLong(jsonMap.get("ext_expires_in")); this.refreshIn = StringHelper.isNullOrBlank(jsonMap.get("refresh_in")) ? 0: Long.parseLong(jsonMap.get("refresh_in")); this.foci = jsonMap.get("foci"); + this.tokenType = jsonMap.get("token_type"); } static TokenResponse parseHttpResponse(final HttpResponse httpResponse) { @@ -73,4 +75,8 @@ public String idToken() { public String refreshToken() { return refreshToken; } + + public String getTokenType() { + return tokenType; + } } diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MtlsPopTest.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MtlsPopTest.java new file mode 100644 index 00000000..2b496b20 --- /dev/null +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MtlsPopTest.java @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import javax.net.ssl.SSLSocketFactory; +import java.io.InputStream; +import java.security.*; +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for mTLS Proof-of-Possession implementation: + * - MtlsPopAuthenticationScheme + * - MtlsSslContextHelper + * - ClientCredentialParameters.withMtlsProofOfPossession() + * - ManagedIdentityParameters.withMtlsProofOfPossession() + * - AccessTokenCacheEntity cache key isolation (keyId / AccessToken_With_AuthScheme) + * - CredentialTypeEnum.ACCESS_TOKEN_WITH_AUTH_SCHEME + * - AuthenticationResult tokenType / bindingCertificate + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class MtlsPopTest { + + private static X509Certificate testCert; + private static PrivateKey testKey; + + @BeforeAll + static void loadTestCertificate() throws Exception { + // Load the pre-generated test PKCS12 from test resources + try (InputStream is = MtlsPopTest.class.getResourceAsStream("/mtls-test-cert.p12")) { + assertNotNull(is, "mtls-test-cert.p12 must exist in test/resources"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(is, "changeit".toCharArray()); + String alias = ks.aliases().nextElement(); + testKey = (PrivateKey) ks.getKey(alias, "changeit".toCharArray()); + testCert = (X509Certificate) ks.getCertificate(alias); + } + } + + // ─── MtlsPopAuthenticationScheme ──────────────────────────────────────── + + @Test + void scheme_constants() { + assertEquals("mtls_pop", MtlsPopAuthenticationScheme.TOKEN_TYPE_MTLS_POP); + } + + @Test + void buildMtlsTokenEndpoint_publicCloud() throws Exception { + String endpoint = MtlsPopAuthenticationScheme.buildMtlsTokenEndpoint( + "eastus", "mytenant", "login.microsoftonline.com"); + assertEquals("https://eastus.mtlsauth.microsoft.com/mytenant/oauth2/v2.0/token", endpoint); + } + + @Test + void buildMtlsTokenEndpoint_publicCloud_noRegion() throws Exception { + String endpoint = MtlsPopAuthenticationScheme.buildMtlsTokenEndpoint( + null, "mytenant", "login.microsoftonline.com"); + assertEquals("https://mtlsauth.microsoft.com/mytenant/oauth2/v2.0/token", endpoint); + } + + @Test + void buildMtlsTokenEndpoint_sovereignCloud() throws Exception { + String endpoint = MtlsPopAuthenticationScheme.buildMtlsTokenEndpoint( + "eastus", "mytenant", "login.microsoftonline.us"); + assertEquals("https://eastus.mtlsauth.microsoftonline.us/mytenant/oauth2/v2.0/token", endpoint); + } + + @Test + void buildMtlsTokenEndpoint_usGov_throws() { + MsalClientException ex = assertThrows(MsalClientException.class, () -> + MtlsPopAuthenticationScheme.buildMtlsTokenEndpoint( + "eastus", "mytenant", "login.usgovcloudapi.net")); + assertTrue(ex.getMessage().contains("not supported")); + } + + @Test + void buildMtlsTokenEndpoint_china_throws() { + MsalClientException ex = assertThrows(MsalClientException.class, () -> + MtlsPopAuthenticationScheme.buildMtlsTokenEndpoint( + "eastus", "mytenant", "login.chinacloudapi.cn")); + assertTrue(ex.getMessage().contains("not supported")); + } + + @Test + void computeX5tS256_producesBase64UrlNoPadding() throws Exception { + String thumbprint = MtlsPopAuthenticationScheme.computeX5tS256(testCert); + assertNotNull(thumbprint); + assertFalse(thumbprint.isEmpty()); + assertFalse(thumbprint.contains("+"), "x5t#S256 must use Base64URL encoding (no +)"); + assertFalse(thumbprint.contains("/"), "x5t#S256 must use Base64URL encoding (no /)"); + assertFalse(thumbprint.contains("="), "x5t#S256 must not have padding"); + } + + @Test + void computeX5tS256_deterministicForSameCert() throws Exception { + String t1 = MtlsPopAuthenticationScheme.computeX5tS256(testCert); + String t2 = MtlsPopAuthenticationScheme.computeX5tS256(testCert); + assertEquals(t1, t2); + } + + // ─── MtlsSslContextHelper ─────────────────────────────────────────────── + + @Test + void sslContextHelper_createsSslSocketFactory() throws Exception { + SSLSocketFactory factory = MtlsSslContextHelper.createSslSocketFactory( + testKey, new X509Certificate[]{testCert}); + assertNotNull(factory); + } + + @Test + void sslContextHelper_nullKey_throws() { + assertThrows(Exception.class, () -> + MtlsSslContextHelper.createSslSocketFactory(null, new X509Certificate[]{testCert})); + } + + @Test + void sslContextHelper_nullCertChain_throws() { + assertThrows(Exception.class, () -> + MtlsSslContextHelper.createSslSocketFactory(testKey, null)); + } + + // ─── ClientCredentialParameters ───────────────────────────────────────── + + @Test + void clientCredentialParameters_mtlsPopDefault_false() { + ClientCredentialParameters params = ClientCredentialParameters + .builder(Collections.singleton("https://management.azure.com/.default")) + .build(); + assertFalse(params.mtlsProofOfPossession()); + } + + @Test + void clientCredentialParameters_withMtlsProofOfPossession_true() { + ClientCredentialParameters params = ClientCredentialParameters + .builder(Collections.singleton("https://management.azure.com/.default")) + .withMtlsProofOfPossession() + .build(); + assertTrue(params.mtlsProofOfPossession()); + } + + // ─── ManagedIdentityParameters ────────────────────────────────────────── + + @Test + void managedIdentityParameters_mtlsPopDefault_false() { + ManagedIdentityParameters params = ManagedIdentityParameters + .builder("https://management.azure.com") + .build(); + assertFalse(params.mtlsProofOfPossession()); + } + + @Test + void managedIdentityParameters_withMtlsProofOfPossession_true() { + ManagedIdentityParameters params = ManagedIdentityParameters + .builder("https://management.azure.com") + .withMtlsProofOfPossession() + .build(); + assertTrue(params.mtlsProofOfPossession()); + } + + // ─── CredentialTypeEnum ────────────────────────────────────────────────── + + @Test + void credentialTypeEnum_accessTokenWithAuthScheme_value() { + assertEquals("AccessToken_With_AuthScheme", + CredentialTypeEnum.ACCESS_TOKEN_WITH_AUTH_SCHEME.value()); + } + + // ─── AccessTokenCacheEntity ────────────────────────────────────────────── + + @Test + void cacheEntity_standardToken_keyHasNoKeyId() { + AccessTokenCacheEntity entity = buildCacheEntity("AccessToken", null); + String key = entity.getKey(); + // Standard Bearer token cache key: 6 parts separated by 5 dashes + assertEquals(5, countOccurrences(key, '-')); + } + + @Test + void cacheEntity_mtlsPopToken_keyIncludesKeyId() throws Exception { + String thumbprint = MtlsPopAuthenticationScheme.computeX5tS256(testCert); + + AccessTokenCacheEntity entity = buildCacheEntity( + CredentialTypeEnum.ACCESS_TOKEN_WITH_AUTH_SCHEME.value(), thumbprint); + String key = entity.getKey(); + assertTrue(key.contains(thumbprint.toLowerCase()), + "Cache key for mTLS PoP token must include the thumbprint"); + // Key must have more segments than a standard Bearer token key (which has 6 parts / 5 dashes) + assertTrue(key.endsWith("-" + thumbprint.toLowerCase()), + "Thumbprint must be the last segment of the mTLS PoP cache key"); + } + + @Test + void cacheEntity_mtlsPopAndBearerTokens_haveDifferentKeys() throws Exception { + String thumbprint = MtlsPopAuthenticationScheme.computeX5tS256(testCert); + + AccessTokenCacheEntity bearer = buildCacheEntity(CredentialTypeEnum.ACCESS_TOKEN.value(), null); + AccessTokenCacheEntity mtlsPop = buildCacheEntity( + CredentialTypeEnum.ACCESS_TOKEN_WITH_AUTH_SCHEME.value(), thumbprint); + + assertNotEquals(bearer.getKey(), mtlsPop.getKey(), + "Bearer and mTLS PoP tokens for the same scope must have different cache keys"); + } + + // ─── AuthenticationResult tokenType / bindingCertificate ──────────────── + + @Test + void authResult_defaultTokenType_null() { + AuthenticationResult result = AuthenticationResult.builder() + .accessToken("token") + .build(); + assertNull(result.tokenType()); + assertNull(result.bindingCertificate()); + } + + @Test + void authResult_mtlsPopFields_set() throws Exception { + AuthenticationResult result = AuthenticationResult.builder() + .accessToken("mtls_pop_token") + .tokenType("mtls_pop") + .bindingCertificate(testCert) + .build(); + + assertEquals("mtls_pop", result.tokenType()); + assertEquals(testCert, result.bindingCertificate()); + } + + // ─── Helpers ───────────────────────────────────────────────────────────── + + private static AccessTokenCacheEntity buildCacheEntity(String credentialType, String keyId) { + AccessTokenCacheEntity entity = new AccessTokenCacheEntity(); + entity.homeAccountId(""); + entity.environment("login.microsoftonline.com"); + entity.credentialType(credentialType); + entity.clientId("clientid"); + entity.realm("tenant"); + entity.target("scope"); + if (keyId != null) { + entity.keyId(keyId); + } + return entity; + } + + private static int countOccurrences(String s, char c) { + int count = 0; + for (int i = 0; i < s.length(); i++) { + if (s.charAt(i) == c) count++; + } + return count; + } +} diff --git a/msal4j-sdk/src/test/resources/mtls-test-cert.p12 b/msal4j-sdk/src/test/resources/mtls-test-cert.p12 new file mode 100644 index 00000000..63315c69 Binary files /dev/null and b/msal4j-sdk/src/test/resources/mtls-test-cert.p12 differ diff --git a/pom.xml b/pom.xml index 373e77e7..9fbda531 100644 --- a/pom.xml +++ b/pom.xml @@ -9,5 +9,6 @@ msal4j-sdk msal4j-brokers msal4j-persistence-extension + msal4j-mtls-extensions