diff --git a/.Pipelines/CI-AND-RELEASE-PIPELINES.md b/.Pipelines/CI-AND-RELEASE-PIPELINES.md new file mode 100644 index 00000000..d8d18d00 --- /dev/null +++ b/.Pipelines/CI-AND-RELEASE-PIPELINES.md @@ -0,0 +1,126 @@ +# CI/CD Pipelines + +This document describes the pipeline structure for the `msal` Python package, +including what each pipeline does, when it runs, and how to trigger a release. + +--- + +## Pipeline Files + +| File | Purpose | +|------|---------| +| [`azure-pipelines.yml`](../azure-pipelines.yml) | PR gate and post-merge CI — calls the shared template with `runPublish: false` | +| [`pipeline-publish.yml`](pipeline-publish.yml) | Release pipeline — manually queued, builds and publishes to PyPI | +| [`template-pipeline-stages.yml`](template-pipeline-stages.yml) | Shared stages template — PreBuildCheck, Validate, and CI stages reused by both pipelines | +| [`credscan-exclusion.json`](credscan-exclusion.json) | CredScan suppression file for known test fixtures | + +--- + +## PR / CI Pipeline (`azure-pipelines.yml`) + +### Triggers + +| Event | Branches | +|-------|----------| +| Pull request opened / updated | all branches | +| Push / merge | `dev`, `azure-pipelines` | +| Scheduled | Daily at 11:45 PM Pacific, `dev` branch (only when there are new changes) | + +### Stages + +``` +PreBuildCheck ─► CI +``` + +| Stage | What it does | +|-------|-------------| +| **PreBuildCheck** | Runs SDL security scans: PoliCheck (policy/offensive content), CredScan (leaked credentials), and PostAnalysis (breaks the build on findings) | +| **CI** | Runs the full test suite on Python 3.9, 3.10, 3.11, 3.12, 3.13, and 3.14 | + +The Validate stage is **skipped** on PR/CI runs (it only applies to release builds). + +> **SDL coverage:** The PreBuildCheck stage satisfies the OneBranch SDL requirement. +> It runs on every PR, every merge to `dev`, and on the daily schedule — ensuring +> continuous security scanning without a separate dedicated SDL pipeline. + +--- + +## Release Pipeline (`pipeline-publish.yml`) + +### Triggers + +**Manual only** — no automatic branch or tag triggers. Must be queued explicitly +with both parameters filled in. + +### Parameters + +| Parameter | Description | Example values | +|-----------|-------------|----------------| +| **Package version to publish** | Must exactly match `msal/sku.py __version__`. [PEP 440](https://peps.python.org/pep-0440/) format. | `1.36.0`, `1.36.0rc1`, `1.36.0b1` | +| **Publish target** | Destination for this release. | `test.pypi.org (Preview / RC)` or `pypi.org (ESRP Production)` | + +### Stage Flow + +``` +PreBuildCheck ─► Validate ─► CI ─► Build ─┬─► PublishMSALPython (publishTarget == 'test.pypi.org (Preview / RC)') + └─► PublishPyPI (publishTarget == 'pypi.org (ESRP Production)') +``` + +| Stage | What it does | Condition | +|-------|-------------|-----------| +| **PreBuildCheck** | PoliCheck + CredScan scans | Always | +| **Validate** | Asserts the `packageVersion` parameter matches `msal/sku.py __version__` | Always (release runs only) | +| **CI** | Full test matrix (Python 3.9–3.14) | After Validate passes | +| **Build** | Builds `sdist` and `wheel` via `python -m build`; publishes `python-dist` artifact | After CI passes | +| **PublishMSALPython** | Uploads to test.pypi.org | `publishTarget == test.pypi.org (Preview / RC)` | +| **PublishPyPI** | Uploads to PyPI via ESRP; requires manual approval | `publishTarget == pypi.org (ESRP Production)` | + +--- + +## How to Publish a Release + +### Step 1 — Update the version + +Edit `msal/sku.py` and set `__version__` to the target version: + +```python +__version__ = "1.36.0rc1" # RC / preview +__version__ = "1.36.0" # production release +``` + +Push the change to the branch you intend to release from. + +### Step 2 — Queue the pipeline + +1. Go to the **MSAL.Python-Publish** pipeline in ADO. +2. Click **Run pipeline**. +3. Select the branch to release from. +4. Enter the **Package version to publish** (must match `msal/sku.py` exactly). +5. Select the **Publish target**: + - `test.pypi.org (Preview / RC)` — for release candidates and previews + - `pypi.org (ESRP Production)` — for final releases (requires approval gate) +6. Click **Run**. + +### Step 3 — Approve (production releases only) + +The `pypi.org (ESRP Production)` path includes a required manual approval before +the package is uploaded. An approver must review and approve in the ADO +**Environments** panel before the `PublishPyPI` stage proceeds. + +### Step 4 — Verify + +- **test.pypi.org:** https://test.pypi.org/project/msal/ +- **PyPI:** https://pypi.org/project/msal/ + +--- + +## Version Format + +PyPI enforces [PEP 440](https://peps.python.org/pep-0440/). Versions with `-` (e.g. `1.36.0-Preview`) are rejected at upload time. Use standard suffixes: + +| Release type | Format | +|-------------|--------| +| Production | `1.36.0` | +| Release candidate | `1.36.0rc1` | +| Beta | `1.36.0b1` | +| Alpha | `1.36.0a1` | diff --git a/.Pipelines/credscan-exclusion.json b/.Pipelines/credscan-exclusion.json new file mode 100644 index 00000000..4defd2b8 --- /dev/null +++ b/.Pipelines/credscan-exclusion.json @@ -0,0 +1,13 @@ +{ + "tool": "Credential Scanner", + "suppressions": [ + { + "file": "tests/certificate-with-password.pfx", + "_justification": "Self-signed certificate used only in unit tests. Not a production credential." + }, + { + "file": "tests/test_mi.py", + "_justification": "WWW-Authenticate challenge header value used as a mock HTTP response fixture in unit tests. Not a real credential." + } + ] +} diff --git a/.Pipelines/pipeline-publish.yml b/.Pipelines/pipeline-publish.yml new file mode 100644 index 00000000..818dfddb --- /dev/null +++ b/.Pipelines/pipeline-publish.yml @@ -0,0 +1,179 @@ +# pipeline-publish.yml +# +# Release pipeline for the msal Python package — manually triggered only. +# Source: https://github.com/AzureAD/microsoft-authentication-library-for-python +# +# Publish targets: +# test.pypi.org (Preview / RC) — preview releases via MSAL-Test-Python-Upload SC +# (SC creation pending test.pypi.org API token) +# pypi.org (ESRP Production) — production releases via ESRP (EsrpRelease@9) using MSAL-ESRP-AME SC +# +# For pipeline documentation, see .Pipelines/CI-AND-RELEASE-PIPELINES.md. + +parameters: +- name: packageVersion + displayName: 'Package version to publish (must match msal/sku.py, e.g. 1.36.0 or 1.36.0rc1)' + type: string + +- name: publishTarget + displayName: 'Publish target' + type: string + values: + - 'test.pypi.org (Preview / RC)' + - 'pypi.org (ESRP Production)' + +trigger: none # manual runs only — no automatic branch or tag triggers +pr: none + +# Stage flow: +# +# PreBuildCheck ─► Validate ─► CI ─► Build ─► PublishMSALPython (publishTarget == Preview) +# └─► PublishPyPI (publishTarget == ESRP Production) + +stages: + +# PreBuildCheck, Validate, and CI stages are defined in the shared template. +- template: template-pipeline-stages.yml + parameters: + packageVersion: ${{ parameters.packageVersion }} + runPublish: true + +# ══════════════════════════════════════════════════════════════════════════════ +# Stage 3 · Build — build sdist + wheel +# ══════════════════════════════════════════════════════════════════════════════ +- stage: Build + displayName: 'Build package' + dependsOn: CI + condition: eq(dependencies.CI.result, 'Succeeded') + jobs: + - job: BuildDist + displayName: 'Build sdist + wheel (Python 3.12)' + pool: + vmImage: ubuntu-latest + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + displayName: 'Use Python 3.12' + + - script: | + python -m pip install --upgrade pip build twine + displayName: 'Install build toolchain' + + - script: | + python -m build + displayName: 'Build sdist and wheel' + + - script: | + python -m twine check dist/* + displayName: 'Verify distribution (twine check)' + + - task: PublishPipelineArtifact@1 + displayName: 'Publish dist/ as pipeline artifact' + inputs: + targetPath: dist/ + artifact: python-dist + +# ══════════════════════════════════════════════════════════════════════════════ +# Stage 4a · Publish to test.pypi.org (Preview / RC) +# Note: requires MSAL-Test-Python-Upload SC in ADO (pending test.pypi.org API token) +# ══════════════════════════════════════════════════════════════════════════════ +- stage: PublishMSALPython + displayName: 'Publish to test.pypi.org (Preview)' + dependsOn: Build + condition: > + and( + eq(dependencies.Build.result, 'Succeeded'), + eq('${{ parameters.publishTarget }}', 'test.pypi.org (Preview / RC)') + ) + jobs: + - deployment: DeployMSALPython + displayName: 'Upload to test.pypi.org' + pool: + vmImage: ubuntu-latest + environment: MSAL-Python + strategy: + runOnce: + deploy: + steps: + - task: DownloadPipelineArtifact@2 + displayName: 'Download python-dist artifact' + inputs: + artifactName: python-dist + targetPath: $(Pipeline.Workspace)/python-dist + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + displayName: 'Use Python 3.12' + + - script: | + python -m pip install --upgrade pip twine + displayName: 'Install twine' + + # TODO: create MSAL-Test-Python-Upload SC with test.pypi.org API token, then uncomment: + # - task: TwineAuthenticate@1 + # displayName: 'Authenticate with MSAL-Test-Python-Upload' + # inputs: + # pythonUploadServiceConnection: MSAL-Test-Python-Upload + + # - script: | + # python -m twine upload \ + # -r "MSAL-Test-Python-Upload" \ + # --config-file $(PYPIRC_PATH) \ + # --skip-existing \ + # $(Pipeline.Workspace)/python-dist/* + # displayName: 'Upload to test.pypi.org' + + - script: echo "Publish to test.pypi.org skipped — MSAL-Test-Python-Upload SC not yet created." + displayName: 'Skip upload (SC pending)' + +# ══════════════════════════════════════════════════════════════════════════════ +# Stage 4b · Publish to PyPI (ESRP Production) +# Uses EsrpRelease@9 via the MSAL-ESRP-AME service connection. +# IMPORTANT: configure a required manual approval on this environment in +# ADO → Pipelines → Environments → MSAL-Python-Release → Approvals and checks. +# IMPORTANT: EsrpRelease@9 requires a Windows agent. +# ══════════════════════════════════════════════════════════════════════════════ +- stage: PublishPyPI + displayName: 'Publish to PyPI (ESRP Production)' + dependsOn: Build + condition: > + and( + eq(dependencies.Build.result, 'Succeeded'), + eq('${{ parameters.publishTarget }}', 'pypi.org (ESRP Production)') + ) + jobs: + - deployment: DeployPyPI + displayName: 'Upload to PyPI via ESRP' + pool: + vmImage: windows-latest + environment: MSAL-Python-Release + strategy: + runOnce: + deploy: + steps: + - task: DownloadPipelineArtifact@2 + displayName: 'Download python-dist artifact' + inputs: + artifactName: python-dist + targetPath: $(Pipeline.Workspace)/python-dist + + - task: EsrpRelease@9 + displayName: 'Publish to PyPI via ESRP' + inputs: + connectedservicename: 'MSAL-ESRP-AME' + usemanagedidentity: true + keyvaultname: 'MSALVault' + signcertname: 'MSAL-ESRP-Release-Signing' + clientid: '8650ce2b-38d4-466a-9144-bc5c19c88112' + intent: 'PackageDistribution' + contenttype: 'PyPi' + contentsource: 'Folder' + folderlocation: '$(Pipeline.Workspace)/python-dist' + waitforreleasecompletion: true + owners: 'ryauld@microsoft.com,avdunn@microsoft.com' + approvers: 'avdunn@microsoft.com,bogavril@microsoft.com' + serviceendpointurl: 'https://api.esrp.microsoft.com' + mainpublisher: 'ESRPRELPACMAN' + domaintenantid: '33e01921-4d64-4f8c-a055-5bdaffd5e33d' diff --git a/.Pipelines/template-pipeline-stages.yml b/.Pipelines/template-pipeline-stages.yml new file mode 100644 index 00000000..972121c3 --- /dev/null +++ b/.Pipelines/template-pipeline-stages.yml @@ -0,0 +1,204 @@ +# template-pipeline-stages.yml +# +# Shared stages template for the msal Python package. +# +# Called from: +# pipeline-publish.yml — release build (runPublish: true) +# azure-pipelines.yml — PR gate and post-merge CI (runPublish: false) +# +# Parameters: +# packageVersion - Version to validate against msal/sku.py +# Required when runPublish is true; unused otherwise. +# runPublish - When true: also runs the Validate stage before CI. +# When false (PR / merge builds): only PreBuildCheck + CI run. +# +# Stage flow: +# +# runPublish: true → PreBuildCheck ─► Validate ─► CI +# runPublish: false → PreBuildCheck ─► CI (Validate is skipped) +# +# Build and Publish stages are defined in pipeline-publish.yml (not here), +# so that the PR build never references PyPI service connections. + +parameters: +- name: packageVersion + type: string + default: '' +- name: runPublish + type: boolean + default: false + +stages: + +# ══════════════════════════════════════════════════════════════════════════════ +# Stage 0 · PreBuildCheck — SDL security scans (PoliCheck + CredScan) +# Always runs, mirrors MSAL.NET pre-build analysis. +# ══════════════════════════════════════════════════════════════════════════════ +- stage: PreBuildCheck + displayName: 'Pre-build security checks' + jobs: + - job: SecurityScan + displayName: 'PoliCheck + CredScan' + pool: + vmImage: windows-latest + variables: + Codeql.SkipTaskAutoInjection: true + steps: + - task: NodeTool@0 + displayName: 'Install Node.js (includes npm)' + inputs: + versionSpec: '20.x' + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-policheck.PoliCheck@2 + displayName: 'Run PoliCheck' + inputs: + targetType: F + continueOnError: true + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-credscan.CredScan@3 + displayName: 'Run CredScan' + inputs: + suppressionsFile: '$(Build.SourcesDirectory)/.Pipelines/credscan-exclusion.json' + toolMajorVersion: V2 + debugMode: false + + - task: securedevelopmentteam.vss-secure-development-tools.build-task-postanalysis.PostAnalysis@2 + displayName: 'Post Analysis' + inputs: + GdnBreakGdnToolCredScan: true + GdnBreakGdnToolPoliCheck: true + +# ══════════════════════════════════════════════════════════════════════════════ +# Stage 1 · Validate — verify packageVersion matches msal/sku.py __version__ +# Skipped when runPublish is false (PR / merge builds). +# ══════════════════════════════════════════════════════════════════════════════ +- stage: Validate + displayName: 'Validate version' + dependsOn: PreBuildCheck + condition: and(${{ parameters.runPublish }}, eq(dependencies.PreBuildCheck.result, 'Succeeded')) + jobs: + - job: ValidateVersion + displayName: 'Check version matches source' + pool: + vmImage: ubuntu-latest + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + displayName: 'Set up Python' + + - script: | + python - <<'EOF' + import sys, runpy + ns = runpy.run_path("msal/sku.py") + sku_ver = ns.get("__version__", "") + param_ver = "${{ parameters.packageVersion }}" + if not param_ver: + print("##vso[task.logissue type=error]packageVersion is required. Enter the version to publish (must match msal/sku.py __version__).") + sys.exit(1) + elif param_ver != sku_ver: + print(f"##vso[task.logissue type=error]Version mismatch: parameter '{param_ver}' != msal/sku.py '{sku_ver}'") + print("Update msal/sku.py __version__ to match the packageVersion parameter, or correct the parameter.") + sys.exit(1) + else: + print(f"Version validated: {param_ver}") + EOF + displayName: 'Verify version parameter matches msal/sku.py' + +# ══════════════════════════════════════════════════════════════════════════════ +# Stage 2 · CI — run the full test matrix across all supported Python versions. +# Always runs. Waits for Validate when runPublish is true; +# runs immediately when Validate is skipped (PR / merge builds). +# ══════════════════════════════════════════════════════════════════════════════ +- stage: CI + displayName: 'Run tests' + dependsOn: + - PreBuildCheck + - Validate + condition: | + and( + eq(dependencies.PreBuildCheck.result, 'Succeeded'), + in(dependencies.Validate.result, 'Succeeded', 'Skipped') + ) + jobs: + - job: Test + displayName: 'Run unit tests' + pool: + vmImage: ubuntu-latest + strategy: + matrix: + Python39: + python.version: '3.9' + Python310: + python.version: '3.10' + Python311: + python.version: '3.11' + Python312: + python.version: '3.12' + Python313: + python.version: '3.13' + Python314: + python.version: '3.14' + steps: + # Retrieve the MSID Lab certificate from Key Vault (via AuthSdkResourceManager SC). + # Matches the pattern used by MSAL.js (install-keyvault-secrets.yml) and MSAL Java. + # Skipped on forked PRs — service connections are not available to forks. + # E2E tests self-skip when LAB_APP_CLIENT_CERT_PFX_PATH is unset. + - task: AzureKeyVault@2 + displayName: 'Retrieve lab certificate from Key Vault' + condition: ne(variables['System.PullRequest.IsFork'], 'True') + inputs: + azureSubscription: 'AuthSdkResourceManager' + KeyVaultName: 'msidlabs' + SecretsFilter: 'LabAuth' + RunAsPreJob: false + + - bash: | + set -euo pipefail + if [ -z "${LAB_AUTH_B64:-}" ]; then + echo "##vso[task.logissue type=error]LabAuth secret is empty or was not injected — Key Vault retrieval may have failed." + exit 1 + fi + CERT_PATH="$(Agent.TempDirectory)/lab-auth.pfx" + printf '%s' "$LAB_AUTH_B64" | base64 -d > "$CERT_PATH" + echo "##vso[task.setvariable variable=LAB_APP_CLIENT_CERT_PFX_PATH]$CERT_PATH" + echo "Lab cert written to: $CERT_PATH ($(wc -c < "$CERT_PATH") bytes)" + displayName: 'Write lab certificate to disk' + condition: ne(variables['System.PullRequest.IsFork'], 'True') + env: + LAB_AUTH_B64: $(LabAuth) + + - task: UsePythonVersion@0 + inputs: + versionSpec: '$(python.version)' + displayName: 'Set up Python' + + - script: | + python -m pip install --upgrade pip + pip install -r requirements.txt + displayName: 'Install dependencies' + + # Use bash: explicitly; set -o pipefail so that pytest failures aren't hidden by the pipe to tee. + # Without pipefail, tee exits 0 and the step can succeed even when tests fail. + # (set -o pipefail also works in script: steps, but bash: makes the shell choice explicit.) + - bash: | + pip install pytest pytest-azurepipelines + mkdir -p test-results + set -o pipefail + pytest -vv --junitxml=test-results/junit.xml 2>&1 | tee test-results/pytest.log + displayName: 'Run tests' + env: + LAB_APP_CLIENT_CERT_PFX_PATH: $(LAB_APP_CLIENT_CERT_PFX_PATH) + + - task: PublishTestResults@2 + displayName: 'Publish test results' + condition: succeededOrFailed() + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: 'test-results/junit.xml' + failTaskOnFailedTests: true + testRunTitle: 'Python $(python.version)' + + - bash: rm -f "$(Agent.TempDirectory)/lab-auth.pfx" + displayName: 'Clean up lab certificate' + condition: always() diff --git a/.gitignore b/.gitignore index 1af10eff..b2344a37 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ __pycache__/ # Result of running python setup.py install/pip install -e /build/ /msal.egg-info/ +/msal-key-attestation/msal_key_attestation.egg-info/ # Test results /TestResults/ diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0426b1eb..b800165e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -1,59 +1,26 @@ -# Derived from the default YAML generated by Azure DevOps for a Python package -# Create and test a Python package on multiple Python versions. -# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more: -# https://docs.microsoft.com/azure/devops/pipelines/languages/python +# PR gate and branch CI for the msal Python package. +# Runs on pushes to dev/azure-pipelines, on all pull requests, and on a daily schedule. +# Delegates all stages to .Pipelines/template-pipeline-stages.yml with +# runPublish: false — PreBuildCheck (SDL scans) + CI (test matrix) only. trigger: - dev - azure-pipelines -pool: - vmImage: ubuntu-latest -strategy: - matrix: - Python39: - python.version: '3.9' - Python310: - python.version: '3.10' - Python311: - python.version: '3.11' - Python312: - python.version: '3.12' - Python313: - python.version: '3.13' - Python314: - python.version: '3.14' +pr: + branches: + include: + - '*' -steps: -- task: UsePythonVersion@0 - inputs: - versionSpec: '$(python.version)' - displayName: 'Use Python $(python.version)' +schedules: +- cron: '45 7 * * *' # 07:45 UTC daily (11:45 PM PST / 12:45 AM PDT) — matches legacy MSAL-Python-SDL-CI schedule + displayName: 'Daily SDL + CI (dev)' + branches: + include: + - dev + always: false # only run when there are new changes -- script: | - python -m pip install --upgrade pip - pip install -r requirements.txt - displayName: 'Install dependencies' - -- script: | - pip install pytest pytest-azurepipelines - mkdir -p test-results - set -o pipefail - pytest -vv --junitxml=test-results/junit.xml 2>&1 | tee test-results/pytest.log - displayName: 'pytest (verbose + junit + log)' - -- task: PublishTestResults@2 - displayName: 'Publish test results' - condition: succeededOrFailed() - inputs: - testResultsFormat: 'JUnit' - testResultsFiles: 'test-results/junit.xml' - failTaskOnFailedTests: true - testRunTitle: 'Python $(python.version) pytest' - -- task: PublishPipelineArtifact@1 - displayName: 'Publish pytest log artifact' - condition: succeededOrFailed() - inputs: - targetPath: 'test-results' - artifact: 'pytest-logs-$(python.version)' +stages: +- template: .Pipelines/template-pipeline-stages.yml + parameters: + runPublish: false diff --git a/msal-key-attestation/README.md b/msal-key-attestation/README.md new file mode 100644 index 00000000..97b0dfef --- /dev/null +++ b/msal-key-attestation/README.md @@ -0,0 +1,94 @@ +# msal-key-attestation + +KeyGuard attestation support for **MSAL Python** MSI v2 (mTLS Proof-of-Possession). + +This package provides the `AttestationClientLib.dll` bindings for Windows +Credential Guard / KeyGuard key attestation via Azure Attestation (MAA). + +## Installation + +```bash +pip install msal msal-key-attestation +``` + +## Prerequisites + +- **Windows** with Credential Guard / KeyGuard enabled (Azure VM with VBS) +- **AttestationClientLib.dll** — place it next to your application, or set + `ATTESTATION_CLIENTLIB_PATH` environment variable to its full path. + +## Usage + +```python +import msal, requests + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), +) + +# with_attestation_support=True auto-discovers msal-key-attestation +result = client.acquire_token_for_client( + resource="https://graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True, +) + +if "access_token" in result: + print(f"Token type: {result['token_type']}") # mtls_pop + print(f"Cert thumbprint: {result.get('cert_thumbprint_sha256', 'N/A')}") +else: + print(f"Error: {result.get('error_description', result)}") +``` + +## How it works + +1. MSAL Python's MSI v2 flow creates a KeyGuard-protected RSA key (via NCrypt) +2. When `with_attestation_support=True`, MSAL auto-imports this package +3. This package calls `AttestationClientLib.dll` to attest the key with MAA +4. The attestation JWT is cached in-memory (~90% of its lifetime) +5. MSAL sends the JWT + CSR to IMDS `/issuecredential` +6. IMDS returns a short-lived certificate, which MSAL uses for mTLS token + acquisition + +## Architecture + +``` +┌─────────────────────────────────────────┐ +│ msal (pip install msal) │ +│ │ +│ ManagedIdentityClient │ +│ └─ acquire_token_for_client() │ +│ mtls_proof_of_possession=True │ +│ with_attestation_support=True │ +│ │ +│ msal.msi_v2 (core flow) │ +│ - NCrypt KeyGuard key (ctypes) │ +│ - PKCS#10 CSR builder │ +│ - IMDS getplatformmetadata │ +│ - IMDS issuecredential │ +│ - Crypt32 cert binding │ +│ - WinHTTP/SChannel mTLS │ +│ - Certificate cache (in-memory) │ +└────────────────┬────────────────────────┘ + │ auto-discovers via import +┌────────────────▼────────────────────────┐ +│ msal-key-attestation │ +│ (pip install msal-key-attestation) │ +│ │ +│ create_attestation_provider() │ +│ - AttestationClientLib.dll bindings │ +│ - MAA token cache (in-memory) │ +└─────────────────────────────────────────┘ +``` + +## Environment Variables + +| Variable | Description | +|---|---| +| `ATTESTATION_CLIENTLIB_PATH` | Full path to `AttestationClientLib.dll` | +| `MSAL_MSI_V2_ATTESTATION_CACHE` | `"0"` to disable MAA JWT caching | + +## License + +MIT diff --git a/msal-key-attestation/msal_key_attestation/__init__.py b/msal-key-attestation/msal_key_attestation/__init__.py new file mode 100644 index 00000000..2cfe90ec --- /dev/null +++ b/msal-key-attestation/msal_key_attestation/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +msal-key-attestation — KeyGuard attestation support for MSAL Python MSI v2. + +This package provides the ``create_attestation_provider()`` function that +returns a callable suitable for the ``attestation_token_provider`` parameter +in ``msal.msi_v2.obtain_token()``. + +It loads the Windows-only ``AttestationClientLib.dll`` (Azure Attestation +native library) via ctypes and exposes a high-level API that: + +- Initializes the native attestation library +- Calls ``AttestKeyGuardImportKey`` with the CNG key handle +- Returns the attestation JWT +- Caches the JWT in-memory until ~90 % of its lifetime + +Usage:: + + from msal_key_attestation import create_attestation_provider + + # Pass to MSAL's MSI v2 flow: + result = client.acquire_token_for_client( + resource="https://graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True, # auto-discovers this package + ) + + # Or use the provider directly: + provider = create_attestation_provider() + jwt = provider(attestation_endpoint, key_handle_int, client_id) +""" + +__version__ = "0.1.0" + +from .attestation import create_attestation_provider, get_attestation_jwt + +__all__ = ["create_attestation_provider", "get_attestation_jwt"] diff --git a/msal-key-attestation/msal_key_attestation/attestation.py b/msal-key-attestation/msal_key_attestation/attestation.py new file mode 100644 index 00000000..408318c2 --- /dev/null +++ b/msal-key-attestation/msal_key_attestation/attestation.py @@ -0,0 +1,387 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +Windows attestation for MSI v2 KeyGuard keys using AttestationClientLib.dll. + +This module calls into AttestationClientLib.dll to mint an attestation JWT for +a KeyGuard key handle. It also provides a small in-memory cache to reuse the +attestation JWT until ~90% of its lifetime. + +Caching notes: + - Cache is process-local (in-memory). Does not persist across process + restarts. + - Cache is keyed by (attestation_endpoint, client_id, cache_key). + - Provide a stable cache_key (e.g., the named per-boot key name) to + maximize hits. + - If the token cannot be parsed or has no ``exp`` claim, it is not cached. + +Env vars: + - ATTESTATION_CLIENTLIB_PATH: absolute path to AttestationClientLib.dll + - MSAL_MSI_V2_ATTESTATION_CACHE: "0" disables caching (default enabled) +""" + +from __future__ import annotations + +import base64 +import ctypes +import json +import logging +import os +import sys +import threading +import time +from ctypes import POINTER, Structure, c_char_p, c_int, c_void_p +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Native callback type — prevent GC of the delegate +# --------------------------------------------------------------------------- + +_NATIVE_LOG_CB = None + +# void LogFunc(void* ctx, const char* tag, int lvl, const char* func, +# int line, const char* msg); +_LogFunc = ctypes.CFUNCTYPE( + None, c_void_p, c_char_p, c_int, c_char_p, c_int, c_char_p) + + +class AttestationLogInfo(Structure): + _fields_ = [("Log", c_void_p), ("Ctx", c_void_p)] + + +def _default_logger(ctx, tag, lvl, func, line, msg): + try: + tag_s = tag.decode("utf-8", errors="replace") if tag else "" + func_s = func.decode("utf-8", errors="replace") if func else "" + msg_s = msg.decode("utf-8", errors="replace") if msg else "" + logger.debug("[Native:%s:%s] %s:%s - %s", + tag_s, lvl, func_s, line, msg_s) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Env helpers +# --------------------------------------------------------------------------- + +def _truthy_env(name: str, default: str = "1") -> bool: + val = os.getenv(name, default) + return (val or "").strip().lower() in ("1", "true", "yes", "y", "on") + + +def _maybe_add_dll_dirs(): + """Make DLL resolution more reliable (especially for packaged apps).""" + if sys.platform != "win32": + return + add_dir = getattr(os, "add_dll_directory", None) + if not add_dir: + return + for p in (os.path.dirname(sys.executable), + os.getcwd(), os.path.dirname(__file__)): + try: + if p and os.path.isdir(p): + add_dir(p) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# DLL loading +# --------------------------------------------------------------------------- + +def _load_lib(): + if sys.platform != "win32": + raise RuntimeError( + "[msal_key_attestation] AttestationClientLib is Windows-only.") + + _maybe_add_dll_dirs() + + explicit = os.getenv("ATTESTATION_CLIENTLIB_PATH") + try: + if explicit: + return ctypes.CDLL(explicit) + return ctypes.CDLL("AttestationClientLib.dll") + except OSError as exc: + raise RuntimeError( + "[msal_key_attestation] Unable to load AttestationClientLib.dll. " + "Place it next to the app/exe or set ATTESTATION_CLIENTLIB_PATH." + ) from exc + + +# --------------------------------------------------------------------------- +# JWT parsing (for cache lifetime) +# --------------------------------------------------------------------------- + +def _b64url_decode(s: str) -> bytes: + s = (s or "").strip() + s += "=" * ((4 - len(s) % 4) % 4) + return base64.urlsafe_b64decode(s.encode("ascii")) + + +def _try_extract_exp_iat(jwt: str) -> Tuple[Optional[int], Optional[int]]: + """Extract exp and iat (Unix seconds) from a JWT without validation.""" + try: + parts = jwt.split(".") + if len(parts) < 2: + return None, None + payload = json.loads( + _b64url_decode(parts[1]).decode("utf-8", errors="replace")) + if not isinstance(payload, dict): + return None, None + + def _to_int(v): + if isinstance(v, bool): + return None + if isinstance(v, int): + return v + if isinstance(v, float): + return int(v) + if isinstance(v, str) and v.strip().isdigit(): + return int(v.strip()) + return None + + return _to_int(payload.get("exp")), _to_int(payload.get("iat")) + except Exception: + return None, None + + +# --------------------------------------------------------------------------- +# MAA token cache (in-memory, process-local) +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class _CacheKey: + attestation_endpoint: str + client_id: str + cache_key: str + auth_token: str + client_payload: str + + +@dataclass +class _CacheEntry: + jwt: str + exp: int + refresh_after: float # epoch seconds + + +_CACHE_LOCK = threading.Lock() +_CACHE: dict = {} + + +def _cache_lookup(key: _CacheKey) -> Optional[str]: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return None + now = time.time() + with _CACHE_LOCK: + entry = _CACHE.get(key) + if not entry: + return None + if now >= entry.refresh_after or now >= entry.exp - 5: + return None + logger.debug("[msal_key_attestation] MAA cache HIT") + return entry.jwt + + +def _cache_store(key: _CacheKey, jwt: str) -> None: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return + exp, iat = _try_extract_exp_iat(jwt) + if exp is None: + return + now = int(time.time()) + issued_at = iat if iat is not None else now + lifetime = exp - issued_at + if lifetime <= 0: + return + # Refresh at 90% of lifetime with small absolute guard + refresh_after = issued_at + (0.90 * lifetime) + refresh_after = min(refresh_after, exp - 10) + with _CACHE_LOCK: + _CACHE[key] = _CacheEntry( + jwt=jwt, exp=exp, refresh_after=float(refresh_after)) + logger.debug("[msal_key_attestation] MAA cache SET") + + +def _cache_clear() -> None: + """Clear cache (for testing).""" + with _CACHE_LOCK: + _CACHE.clear() + + +# --------------------------------------------------------------------------- +# Core attestation call +# --------------------------------------------------------------------------- + +def get_attestation_jwt( + *, + attestation_endpoint: str, + client_id: str, + key_handle: int, + auth_token: str = "", + client_payload: str = "{}", + cache_key: Optional[str] = None, +) -> str: + """ + Get attestation JWT from AttestationClientLib.dll for a KeyGuard key. + + Args: + attestation_endpoint: MAA endpoint URL. + client_id: Client ID for attestation. + key_handle: NCrypt key handle (integer). + auth_token: Optional auth token for attestation service. + client_payload: Optional JSON payload. + cache_key: Stable identifier for caching (recommended: key name). + + Returns: + Attestation JWT string. + + Raises: + RuntimeError: on DLL load or attestation failure. + """ + if not attestation_endpoint: + raise ValueError( + "[msal_key_attestation] attestation_endpoint must be non-empty") + if not client_id: + raise ValueError( + "[msal_key_attestation] client_id must be non-empty") + if not key_handle: + raise ValueError( + "[msal_key_attestation] key_handle must be non-zero") + + stable = cache_key if cache_key is not None else f"handle:{key_handle}" + ck = _CacheKey( + attestation_endpoint=str(attestation_endpoint), + client_id=str(client_id), + cache_key=str(stable), + auth_token=str(auth_token or ""), + client_payload=str(client_payload or "{}"), + ) + + cached = _cache_lookup(ck) + if cached: + return cached + + lib = _load_lib() + + lib.InitAttestationLib.argtypes = [POINTER(AttestationLogInfo)] + lib.InitAttestationLib.restype = c_int + + lib.AttestKeyGuardImportKey.argtypes = [ + c_char_p, # endpoint + c_char_p, # authToken + c_char_p, # clientPayload + c_void_p, # keyHandle (NCRYPT_KEY_HANDLE) + POINTER(c_void_p), # out token (char*) + c_char_p, # clientId + ] + lib.AttestKeyGuardImportKey.restype = c_int + + lib.FreeAttestationToken.argtypes = [c_void_p] + lib.FreeAttestationToken.restype = None + + lib.UninitAttestationLib.argtypes = [] + lib.UninitAttestationLib.restype = None + + global _NATIVE_LOG_CB # pylint: disable=global-statement + _NATIVE_LOG_CB = _LogFunc(_default_logger) + + info = AttestationLogInfo() + info.Log = ctypes.cast(_NATIVE_LOG_CB, c_void_p).value + info.Ctx = c_void_p(0) + + rc = lib.InitAttestationLib(ctypes.byref(info)) + if rc != 0: + raise RuntimeError( + f"[msal_key_attestation] InitAttestationLib failed: {rc}") + + token_ptr = c_void_p() + try: + rc = lib.AttestKeyGuardImportKey( + attestation_endpoint.encode("utf-8"), + (auth_token or "").encode("utf-8"), + (client_payload or "{}").encode("utf-8"), + c_void_p(int(key_handle)), + ctypes.byref(token_ptr), + client_id.encode("utf-8"), + ) + if rc != 0: + raise RuntimeError( + f"[msal_key_attestation] AttestKeyGuardImportKey failed: {rc}") + if not token_ptr.value: + raise RuntimeError( + "[msal_key_attestation] Attestation token pointer is NULL") + + token = ctypes.string_at(token_ptr.value).decode( + "utf-8", errors="replace") + if not token or "." not in token: + raise RuntimeError( + "[msal_key_attestation] Attestation token looks malformed") + + _cache_store(ck, token) + return token + finally: + try: + if token_ptr.value: + lib.FreeAttestationToken(token_ptr) + finally: + try: + lib.UninitAttestationLib() + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Public factory — matches the callback signature MSAL expects: +# (endpoint: str, key_handle: int, client_id: str, cache_key: str) -> str +# --------------------------------------------------------------------------- + +def create_attestation_provider() -> Callable[[str, int, str, str], str]: + """ + Create an attestation token provider callable for MSAL MSI v2. + + The returned callable has signature:: + + provider(attestation_endpoint: str, key_handle: int, + client_id: str, cache_key: str) -> str + + ``cache_key`` should be the stable per-boot key name. Using the key + name (rather than the numeric handle) maximizes MAA-token cache hits + across key re-opens. + + It wraps :func:`get_attestation_jwt` with caching support. + + Usage:: + + from msal_key_attestation import create_attestation_provider + provider = create_attestation_provider() + + # MSAL auto-discovers this when with_attestation_support=True. + # Or pass explicitly: + from msal.msi_v2 import obtain_token + result = obtain_token( + http_client, managed_identity, resource, + attestation_token_provider=provider, + ) + + Returns: + Callable suitable for ``attestation_token_provider`` parameter. + """ + def _provider( + attestation_endpoint: str, + key_handle: int, + client_id: str, + cache_key: str = "", + ) -> str: + return get_attestation_jwt( + attestation_endpoint=attestation_endpoint, + client_id=client_id, + key_handle=key_handle, + cache_key=cache_key or None, + ) + return _provider diff --git a/msal-key-attestation/setup.cfg b/msal-key-attestation/setup.cfg new file mode 100644 index 00000000..274926c1 --- /dev/null +++ b/msal-key-attestation/setup.cfg @@ -0,0 +1,37 @@ +[bdist_wheel] +universal=0 + +[metadata] +name = msal-key-attestation +version = attr: msal_key_attestation.__version__ +description = KeyGuard attestation support for MSAL Python MSI v2 (mTLS PoP). Provides AttestationClientLib.dll bindings for Windows Credential Guard key attestation. +long_description = file: README.md +long_description_content_type = text/markdown +license = MIT +author = Microsoft Corporation +author_email = nugetaad@microsoft.com +url = https://github.com/AzureAD/microsoft-authentication-library-for-python +classifiers = + Development Status :: 3 - Alpha + Programming Language :: Python :: 3 :: Only + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 + Programming Language :: Python :: 3.14 + License :: OSI Approved :: MIT License + Operating System :: Microsoft :: Windows + +[options] +packages = find: +python_requires = >=3.8 +install_requires = + msal>=1.32.0 + +[options.packages.find] +exclude = + tests + tests.* diff --git a/msal-key-attestation/setup.py b/msal-key-attestation/setup.py new file mode 100644 index 00000000..8bf1ba93 --- /dev/null +++ b/msal-key-attestation/setup.py @@ -0,0 +1,2 @@ +from setuptools import setup +setup() diff --git a/msal-key-attestation/tests/test_attestation.py b/msal-key-attestation/tests/test_attestation.py new file mode 100644 index 00000000..9ee0ae8a --- /dev/null +++ b/msal-key-attestation/tests/test_attestation.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +"""Tests for msal-key-attestation package.""" + +import json +import time +import unittest +from unittest.mock import patch, MagicMock + +from msal_key_attestation.attestation import ( + _cache_clear, + _cache_lookup, + _cache_store, + _CacheKey, + _try_extract_exp_iat, + create_attestation_provider, +) + + +class TestMaaTokenCache(unittest.TestCase): + def setUp(self): + _cache_clear() + + def tearDown(self): + _cache_clear() + + def _make_jwt(self, exp=None, iat=None) -> str: + """Create a minimal JWT with exp/iat claims.""" + import base64 + header = base64.urlsafe_b64encode( + b'{"alg":"none"}').rstrip(b"=").decode("ascii") + payload_obj = {} + if exp is not None: + payload_obj["exp"] = exp + if iat is not None: + payload_obj["iat"] = iat + payload = base64.urlsafe_b64encode( + json.dumps(payload_obj).encode("utf-8") + ).rstrip(b"=").decode("ascii") + return f"{header}.{payload}.sig" + + def test_extract_exp_iat(self): + now = int(time.time()) + jwt = self._make_jwt(exp=now + 3600, iat=now) + exp, iat = _try_extract_exp_iat(jwt) + self.assertEqual(exp, now + 3600) + self.assertEqual(iat, now) + + def test_extract_no_exp(self): + jwt = self._make_jwt() + exp, iat = _try_extract_exp_iat(jwt) + self.assertIsNone(exp) + self.assertIsNone(iat) + + def test_extract_malformed_jwt(self): + exp, iat = _try_extract_exp_iat("not.a.jwt.at.all") + self.assertIsNone(exp) + + def test_cache_stores_and_retrieves(self): + now = int(time.time()) + jwt = self._make_jwt(exp=now + 3600, iat=now) + key = _CacheKey("ep", "cid", "ck", "", "{}") + _cache_store(key, jwt) + self.assertEqual(_cache_lookup(key), jwt) + + def test_cache_miss(self): + key = _CacheKey("ep", "cid", "ck", "", "{}") + self.assertIsNone(_cache_lookup(key)) + + def test_cache_expired_not_returned(self): + now = int(time.time()) + jwt = self._make_jwt(exp=now - 100, iat=now - 200) + key = _CacheKey("ep", "cid", "ck", "", "{}") + _cache_store(key, jwt) + # Token already expired, should not be cached + self.assertIsNone(_cache_lookup(key)) + + def test_cache_no_exp_not_cached(self): + jwt = self._make_jwt() + key = _CacheKey("ep", "cid", "ck", "", "{}") + _cache_store(key, jwt) + self.assertIsNone(_cache_lookup(key)) + + @patch.dict("os.environ", {"MSAL_MSI_V2_ATTESTATION_CACHE": "0"}) + def test_cache_disabled_by_env(self): + now = int(time.time()) + jwt = self._make_jwt(exp=now + 3600, iat=now) + key = _CacheKey("ep", "cid", "ck", "", "{}") + _cache_store(key, jwt) + self.assertIsNone(_cache_lookup(key)) + + def test_cache_clear(self): + now = int(time.time()) + jwt = self._make_jwt(exp=now + 3600, iat=now) + key = _CacheKey("ep", "cid", "ck", "", "{}") + _cache_store(key, jwt) + _cache_clear() + self.assertIsNone(_cache_lookup(key)) + + +class TestCreateAttestationProvider(unittest.TestCase): + def test_returns_callable(self): + provider = create_attestation_provider() + self.assertTrue(callable(provider)) + + @patch("msal_key_attestation.attestation.get_attestation_jwt") + def test_provider_calls_get_attestation_jwt(self, mock_get): + mock_get.return_value = "fake.attestation.jwt" + provider = create_attestation_provider() + result = provider( + "https://attest.example.com", 12345, "client-id", "my-key-name") + self.assertEqual(result, "fake.attestation.jwt") + mock_get.assert_called_once_with( + attestation_endpoint="https://attest.example.com", + client_id="client-id", + key_handle=12345, + cache_key="my-key-name", + ) + + @patch("msal_key_attestation.attestation.get_attestation_jwt") + def test_provider_forwards_empty_cache_key_as_none(self, mock_get): + mock_get.return_value = "fake.jwt" + provider = create_attestation_provider() + provider("https://ep", 1, "cid", "") + mock_get.assert_called_once_with( + attestation_endpoint="https://ep", + client_id="cid", + key_handle=1, + cache_key=None, + ) + + +class TestGetAttestationJwtValidation(unittest.TestCase): + def test_empty_endpoint_raises(self): + from msal_key_attestation.attestation import get_attestation_jwt + with self.assertRaises(ValueError): + get_attestation_jwt( + attestation_endpoint="", + client_id="c", + key_handle=1) + + def test_empty_client_id_raises(self): + from msal_key_attestation.attestation import get_attestation_jwt + with self.assertRaises(ValueError): + get_attestation_jwt( + attestation_endpoint="https://ep", + client_id="", + key_handle=1) + + def test_zero_key_handle_raises(self): + from msal_key_attestation.attestation import get_attestation_jwt + with self.assertRaises(ValueError): + get_attestation_jwt( + attestation_endpoint="https://ep", + client_id="c", + key_handle=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/msal/__init__.py b/msal/__init__.py index 295e9756..81763bb1 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -38,6 +38,7 @@ SystemAssignedManagedIdentity, UserAssignedManagedIdentity, ManagedIdentityClient, ManagedIdentityError, + MsiV2Error, ArcPlatformNotSupportedError, ) diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 422b76e3..f35c898a 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -24,6 +24,11 @@ class ManagedIdentityError(ValueError): pass +class MsiV2Error(ManagedIdentityError): + """Raised when the MSI v2 (mTLS PoP) flow fails.""" + pass + + class ManagedIdentity(UserDict): """Feed an instance of this class to :class:`msal.ManagedIdentityClient` to acquire token for the specified managed identity. @@ -259,6 +264,8 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, + mtls_proof_of_possession: bool = False, + with_attestation_support: bool = False, ): """Acquire token for the managed identity. @@ -278,6 +285,25 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. + :param bool mtls_proof_of_possession: (optional) + When True **and** ``with_attestation_support`` is also True, + use the MSI v2 (mTLS Proof-of-Possession) flow to acquire an + ``mtls_pop`` token bound to a short-lived mTLS certificate issued + by the IMDS ``/issuecredential`` endpoint. + + Requires Windows with Credential Guard / KeyGuard active. + Without ``with_attestation_support``, this flag alone falls + through to the legacy IMDS v1 flow. Defaults to False. + + :param bool with_attestation_support: (optional) + When True (and ``mtls_proof_of_possession`` is also True), + perform KeyGuard / platform attestation before credential + issuance. This requires the **msal-key-attestation** package + (``pip install msal-key-attestation``). + + Setting this to True without ``mtls_proof_of_possession`` + raises :class:`ManagedIdentityError`. Defaults to False. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -292,6 +318,46 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() + + # --- MSI v2 gate --- + # MSI v2 is opt-in: both mtls_proof_of_possession AND + # with_attestation_support must be True. + # No auto-fallback: if v2 fails, MsiV2Error is raised. + use_msi_v2 = bool(mtls_proof_of_possession and with_attestation_support) + + if with_attestation_support and not mtls_proof_of_possession: + raise ManagedIdentityError( + "attestation_requires_pop: with_attestation_support=True " + "requires mtls_proof_of_possession=True (mTLS PoP).") + + if use_msi_v2: + # Auto-discover attestation provider from msal-key-attestation + attestation_token_provider = None + try: + from msal_key_attestation import create_attestation_provider + attestation_token_provider = create_attestation_provider() + except ImportError: + raise MsiV2Error( + "[msi_v2] with_attestation_support=True requires the " + "msal-key-attestation package. " + "Install it with: pip install msal-key-attestation") + + from .msi_v2 import obtain_token as _obtain_token_v2 + try: + result = _obtain_token_v2( + self._http_client, self._managed_identity, resource, + attestation_enabled=True, + attestation_token_provider=attestation_token_provider, + ) + except MsiV2Error: + raise + except Exception as exc: + raise MsiV2Error( + f"[msi_v2] Unexpected failure: {exc}") from exc + if "access_token" in result and "error" not in result: + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP + return result + if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( diff --git a/msal/msi_v2.py b/msal/msi_v2.py new file mode 100644 index 00000000..468201dd --- /dev/null +++ b/msal/msi_v2.py @@ -0,0 +1,1437 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +MSI v2 (IMDSv2) Managed Identity flow — Windows KeyGuard + SChannel mTLS PoP. + +This module implements the MSI v2 token acquisition path using Windows native APIs +via ctypes: + - CNG/NCrypt: create/open a KeyGuard-protected per-boot RSA key (non-exportable) + - Minimal DER/PKCS#10: build a CSR signed with RSA-PSS/SHA256 + - IMDS: call getplatformmetadata + issuecredential + - Crypt32: bind the issued certificate to the CNG private key + - WinHTTP/SChannel: acquire access token over mTLS (token_type=mtls_pop) + +Key behavior: + - Uses a *named per-boot key*: opens the key if it already exists for this boot; + otherwise creates it. + - No MSI v1 fallback: any MSI v2 failure raises MsiV2Error. + - Production-ready handle management: all WinHTTP / Crypt32 / NCrypt handles are + released in finally blocks. + - Certificate cache: in-memory with lifetime-based eviction (like .NET + InMemoryCertificateCache). + - Returns certificate with token for mTLS with resource. + +Environment variables (optional): + - AZURE_POD_IDENTITY_AUTHORITY_HOST: override IMDS base URL + (default http://169.254.169.254) + - MSAL_MSI_V2_KEY_NAME: override the per-boot key name (otherwise derived from + metadata clientId) +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import struct +import sys +import threading +import time +import uuid +from typing import Any, Callable, Dict, List, Optional, Tuple +from urllib.parse import urlparse, urlencode + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# IMDS constants +# --------------------------------------------------------------------------- + +_IMDS_DEFAULT_BASE = "http://169.254.169.254" +_IMDS_BASE_ENVVAR = "AZURE_POD_IDENTITY_AUTHORITY_HOST" + +_API_VERSION_QUERY_PARAM = "cred-api-version" +_IMDS_V2_API_VERSION = "2.0" + +_CSR_METADATA_PATH = "/metadata/identity/getplatformmetadata" +_ISSUE_CREDENTIAL_PATH = "/metadata/identity/issuecredential" +_ACQUIRE_ENTRA_TOKEN_PATH = "/oauth2/v2.0/token" + +_CU_ID_OID_STR = "1.3.6.1.4.1.311.90.2.10" + +# --------------------------------------------------------------------------- +# NCrypt/CNG flags +# --------------------------------------------------------------------------- + +_NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000 +_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 + +_RSA_KEY_SIZE = 2048 + +_AT_SIGNATURE = 2 + +_NCRYPT_SILENT_FLAG = 0x40 + +_KEY_NAME_ENVVAR = "MSAL_MSI_V2_KEY_NAME" + +# NCrypt "not found" status codes +_NTE_BAD_KEYSET = 0x80090016 +_NTE_NO_KEY = 0x8009000D +_NTE_NOT_FOUND = 0x80090011 +_NTE_EXISTS = 0x8009000F + +# Lazy-loaded Win32 API cache +_WIN32: Optional[Dict[str, Any]] = None + + +# --------------------------------------------------------------------------- +# Certificate cache (in-memory, process-local) +# --------------------------------------------------------------------------- + +class _CertCacheEntry: + """Cached mTLS certificate + metadata.""" + __slots__ = ("cert_der", "cert_pem", "token_endpoint", "client_id", + "not_after", "created_at") + + # Minimum remaining cert lifetime to cache (24 hours) + MIN_REMAINING_LIFETIME_SEC = 24 * 3600 + + def __init__(self, cert_der: bytes, cert_pem: str, + token_endpoint: str, client_id: str, + not_after: float): + self.cert_der = cert_der + self.cert_pem = cert_pem + self.token_endpoint = token_endpoint + self.client_id = client_id + self.not_after = not_after + self.created_at = time.time() + + def is_expired(self, now: Optional[float] = None) -> bool: + now = now or time.time() + return now >= self.not_after - self.MIN_REMAINING_LIFETIME_SEC + + +_CERT_CACHE_LOCK = threading.Lock() +_CERT_CACHE: Dict[str, _CertCacheEntry] = {} + + +def _cert_cache_key(managed_identity: Optional[Dict[str, Any]], + attested: bool) -> str: + """Build a cache key from managed identity + identifier type + attestation flag.""" + mi_id_type = "SYSTEM_ASSIGNED" + mi_id = "SYSTEM_ASSIGNED" + if isinstance(managed_identity, dict): + mi_id_type = str( + managed_identity.get("ManagedIdentityIdType") or "SYSTEM_ASSIGNED") + mi_id = str(managed_identity.get("Id") or "SYSTEM_ASSIGNED") + tag = "#att=1" if attested else "#att=0" + return mi_id_type + ":" + mi_id + tag + + +def _cert_cache_get(key: str) -> Optional[_CertCacheEntry]: + """Return cached entry or None if missing/expired.""" + now = time.time() + with _CERT_CACHE_LOCK: + entry = _CERT_CACHE.get(key) + if entry is None: + return None + if entry.is_expired(now): + del _CERT_CACHE[key] + logger.debug("[msi_v2] Cert cache EVICT (expired) key=%s", key[:20]) + return None + logger.debug("[msi_v2] Cert cache HIT key=%s", key[:20]) + return entry + + +def _cert_cache_set(key: str, entry: _CertCacheEntry) -> None: + """Store entry if it has sufficient remaining lifetime.""" + now = time.time() + if entry.not_after <= now + _CertCacheEntry.MIN_REMAINING_LIFETIME_SEC: + logger.debug("[msi_v2] Cert cache SKIP (insufficient lifetime) key=%s", + key[:20]) + return + with _CERT_CACHE_LOCK: + _CERT_CACHE[key] = entry + logger.debug("[msi_v2] Cert cache SET key=%s", key[:20]) + + +def _cert_cache_remove(key: str) -> None: + """Remove entry (e.g., on SChannel failure).""" + with _CERT_CACHE_LOCK: + _CERT_CACHE.pop(key, None) + + +def _cert_cache_clear() -> None: + """Clear all entries (for testing).""" + with _CERT_CACHE_LOCK: + _CERT_CACHE.clear() + + +# --------------------------------------------------------------------------- +# Compatibility helpers (tests + cross-language parity) +# --------------------------------------------------------------------------- + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + """ + Return base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 + comparisons. Accepts a PEM certificate string. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + cert = x509.load_pem_x509_certificate( + cert_pem.encode("utf-8"), default_backend()) + der = cert.public_bytes(serialization.Encoding.DER) + digest = hashlib.sha256(der).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + except Exception: + return "" + + +def verify_cnf_binding(token: str, cert_pem: str) -> bool: + """ + Verify that JWT payload contains cnf.x5t#S256 matching the cert + thumbprint. + """ + try: + parts = token.split(".") + if len(parts) != 3: + return False + + payload_b64 = parts[1] + payload_b64 += "=" * ((4 - len(payload_b64) % 4) % 4) + claims = json.loads( + base64.urlsafe_b64decode(payload_b64.encode("ascii"))) + + cnf = claims.get("cnf", {}) if isinstance(claims, dict) else {} + if not isinstance(cnf, dict): + return False + token_x5t = cnf.get("x5t#S256") + if not token_x5t: + return False + + cert_x5t = get_cert_thumbprint_sha256(cert_pem) + if not cert_x5t: + return False + + return token_x5t == cert_x5t + except Exception: + return False + + +def _der_to_pem(der_bytes: bytes) -> str: + """Convert DER certificate bytes to PEM string format.""" + b64 = base64.b64encode(der_bytes).decode("ascii") + lines = [b64[i:i + 64] for i in range(0, len(b64), 64)] + return ("-----BEGIN CERTIFICATE-----\n" + + "\n".join(lines) + + "\n-----END CERTIFICATE-----") + + +def _try_parse_cert_not_after(der_bytes: bytes) -> float: + """ + Best-effort extraction of notAfter from a DER X.509 certificate. + Returns epoch seconds. Falls back to now + 8 hours on any failure. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + cert = x509.load_der_x509_certificate(der_bytes, default_backend()) + na = cert.not_valid_after_utc if hasattr( + cert, "not_valid_after_utc") else cert.not_valid_after + if na.tzinfo is None: + import calendar + return float(calendar.timegm(na.timetuple())) + return na.timestamp() + except Exception: + # Default: assume 8-hour cert lifetime (IMDS typical) + return time.time() + 8 * 3600 + + +# --------------------------------------------------------------------------- +# IMDS helpers +# --------------------------------------------------------------------------- + +def _imds_base() -> str: + return os.getenv(_IMDS_BASE_ENVVAR, _IMDS_DEFAULT_BASE).strip().rstrip("/") + + +def _new_correlation_id() -> str: + return str(uuid.uuid4()) + + +def _imds_headers(correlation_id: Optional[str] = None) -> Dict[str, str]: + return { + "Metadata": "true", + "x-ms-client-request-id": correlation_id or _new_correlation_id(), + } + + +def _resource_to_scope(resource_or_scope: str) -> str: + """Normalize resource to scope format (append /.default if needed).""" + s = (resource_or_scope or "").strip() + if not s: + raise ValueError("resource must be non-empty") + if s.endswith("/.default"): + return s + return s.rstrip("/") + "/.default" + + +def _der_utf8string(value: str) -> bytes: + """DER UTF8String encoder (tag 0x0C).""" + raw = value.encode("utf-8") + n = len(raw) + if n < 0x80: + len_bytes = bytes([n]) + else: + tmp = bytearray() + m = n + while m > 0: + tmp.insert(0, m & 0xFF) + m >>= 8 + len_bytes = bytes([0x80 | len(tmp)]) + bytes(tmp) + return bytes([0x0C]) + len_bytes + raw + + +def _json_loads(text: str, what: str) -> Dict[str, Any]: + """Parse JSON with error context.""" + from .managed_identity import MsiV2Error + try: + obj = json.loads(text) + if not isinstance(obj, dict): + raise TypeError("expected JSON object") + return obj + except Exception as exc: + raise MsiV2Error( + f"[msi_v2] Invalid JSON from {what}: {text!r}") from exc + + +def _get_first(obj: Dict[str, Any], *names: str) -> Optional[str]: + """Get first non-empty value from object by multiple name variants.""" + for n in names: + if n in obj and obj[n] is not None and str(obj[n]).strip() != "": + return str(obj[n]) + lower = {str(k).lower(): k for k in obj.keys()} + for n in names: + k = lower.get(n.lower()) + if k and obj[k] is not None and str(obj[k]).strip() != "": + return str(obj[k]) + return None + + +def _mi_query_params( + managed_identity: Optional[Dict[str, Any]], +) -> Dict[str, str]: + """Build IMDS query params: cred-api-version=2.0 + optional UAMI selector.""" + params: Dict[str, str] = {_API_VERSION_QUERY_PARAM: _IMDS_V2_API_VERSION} + if not isinstance(managed_identity, dict): + return params + id_type = managed_identity.get("ManagedIdentityIdType") + identifier = managed_identity.get("Id") + mapping = {"ClientId": "client_id", "ObjectId": "object_id", + "ResourceId": "msi_res_id"} + wire = mapping.get(id_type) + if wire and identifier: + params[wire] = str(identifier) + return params + + +def _imds_get_json( + http_client, url: str, params: Dict[str, str], + headers: Dict[str, str], +) -> Dict[str, Any]: + """GET request to IMDS with server header verification.""" + from .managed_identity import MsiV2Error + resp = http_client.get(url, params=params, headers=headers) + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error( + f"[msi_v2] IMDS server header check failed. " + f"server={server!r} url={url}") + if resp.status_code != 200: + raise MsiV2Error( + f"[msi_v2] IMDSv2 GET {url} failed: " + f"HTTP {resp.status_code}: {resp.text}") + return _json_loads(resp.text, f"GET {url}") + + +def _imds_post_json( + http_client, url: str, params: Dict[str, str], + headers: Dict[str, str], body: Dict[str, Any], +) -> Dict[str, Any]: + """POST request to IMDS with server header verification.""" + from .managed_identity import MsiV2Error + resp = http_client.post( + url, params=params, headers=headers, + data=json.dumps(body, separators=(",", ":"))) + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error( + f"[msi_v2] IMDS server header check failed. " + f"server={server!r} url={url}") + if resp.status_code != 200: + raise MsiV2Error( + f"[msi_v2] IMDSv2 POST {url} failed: " + f"HTTP {resp.status_code}: {resp.text}") + return _json_loads(resp.text, f"POST {url}") + + +def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: + """ + Extract token endpoint from issuecredential response. + Prefers explicit token_endpoint, falls back to + mtls_authentication_endpoint + tenant_id. + """ + token_endpoint = _get_first(cred, "token_endpoint", "tokenEndpoint") + if token_endpoint: + return token_endpoint + + mtls_auth = _get_first( + cred, "mtls_authentication_endpoint", + "mtlsAuthenticationEndpoint", "mtls_authenticationEndpoint") + tenant_id = _get_first(cred, "tenant_id", "tenantId") + if not mtls_auth or not tenant_id: + from .managed_identity import MsiV2Error + raise MsiV2Error( + f"[msi_v2] issuecredential missing " + f"mtls_authentication_endpoint/tenant_id: {cred}") + + base = mtls_auth.rstrip("/") + "/" + tenant_id.strip("/") + return base + _ACQUIRE_ENTRA_TOKEN_PATH + + +# --------------------------------------------------------------------------- +# Win32 primitives (ctypes) — lazy loaded +# --------------------------------------------------------------------------- + +def _load_win32() -> Dict[str, Any]: + """Lazy-load Win32 APIs via ctypes (safe to import on non-Windows).""" + global _WIN32 + from .managed_identity import MsiV2Error + + if _WIN32 is not None: + return _WIN32 + if sys.platform != "win32": + raise MsiV2Error("[msi_v2] KeyGuard + mTLS PoP is Windows-only.") + + import ctypes + from ctypes import wintypes + + ncrypt = ctypes.WinDLL("ncrypt.dll") + crypt32 = ctypes.WinDLL("crypt32.dll", use_last_error=True) + winhttp = ctypes.WinDLL("winhttp.dll", use_last_error=True) + + NCRYPT_PROV_HANDLE = ctypes.c_void_p + NCRYPT_KEY_HANDLE = ctypes.c_void_p + SECURITY_STATUS = ctypes.c_long + + class CERT_CONTEXT(ctypes.Structure): + _fields_ = [ + ("dwCertEncodingType", wintypes.DWORD), + ("pbCertEncoded", ctypes.POINTER(ctypes.c_ubyte)), + ("cbCertEncoded", wintypes.DWORD), + ("pCertInfo", ctypes.c_void_p), + ("hCertStore", ctypes.c_void_p), + ] + + PCCERT_CONTEXT = ctypes.POINTER(CERT_CONTEXT) + + class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): + _fields_ = [ + ("pszAlgId", ctypes.c_wchar_p), + ("cbSalt", wintypes.ULONG), + ] + + # NCrypt prototypes + ncrypt.NCryptOpenStorageProvider.argtypes = [ + ctypes.POINTER(NCRYPT_PROV_HANDLE), ctypes.c_wchar_p, wintypes.DWORD] + ncrypt.NCryptOpenStorageProvider.restype = SECURITY_STATUS + + ncrypt.NCryptOpenKey.argtypes = [ + NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptOpenKey.restype = SECURITY_STATUS + + ncrypt.NCryptCreatePersistedKey.argtypes = [ + NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, ctypes.c_wchar_p, wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptCreatePersistedKey.restype = SECURITY_STATUS + + ncrypt.NCryptSetProperty.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, + wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptSetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptFinalizeKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] + ncrypt.NCryptFinalizeKey.restype = SECURITY_STATUS + + ncrypt.NCryptGetProperty.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptGetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptExportKey.argtypes = [ + NCRYPT_KEY_HANDLE, NCRYPT_KEY_HANDLE, ctypes.c_wchar_p, + ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptExportKey.restype = SECURITY_STATUS + + ncrypt.NCryptSignHash.argtypes = [ + NCRYPT_KEY_HANDLE, ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), + wintypes.DWORD] + ncrypt.NCryptSignHash.restype = SECURITY_STATUS + + ncrypt.NCryptFreeObject.argtypes = [ctypes.c_void_p] + ncrypt.NCryptFreeObject.restype = SECURITY_STATUS + + # Crypt32 prototypes + crypt32.CertCreateCertificateContext.argtypes = [ + wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + crypt32.CertCreateCertificateContext.restype = PCCERT_CONTEXT + + crypt32.CertSetCertificateContextProperty.argtypes = [ + PCCERT_CONTEXT, wintypes.DWORD, wintypes.DWORD, ctypes.c_void_p] + crypt32.CertSetCertificateContextProperty.restype = wintypes.BOOL + + crypt32.CertFreeCertificateContext.argtypes = [PCCERT_CONTEXT] + crypt32.CertFreeCertificateContext.restype = wintypes.BOOL + + # WinHTTP prototypes + winhttp.WinHttpOpen.argtypes = [ + ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_wchar_p, + ctypes.c_wchar_p, wintypes.DWORD] + winhttp.WinHttpOpen.restype = ctypes.c_void_p + + winhttp.WinHttpConnect.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, wintypes.WORD, wintypes.DWORD] + winhttp.WinHttpConnect.restype = ctypes.c_void_p + + winhttp.WinHttpOpenRequest.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_wchar_p, + ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpOpenRequest.restype = ctypes.c_void_p + + winhttp.WinHttpSetOption.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpSetOption.restype = wintypes.BOOL + + winhttp.WinHttpSendRequest.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_void_p, + wintypes.DWORD, wintypes.DWORD, ctypes.c_ulonglong] + winhttp.WinHttpSendRequest.restype = wintypes.BOOL + + winhttp.WinHttpReceiveResponse.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p] + winhttp.WinHttpReceiveResponse.restype = wintypes.BOOL + + winhttp.WinHttpQueryHeaders.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_void_p, + ctypes.POINTER(wintypes.DWORD), ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryHeaders.restype = wintypes.BOOL + + winhttp.WinHttpQueryDataAvailable.argtypes = [ + ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryDataAvailable.restype = wintypes.BOOL + + winhttp.WinHttpReadData.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpReadData.restype = wintypes.BOOL + + winhttp.WinHttpCloseHandle.argtypes = [ctypes.c_void_p] + winhttp.WinHttpCloseHandle.restype = wintypes.BOOL + + _WIN32 = { + "ctypes": ctypes, "wintypes": wintypes, + "ncrypt": ncrypt, "crypt32": crypt32, "winhttp": winhttp, + "NCRYPT_PROV_HANDLE": NCRYPT_PROV_HANDLE, + "NCRYPT_KEY_HANDLE": NCRYPT_KEY_HANDLE, + "SECURITY_STATUS": SECURITY_STATUS, + "CERT_CONTEXT": CERT_CONTEXT, + "PCCERT_CONTEXT": PCCERT_CONTEXT, + "BCRYPT_PSS_PADDING_INFO": BCRYPT_PSS_PADDING_INFO, + "ERROR_SUCCESS": 0, + "NCRYPT_OVERWRITE_KEY_FLAG": 0x00000080, + "NCRYPT_LENGTH_PROPERTY": "Length", + "NCRYPT_EXPORT_POLICY_PROPERTY": "Export Policy", + "NCRYPT_KEY_USAGE_PROPERTY": "Key Usage", + "NCRYPT_ALLOW_SIGNING_FLAG": 0x00000002, + "NCRYPT_ALLOW_DECRYPT_FLAG": 0x00000001, + "BCRYPT_PAD_PSS": 0x00000008, + "BCRYPT_SHA256_ALGORITHM": "SHA256", + "BCRYPT_RSA_ALGORITHM": "RSA", + "BCRYPT_RSAPUBLIC_BLOB": "RSAPUBLICBLOB", + "BCRYPT_RSAPUBLIC_MAGIC": 0x31415352, + "X509_ASN_ENCODING": 0x00000001, + "PKCS_7_ASN_ENCODING": 0x00010000, + "CERT_NCRYPT_KEY_HANDLE_PROP_ID": 78, + "CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG": 0x40000000, + "WINHTTP_ACCESS_TYPE_DEFAULT_PROXY": 0, + "WINHTTP_FLAG_SECURE": 0x00800000, + "WINHTTP_OPTION_CLIENT_CERT_CONTEXT": 47, + "WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT": 161, + "WINHTTP_QUERY_STATUS_CODE": 19, + "WINHTTP_QUERY_FLAG_NUMBER": 0x20000000, + } + return _WIN32 + + +# --------------------------------------------------------------------------- +# Win32 error helpers +# --------------------------------------------------------------------------- + +def _raise_win32_last_error(msg: str) -> None: + from .managed_identity import MsiV2Error + win32 = _load_win32() + ctypes_mod = win32["ctypes"] + err = ctypes_mod.get_last_error() + detail = "" + try: + detail = ctypes_mod.FormatError(err).strip() + except Exception: + pass + raise MsiV2Error(f"{msg} (winerror={err} {detail})" if detail + else f"{msg} (winerror={err})") + + +def _check_security_status(status: int, what: str) -> None: + from .managed_identity import MsiV2Error + if int(status) != 0: + code_u32 = int(status) & 0xFFFFFFFF + raise MsiV2Error(f"[msi_v2] {what} failed: status=0x{code_u32:08X}") + + +def _status_u32(status: int) -> int: + return int(status) & 0xFFFFFFFF + + +def _is_key_not_found(status: int) -> bool: + return _status_u32(status) in (_NTE_BAD_KEYSET, _NTE_NO_KEY, _NTE_NOT_FOUND) + + +# --------------------------------------------------------------------------- +# DER helpers (minimal PKCS#10 CSR builder) +# --------------------------------------------------------------------------- + +def _der_len(n: int) -> bytes: + if n < 0: + raise ValueError("DER length cannot be negative") + if n < 0x80: + return bytes([n]) + out = bytearray() + m = n + while m > 0: + out.insert(0, m & 0xFF) + m >>= 8 + return bytes([0x80 | len(out)]) + bytes(out) + + +def _der(tag: int, content: bytes) -> bytes: + return bytes([tag]) + _der_len(len(content)) + content + + +def _der_null() -> bytes: + return b"\x05\x00" + + +def _der_integer(value: int) -> bytes: + if value < 0: + raise ValueError("Only non-negative INTEGER supported") + if value == 0: + raw = b"\x00" + else: + raw = value.to_bytes((value.bit_length() + 7) // 8, "big") + if raw[0] & 0x80: + raw = b"\x00" + raw + return _der(0x02, raw) + + +def _der_oid(oid: str) -> bytes: + parts = [int(x) for x in oid.split(".")] + if len(parts) < 2 or parts[0] > 2 or (parts[0] < 2 and parts[1] >= 40): + raise ValueError(f"Invalid OID: {oid}") + first = 40 * parts[0] + parts[1] + out = bytearray([first]) + for p in parts[2:]: + if p < 0: + raise ValueError(f"Invalid OID component: {oid}") + stack = bytearray() + if p == 0: + stack.append(0) + else: + m = p + while m > 0: + stack.insert(0, m & 0x7F) + m >>= 7 + for i in range(len(stack) - 1): + stack[i] |= 0x80 + out.extend(stack) + return _der(0x06, bytes(out)) + + +def _der_sequence(*items: bytes) -> bytes: + return _der(0x30, b"".join(items)) + + +def _der_set(*items: bytes) -> bytes: + enc = sorted(items) + return _der(0x31, b"".join(enc)) + + +def _der_bitstring(data: bytes) -> bytes: + return _der(0x03, b"\x00" + data) + + +def _der_ia5string(value: str) -> bytes: + return _der(0x16, value.encode("ascii")) + + +def _der_context_explicit(tagnum: int, inner: bytes) -> bytes: + return _der(0xA0 + tagnum, inner) + + +def _der_context_implicit_constructed(tagnum: int, inner_content: bytes) -> bytes: + return _der(0xA0 + tagnum, inner_content) + + +def _der_name_cn_dc(cn: str, dc: str) -> bytes: + cn_atv = _der_sequence(_der_oid("2.5.4.3"), _der_utf8string(cn)) + cn_rdn = _der_set(cn_atv) + try: + dc_value = _der_ia5string(dc) + except Exception: + dc_value = _der_utf8string(dc) + dc_atv = _der_sequence( + _der_oid("0.9.2342.19200300.100.1.25"), dc_value) + dc_rdn = _der_set(dc_atv) + return _der_sequence(cn_rdn, dc_rdn) + + +def _der_subject_public_key_info_rsa(modulus: int, exponent: int) -> bytes: + rsa_pub = _der_sequence(_der_integer(modulus), _der_integer(exponent)) + alg = _der_sequence( + _der_oid("1.2.840.113549.1.1.1"), _der_null()) # rsaEncryption + return _der_sequence(alg, _der_bitstring(rsa_pub)) + + +def _der_algid_rsapss_sha256() -> bytes: + """AlgorithmIdentifier for RSASSA-PSS with SHA-256, MGF1(SHA-256), + saltLength=32. trailerField omitted (DEFAULT=1, per .NET).""" + sha256 = _der_sequence( + _der_oid("2.16.840.1.101.3.4.2.1"), _der_null()) + mgf1 = _der_sequence(_der_oid("1.2.840.113549.1.1.8"), sha256) + salt_len = _der_integer(32) + params = _der_sequence( + _der_context_explicit(0, sha256), + _der_context_explicit(1, mgf1), + _der_context_explicit(2, salt_len), + # trailerField [3] omitted — DEFAULT trailerFieldBC(1) + ) + return _der_sequence(_der_oid("1.2.840.113549.1.1.10"), params) + + +# --------------------------------------------------------------------------- +# CNG/NCrypt wrappers +# --------------------------------------------------------------------------- + +def _ncrypt_get_property(win32: Dict[str, Any], h: Any, name: str) -> bytes: + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + cb = wintypes.DWORD(0) + status = ncrypt.NCryptGetProperty(h, name, None, 0, + ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, f"NCryptGetProperty({name})") + if cb.value == 0: + return b"" + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptGetProperty(h, name, buf, cb.value, + ctypes_mod.byref(cb), 0) + _check_security_status(status, f"NCryptGetProperty({name})") + return bytes(buf[:cb.value]) + + +def _stable_key_name(client_id: str) -> str: + base = (client_id or "").strip() + safe = [] + for ch in base: + if ch.isalnum() or ch in ("-", "_"): + safe.append(ch) + else: + safe.append("_") + return "MsalMsiV2Key_" + "".join(safe)[:90] + + +def _open_or_create_keyguard_rsa_key( + win32: Dict[str, Any], *, key_name: str, +) -> Tuple[Any, Any, str, bool]: + """ + Open a named per-boot KeyGuard RSA key if it exists; otherwise create it. + Returns: (prov_handle, key_handle, key_name, opened_existing) + """ + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + prov = win32["NCRYPT_PROV_HANDLE"]() + status = ncrypt.NCryptOpenStorageProvider( + ctypes_mod.byref(prov), + "Microsoft Software Key Storage Provider", 0) + _check_security_status(status, "NCryptOpenStorageProvider") + + key = win32["NCRYPT_KEY_HANDLE"]() + + # 1) Try open first + status = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), + str(key_name), _AT_SIGNATURE, 0) + if int(status) == 0: + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error( + "[msi_v2] Virtual Iso property missing/invalid; " + "Credential Guard likely not active.") + return prov, key, str(key_name), True + + if not _is_key_not_found(status): + _check_security_status(status, f"NCryptOpenKey({key_name})") + + # 2) Create if missing + flags = (win32["NCRYPT_OVERWRITE_KEY_FLAG"] + | _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG + | _NCRYPT_USE_PER_BOOT_KEY_FLAG) + + status = ncrypt.NCryptCreatePersistedKey( + prov, ctypes_mod.byref(key), win32["BCRYPT_RSA_ALGORITHM"], + str(key_name), _AT_SIGNATURE, flags) + + if _status_u32(status) == _NTE_EXISTS: + # Race: another thread/process created it + status2 = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), + str(key_name), _AT_SIGNATURE, 0) + _check_security_status(status2, + f"NCryptOpenKey({key_name}) after exists") + return prov, key, str(key_name), True + + _check_security_status(status, "NCryptCreatePersistedKey") + + # Set key properties + length = wintypes.DWORD(int(_RSA_KEY_SIZE)) + status = ncrypt.NCryptSetProperty( + key, win32["NCRYPT_LENGTH_PROPERTY"], + ctypes_mod.byref(length), ctypes_mod.sizeof(length), 0) + _check_security_status(status, "NCryptSetProperty(Length)") + + usage = wintypes.DWORD( + win32["NCRYPT_ALLOW_SIGNING_FLAG"] + | win32["NCRYPT_ALLOW_DECRYPT_FLAG"]) + status = ncrypt.NCryptSetProperty( + key, win32["NCRYPT_KEY_USAGE_PROPERTY"], + ctypes_mod.byref(usage), ctypes_mod.sizeof(usage), 0) + _check_security_status(status, "NCryptSetProperty(Key Usage)") + + export_policy = wintypes.DWORD(0) # non-exportable + status = ncrypt.NCryptSetProperty( + key, win32["NCRYPT_EXPORT_POLICY_PROPERTY"], + ctypes_mod.byref(export_policy), ctypes_mod.sizeof(export_policy), 0) + _check_security_status(status, "NCryptSetProperty(Export Policy)") + + status = ncrypt.NCryptFinalizeKey(key, 0) + _check_security_status(status, "NCryptFinalizeKey") + + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error( + "[msi_v2] Virtual Iso property not available; " + "Credential Guard likely not active.") + + return prov, key, str(key_name), False + + +def _ncrypt_export_rsa_public( + win32: Dict[str, Any], key: Any, +) -> Tuple[int, int]: + """Export RSA public key (modulus, exponent) from an NCrypt key handle.""" + from .managed_identity import MsiV2Error + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + cb = wintypes.DWORD(0) + status = ncrypt.NCryptExportKey( + key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, None, 0, + ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, "NCryptExportKey(size)") + if cb.value == 0: + raise MsiV2Error("[msi_v2] NCryptExportKey returned empty blob size") + + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptExportKey( + key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, + buf, cb.value, ctypes_mod.byref(cb), 0) + _check_security_status(status, "NCryptExportKey(RSAPUBLICBLOB)") + blob = bytes(buf[:cb.value]) + + if len(blob) < 24: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB too small") + + magic, bitlen, cb_exp, cb_mod, cb_p1, cb_p2 = struct.unpack( + "<6I", blob[:24]) + if magic != win32["BCRYPT_RSAPUBLIC_MAGIC"]: + raise MsiV2Error( + f"[msi_v2] RSAPUBLICBLOB magic mismatch: 0x{magic:08X}") + + offset = 24 + if len(blob) < offset + cb_exp + cb_mod: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB truncated") + + exp_bytes = blob[offset:offset + cb_exp] + offset += cb_exp + mod_bytes = blob[offset:offset + cb_mod] + + exponent = int.from_bytes(exp_bytes, "big") + modulus = int.from_bytes(mod_bytes, "big") + return modulus, exponent + + +def _ncrypt_sign_pss_sha256( + win32: Dict[str, Any], key: Any, digest: bytes, +) -> bytes: + """Sign a SHA-256 digest using RSA-PSS via NCryptSignHash.""" + from .managed_identity import MsiV2Error + if len(digest) != 32: + raise MsiV2Error("[msi_v2] Expected SHA-256 digest (32 bytes)") + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + pad = win32["BCRYPT_PSS_PADDING_INFO"]( + win32["BCRYPT_SHA256_ALGORITHM"], 32) + hash_buf = (ctypes_mod.c_ubyte * len(digest)).from_buffer_copy(digest) + + cb_sig = wintypes.DWORD(0) + status = ncrypt.NCryptSignHash( + key, ctypes_mod.byref(pad), hash_buf, len(digest), + None, 0, ctypes_mod.byref(cb_sig), win32["BCRYPT_PAD_PSS"]) + if int(status) != 0 and cb_sig.value == 0: + _check_security_status(status, "NCryptSignHash(size)") + if cb_sig.value == 0: + raise MsiV2Error("[msi_v2] NCryptSignHash returned empty sig size") + + sig_buf = (ctypes_mod.c_ubyte * cb_sig.value)() + status = ncrypt.NCryptSignHash( + key, ctypes_mod.byref(pad), hash_buf, len(digest), + sig_buf, cb_sig.value, ctypes_mod.byref(cb_sig), + win32["BCRYPT_PAD_PSS"]) + _check_security_status(status, "NCryptSignHash") + return bytes(sig_buf[:cb_sig.value]) + + +# --------------------------------------------------------------------------- +# CSR builder +# --------------------------------------------------------------------------- + +def _build_csr_b64( + win32: Dict[str, Any], key: Any, + client_id: str, tenant_id: str, cu_id: Any, +) -> str: + """Build CSR signed by KeyGuard key (RSA-PSS SHA256), with cuId OID + attribute.""" + modulus, exponent = _ncrypt_export_rsa_public(win32, key) + subject = _der_name_cn_dc(client_id, tenant_id) + spki = _der_subject_public_key_info_rsa(modulus, exponent) + + cuid_json = json.dumps(cu_id, separators=(",", ":"), ensure_ascii=False) + cuid_val = _der_utf8string(cuid_json) + + attr = _der_sequence(_der_oid(_CU_ID_OID_STR), _der_set(cuid_val)) + attrs_content = b"".join(sorted([attr])) + attrs = _der_context_implicit_constructed(0, attrs_content) + + cri = _der_sequence(_der_integer(0), subject, spki, attrs) + digest = hashlib.sha256(cri).digest() + signature = _ncrypt_sign_pss_sha256(win32, key, digest) + + csr = _der_sequence(cri, _der_algid_rsapss_sha256(), + _der_bitstring(signature)) + return base64.b64encode(csr).decode("ascii") + + +# --------------------------------------------------------------------------- +# Certificate binding + WinHTTP mTLS +# --------------------------------------------------------------------------- + +def _create_cert_context_with_key( + win32: Dict[str, Any], cert_der: bytes, key: Any, key_name: str, + *, ksp_name: str = "Microsoft Software Key Storage Provider", +) -> Tuple[Any, Any, Tuple[Any, ...]]: + """Create a CERT_CONTEXT from DER bytes and associate it with a CNG + private key via multiple properties for SChannel compatibility.""" + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + crypt32 = win32["crypt32"] + + enc = win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"] + buf = ctypes_mod.create_string_buffer(cert_der) + ctx = crypt32.CertCreateCertificateContext(enc, buf, len(cert_der)) + if not ctx: + _raise_win32_last_error( + "[msi_v2] CertCreateCertificateContext failed") + + keepalive: List[Any] = [buf] + + try: + # (A) Direct NCrypt key handle + key_handle = ctypes_mod.c_void_p(int(key.value)) + keepalive.append(key_handle) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, win32["CERT_NCRYPT_KEY_HANDLE_PROP_ID"], + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_handle)) + if not ok: + _raise_win32_last_error( + "[msi_v2] CertSetCertificateContextProperty" + "(CERT_NCRYPT_KEY_HANDLE_PROP_ID) failed") + + # (B) CERT_KEY_CONTEXT_PROP_ID (best-effort) + CERT_KEY_CONTEXT_PROP_ID = 5 + CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF + + class CERT_KEY_CONTEXT(ctypes_mod.Structure): + _fields_ = [ + ("cbSize", wintypes.DWORD), + ("hCryptProvOrNCryptKey", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + key_ctx = CERT_KEY_CONTEXT( + ctypes_mod.sizeof(CERT_KEY_CONTEXT), key_handle, + wintypes.DWORD(CERT_NCRYPT_KEY_SPEC)) + keepalive.append(key_ctx) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, CERT_KEY_CONTEXT_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_ctx)) + if not ok: + logger.debug("[msi_v2] Failed to set CERT_KEY_CONTEXT_PROP_ID " + "(last_error=%s)", ctypes_mod.get_last_error()) + + # (C) CERT_KEY_PROV_INFO_PROP_ID (for SChannel reopen by name) + CERT_KEY_PROV_INFO_PROP_ID = 2 + + class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): + _fields_ = [ + ("pwszContainerName", wintypes.LPWSTR), + ("pwszProvName", wintypes.LPWSTR), + ("dwProvType", wintypes.DWORD), + ("dwFlags", wintypes.DWORD), + ("cProvParam", wintypes.DWORD), + ("rgProvParam", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + container_buf = ctypes_mod.create_unicode_buffer(str(key_name)) + provider_buf = ctypes_mod.create_unicode_buffer(str(ksp_name)) + keepalive.extend([container_buf, provider_buf]) + + prov_info = CRYPT_KEY_PROV_INFO( + ctypes_mod.cast(container_buf, wintypes.LPWSTR), + ctypes_mod.cast(provider_buf, wintypes.LPWSTR), + wintypes.DWORD(0), # CNG/KSP + wintypes.DWORD(_NCRYPT_SILENT_FLAG), + wintypes.DWORD(0), None, + wintypes.DWORD(_AT_SIGNATURE)) + keepalive.append(prov_info) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, CERT_KEY_PROV_INFO_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(prov_info)) + if not ok: + logger.debug("[msi_v2] Failed to set CERT_KEY_PROV_INFO_PROP_ID " + "(last_error=%s)", ctypes_mod.get_last_error()) + + return ctx, buf, tuple(keepalive) + + except Exception: + try: + crypt32.CertFreeCertificateContext(ctx) + except Exception: + pass + raise + + +def _winhttp_close(win32: Dict[str, Any], h: Any) -> None: + try: + if h: + win32["winhttp"].WinHttpCloseHandle(h) + except Exception: + pass + + +def _winhttp_post( + win32: Dict[str, Any], url: str, cert_ctx: Any, + body: bytes, headers: Dict[str, str], +) -> Tuple[int, bytes]: + """POST to https URL using WinHTTP + SChannel with client cert.""" + from .managed_identity import MsiV2Error + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + winhttp = win32["winhttp"] + + u = urlparse(url) + if u.scheme.lower() != "https": + raise MsiV2Error( + f"[msi_v2] Token endpoint must be https, got: {url!r}") + if not u.hostname: + raise MsiV2Error(f"[msi_v2] Invalid token endpoint: {url!r}") + + host = u.hostname + port = u.port or 443 + path = u.path or "/" + if u.query: + path += "?" + u.query + + h_session = winhttp.WinHttpOpen( + "msal-python-msi-v2", win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], + None, None, 0) + if not h_session: + _raise_win32_last_error("[msi_v2] WinHttpOpen failed") + + h_connect = None + h_request = None + try: + # Best-effort: HTTP/2 + client cert + enable = wintypes.DWORD(1) + try: + winhttp.WinHttpSetOption( + h_session, + win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], + ctypes_mod.byref(enable), ctypes_mod.sizeof(enable)) + except Exception: + pass + + h_connect = winhttp.WinHttpConnect(h_session, host, int(port), 0) + if not h_connect: + _raise_win32_last_error("[msi_v2] WinHttpConnect failed") + + h_request = winhttp.WinHttpOpenRequest( + h_connect, "POST", path, None, None, None, + win32["WINHTTP_FLAG_SECURE"]) + if not h_request: + _raise_win32_last_error("[msi_v2] WinHttpOpenRequest failed") + + # Attach cert for mTLS + ok = winhttp.WinHttpSetOption( + h_request, win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], + cert_ctx, ctypes_mod.sizeof(win32["CERT_CONTEXT"])) + if not ok: + _raise_win32_last_error( + "[msi_v2] WinHttpSetOption(CLIENT_CERT) failed") + + header_lines = "".join(f"{k}: {v}\r\n" for k, v in headers.items()) + body_buf = ctypes_mod.create_string_buffer(body) + + ok = winhttp.WinHttpSendRequest( + h_request, header_lines, 0xFFFFFFFF, + body_buf, len(body), len(body), 0) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSendRequest failed") + + ok = winhttp.WinHttpReceiveResponse(h_request, None) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReceiveResponse failed") + + # Read status code + status = wintypes.DWORD(0) + status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) + index = wintypes.DWORD(0) + + ok = winhttp.WinHttpQueryHeaders( + h_request, + win32["WINHTTP_QUERY_STATUS_CODE"] + | win32["WINHTTP_QUERY_FLAG_NUMBER"], + None, ctypes_mod.byref(status), + ctypes_mod.byref(status_size), ctypes_mod.byref(index)) + if not ok: + _raise_win32_last_error( + "[msi_v2] WinHttpQueryHeaders(STATUS_CODE) failed") + + # Read body + chunks: List[bytes] = [] + while True: + avail = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryDataAvailable( + h_request, ctypes_mod.byref(avail)) + if not ok: + _raise_win32_last_error( + "[msi_v2] WinHttpQueryDataAvailable failed") + if avail.value == 0: + break + buf = (ctypes_mod.c_ubyte * avail.value)() + read = wintypes.DWORD(0) + ok = winhttp.WinHttpReadData( + h_request, buf, avail.value, ctypes_mod.byref(read)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReadData failed") + if read.value: + chunks.append(bytes(buf[:read.value])) + if read.value == 0: + break + + return int(status.value), b"".join(chunks) + finally: + _winhttp_close(win32, h_request) + _winhttp_close(win32, h_connect) + _winhttp_close(win32, h_session) + + +def _acquire_token_mtls_schannel( + win32: Dict[str, Any], token_endpoint: str, cert_ctx: Any, + client_id: str, scope: str, +) -> Dict[str, Any]: + """Acquire an mtls_pop token from ESTS using WinHTTP/SChannel.""" + from .managed_identity import MsiV2Error + + form = urlencode({ + "grant_type": "client_credentials", + "client_id": client_id, + "scope": scope, + "token_type": "mtls_pop", + }).encode("utf-8") + + status, resp_body = _winhttp_post( + win32, token_endpoint, cert_ctx, form, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }) + + text = resp_body.decode("utf-8", errors="replace") + if status < 200 or status >= 300: + raise MsiV2Error( + f"[msi_v2] ESTS token request failed: " + f"HTTP {status} Body={text!r}") + return _json_loads(text, "ESTS token") + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +# Type alias for attestation provider callback. +# Signature: (endpoint, key_handle, client_id, cache_key) -> JWT string. +# cache_key is the stable per-boot key name for optimal caching. +AttestationTokenProvider = Callable[[str, int, str, str], str] + + +def obtain_token( + http_client, + managed_identity: Dict[str, Any], + resource: str, + *, + attestation_enabled: bool = True, + attestation_token_provider: Optional[AttestationTokenProvider] = None, +) -> Dict[str, Any]: + """ + Acquire mtls_pop token using Windows KeyGuard + optional MAA attestation. + + Flow: + 1. getplatformmetadata → client_id, tenant_id, cu_id, attestationEndpoint + 2. Open/create named per-boot KeyGuard RSA key (non-exportable) + 3. Build PKCS#10 CSR with cuId attribute, sign with RSA-PSS/SHA256 + 4. Get attestation JWT from MAA (if attestation_token_provider given) + 5. issuecredential → X.509 cert + 6. Create CERT_CONTEXT, bind to KeyGuard private key + 7. POST /oauth2/v2.0/token via WinHTTP/SChannel with mTLS + + Args: + http_client: HTTP client (e.g., requests.Session()) + managed_identity: MSAL managed identity dict + resource: Resource URI for token acquisition + attestation_enabled: Whether attestation is enabled + attestation_token_provider: Callback (endpoint, key_handle, + client_id, cache_key) -> JWT string. Provided by + msal-key-attestation package. cache_key is the stable + per-boot key name for optimal caching. None means + non-attested flow. + + Returns: + Token response dict with access_token, expires_in, token_type, + cert_pem, cert_der_b64, cert_thumbprint_sha256. + + Raises: + MsiV2Error: on any failure (no fallback to MSI v1) + """ + from .managed_identity import MsiV2Error + + win32 = _load_win32() + ncrypt = win32["ncrypt"] + crypt32 = win32["crypt32"] + + base = _imds_base() + params = _mi_query_params(managed_identity) + corr = _new_correlation_id() + + # Check certificate cache first + cache_key = _cert_cache_key( + managed_identity, attestation_token_provider is not None) + cached = _cert_cache_get(cache_key) + + prov = None + key = None + cert_ctx = None + cert_der = None + + try: + # 1) getplatformmetadata + meta_url = base + _CSR_METADATA_PATH + meta = _imds_get_json(http_client, meta_url, params, + _imds_headers(corr)) + + client_id = _get_first(meta, "clientId", "client_id") + tenant_id = _get_first(meta, "tenantId", "tenant_id") + cu_id = meta.get("cuId") if "cuId" in meta else meta.get("cu_id") + attestation_endpoint = _get_first( + meta, "attestationEndpoint", "attestation_endpoint") + + if not client_id or not tenant_id or cu_id is None: + raise MsiV2Error( + f"[msi_v2] getplatformmetadata missing required fields: " + f"{meta}") + + # 2) Open-or-create KeyGuard RSA key + key_name = (os.getenv(_KEY_NAME_ENVVAR) + or _stable_key_name(str(client_id))) + prov, key, key_name, opened = _open_or_create_keyguard_rsa_key( + win32, key_name=key_name) + logger.debug("[msi_v2] KeyGuard key=%s opened_existing=%s", + key_name, opened) + + # Use cached cert if available + if cached is not None: + cert_der = cached.cert_der + token_endpoint = cached.token_endpoint + canonical_client_id = cached.client_id + logger.debug("[msi_v2] Using cached certificate") + else: + # 3) Build CSR + csr_b64 = _build_csr_b64( + win32, key, str(client_id), str(tenant_id), cu_id) + + # 4) Attestation (if provider given) + att_jwt = "" + if attestation_enabled and attestation_token_provider is not None: + if not attestation_endpoint: + raise MsiV2Error( + "[msi_v2] attestationEndpoint missing from metadata.") + try: + att_jwt = attestation_token_provider( + str(attestation_endpoint), + int(key.value), + str(client_id), + str(key_name)) + except MsiV2Error: + raise + except Exception as exc: + raise MsiV2Error( + f"[msi_v2] Attestation provider failed: {exc}" + ) from exc + if not att_jwt or not str(att_jwt).strip(): + raise MsiV2Error( + "[msi_v2] Attestation provider returned empty JWT.") + + # 5) issuecredential + issue_url = base + _ISSUE_CREDENTIAL_PATH + issue_headers = _imds_headers(corr) + issue_headers["Content-Type"] = "application/json" + + cred = _imds_post_json( + http_client, issue_url, params, issue_headers, + {"csr": csr_b64, "attestation_token": att_jwt}) + + cert_b64 = _get_first(cred, "certificate", "Certificate") + if not cert_b64: + raise MsiV2Error( + f"[msi_v2] issuecredential missing certificate: {cred}") + + try: + cert_der = base64.b64decode(cert_b64) + except Exception as exc: + raise MsiV2Error( + "[msi_v2] issuecredential returned invalid base64 " + "certificate") from exc + + canonical_client_id = (_get_first(cred, "client_id", "clientId") + or str(client_id)) + token_endpoint = _token_endpoint_from_credential(cred) + + # Cache the cert + not_after = _try_parse_cert_not_after(cert_der) + _cert_cache_set(cache_key, _CertCacheEntry( + cert_der=cert_der, + cert_pem=_der_to_pem(cert_der), + token_endpoint=token_endpoint, + client_id=canonical_client_id, + not_after=not_after or (time.time() + 8 * 3600), + )) + + # 6) Create CERT_CONTEXT, bind to KeyGuard private key + cert_ctx, _, _ = _create_cert_context_with_key( + win32, cert_der, key, key_name) + scope = _resource_to_scope(resource) + + # 7) POST token via WinHTTP/SChannel mTLS + token_json = _acquire_token_mtls_schannel( + win32, token_endpoint, cert_ctx, canonical_client_id, scope) + + if token_json.get("access_token") and token_json.get("expires_in"): + cert_pem = _der_to_pem(cert_der) + cert_thumbprint = get_cert_thumbprint_sha256(cert_pem) + + return { + "access_token": token_json["access_token"], + "expires_in": int(token_json["expires_in"]), + "token_type": token_json.get("token_type") or "mtls_pop", + "resource": token_json.get("resource"), + "cert_pem": cert_pem, + "cert_der_b64": base64.b64encode( + cert_der).decode("ascii"), + "cert_thumbprint_sha256": cert_thumbprint, + } + return token_json + + except Exception: + # On failure, evict cached cert (may be stale/bad) + _cert_cache_remove(cache_key) + raise + + finally: + try: + if cert_ctx: + crypt32.CertFreeCertificateContext(cert_ctx) + except Exception: + pass + try: + if key: + ncrypt.NCryptFreeObject(key) + except Exception: + pass + try: + if prov: + ncrypt.NCryptFreeObject(prov) + except Exception: + pass diff --git a/msal/sku.py b/msal/sku.py index 01751048..19ff0138 100644 --- a/msal/sku.py +++ b/msal/sku.py @@ -1,6 +1,6 @@ -"""This module is from where we recieve the client sku name and version. +"""This module is from where we receive the client sku name and version. """ # The __init__.py will import this. Not the other way around. -__version__ = "1.35.0" +__version__ = "1.35.2rc1" SKU = "MSAL.Python" diff --git a/sample/MSI_V2_GUIDE.md b/sample/MSI_V2_GUIDE.md new file mode 100644 index 00000000..c726153e --- /dev/null +++ b/sample/MSI_V2_GUIDE.md @@ -0,0 +1,182 @@ +# MSI v2 (mTLS Proof-of-Possession) — Setup & Usage Guide + +## Overview + +MSI v2 enables Managed Identity token acquisition using mTLS Proof-of-Possession +on Windows Azure VMs with Credential Guard / KeyGuard. + +The implementation is split into two packages mirroring the MSAL .NET architecture: + +| Package | .NET Equivalent | What | +|---|---|---| +| `msal` | `Microsoft.Identity.Client` | Core mTLS PoP flow (KeyGuard key, CSR, IMDS, WinHTTP) | +| `msal-key-attestation` | `Microsoft.Identity.Client.KeyAttestation` | AttestationClientLib.dll native bindings | + +## Prerequisites + +1. **Windows Azure VM** with: + - Credential Guard / KeyGuard enabled (VBS) + - System-assigned or user-assigned managed identity + - Network access to IMDS (169.254.169.254) + +2. **AttestationClientLib.dll** — place in one of: + - Current working directory + - Same directory as `python.exe` + - Directory of the `msal_key_attestation` package + - Path specified by `ATTESTATION_CLIENTLIB_PATH` env var + +## Installation + +```bash +# Core MSAL package (includes MSI v2 flow) +pip install msal + +# Attestation support (loads AttestationClientLib.dll) +pip install msal-key-attestation +``` + +For development (from this repo): + +```bash +# Install msal in editable mode +pip install -e . + +# Install msal-key-attestation in editable mode +pip install -e msal-key-attestation/ +``` + +## Quick Start + +```python +import msal +import requests + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), +) + +result = client.acquire_token_for_client( + resource="https://graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True, +) + +if "access_token" in result: + print(f"Token type: {result['token_type']}") # mtls_pop + print(f"Expires in: {result['expires_in']}s") + print(f"Thumbprint: {result['cert_thumbprint_sha256']}") +else: + print(f"Error: {result}") +``` + +## API Reference + +### `acquire_token_for_client()` — New Parameters + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `mtls_proof_of_possession` | `bool` | `False` | Enable MSI v2 mTLS PoP flow | +| `with_attestation_support` | `bool` | `False` | Enable KeyGuard attestation (requires `msal-key-attestation`) | + +**Behavior matrix:** + +| `mtls_proof_of_possession` | `with_attestation_support` | Result | +|---|---|---| +| `False` | `False` | MSI v1 (default, unchanged) | +| `True` | `False` | MSI v1 fallthrough (PoP alone = no-op) | +| `False` | `True` | Raises `ManagedIdentityError` | +| `True` | `True` | **MSI v2** — KeyGuard + attestation + mTLS PoP | + +### Response + +When MSI v2 succeeds, the response dict includes extra fields: + +```python +{ + "access_token": "eyJ...", + "expires_in": 3600, + "token_type": "mtls_pop", + "resource": "https://graph.microsoft.com", + # Additional MSI v2 fields: + "cert_pem": "-----BEGIN CERTIFICATE-----\n...", + "cert_der_b64": "MIID...", + "cert_thumbprint_sha256": "abc123...", +} +``` + +### Errors + +| Error Class | When | +|---|---| +| `ManagedIdentityError` | `with_attestation_support=True` without `mtls_proof_of_possession` | +| `MsiV2Error` | Any MSI v2 flow failure (no fallback to v1) | +| `MsiV2Error` | `msal-key-attestation` package not installed | + +### Verification + +```python +from msal.msi_v2 import verify_cnf_binding + +bound = verify_cnf_binding(result["access_token"], result["cert_pem"]) +assert bound, "Token is not bound to the certificate" +``` + +## Flow Diagram + +``` +App MSAL Python IMDS ESTS (mTLS) + | | | | + |-- acquire_token ------>| | | + | (mtls_pop=True, | | | + | attestation=True) | | | + | | | | + | [1] NCrypt: KeyGuard key | | + | [2] GET /getplatformmetadata ->| | + | |<-- clientId, tenantId,| | + | | cuId, attestEP | | + | [3] Build CSR (RSA-PSS/SHA256) | | + | [4] AttestationClientLib.dll | | + | |--- MAA attest -------->| | + | |<-- attestation JWT | | + | [5] POST /issuecredential ---->| | + | |<-- certificate, endpoint | + | [6] Crypt32: bind cert to key | | + | [7] WinHTTP: POST /token ------|--------->| | + | | | | | + | |<-- mtls_pop token ----|---------| | + |<-- result -------------| | | +``` + +## Environment Variables + +| Variable | Description | +|---|---| +| `AZURE_POD_IDENTITY_AUTHORITY_HOST` | Override IMDS base URL | +| `MSAL_MSI_V2_KEY_NAME` | Override per-boot key name | +| `ATTESTATION_CLIENTLIB_PATH` | Full path to AttestationClientLib.dll | +| `MSAL_MSI_V2_ATTESTATION_CACHE` | `"0"` to disable MAA JWT caching | + +## Running the Sample + +```bash +cd sample/ +python msi_v2_sample.py + +# With verbose logging: +MSI_V2_VERBOSE=1 python msi_v2_sample.py + +# Custom resource: +RESOURCE=https://vault.azure.net python msi_v2_sample.py +``` + +## Running Tests + +```bash +# Core MSI v2 tests (no Windows/KeyGuard dependency) +pytest tests/test_msi_v2.py -v + +# Attestation package tests +cd msal-key-attestation/ +pytest tests/test_attestation.py -v +``` diff --git a/sample/msi_v2_sample.py b/sample/msi_v2_sample.py new file mode 100644 index 00000000..44be7ebf --- /dev/null +++ b/sample/msi_v2_sample.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +""" +MSI v2 (mTLS PoP) Sample — Managed Identity with KeyGuard Attestation. + +This sample demonstrates acquiring an mTLS Proof-of-Possession token using +MSAL Python's MSI v2 flow on a Windows Azure VM with Credential Guard. + +Prerequisites: + - Windows Azure VM with Credential Guard / KeyGuard enabled + - AttestationClientLib.dll accessible (next to script or via env var) + - pip install msal msal-key-attestation requests + +Usage: + python msi_v2_sample.py + +Environment variables (optional): + RESOURCE - Resource URI (default: https://graph.microsoft.com) + RESOURCE_URL - URL to call with the token (default: Graph /applications) + MSI_V2_VERBOSE - Set to "1" for verbose logging +""" + +import logging +import os +import sys + +import requests +import msal + +# Optional: enable verbose logging +if os.getenv("MSI_V2_VERBOSE", "").strip() in ("1", "true"): + logging.basicConfig(level=logging.DEBUG) +else: + logging.basicConfig(level=logging.INFO) + +logger = logging.getLogger(__name__) + + +def main(): + # --- Configuration --- + resource = os.getenv( + "RESOURCE", "https://graph.microsoft.com") + resource_url = os.getenv( + "RESOURCE_URL", + "https://mtlstb.graph.microsoft.com/v1.0/applications?$top=5") + + logger.info("=" * 60) + logger.info("MSI v2 (mTLS PoP) Sample") + logger.info("=" * 60) + logger.info("Resource: %s", resource) + logger.info("Resource URL: %s", resource_url) + + # --- Create client --- + http_session = requests.Session() + + # Optionally add retry + from requests.adapters import HTTPAdapter + try: + from urllib3.util.retry import Retry + retries = Retry(total=3, backoff_factor=0.5, + status_forcelist=[429, 500, 502, 503, 504]) + http_session.mount("https://", HTTPAdapter(max_retries=retries)) + http_session.mount("http://", HTTPAdapter(max_retries=retries)) + except ImportError: + pass + + client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=http_session, + ) + + # --- Acquire token --- + logger.info("Acquiring mTLS PoP token...") + result = client.acquire_token_for_client( + resource=resource, + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + + if "access_token" not in result: + print("ERROR: Token acquisition failed. Check logs for details.") + sys.exit(1) + + token_type = result.get("token_type", "unknown") + expires_in = result.get("expires_in", 0) + + print("Token acquired successfully!") + print(f" token_type: {token_type}") + print(f" expires_in: {expires_in} seconds") + + if token_type != "mtls_pop": + print(f"WARNING: Expected token_type='mtls_pop' but got '{token_type}'.") + + # --- Verify binding --- + from msal.msi_v2 import verify_cnf_binding + cert_pem = result.get("cert_pem", "") + if cert_pem: + bound = verify_cnf_binding(result["access_token"], cert_pem) + logger.info(" cnf binding: %s", "VERIFIED" if bound else "FAILED") + if not bound: + logger.error("Token is NOT bound to the certificate!") + sys.exit(1) + + # --- Call resource over mTLS (optional) --- + if resource_url: + logger.info("Calling resource: %s", resource_url) + + # Note: mTLS resource calls require presenting the same cert. + # The cert + private key are bound via KeyGuard; a real mTLS call + # would use WinHTTP/SChannel. This demonstrates the auth header. + access_token = result["access_token"] + headers = { + "Authorization": f"{token_type} {access_token}", + "Accept": "application/json", + } + + try: + resp = http_session.get(resource_url, headers=headers) + logger.info(" Status: %d", resp.status_code) + if not resp.ok: + logger.warning(" Request failed with status %d", + resp.status_code) + except Exception as exc: + logger.warning(" Resource call failed: %s", type(exc).__name__) + logger.info( + "Note: mTLS resource calls may require WinHTTP/SChannel; " + "the requests library may not present the mTLS cert.") + + logger.info("=" * 60) + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index c770c9f1..49260d02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,3 +72,6 @@ broker = [options.packages.find] exclude = tests + +[tool:pytest] +testpaths = tests diff --git a/tests/lab_config.py b/tests/lab_config.py index f99b1683..8b692e3f 100644 --- a/tests/lab_config.py +++ b/tests/lab_config.py @@ -20,7 +20,6 @@ app = get_app_config(AppSecrets.PCA_CLIENT) Environment Variables: - LAB_APP_CLIENT_ID: Client ID for Key Vault authentication (required) LAB_APP_CLIENT_CERT_PFX_PATH: Path to .pfx certificate file (required) """ @@ -37,6 +36,7 @@ __all__ = [ # Constants + "LAB_APP_CLIENT_ID", "UserSecrets", "AppSecrets", # Data classes @@ -48,6 +48,7 @@ "get_app_config", "get_user_password", "get_client_certificate", + "clean_env", ] # ============================================================================= @@ -57,6 +58,12 @@ _MSID_LAB_VAULT = "https://msidlabs.vault.azure.net" _MSAL_TEAM_VAULT = "https://id4skeyvault.vault.azure.net" +# Client ID for the RequestMSIDLAB app used to authenticate against the lab +# Key Vaults. Hardcoded here following the same pattern as MSAL.NET +# (see build/template-install-keyvault-secrets.yaml in that repo). +# See https://docs.msidlab.com/accounts/confidentialclient.html +LAB_APP_CLIENT_ID = "f62c5ae3-bf3a-4af5-afa8-a68b800396e9" + # ============================================================================= # Secret Name Constants # ============================================================================= @@ -164,6 +171,21 @@ class AppConfig: _msal_team_client: Optional[SecretClient] = None +def clean_env(name: str) -> Optional[str]: + """Return the env var value, or None if unset or it contains an unexpanded + ADO pipeline variable literal such as ``$(VAR_NAME)``. + + Azure DevOps injects the literal string ``$(VAR_NAME)`` when a ``$(...)`` + reference in a step ``env:`` block refers to a variable that has not been + defined at runtime. That literal is truthy, so a plain ``os.getenv()`` + check would incorrectly proceed as if the variable were set. + """ + value = os.getenv(name) + if value and value.startswith("$("): + return None + return value or None + + def _get_credential(): """ Create an Azure credential for Key Vault access. @@ -177,19 +199,14 @@ def _get_credential(): Raises: EnvironmentError: If required environment variables are not set. """ - client_id = os.getenv("LAB_APP_CLIENT_ID") - cert_path = os.getenv("LAB_APP_CLIENT_CERT_PFX_PATH") + cert_path = clean_env("LAB_APP_CLIENT_CERT_PFX_PATH") tenant_id = "72f988bf-86f1-41af-91ab-2d7cd011db47" # Microsoft tenant - - if not client_id: - raise EnvironmentError( - "LAB_APP_CLIENT_ID environment variable is required for Key Vault access") - + if cert_path: logger.debug("Using certificate credential for Key Vault access") return CertificateCredential( tenant_id=tenant_id, - client_id=client_id, + client_id=LAB_APP_CLIENT_ID, certificate_path=cert_path, send_certificate_chain=True, ) @@ -396,7 +413,7 @@ def get_client_certificate() -> Dict[str, object]: Raises: EnvironmentError: If LAB_APP_CLIENT_CERT_PFX_PATH is not set. """ - cert_path = os.getenv("LAB_APP_CLIENT_CERT_PFX_PATH") + cert_path = clean_env("LAB_APP_CLIENT_CERT_PFX_PATH") if not cert_path: raise EnvironmentError( "LAB_APP_CLIENT_CERT_PFX_PATH environment variable is required " @@ -407,4 +424,4 @@ def get_client_certificate() -> Dict[str, object]: return { "private_key_pfx_path": cert_path, "public_certificate": True, # Enable SNI (send certificate chain) - } \ No newline at end of file + } diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 1202d443..37632ee7 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -1,5 +1,4 @@ """If the following ENV VAR were available, many end-to-end test cases would run. -LAB_APP_CLIENT_ID=... LAB_APP_CLIENT_CERT_PFX_PATH=... """ try: @@ -29,7 +28,7 @@ from tests.broker_util import is_pymsalruntime_installed from tests.lab_config import ( get_user_config, get_app_config, get_user_password, get_secret, - UserSecrets, AppSecrets, + UserSecrets, AppSecrets, LAB_APP_CLIENT_ID, clean_env, ) @@ -44,7 +43,17 @@ _PYMSALRUNTIME_INSTALLED = is_pymsalruntime_installed() _AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" -_SKIP_UNATTENDED_E2E_TESTS = os.getenv("TRAVIS") or not os.getenv("CI") +# Skip interactive / browser-dependent tests when: +# - on Travis CI (TRAVIS), or +# - on Azure DevOps (TF_BUILD) where there is no display/browser on the agent, or +# - not running in any CI environment at all (not CI). +# Service-principal and ROPC tests are NOT gated on this flag; only tests that +# call acquire_token_interactive() or acquire_token_by_device_flow() are. +_SKIP_UNATTENDED_E2E_TESTS = ( + os.getenv("TRAVIS") or os.getenv("TF_BUILD") or not os.getenv("CI") +) + + def _get_app_and_auth_code( client_id, @@ -329,13 +338,16 @@ def test_access_token_should_be_obtained_for_a_supported_scope(self): self.assertIsNotNone(result.get("access_token")) -@unittest.skipIf(os.getenv("TF_BUILD"), "Skip PublicCloud scenarios on Azure DevOps") class PublicCloudScenariosTestCase(E2eTestCase): # Historically this class was driven by tests/config.json for semi-automated runs. - # It now uses lab config + env vars so it can run automatically without local files. + # It now uses lab config + env vars so it can run automatically on any CI + # (including Azure DevOps) as long as LAB_APP_CLIENT_CERT_PFX_PATH is set. @classmethod def setUpClass(cls): + if not clean_env("LAB_APP_CLIENT_CERT_PFX_PATH"): + raise unittest.SkipTest( + "LAB_APP_CLIENT_CERT_PFX_PATH not set; skipping PublicCloud e2e tests") pca_app = get_app_config(AppSecrets.PCA_CLIENT) user = get_user_config(UserSecrets.PUBLIC_CLOUD) cls.config = { @@ -416,13 +428,11 @@ def test_client_secret(self): def test_subject_name_issuer_authentication(self): from tests.lab_config import get_client_certificate - - client_id = os.getenv("LAB_APP_CLIENT_ID") - if not client_id: - self.skipTest("LAB_APP_CLIENT_ID environment variable is required") + if not clean_env("LAB_APP_CLIENT_CERT_PFX_PATH"): + self.skipTest("LAB_APP_CLIENT_CERT_PFX_PATH not set") self.app = msal.ConfidentialClientApplication( - client_id, + LAB_APP_CLIENT_ID, authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com", client_credential=get_client_certificate(), http_client=MinimalHttpClient()) @@ -447,7 +457,6 @@ def manual_test_device_flow(self): def get_lab_app( - env_client_id="LAB_APP_CLIENT_ID", env_client_cert_path="LAB_APP_CLIENT_CERT_PFX_PATH", authority="https://login.microsoftonline.com/" "72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID @@ -455,19 +464,20 @@ def get_lab_app( **kwargs): """Returns the lab app as an MSAL confidential client. - Get it from environment variables if defined, otherwise fall back to use MSI. + Uses the hardcoded lab app client ID (RequestMSIDLAB) and a certificate + from the LAB_APP_CLIENT_CERT_PFX_PATH env var. """ logger.info( - "Reading ENV variables %s and %s for lab app defined at " + "Reading ENV variable %s for lab app defined at " "https://docs.msidlab.com/accounts/confidentialclient.html", - env_client_id, env_client_cert_path) - if os.getenv(env_client_id) and os.getenv(env_client_cert_path): + env_client_cert_path) + cert_path = clean_env(env_client_cert_path) + if cert_path: # id came from https://docs.msidlab.com/accounts/confidentialclient.html - client_id = os.getenv(env_client_id) client_credential = { "private_key_pfx_path": # Cert came from https://ms.portal.azure.com/#@microsoft.onmicrosoft.com/asset/Microsoft_Azure_KeyVault/Certificate/https://msidlabs.vault.azure.net/certificates/LabAuth - os.getenv(env_client_cert_path), + cert_path, "public_certificate": True, # Opt in for SNI } else: @@ -475,7 +485,7 @@ def get_lab_app( # See also https://microsoft.sharepoint-df.com/teams/MSIDLABSExtended/SitePages/Programmatically-accessing-LAB-API's.aspx raise unittest.SkipTest("MSI-based mechanism has not been implemented yet") return msal.ConfidentialClientApplication( - client_id, + LAB_APP_CLIENT_ID, client_credential=client_credential, authority=authority, http_client=MinimalHttpClient(timeout=timeout), @@ -831,7 +841,6 @@ def test_user_account(self): class WorldWideTestCase(LabBasedTestCase): - _ADFS_LABS_UNAVAILABLE = "ADFS labs were temporarily down since July 2025 until further notice" def test_aad_managed_user(self): # Pure cloud """Test username/password flow for a managed AAD user.""" @@ -846,7 +855,6 @@ def test_aad_managed_user(self): # Pure cloud scope=["https://graph.microsoft.com/.default"], ) - @unittest.skip(_ADFS_LABS_UNAVAILABLE) def test_adfs2022_fed_user(self): """Test username/password flow for a federated user via ADFS 2022.""" app = get_app_config(AppSecrets.PCA_CLIENT) @@ -1159,18 +1167,15 @@ def _test_acquire_token_for_client(self, configured_region, expected_region): Uses the lab app certificate for authentication. """ - import os from tests.lab_config import get_client_certificate - # Get client ID from environment and certificate from lab_config - client_id = os.getenv("LAB_APP_CLIENT_ID") - if not client_id: - self.skipTest("LAB_APP_CLIENT_ID environment variable is required") - + # Get client ID from lab_config constant and certificate from lab_config + if not clean_env("LAB_APP_CLIENT_CERT_PFX_PATH"): + self.skipTest("LAB_APP_CLIENT_CERT_PFX_PATH is required") client_credential = get_client_certificate() self.app = msal.ConfidentialClientApplication( - client_id, + LAB_APP_CLIENT_ID, client_credential=client_credential, authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com", azure_region=configured_region, diff --git a/tests/test_msi_v2.py b/tests/test_msi_v2.py new file mode 100644 index 00000000..2b803ba8 --- /dev/null +++ b/tests/test_msi_v2.py @@ -0,0 +1,494 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +"""Tests for MSI v2 (mTLS PoP) implementation. + +Goals: +- Provide strong unit coverage without depending on KeyGuard / real IMDS. +- Validate: + * x5t#S256 helper correctness (local) + * verify_cnf_binding behavior (msal.msi_v2) + * Certificate cache behavior + * ManagedIdentityClient strict gating behavior + * IMDS wire-contract helpers +""" + +import base64 +import datetime +import hashlib +import json +import os +import time +import unittest + +try: + from unittest.mock import patch, MagicMock +except ImportError: + from mock import patch, MagicMock + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +import msal +from msal import MsiV2Error +from msal.msi_v2 import ( + verify_cnf_binding, + _cert_cache_clear, + _cert_cache_get, + _cert_cache_set, + _cert_cache_key, + _cert_cache_remove, + _CertCacheEntry, + _mi_query_params, + _resource_to_scope, + _token_endpoint_from_credential, + _der_to_pem, +) + +from tests.http_client import MinimalResponse + + +# --------------------------------------------------------------------------- +# Local helpers +# --------------------------------------------------------------------------- + +def _make_self_signed_cert(private_key, common_name="test"): + subject = issuer = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, common_name)]) + now = datetime.datetime.now(datetime.timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=1)) + .sign(private_key, hashes.SHA256(), default_backend()) + ) + return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + cert = x509.load_pem_x509_certificate( + cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + return (base64.urlsafe_b64encode(hashlib.sha256(cert_der).digest()) + .rstrip(b"=").decode("ascii")) + + +def _b64url(s: bytes) -> str: + return base64.urlsafe_b64encode(s).rstrip(b"=").decode("ascii") + + +def _make_jwt(payload_obj, header_obj=None) -> str: + header_obj = header_obj or {"alg": "RS256", "typ": "JWT"} + header = _b64url( + json.dumps(header_obj, separators=(",", ":")).encode("utf-8")) + payload = _b64url( + json.dumps(payload_obj, separators=(",", ":")).encode("utf-8")) + sig = _b64url(b"sig") + return f"{header}.{payload}.{sig}" + + +# --------------------------------------------------------------------------- +# Thumbprint helper +# --------------------------------------------------------------------------- + +class TestThumbprintHelper(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "thumbprint-test") + + def test_returns_base64url_no_padding(self): + thumb = get_cert_thumbprint_sha256(self.cert_pem) + self.assertIsInstance(thumb, str) + self.assertNotIn("=", thumb) + decoded = base64.urlsafe_b64decode(thumb + "==") + self.assertEqual(len(decoded), 32) + + def test_same_cert_same_thumbprint(self): + t1 = get_cert_thumbprint_sha256(self.cert_pem) + t2 = get_cert_thumbprint_sha256(self.cert_pem) + self.assertEqual(t1, t2) + + def test_different_certs_different_thumbprints(self): + key2 = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + cert2_pem = _make_self_signed_cert(key2, "thumbprint-test-2") + self.assertNotEqual( + get_cert_thumbprint_sha256(self.cert_pem), + get_cert_thumbprint_sha256(cert2_pem)) + + def test_matches_manual_sha256_der(self): + cert = x509.load_pem_x509_certificate( + self.cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + expected = (base64.urlsafe_b64encode( + hashlib.sha256(cert_der).digest()) + .rstrip(b"=").decode("ascii")) + self.assertEqual(get_cert_thumbprint_sha256(self.cert_pem), expected) + + +# --------------------------------------------------------------------------- +# verify_cnf_binding +# --------------------------------------------------------------------------- + +class TestVerifyCnfBinding(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "cnf-test") + self.thumbprint = get_cert_thumbprint_sha256(self.cert_pem) + + def test_valid_binding_true(self): + token = _make_jwt({"cnf": {"x5t#S256": self.thumbprint}}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + def test_wrong_thumbprint_false(self): + token = _make_jwt({"cnf": {"x5t#S256": "wrong"}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_cnf_false(self): + token = _make_jwt({"sub": "nobody"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_x5t_false(self): + token = _make_jwt({"cnf": {}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_cnf_not_object_false(self): + token = _make_jwt({"cnf": "not-an-object"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_not_a_jwt_false(self): + self.assertFalse(verify_cnf_binding("notajwt", self.cert_pem)) + + def test_two_part_jwt_false(self): + self.assertFalse(verify_cnf_binding("a.b", self.cert_pem)) + + def test_four_part_jwt_false(self): + self.assertFalse(verify_cnf_binding("a.b.c.d", self.cert_pem)) + + def test_malformed_payload_base64_false(self): + self.assertFalse(verify_cnf_binding("header.!!!.sig", self.cert_pem)) + + def test_payload_not_json_false(self): + header = _b64url(b'{"alg":"none"}') + payload = _b64url(b"not-json") + self.assertFalse( + verify_cnf_binding(f"{header}.{payload}.sig", self.cert_pem)) + + def test_payload_with_padding_works(self): + header = base64.urlsafe_b64encode( + b'{"alg":"RS256"}').decode("ascii") + payload = base64.urlsafe_b64encode(json.dumps( + {"cnf": {"x5t#S256": self.thumbprint}}).encode("utf-8") + ).decode("ascii") + self.assertTrue( + verify_cnf_binding(f"{header}.{payload}.sig", self.cert_pem)) + + def test_unicode_in_payload(self): + token = _make_jwt({ + "cnf": {"x5t#S256": self.thumbprint}, "msg": "こんにちは"}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + +# --------------------------------------------------------------------------- +# Certificate cache +# --------------------------------------------------------------------------- + +class TestCertificateCache(unittest.TestCase): + def setUp(self): + _cert_cache_clear() + + def tearDown(self): + _cert_cache_clear() + + def _make_entry(self, *, not_after=None): + return _CertCacheEntry( + cert_der=b"fake-der", + cert_pem="-----BEGIN CERTIFICATE-----\nfake\n" + "-----END CERTIFICATE-----", + token_endpoint="https://login.microsoftonline.com/t/oauth2/v2.0/token", + client_id="test-client-id", + not_after=not_after or (time.time() + 48 * 3600), + ) + + def test_set_and_get(self): + entry = self._make_entry() + _cert_cache_set("k1", entry) + got = _cert_cache_get("k1") + self.assertIsNotNone(got) + self.assertEqual(got.cert_der, b"fake-der") + self.assertEqual(got.client_id, "test-client-id") + + def test_miss_returns_none(self): + self.assertIsNone(_cert_cache_get("no-such-key")) + + def test_expired_entry_evicted(self): + entry = self._make_entry(not_after=time.time() + 100) + _cert_cache_set("k2", entry) + # Force it to look expired + entry.not_after = time.time() - 1 + self.assertIsNone(_cert_cache_get("k2")) + + def test_insufficient_lifetime_not_cached(self): + # Not enough remaining lifetime (< 24h) + entry = self._make_entry(not_after=time.time() + 3600) + _cert_cache_set("k3", entry) + self.assertIsNone(_cert_cache_get("k3")) + + def test_remove(self): + entry = self._make_entry() + _cert_cache_set("k4", entry) + _cert_cache_remove("k4") + self.assertIsNone(_cert_cache_get("k4")) + + def test_clear(self): + _cert_cache_set("k5", self._make_entry()) + _cert_cache_set("k6", self._make_entry()) + _cert_cache_clear() + self.assertIsNone(_cert_cache_get("k5")) + self.assertIsNone(_cert_cache_get("k6")) + + def test_cache_key_generation(self): + mi_sys = {"ManagedIdentityIdType": "SystemAssigned", "Id": None} + mi_user = {"ManagedIdentityIdType": "ClientId", "Id": "abc"} + mi_obj = {"ManagedIdentityIdType": "ObjectId", "Id": "abc"} + k1 = _cert_cache_key(mi_sys, True) + k2 = _cert_cache_key(mi_sys, False) + k3 = _cert_cache_key(mi_user, True) + k4 = _cert_cache_key(mi_obj, True) + self.assertNotEqual(k1, k2) + self.assertNotEqual(k1, k3) + # Same Id but different IdType must produce different keys + self.assertNotEqual(k3, k4) + self.assertIn("#att=1", k1) + self.assertIn("#att=0", k2) + self.assertIn("ClientId:", k3) + self.assertIn("ObjectId:", k4) + + +# --------------------------------------------------------------------------- +# IMDS wire-contract helpers +# --------------------------------------------------------------------------- + +class TestImdsHelpers(unittest.TestCase): + def test_mi_query_params_system_assigned(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "SystemAssigned", "Id": None}) + self.assertEqual(p["cred-api-version"], "2.0") + self.assertNotIn("client_id", p) + + def test_mi_query_params_client_id(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "ClientId", "Id": "abc"}) + self.assertEqual(p["client_id"], "abc") + + def test_mi_query_params_object_id(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "ObjectId", "Id": "oid"}) + self.assertEqual(p["object_id"], "oid") + + def test_mi_query_params_resource_id(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "ResourceId", "Id": "/sub/..."}) + self.assertEqual(p["msi_res_id"], "/sub/...") + + def test_resource_to_scope_appends_default(self): + self.assertEqual( + _resource_to_scope("https://graph.microsoft.com"), + "https://graph.microsoft.com/.default") + + def test_resource_to_scope_preserves_existing(self): + self.assertEqual( + _resource_to_scope("https://graph.microsoft.com/.default"), + "https://graph.microsoft.com/.default") + + def test_resource_to_scope_strips_trailing_slash(self): + self.assertEqual( + _resource_to_scope("https://graph.microsoft.com/"), + "https://graph.microsoft.com/.default") + + def test_resource_to_scope_raises_on_empty(self): + with self.assertRaises(ValueError): + _resource_to_scope("") + + def test_token_endpoint_prefers_explicit(self): + cred = {"token_endpoint": "https://explicit.com/token", + "mtls_authentication_endpoint": "https://other"} + self.assertEqual( + _token_endpoint_from_credential(cred), + "https://explicit.com/token") + + def test_token_endpoint_falls_back_to_mtls_auth(self): + cred = { + "mtls_authentication_endpoint": "https://login.example.com", + "tenant_id": "tid", + } + self.assertEqual( + _token_endpoint_from_credential(cred), + "https://login.example.com/tid/oauth2/v2.0/token") + + def test_token_endpoint_raises_on_missing(self): + with self.assertRaises(MsiV2Error): + _token_endpoint_from_credential({}) + + +# --------------------------------------------------------------------------- +# ManagedIdentityClient gating +# --------------------------------------------------------------------------- + +class TestManagedIdentityClientStrictGating(unittest.TestCase): + def _make_client(self): + import requests + return msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + ) + + def test_error_is_exported(self): + self.assertIs(msal.MsiV2Error, MsiV2Error) + + def test_error_is_subclass(self): + self.assertTrue(issubclass(MsiV2Error, msal.ManagedIdentityError)) + + @patch("msal.managed_identity._obtain_token") + def test_default_path_calls_v1(self, mock_v1): + mock_v1.return_value = { + "access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client(resource="R") + self.assertEqual(res["access_token"], "V1") + mock_v1.assert_called_once() + + def test_attestation_requires_pop(self): + client = self._make_client() + with self.assertRaises(msal.ManagedIdentityError): + client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=False, + with_attestation_support=True) + + @patch("msal.msi_v2.obtain_token") + @patch("msal.managed_identity._obtain_token") + def test_pop_without_attestation_does_not_call_v2( + self, mock_v1, mock_v2): + mock_v1.return_value = { + "access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=True, + with_attestation_support=False) + self.assertEqual(res["token_type"], "Bearer") + mock_v2.assert_not_called() + mock_v1.assert_called_once() + + @patch("msal.managed_identity.create_attestation_provider", + create=True) + @patch("msal.msi_v2.obtain_token") + def test_v2_called_when_both_flags_true(self, mock_v2, _): + mock_v2.return_value = { + "access_token": "V2", "expires_in": 3600, + "token_type": "mtls_pop"} + client = self._make_client() + + with patch.dict("sys.modules", { + "msal_key_attestation": MagicMock( + create_attestation_provider=MagicMock( + return_value=lambda ep, kh, ci, ck="": "fake.jwt")) + }): + res = client.acquire_token_for_client( + resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + + self.assertEqual(res["token_type"], "mtls_pop") + mock_v2.assert_called_once() + args, kwargs = mock_v2.call_args + self.assertTrue(len(args) >= 3) + self.assertEqual(args[2], "https://mtlstb.graph.microsoft.com") + self.assertTrue(kwargs["attestation_enabled"]) + + @patch("msal.msi_v2.obtain_token", side_effect=MsiV2Error("boom")) + @patch("msal.managed_identity._obtain_token") + def test_strict_v2_failure_raises_no_v1_fallback( + self, mock_v1, mock_v2): + client = self._make_client() + with patch.dict("sys.modules", { + "msal_key_attestation": MagicMock( + create_attestation_provider=MagicMock( + return_value=lambda ep, kh, ci, ck="": "fake.jwt")) + }): + with self.assertRaises(MsiV2Error): + client.acquire_token_for_client( + resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + mock_v1.assert_not_called() + + @patch("msal.msi_v2.obtain_token", + side_effect=RuntimeError("DLL load failed")) + @patch("msal.managed_identity._obtain_token") + def test_runtime_error_wrapped_as_msi_v2_error( + self, mock_v1, mock_v2): + """RuntimeError from provider/DLL must surface as MsiV2Error.""" + client = self._make_client() + with patch.dict("sys.modules", { + "msal_key_attestation": MagicMock( + create_attestation_provider=MagicMock( + return_value=lambda ep, kh, ci, ck="": "fake.jwt")) + }): + with self.assertRaises(MsiV2Error) as ctx: + client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=True, + with_attestation_support=True) + self.assertIn("DLL load failed", str(ctx.exception)) + mock_v1.assert_not_called() + + def test_missing_attestation_package_raises_clear_error(self): + client = self._make_client() + with patch.dict("sys.modules", {"msal_key_attestation": None}): + with self.assertRaises(MsiV2Error) as ctx: + client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=True, + with_attestation_support=True) + self.assertIn("pip install msal-key-attestation", + str(ctx.exception)) + + +# --------------------------------------------------------------------------- +# DER helpers +# --------------------------------------------------------------------------- + +class TestDerHelpers(unittest.TestCase): + def test_der_to_pem_roundtrip(self): + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + cert_pem = _make_self_signed_cert(key, "der-test") + cert = x509.load_pem_x509_certificate( + cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + + pem_out = _der_to_pem(cert_der) + self.assertIn("-----BEGIN CERTIFICATE-----", pem_out) + self.assertIn("-----END CERTIFICATE-----", pem_out) + + # Verify the PEM round-trips back to same DER + cert2 = x509.load_pem_x509_certificate( + pem_out.encode("utf-8"), default_backend()) + self.assertEqual( + cert2.public_bytes(serialization.Encoding.DER), cert_der) + + +if __name__ == "__main__": + unittest.main()