diff --git a/cmdline/pom.xml b/cmdline/pom.xml index 3f82579a..82b9a924 100644 --- a/cmdline/pom.xml +++ b/cmdline/pom.xml @@ -77,5 +77,24 @@ sdk ${project.version} + + + org.bouncycastle + bcpkix-jdk18on + + + + org.junit.jupiter + junit-jupiter + test + + + org.assertj + assertj-core + 3.25.3 + test + diff --git a/cmdline/src/main/java/io/opentdf/platform/CliDpopOptions.java b/cmdline/src/main/java/io/opentdf/platform/CliDpopOptions.java new file mode 100644 index 00000000..30d155ac --- /dev/null +++ b/cmdline/src/main/java/io/opentdf/platform/CliDpopOptions.java @@ -0,0 +1,127 @@ +package io.opentdf.platform; + +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.gen.ECKeyGenerator; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; +import io.opentdf.platform.sdk.DpopKeyValidation; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; +import java.util.UUID; + +final class CliDpopOptions { + private CliDpopOptions() { + } + + static final class DpopMaterial { + final JWK jwk; + final JWSAlgorithm alg; + + DpopMaterial(JWK jwk, JWSAlgorithm alg) { + this.jwk = jwk; + this.alg = alg; + } + } + + static Optional parse(String dpopAlg, Path dpopKeyPath) { + if (dpopKeyPath != null) { + JWK jwk = loadPrivateKey(dpopKeyPath); + JWSAlgorithm alg; + if (dpopAlg != null && !dpopAlg.isEmpty()) { + alg = parseAlgorithm(dpopAlg); + } else if (jwk instanceof ECKey) { + Curve curve = ((ECKey) jwk).getCurve(); + try { + alg = DpopKeyValidation.inferEcAlgorithm(curve); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "DPoP key file " + dpopKeyPath + " uses unsupported EC curve " + curve, e); + } + } else { + alg = JWSAlgorithm.RS256; + } + try { + DpopKeyValidation.validate(jwk, alg); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "DPoP key file " + dpopKeyPath + " is incompatible with --dpop=" + alg + ": " + e.getMessage(), + e); + } + return Optional.of(new DpopMaterial(jwk, alg)); + } + if (dpopAlg != null) { + JWSAlgorithm alg = dpopAlg.isEmpty() ? JWSAlgorithm.RS256 : parseAlgorithm(dpopAlg); + return Optional.of(new DpopMaterial(generateKeyForAlgorithm(alg), alg)); + } + return Optional.empty(); + } + + static JWSAlgorithm parseAlgorithm(String alg) { + switch (alg.toUpperCase()) { + case "RS256": return JWSAlgorithm.RS256; + case "RS384": return JWSAlgorithm.RS384; + case "RS512": return JWSAlgorithm.RS512; + case "ES256": return JWSAlgorithm.ES256; + case "ES384": return JWSAlgorithm.ES384; + case "ES512": return JWSAlgorithm.ES512; + default: + throw new IllegalArgumentException("Unsupported DPoP algorithm: " + alg + + ". Supported: RS256, RS384, RS512, ES256, ES384, ES512"); + } + } + + private static JWK loadPrivateKey(Path path) { + String pem; + try { + pem = Files.readString(path); + } catch (IOException e) { + throw new IllegalArgumentException("Cannot read DPoP key file " + path + ": " + e.getMessage(), e); + } + JWK jwk; + try { + jwk = JWK.parseFromPEMEncodedObjects(pem); + } catch (JOSEException e) { + throw new IllegalArgumentException( + "DPoP key file " + path + " is not a valid PEM-encoded key: " + e.getMessage(), e); + } + if (!jwk.isPrivate()) { + throw new IllegalArgumentException( + "DPoP key file " + path + " contains a public key only; a private key is required"); + } + return jwk; + } + + private static JWK generateKeyForAlgorithm(JWSAlgorithm alg) { + try { + if (JWSAlgorithm.RS256.equals(alg) || JWSAlgorithm.RS384.equals(alg) || JWSAlgorithm.RS512.equals(alg)) { + return new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + } + Curve curve; + if (JWSAlgorithm.ES256.equals(alg)) { + curve = Curve.P_256; + } else if (JWSAlgorithm.ES384.equals(alg)) { + curve = Curve.P_384; + } else if (JWSAlgorithm.ES512.equals(alg)) { + curve = Curve.P_521; + } else { + throw new IllegalArgumentException("Cannot generate key for algorithm: " + alg); + } + return new ECKeyGenerator(curve) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + } catch (JOSEException e) { + throw new IllegalArgumentException("Failed to generate DPoP key for algorithm " + alg + ": " + e.getMessage(), e); + } + } +} diff --git a/cmdline/src/main/java/io/opentdf/platform/Command.java b/cmdline/src/main/java/io/opentdf/platform/Command.java index 685f8782..bc9ba5cc 100644 --- a/cmdline/src/main/java/io/opentdf/platform/Command.java +++ b/cmdline/src/main/java/io/opentdf/platform/Command.java @@ -1,29 +1,14 @@ package io.opentdf.platform; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; import com.google.gson.JsonDeserializationContext; import com.google.gson.JsonDeserializer; import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonParseException; -import com.nimbusds.jose.jwk.JWK; -import com.google.gson.GsonBuilder; -import com.google.gson.reflect.TypeToken; - -import java.security.cert.X509Certificate; -import java.text.ParseException; import com.google.gson.JsonSyntaxException; -import io.opentdf.platform.sdk.AssertionConfig; -import io.opentdf.platform.sdk.AutoConfigureException; -import io.opentdf.platform.sdk.Config; -import io.opentdf.platform.sdk.KeyType; -import io.opentdf.platform.sdk.SDK; -import io.opentdf.platform.sdk.SDKBuilder; -import picocli.CommandLine; -import picocli.CommandLine.HelpCommand; -import picocli.CommandLine.Option; - -import javax.net.ssl.X509TrustManager; +import com.google.gson.reflect.TypeToken; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.File; @@ -41,12 +26,25 @@ import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.X509EncodedKeySpec; +import java.text.ParseException; import java.util.ArrayList; import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.Callable; import java.util.function.Consumer; +import io.opentdf.platform.sdk.AssertionConfig; +import io.opentdf.platform.sdk.AutoConfigureException; +import io.opentdf.platform.sdk.Config; +import io.opentdf.platform.sdk.KeyType; +import io.opentdf.platform.sdk.SDK; +import io.opentdf.platform.sdk.SDKBuilder; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.core.config.Configurator; +import picocli.CommandLine; +import picocli.CommandLine.HelpCommand; +import picocli.CommandLine.Option; /** * Constants for the TDF command line tool. @@ -60,18 +58,38 @@ class Versions { public static final String TDF_SPEC = "4.3.0"; } -@CommandLine.Command(name = "tdf", subcommands = { HelpCommand.class }, version = "{\"version\":\"" + Versions.SDK - + "\",\"tdfSpecVersion\":\"" + Versions.TDF_SPEC + "\"}") +@CommandLine.Command(name = "tdf", subcommands = { HelpCommand.class, + Command.Supports.class }, version = "{\"version\":\"" + Versions.SDK + + "\",\"tdfSpecVersion\":\"" + Versions.TDF_SPEC + "\"}") class Command { @Option(names = { "-V", "--version" }, versionHelp = true, description = "display version info") boolean versionInfoRequested; + // Picocli injects the parsed command spec here so buildSDK() can raise + // ParameterException with the right help context when required options + // are missing for encrypt/decrypt/metadata (which all call buildSDK()). + @CommandLine.Spec + CommandLine.Model.CommandSpec spec; + + @CommandLine.Command(name = "supports", description = "Check if a feature is supported") + static class Supports implements Callable { + @CommandLine.Parameters(index = "0", description = "Feature to check (e.g., dpop)") + private String feature; + + @Override + public Integer call() { + return ("dpop".equalsIgnoreCase(feature) || "dpop_nonce_challenge".equalsIgnoreCase(feature)) ? 0 : 1; + } + } + private static class AssertionKeyDeserializer implements JsonDeserializer { @Override - public AssertionConfig.AssertionKey deserialize(JsonElement json, java.lang.reflect.Type typeOfT, JsonDeserializationContext context) throws JsonParseException { + public AssertionConfig.AssertionKey deserialize(JsonElement json, java.lang.reflect.Type typeOfT, + JsonDeserializationContext context) throws JsonParseException { JsonObject jsonObject = json.getAsJsonObject(); - AssertionConfig.AssertionKey assertionKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.NotDefined, null); + AssertionConfig.AssertionKey assertionKey = new AssertionConfig.AssertionKey( + AssertionConfig.AssertionKeyAlg.NotDefined, null); if (jsonObject.has("alg")) { assertionKey.alg = context.deserialize(jsonObject.get("alg"), AssertionConfig.AssertionKeyAlg.class); @@ -81,13 +99,15 @@ public AssertionConfig.AssertionKey deserialize(JsonElement json, java.lang.refl } if (jsonObject.has("jwk")) { try { - assertionKey.jwk = JWK.parse(jsonObject.get("jwk").toString()); + assertionKey.jwk = com.nimbusds.jose.jwk.JWK.parse(jsonObject.get("jwk").toString()); } catch (ParseException e) { throw new JsonParseException("Failed to parse jwk", e); } } if (jsonObject.has("x5c")) { - assertionKey.x5c = context.deserialize(jsonObject.get("x5c"), new TypeToken>() {}.getType()); + assertionKey.x5c = context.deserialize(jsonObject.get("x5c"), + new TypeToken>() { + }.getType()); } return assertionKey; @@ -105,7 +125,19 @@ private Gson buildGson() { private static final String PEM_HEADER = "-----BEGIN (.*)-----"; private static final String PEM_FOOTER = "-----END (.*)-----"; - @Option(names = { "--client-secret" }, required = true) + @Option(names = { "-v", "--verbose" }, scope = CommandLine.ScopeType.INHERIT, defaultValue = "false", description = "Enable verbose output including stack traces on error") + void setVerbose(boolean verbose) { + this.verbose = verbose; + if (verbose) { + var root = org.apache.logging.log4j.LogManager.getRootLogger(); + if (!root.getLevel().isLessSpecificThan(Level.DEBUG)) { + Configurator.setRootLevel(Level.DEBUG); + } + } + } + boolean verbose; + + @Option(names = { "--client-secret" }) private String clientSecret; @Option(names = { "-h", "--plaintext" }, defaultValue = "false") @@ -114,12 +146,20 @@ private Gson buildGson() { @Option(names = { "-i", "--insecure" }, defaultValue = "false") private boolean insecure; - @Option(names = { "--client-id" }, required = true) + @Option(names = { "--client-id" }) private String clientId; - @Option(names = { "-p", "--platform-endpoint" }, required = true) + @Option(names = { "-p", "--platform-endpoint" }) private String platformEndpoint; + @Option(names = { + "--dpop" }, arity = "0..1", fallbackValue = "", scope = CommandLine.ScopeType.INHERIT, description = "Enable DPoP (RFC 9449). Optional: specify algorithm (RS256, RS384, RS512, ES256, ES384, ES512). Default: RS256.") + private String dpopAlg; + + @Option(names = { + "--dpop-key" }, scope = CommandLine.ScopeType.INHERIT, description = "Enable DPoP using a PEM-encoded private key at . Algorithm inferred from key type. Combinable with --dpop=.") + private Path dpopKeyPath; + private Object correctKeyType(AssertionConfig.AssertionKeyAlg alg, Object key, boolean publicKey) throws RuntimeException { if (alg == AssertionConfig.AssertionKeyAlg.HS256) { @@ -261,16 +301,47 @@ void encrypt( } private SDK buildSDK() { + // The picocli @Option annotations on platformEndpoint/clientId/clientSecret are + // intentionally NOT marked required = true so that `tdf supports ` can + // run without credentials. Subcommands that actually build an SDK enforce them + // here so the failure surfaces as a normal picocli ParameterException (exit 2) + // rather than a deep SDK error. + if (platformEndpoint == null || platformEndpoint.isEmpty()) { + throw new CommandLine.ParameterException(spec.commandLine(), + "Missing required option: '--platform-endpoint='"); + } + if (clientId == null || clientId.isEmpty()) { + throw new CommandLine.ParameterException(spec.commandLine(), + "Missing required option: '--client-id='"); + } + if (clientSecret == null || clientSecret.isEmpty()) { + throw new CommandLine.ParameterException(spec.commandLine(), + "Missing required option: '--client-secret='"); + } + SDKBuilder builder = new SDKBuilder(); if (insecure) { builder.insecureSslFactory(); } + applyDPoPOptions(builder); + return builder.platformEndpoint(platformEndpoint) .clientSecret(clientId, clientSecret).useInsecurePlaintextConnection(plaintext) .build(); } + private void applyDPoPOptions(SDKBuilder builder) { + try { + CliDpopOptions.parse(dpopAlg, dpopKeyPath).ifPresent(m -> { + builder.dpopKey(m.jwk); + builder.dpopAlgorithm(m.alg); + }); + } catch (IllegalArgumentException e) { + throw new CommandLine.ParameterException(spec.commandLine(), e.getMessage()); + } + } + @CommandLine.Command(name = "decrypt") void decrypt( @Option(names = { "-f", "--file" }, required = true) Path tdfPath, @@ -300,7 +371,8 @@ void decrypt( // try it as a file path try { String fileJson = new String(Files.readAllBytes(Paths.get(assertionVerificationInput))); - assertionVerificationKeys = gson.fromJson(fileJson, Config.AssertionVerificationKeys.class); + assertionVerificationKeys = gson.fromJson(fileJson, + Config.AssertionVerificationKeys.class); } catch (JsonSyntaxException e2) { throw new RuntimeException("Failed to parse assertion verification keys from file", e2); } catch (Exception e3) { diff --git a/cmdline/src/main/java/io/opentdf/platform/TDF.java b/cmdline/src/main/java/io/opentdf/platform/TDF.java index 5d9a27c1..8ff9e88b 100644 --- a/cmdline/src/main/java/io/opentdf/platform/TDF.java +++ b/cmdline/src/main/java/io/opentdf/platform/TDF.java @@ -4,7 +4,16 @@ public class TDF { public static void main(String[] args) { - var result = new CommandLine(new Command()).execute(args); - System.exit(result); + var command = new Command(); + var cmd = new CommandLine(command); + cmd.setExecutionExceptionHandler((ex, commandLine, parseResult) -> { + if (command.verbose) { + ex.printStackTrace(System.err); + } else { + System.err.println(ex.getMessage() != null ? ex.getMessage() : ex.toString()); + } + return 1; + }); + System.exit(cmd.execute(args)); } } \ No newline at end of file diff --git a/cmdline/src/main/resources/log4j2.xml b/cmdline/src/main/resources/log4j2.xml index da185a60..a025e82a 100644 --- a/cmdline/src/main/resources/log4j2.xml +++ b/cmdline/src/main/resources/log4j2.xml @@ -6,7 +6,7 @@ - + diff --git a/cmdline/src/test/java/io/opentdf/platform/CliDpopOptionsTest.java b/cmdline/src/test/java/io/opentdf/platform/CliDpopOptionsTest.java new file mode 100644 index 00000000..0e5a22cc --- /dev/null +++ b/cmdline/src/test/java/io/opentdf/platform/CliDpopOptionsTest.java @@ -0,0 +1,165 @@ +package io.opentdf.platform; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.ECKeyGenerator; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class CliDpopOptionsTest { + + @Test + void parse_returnsEmpty_whenNeitherFlagSet() { + assertThat(CliDpopOptions.parse(null, null)).isEmpty(); + } + + @ParameterizedTest + @ValueSource(strings = {"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}) + void parse_generatesKeyForExplicitAlgorithm(String alg) { + Optional result = CliDpopOptions.parse(alg, null); + assertThat(result).isPresent(); + assertThat(result.get().alg).isEqualTo(JWSAlgorithm.parse(alg)); + assertThat(result.get().jwk.isPrivate()).isTrue(); + } + + @Test + void parse_defaultsToRs256_whenDpopFlagWithoutValue() { + Optional result = CliDpopOptions.parse("", null); + assertThat(result).isPresent(); + assertThat(result.get().alg).isEqualTo(JWSAlgorithm.RS256); + assertThat(result.get().jwk).isInstanceOf(RSAKey.class); + } + + @Test + void parse_throwsForUnsupportedAlgorithm() { + assertThatThrownBy(() -> CliDpopOptions.parse("HS256", null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported DPoP algorithm") + .hasMessageContaining("HS256"); + } + + @Test + void parse_throwsForMissingKeyFile() { + Path nonexistent = Path.of("/tmp/definitely-does-not-exist-" + UUID.randomUUID() + ".pem"); + assertThatThrownBy(() -> CliDpopOptions.parse(null, nonexistent)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot read DPoP key file") + .hasMessageContaining(nonexistent.toString()); + } + + @Test + void parse_throwsForMalformedPem(@TempDir Path tmp) throws Exception { + Path badPem = tmp.resolve("bad.pem"); + Files.writeString(badPem, "this is not a PEM file"); + assertThatThrownBy(() -> CliDpopOptions.parse(null, badPem)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("not a valid PEM-encoded key"); + } + + @Test + void parse_throwsForPublicKeyOnlyPem(@TempDir Path tmp) throws Exception { + RSAKey rsa = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + Path publicOnly = tmp.resolve("public.pem"); + Files.writeString(publicOnly, encodePublicKey(rsa)); + assertThatThrownBy(() -> CliDpopOptions.parse(null, publicOnly)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("public key only") + .hasMessageContaining("private key is required"); + } + + @Test + void parse_acceptsRsaPrivateKeyPemAndDefaultsToRs256(@TempDir Path tmp) throws Exception { + RSAKey rsa = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + Path keyFile = tmp.resolve("rsa.pem"); + Files.writeString(keyFile, encodePrivateKey(rsa.toPrivateKey().getEncoded())); + + Optional result = CliDpopOptions.parse(null, keyFile); + assertThat(result).isPresent(); + assertThat(result.get().alg).isEqualTo(JWSAlgorithm.RS256); + assertThat(result.get().jwk).isInstanceOf(RSAKey.class); + assertThat(result.get().jwk.isPrivate()).isTrue(); + } + + @Test + void parse_acceptsEcPrivateKeyAndInfersAlgorithm(@TempDir Path tmp) throws Exception { + ECKey ec = new ECKeyGenerator(Curve.P_256) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + Path keyFile = tmp.resolve("ec.pem"); + Files.writeString(keyFile, encodeEcKeyPair(ec)); + + Optional result = CliDpopOptions.parse(null, keyFile); + assertThat(result).isPresent(); + assertThat(result.get().alg).isEqualTo(JWSAlgorithm.ES256); + assertThat(result.get().jwk).isInstanceOf(ECKey.class); + } + + @Test + void parse_rejectsRsaKeyWithEcAlgorithm(@TempDir Path tmp) throws Exception { + RSAKey rsa = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + Path keyFile = tmp.resolve("rsa.pem"); + Files.writeString(keyFile, encodePrivateKey(rsa.toPrivateKey().getEncoded())); + + assertThatThrownBy(() -> CliDpopOptions.parse("ES256", keyFile)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("incompatible with --dpop=ES256"); + } + + @Test + void parse_explicitAlgorithmOverridesEcInferenceWhenCompatible(@TempDir Path tmp) throws Exception { + ECKey ec = new ECKeyGenerator(Curve.P_256) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + Path keyFile = tmp.resolve("ec.pem"); + Files.writeString(keyFile, encodeEcKeyPair(ec)); + + Optional result = CliDpopOptions.parse("ES256", keyFile); + assertThat(result).isPresent(); + assertThat(result.get().alg).isEqualTo(JWSAlgorithm.ES256); + } + + private static String encodePrivateKey(byte[] pkcs8) { + String base64 = java.util.Base64.getMimeEncoder(64, "\n".getBytes()).encodeToString(pkcs8); + return "-----BEGIN PRIVATE KEY-----\n" + base64 + "\n-----END PRIVATE KEY-----\n"; + } + + private static String encodePublicKey(RSAKey key) throws Exception { + byte[] x509 = key.toPublicKey().getEncoded(); + String base64 = java.util.Base64.getMimeEncoder(64, "\n".getBytes()).encodeToString(x509); + return "-----BEGIN PUBLIC KEY-----\n" + base64 + "\n-----END PUBLIC KEY-----\n"; + } + + private static String encodeEcKeyPair(ECKey ec) throws Exception { + byte[] pubX509 = ec.toPublicKey().getEncoded(); + byte[] privPkcs8 = ec.toPrivateKey().getEncoded(); + String pubB64 = java.util.Base64.getMimeEncoder(64, "\n".getBytes()).encodeToString(pubX509); + String privB64 = java.util.Base64.getMimeEncoder(64, "\n".getBytes()).encodeToString(privPkcs8); + return "-----BEGIN PUBLIC KEY-----\n" + pubB64 + "\n-----END PUBLIC KEY-----\n" + + "-----BEGIN PRIVATE KEY-----\n" + privB64 + "\n-----END PRIVATE KEY-----\n"; + } +} diff --git a/cmdline/src/test/java/io/opentdf/platform/CommandTest.java b/cmdline/src/test/java/io/opentdf/platform/CommandTest.java new file mode 100644 index 00000000..54ae3c44 --- /dev/null +++ b/cmdline/src/test/java/io/opentdf/platform/CommandTest.java @@ -0,0 +1,108 @@ +package io.opentdf.platform; + +import org.junit.jupiter.api.Test; +import picocli.CommandLine; + +import java.io.PrintWriter; +import java.io.StringWriter; + +import static org.assertj.core.api.Assertions.assertThat; + +class CommandTest { + + @Test + void supports_dpop_exits_0() { + int code = new CommandLine(new Command()).execute("supports", "dpop"); + assertThat(code).isEqualTo(0); + } + + @Test + void supports_dpop_nonce_challenge_exits_0() { + int code = new CommandLine(new Command()).execute("supports", "dpop_nonce_challenge"); + assertThat(code).isEqualTo(0); + } + + @Test + void supports_unknown_feature_exits_1() { + int code = new CommandLine(new Command()).execute("supports", "unknown_feature"); + assertThat(code).isEqualTo(1); + } + + @Test + void encrypt_withoutCredentials_failsWithMissingPlatformEndpoint() { + StringWriter err = new StringWriter(); + CommandLine cli = new CommandLine(new Command()); + cli.setErr(new PrintWriter(err)); + + int code = cli.execute("encrypt", "-k", "https://kas.example.com", "-f", "/dev/null"); + + // Picocli exit code for ParameterException is USAGE (2). + assertThat(code).isEqualTo(CommandLine.ExitCode.USAGE); + assertThat(err.toString()).contains("Missing required option: '--platform-endpoint='"); + } + + @Test + void supports_withoutCredentials_stillExits0() { + // Regression sentinel: tdf supports must not require --client-id/--client-secret/--platform-endpoint. + int code = new CommandLine(new Command()).execute("supports", "dpop"); + assertThat(code).isEqualTo(0); + } + + @Test + void verbose_flag_accepted_by_supports() { + int code = new CommandLine(new Command()).execute("--verbose", "supports", "dpop"); + assertThat(code).isEqualTo(0); + } + + @Test + void verbose_short_flag_accepted_by_supports() { + int code = new CommandLine(new Command()).execute("-v", "supports", "dpop"); + assertThat(code).isEqualTo(0); + } + + @Test + void verbose_flag_sets_verbose_field() { + var command = new Command(); + new CommandLine(command).parseArgs("--verbose", "supports", "dpop"); + assertThat(command.verbose).isTrue(); + } + + @Test + void encrypt_withUnsupportedDpopAlgorithm_failsWithUsage() { + StringWriter err = new StringWriter(); + CommandLine cli = new CommandLine(new Command()); + cli.setErr(new PrintWriter(err)); + + int code = cli.execute( + "--platform-endpoint", "https://example.invalid", + "--client-id", "x", + "--client-secret", "x", + "encrypt", + "--dpop=HS256", + "-k", "https://kas.example.invalid", + "-f", "/dev/null"); + + assertThat(code).isEqualTo(CommandLine.ExitCode.USAGE); + assertThat(err.toString()).contains("Unsupported DPoP algorithm").contains("HS256"); + } + + @Test + void encrypt_withMissingDpopKeyFile_failsWithUsage() { + StringWriter err = new StringWriter(); + CommandLine cli = new CommandLine(new Command()); + cli.setErr(new PrintWriter(err)); + + int code = cli.execute( + "--platform-endpoint", "https://example.invalid", + "--client-id", "x", + "--client-secret", "x", + "encrypt", + "--dpop-key", "/tmp/does-not-exist-dpop-key.pem", + "-k", "https://kas.example.invalid", + "-f", "/dev/null"); + + assertThat(code).isEqualTo(CommandLine.ExitCode.USAGE); + assertThat(err.toString()).contains("Cannot read DPoP key file") + .contains("/tmp/does-not-exist-dpop-key.pem"); + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/DpopKeyValidation.java b/sdk/src/main/java/io/opentdf/platform/sdk/DpopKeyValidation.java new file mode 100644 index 00000000..d236b7d5 --- /dev/null +++ b/sdk/src/main/java/io/opentdf/platform/sdk/DpopKeyValidation.java @@ -0,0 +1,55 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.RSAKey; + +public final class DpopKeyValidation { + private DpopKeyValidation() { + } + + public static void validate(JWK jwk, JWSAlgorithm alg) { + if (jwk == null) { + throw new IllegalArgumentException("DPoP JWK cannot be null"); + } + if (alg == null) { + throw new IllegalArgumentException("DPoP algorithm cannot be null"); + } + if (jwk instanceof RSAKey) { + if (!isRsaAlgorithm(alg)) { + throw new IllegalArgumentException("DPoP algorithm " + alg + + " is not compatible with an RSA key; expected one of RS256/RS384/RS512 or PS256/PS384/PS512"); + } + } else if (jwk instanceof ECKey) { + JWSAlgorithm expected = inferEcAlgorithm(((ECKey) jwk).getCurve()); + if (!alg.equals(expected)) { + throw new IllegalArgumentException("DPoP algorithm " + alg + + " is not compatible with EC key on curve " + ((ECKey) jwk).getCurve() + + "; expected " + expected); + } + } else { + throw new IllegalArgumentException("Unsupported JWK type for DPoP: " + jwk.getKeyType() + + "; expected RSA or EC"); + } + } + + public static JWSAlgorithm inferEcAlgorithm(Curve curve) { + if (Curve.P_256.equals(curve)) { + return JWSAlgorithm.ES256; + } + if (Curve.P_384.equals(curve)) { + return JWSAlgorithm.ES384; + } + if (Curve.P_521.equals(curve)) { + return JWSAlgorithm.ES512; + } + throw new IllegalArgumentException("Unsupported EC curve for DPoP: " + curve); + } + + private static boolean isRsaAlgorithm(JWSAlgorithm alg) { + return JWSAlgorithm.RS256.equals(alg) || JWSAlgorithm.RS384.equals(alg) || JWSAlgorithm.RS512.equals(alg) + || JWSAlgorithm.PS256.equals(alg) || JWSAlgorithm.PS384.equals(alg) || JWSAlgorithm.PS512.equals(alg); + } +} diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java index 5b7bbcc7..e612404f 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java @@ -1,15 +1,19 @@ package io.opentdf.platform.sdk; import com.connectrpc.ConnectException; -import com.connectrpc.Interceptor; import com.connectrpc.ProtocolClientConfig; import com.connectrpc.extensions.GoogleJavaProtobufStrategy; import com.connectrpc.impl.ProtocolClient; import com.connectrpc.okhttp.ConnectOkHttpClient; import com.connectrpc.protocols.GETConfiguration; import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.ECKeyGenerator; import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; import com.nimbusds.oauth2.sdk.AuthorizationGrant; import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; @@ -64,6 +68,8 @@ public class SDKBuilder { private AuthorizationGrant authzGrant; private ProtocolType protocolType = ProtocolType.CONNECT; private SrtSigner srtSigner; + private JWK dpopKey; + private JWSAlgorithm dpopAlg; private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class); @@ -194,7 +200,7 @@ public SDKBuilder useInsecurePlaintextConnection(Boolean usePlainText) { /** * Set the network protocol to use for communication with platform services. - * + * * @param protocolType the protocol type to use (CONNECT, GRPC, or GRPC_WEB) * @return this builder instance for method chaining * @throws IllegalArgumentException if protocolType is null @@ -213,7 +219,32 @@ public SDKBuilder srtSigner(SrtSigner signer) { return this; } - private Interceptor getAuthInterceptor(RSAKey rsaKey) { + /** + * Configure a custom JWK (RSA or EC) for DPoP (RFC 9449) proof generation. + * If not provided, the SDK will auto-generate an ephemeral RSA-2048 key. + * RSA keys also serve as the SRT signing key; EC keys use a separate auto-generated RSA key for SRT. + * + * @param dpopKey JWK (RSA or EC) to use for DPoP proofs + * @return this builder instance for method chaining + */ + public SDKBuilder dpopKey(JWK dpopKey) { + this.dpopKey = dpopKey; + return this; + } + + /** + * Set the JWS algorithm to use for DPoP proofs. If omitted, defaults to RS256 for RSA keys + * or the curve-appropriate algorithm for EC keys. + * + * @param dpopAlg JWS algorithm (e.g. RS256, ES256) + * @return this builder instance for method chaining + */ + public SDKBuilder dpopAlgorithm(JWSAlgorithm dpopAlg) { + this.dpopAlg = dpopAlg; + return this; + } + + private AuthInterceptor getAuthInterceptor(JWK dpopJwk, JWSAlgorithm dpopAlgorithm) { if (platformEndpoint == null) { throw new SDKException("cannot build an SDK without specifying the platform endpoint"); } @@ -243,6 +274,12 @@ private Interceptor getAuthInterceptor(RSAKey rsaKey) { .getFieldsOrThrow(PLATFORM_ISSUER) .getStringValue(); } catch (IllegalArgumentException e) { + if (this.dpopKey != null || this.dpopAlg != null) { + throw new SDKException( + "DPoP was requested but the platform_issuer is missing from the well-known " + + "configuration at " + platformEndpoint + + "; the SDK cannot configure DPoP without a token endpoint", e); + } logger.warn( "no `platform_issuer` found in well known configuration. requests from the SDK will be unauthenticated", e); @@ -264,7 +301,7 @@ private Interceptor getAuthInterceptor(RSAKey rsaKey) { if (this.authzGrant == null) { this.authzGrant = new ClientCredentialsGrant(); } - var ts = new TokenSource(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI(), this.authzGrant, sslSocketFactory); + var ts = new TokenSource(clientAuth, dpopJwk, dpopAlgorithm, providerMetadata.getTokenEndpointURI(), this.authzGrant, sslSocketFactory); return new AuthInterceptor(ts); } @@ -282,14 +319,14 @@ public SDKBuilder insecureSslFactory() { } static class ServicesAndInternals { - final Interceptor interceptor; + final AuthInterceptor interceptor; final TrustManager trustManager; final ProtocolClient protocolClient; final SrtSigner srtSigner; final SDK.Services services; - ServicesAndInternals(Interceptor interceptor, TrustManager trustManager, SDK.Services services, ProtocolClient protocolClient, SrtSigner srtSigner) { + ServicesAndInternals(AuthInterceptor interceptor, TrustManager trustManager, SDK.Services services, ProtocolClient protocolClient, SrtSigner srtSigner) { this.interceptor = interceptor; this.trustManager = trustManager; this.services = services; @@ -305,22 +342,55 @@ ServicesAndInternals buildServices() { "gRPC-Web is designed for web browsers and typically operates over HTTP/1.1, " + "while plaintext connections force HTTP/2 prior knowledge."); } - - RSAKey dpopKey; - try { - dpopKey = new RSAKeyGenerator(2048) - .keyUse(KeyUse.SIGNATURE) - .keyID(UUID.randomUUID().toString()) - .generate(); - } catch (JOSEException e) { - throw new SDKException("Error generating DPoP key", e); + + // Resolve the DPoP JWK and algorithm + JWK effectiveDpopJwk; + JWSAlgorithm effectiveDpopAlg; + RSAKey srtKey; // SRT signing always uses RSA + + if (this.dpopKey != null) { + effectiveDpopJwk = this.dpopKey; + if (this.dpopAlg != null) { + effectiveDpopAlg = this.dpopAlg; + } else if (effectiveDpopJwk instanceof ECKey) { + effectiveDpopAlg = inferEcAlgorithm((ECKey) effectiveDpopJwk); + } else { + effectiveDpopAlg = JWSAlgorithm.RS256; + } + if (effectiveDpopJwk instanceof RSAKey) { + srtKey = (RSAKey) effectiveDpopJwk; + } else { + // EC DPoP key: generate a separate RSA key for SRT signing + try { + srtKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + } catch (JOSEException e) { + throw new SDKException("Error generating SRT RSA key", e); + } + } + } else { + // Auto-generate RSA-2048 for both DPoP and SRT + try { + srtKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + } catch (JOSEException e) { + throw new SDKException("Error generating DPoP key", e); + } + effectiveDpopJwk = srtKey; + effectiveDpopAlg = this.dpopAlg != null ? this.dpopAlg : JWSAlgorithm.RS256; } this.platformEndpoint = AddressNormalizer.normalizeAddress(this.platformEndpoint, this.usePlainText); - var authInterceptor = getAuthInterceptor(dpopKey); - var srtSignerToUse = this.srtSigner == null ? new DefaultSrtSigner(dpopKey) : this.srtSigner; - var kasClient = getKASClient(srtSignerToUse, authInterceptor); - var httpClient = getHttpClient(); + var authInterceptor = getAuthInterceptor(effectiveDpopJwk, effectiveDpopAlg); + var srtSignerToUse = this.srtSigner == null ? new DefaultSrtSigner(srtKey) : this.srtSigner; + + okhttp3.Interceptor dpopRetry = authInterceptor != null ? authInterceptor.dpopRetryInterceptor() : null; + var kasClient = getKASClient(srtSignerToUse, authInterceptor, dpopRetry); + var httpClient = getHttpClient(dpopRetry); var client = getProtocolClient(platformEndpoint, httpClient, authInterceptor); var attributeService = new AttributesServiceClient(client); var namespaceService = new NamespaceServiceClient(client); @@ -394,9 +464,9 @@ public SDK.KAS kas() { } @Nonnull - private KASClient getKASClient(SrtSigner srtSigner, Interceptor interceptor) { + private KASClient getKASClient(SrtSigner srtSigner, AuthInterceptor interceptor, okhttp3.Interceptor dpopRetry) { BiFunction protocolClientFactory = (OkHttpClient client, String address) -> getProtocolClient(address, client, interceptor); - return new KASClient(getHttpClient(), protocolClientFactory, srtSigner, usePlainText); + return new KASClient(getHttpClient(dpopRetry), protocolClientFactory, srtSigner, usePlainText); } public SDK build() { @@ -408,13 +478,19 @@ private ProtocolClient getUnauthenticatedProtocolClient(String endpoint, OkHttpC return getProtocolClient(endpoint, httpClient, null); } - private ProtocolClient getProtocolClient(String endpoint, OkHttpClient httpClient, Interceptor authInterceptor) { + private ProtocolClient getProtocolClient(String endpoint, OkHttpClient httpClient, AuthInterceptor authInterceptor) { + // Connect-GET would rewrite idempotent POST RPCs to GET on the wire, which invalidates + // the DPoP proof's htm claim (stamped before the rewrite). Keep it enabled only on the + // unauthenticated bootstrap path where no DPoP proof is attached. + GETConfiguration getConfig = authInterceptor != null + ? GETConfiguration.Disabled.INSTANCE + : GETConfiguration.Enabled.INSTANCE; var protocolClientConfig = new ProtocolClientConfig( endpoint, new GoogleJavaProtobufStrategy(), protocolType.getNetworkProtocol(), null, - GETConfiguration.Enabled.INSTANCE, + getConfig, authInterceptor == null ? Collections.emptyList() : List.of(ignoredConfig -> authInterceptor) ); @@ -423,9 +499,14 @@ private ProtocolClient getProtocolClient(String endpoint, OkHttpClient httpClien @SuppressWarnings("deprecation") private OkHttpClient getHttpClient() { - // using a single http client is apparently the best practice, subject to everyone wanting to - // have the same protocols + return getHttpClient((okhttp3.Interceptor) null); + } + + private OkHttpClient getHttpClient(okhttp3.Interceptor additionalInterceptor) { var httpClient = new OkHttpClient.Builder(); + if (additionalInterceptor != null) { + httpClient.addInterceptor(additionalInterceptor); + } if (usePlainText) { // For plaintext connections, we need HTTP/2 prior knowledge because gRPC servers // expect HTTP/2, and Connect protocol can communicate with gRPC servers over HTTP/2 @@ -451,4 +532,12 @@ SSLSocketFactory getSslFactory() { X509TrustManager getTrustManager() { return this.trustManager; } + + private static JWSAlgorithm inferEcAlgorithm(ECKey ecKey) { + try { + return DpopKeyValidation.inferEcAlgorithm(ecKey.getCurve()); + } catch (IllegalArgumentException e) { + throw new SDKException(e.getMessage(), e); + } + } } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TokenSource.java b/sdk/src/main/java/io/opentdf/platform/sdk/TokenSource.java index 01089452..10ff0209 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TokenSource.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TokenSource.java @@ -2,6 +2,7 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jwt.SignedJWT; import com.nimbusds.oauth2.sdk.AuthorizationGrant; @@ -14,40 +15,56 @@ import com.nimbusds.oauth2.sdk.http.HTTPRequest; import com.nimbusds.oauth2.sdk.http.HTTPResponse; import com.nimbusds.oauth2.sdk.token.AccessToken; +import com.nimbusds.openid.connect.sdk.Nonce; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; import javax.net.ssl.SSLSocketFactory; +import java.io.IOException; +import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * The TokenSource class is responsible for providing authorization tokens. It handles * timeouts and creating OIDC calls. It is thread-safe. */ class TokenSource { + static final String SCHEME_DPOP = "DPoP"; + static final String SCHEME_BEARER = "Bearer"; + private Instant tokenExpiryTime; private AccessToken token; + private String tokenScheme; private final ClientAuthentication clientAuth; - private final RSAKey rsaKey; + private final JWK dpopJwk; + private final JWSAlgorithm dpopAlg; private final URI tokenEndpointURI; private final AuthorizationGrant authzGrant; private final SSLSocketFactory sslSocketFactory; + // Cache for server-issued nonces, keyed by origin (scheme://host:port) + private final Map nonceCache = new ConcurrentHashMap<>(); private static final Logger logger = LoggerFactory.getLogger(TokenSource.class); /** - * Constructs a new TokenSource with the specified client authentication and RSA key. + * Constructs a new TokenSource with the specified client authentication and DPoP key. * * @param clientAuth the client authentication to be used by the interceptor - * @param rsaKey the RSA key to be used by the interceptor + * @param dpopJwk the JWK (RSA or EC) to use for DPoP proof generation + * @param dpopAlg the JWS algorithm matching the key type * @param sslSocketFactory Optional SSLSocketFactory for token endpoint requests */ - public TokenSource(ClientAuthentication clientAuth, RSAKey rsaKey, URI tokenEndpointURI, AuthorizationGrant authzGrant, SSLSocketFactory sslSocketFactory) { + public TokenSource(ClientAuthentication clientAuth, JWK dpopJwk, JWSAlgorithm dpopAlg, URI tokenEndpointURI, AuthorizationGrant authzGrant, SSLSocketFactory sslSocketFactory) { + DpopKeyValidation.validate(dpopJwk, dpopAlg); this.clientAuth = clientAuth; - this.rsaKey = rsaKey; + this.dpopJwk = dpopJwk; + this.dpopAlg = dpopAlg; this.tokenEndpointURI = tokenEndpointURI; this.sslSocketFactory = sslSocketFactory; this.authzGrant = authzGrant; @@ -56,9 +73,10 @@ public TokenSource(ClientAuthentication clientAuth, RSAKey rsaKey, URI tokenEndp class AuthHeaders { private final String authHeader; + @Nullable private final String dpopHeader; - public AuthHeaders(String authHeader, String dpopHeader) { + public AuthHeaders(String authHeader, @Nullable String dpopHeader) { this.authHeader = authHeader; this.dpopHeader = dpopHeader; } @@ -67,20 +85,53 @@ public String getAuthHeader() { return authHeader; } + @Nullable public String getDpopHeader() { return dpopHeader; } } public AuthHeaders getAuthHeaders(URL url, String method) { + return getAuthHeaders(url, method, null); + } + + /** + * Get authorization headers for a request, including DPoP proof. + * + * @param url The URL being accessed + * @param method The HTTP method + * @param nonce Optional server-issued nonce to include in the proof + * @return AuthHeaders containing Authorization and DPoP headers + */ + public AuthHeaders getAuthHeaders(URL url, String method, String nonce) { // Get the access token AccessToken t = getToken(); + // If the AS returned a plain bearer token, send it as a bearer credential + // without a DPoP proof. Sending "Authorization: DPoP " is a misuse + // of the scheme and resource servers that enforce DPoP will reject it. + if (SCHEME_BEARER.equals(tokenScheme)) { + return new AuthHeaders("Bearer " + t.getValue(), null); + } + // Build the DPoP proof for each request String dpopProof; try { - DPoPProofFactory dpopFactory = new DefaultDPoPProofFactory(rsaKey, JWSAlgorithm.RS256); - SignedJWT proof = dpopFactory.createDPoPJWT(method, url.toURI(), t); + DPoPProofFactory dpopFactory = new DefaultDPoPProofFactory(dpopJwk, dpopAlg); + + // Get cached nonce if not explicitly provided + if (nonce == null) { + String origin = getOrigin(url); + nonce = nonceCache.get(origin); + } + + SignedJWT proof; + URI htu = htuOf(url.toURI()); + if (nonce != null) { + proof = dpopFactory.createDPoPJWT(method, htu, t, new Nonce(nonce)); + } else { + proof = dpopFactory.createDPoPJWT(method, htu, t); + } dpopProof = proof.serialize(); } catch (URISyntaxException e) { throw new SDKException("Invalid URI syntax for DPoP proof creation", e); @@ -93,6 +144,47 @@ public AuthHeaders getAuthHeaders(URL url, String method) { dpopProof); } + /** + * Cache a server-issued nonce for the given URL's origin. + * + * @param url The URL from which the nonce was received + * @param nonce The nonce value to cache + */ + public void cacheNonce(URL url, String nonce) { + if (nonce != null && !nonce.isEmpty()) { + String origin = getOrigin(url); + nonceCache.put(origin, nonce); + logger.trace("Cached DPoP nonce for origin: {}", origin); + } + } + + // RFC 9449 §4.2: the htu claim is the request URI with query and fragment removed. + // Nimbus rejects any URI carrying a query, so strip both before handing it off. + private static URI htuOf(URI uri) { + if (uri.getRawQuery() == null && uri.getRawFragment() == null) { + return uri; + } + try { + return new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, null); + } catch (URISyntaxException e) { + throw new SDKException("failed to normalize URI for DPoP htu claim: " + uri, e); + } + } + + /** + * Get the origin (scheme://host:port) from a URL for nonce caching. + * + * @param url The URL to extract origin from + * @return The origin string + */ + private String getOrigin(URL url) { + int port = url.getPort(); + if (port == -1) { + port = url.getDefaultPort(); + } + return url.getProtocol() + "://" + url.getHost() + ":" + port; + } + /** * Either fetches a new access token or returns the cached access token if it is still valid. * @@ -105,42 +197,88 @@ private synchronized AccessToken getToken() { logger.trace("The current access token is expired or empty, getting a new one"); - // Make the token request - TokenRequest tokenRequest = new TokenRequest(this.tokenEndpointURI, - clientAuth, authzGrant, null); + DPoPProofFactory dpopFactory = new DefaultDPoPProofFactory(dpopJwk, dpopAlg); + + // Proactively use any cached nonce for the token endpoint origin (RFC 9449 §8.2) + URL tokenEndpointUrl = tokenEndpointURI.toURL(); + String cachedNonce = nonceCache.get(getOrigin(tokenEndpointUrl)); + + TokenRequest tokenRequest = new TokenRequest(this.tokenEndpointURI, clientAuth, authzGrant, null); HTTPRequest httpRequest = tokenRequest.toHTTPRequest(); if (sslSocketFactory != null) { httpRequest.setSSLSocketFactory(sslSocketFactory); } - - DPoPProofFactory dpopFactory = new DefaultDPoPProofFactory(rsaKey, JWSAlgorithm.RS256); - - SignedJWT proof = dpopFactory.createDPoPJWT(httpRequest.getMethod().name(), httpRequest.getURI()); - + URI tokenHtu = htuOf(httpRequest.getURI()); + SignedJWT proof = (cachedNonce != null) + ? dpopFactory.createDPoPJWT(httpRequest.getMethod().name(), tokenHtu, new Nonce(cachedNonce)) + : dpopFactory.createDPoPJWT(httpRequest.getMethod().name(), tokenHtu); httpRequest.setDPoP(proof); - TokenResponse tokenResponse; HTTPResponse httpResponse = httpRequest.send(); - tokenResponse = TokenResponse.parse(httpResponse); + TokenResponse tokenResponse = TokenResponse.parse(httpResponse); + + // RFC 9449 §8.2: if AS requires a nonce, cache it and retry once if (!tokenResponse.indicatesSuccess()) { ErrorObject error = tokenResponse.toErrorResponse().getErrorObject(); - throw new SDKException("failure to get token. description = [" + error.getDescription() + "] error code = [" + error.getCode() + "] error uri = [" + error.getURI() + "]"); + if ("use_dpop_nonce".equals(error.getCode())) { + String dpopNonce = httpResponse.getHeaderValue("DPoP-Nonce"); + if (dpopNonce != null) { + cacheNonce(tokenEndpointUrl, dpopNonce); + TokenRequest retryRequest = new TokenRequest(tokenEndpointURI, clientAuth, authzGrant, null); + HTTPRequest retryHttpRequest = retryRequest.toHTTPRequest(); + if (sslSocketFactory != null) { + retryHttpRequest.setSSLSocketFactory(sslSocketFactory); + } + SignedJWT retryProof = dpopFactory.createDPoPJWT( + retryHttpRequest.getMethod().name(), + htuOf(retryHttpRequest.getURI()), + new Nonce(dpopNonce)); + retryHttpRequest.setDPoP(retryProof); + httpResponse = retryHttpRequest.send(); + tokenResponse = TokenResponse.parse(httpResponse); + // Cache any nonce rotation from the AS (RFC 9449 §8.2) + String rotatedNonce = httpResponse.getHeaderValue("DPoP-Nonce"); + if (rotatedNonce != null) { + cacheNonce(tokenEndpointUrl, rotatedNonce); + } + } else { + logger.warn("token endpoint {} returned use_dpop_nonce but did not supply a DPoP-Nonce response header", + tokenEndpointURI); + } + } + if (!tokenResponse.indicatesSuccess()) { + ErrorObject finalError = tokenResponse.toErrorResponse().getErrorObject(); + throw new SDKException("failure to get token. description = [" + finalError.getDescription() + + "] error code = [" + finalError.getCode() + + "] error uri = [" + finalError.getURI() + "]"); + } } var tokens = tokenResponse.toSuccessResponse().getTokens(); - if (tokens.getDPoPAccessToken() != null) { + boolean asAssertsDpop = tokens.getDPoPAccessToken() != null; + if (asAssertsDpop) { logger.trace("retrieved a new DPoP access token"); } else if (tokens.getAccessToken() != null) { logger.trace("retrieved a new access token"); } else { - logger.trace("got an access token of unknown type"); + logger.warn("token endpoint {} returned a success response with an unknown access token type", + tokenEndpointURI); } this.token = tokens.getAccessToken(); + if (this.token == null) { + throw new SDKException("token endpoint " + tokenEndpointURI + + " returned a success response with no access token"); + } + this.tokenScheme = asAssertsDpop ? SCHEME_DPOP : SCHEME_BEARER; + if (!asAssertsDpop) { + logger.warn("token endpoint {} returned a non-DPoP-bound access token (token_type=Bearer) despite" + + " DPoP proof — falling back to Bearer scheme. Check the IdP DPoP configuration.", + tokenEndpointURI); + } if (token.getLifetime() != 0) { - // Need some type of leeway but not sure whats best this.tokenExpiryTime = Instant.now().plusSeconds(token.getLifetime() / 3); } @@ -149,8 +287,19 @@ private synchronized AccessToken getToken() { return this.token; } - } catch (Exception e) { - throw new SDKException("failed to get token", e); + } catch (SDKException e) { + // Already shaped for the caller — don't double-wrap. + throw e; + } catch (MalformedURLException e) { + throw new SDKException("invalid token endpoint URL: " + tokenEndpointURI, e); + } catch (IOException e) { + throw new SDKException("network error contacting token endpoint " + tokenEndpointURI, e); + } catch (JOSEException e) { + throw new SDKException("DPoP proof generation failed for token endpoint " + tokenEndpointURI, e); + } catch (com.nimbusds.oauth2.sdk.ParseException e) { + throw new SDKException("malformed token response from " + tokenEndpointURI, e); + } catch (RuntimeException e) { + throw new SDKException("unexpected error fetching token from " + tokenEndpointURI, e); } return this.token; } diff --git a/sdk/src/main/kotlin/io/opentdf/platform/sdk/AuthInterceptor.kt b/sdk/src/main/kotlin/io/opentdf/platform/sdk/AuthInterceptor.kt index a3babe2b..08d82865 100644 --- a/sdk/src/main/kotlin/io/opentdf/platform/sdk/AuthInterceptor.kt +++ b/sdk/src/main/kotlin/io/opentdf/platform/sdk/AuthInterceptor.kt @@ -5,15 +5,31 @@ import com.connectrpc.StreamFunction import com.connectrpc.UnaryFunction import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.http.clone +import com.nimbusds.jwt.SignedJWT +import org.slf4j.LoggerFactory +import java.net.URL + +internal class AuthInterceptor(private val ts: TokenSource) : Interceptor { + private val logger = LoggerFactory.getLogger(AuthInterceptor::class.java) + // The connect-kotlin Interceptor API exposes no per-call context to thread the + // request URL into responseFunction. ThreadLocal is the workaround, relying on + // connect-kotlin's contract that requestFunction and responseFunction for a single + // unary call run synchronously on the same thread. If that assumption ever breaks, + // nonces could be cached against the wrong origin. The okhttp-level + // dpopRetryInterceptor below avoids the issue by reading the URL straight from + // chain.request(). + private val requestUrl = ThreadLocal() -private class AuthInterceptor(private val ts: TokenSource) : Interceptor{ override fun streamFunction(): StreamFunction { return StreamFunction( requestFunction = { request -> val requestHeaders = mutableMapOf>() val authHeaders = ts.getAuthHeaders(request.url, "POST") requestHeaders["Authorization"] = listOf(authHeaders.authHeader) - requestHeaders["DPoP"] = listOf(authHeaders.dpopHeader) + authHeaders.dpopHeader?.let { requestHeaders["DPoP"] = listOf(it) } + + logger.debug("DPoP path=stream url={} method=POST authScheme={} {}", + request.url, authScheme(authHeaders.authHeader), dpopSummary(authHeaders.dpopHeader)) return@StreamFunction request.clone( url = request.url, @@ -31,24 +47,127 @@ private class AuthInterceptor(private val ts: TokenSource) : Interceptor{ override fun unaryFunction(): UnaryFunction { return UnaryFunction( requestFunction = { request -> - val requestHeaders = mutableMapOf>() - val authHeaders = ts.getAuthHeaders(request.url, request.httpMethod.name) - requestHeaders["Authorization"] = listOf(authHeaders.authHeader) - requestHeaders["DPoP"] = listOf(authHeaders.dpopHeader) + // Clear any value left behind by an earlier requestFunction that + // threw before its paired responseFunction could run. + requestUrl.remove() + requestUrl.set(request.url) + try { + val requestHeaders = mutableMapOf>() + val authHeaders = ts.getAuthHeaders(request.url, request.httpMethod.name) + requestHeaders["Authorization"] = listOf(authHeaders.authHeader) + authHeaders.dpopHeader?.let { requestHeaders["DPoP"] = listOf(it) } - return@UnaryFunction UnaryHTTPRequest( - url = request.url, - contentType = request.contentType, - headers = requestHeaders, - message = request.message, - timeout = request.timeout, - methodSpec = request.methodSpec, - httpMethod = request.httpMethod - ) + logger.debug("DPoP path=unary url={} method={} authScheme={} {}", + request.url, request.httpMethod.name, + authScheme(authHeaders.authHeader), dpopSummary(authHeaders.dpopHeader)) + + UnaryHTTPRequest( + url = request.url, + contentType = request.contentType, + headers = requestHeaders, + message = request.message, + timeout = request.timeout, + methodSpec = request.methodSpec, + httpMethod = request.httpMethod + ) + } catch (t: Throwable) { + // responseFunction won't run, so clear the slot ourselves to + // avoid stale state leaking to the next call on this thread. + requestUrl.remove() + throw t + } }, responseFunction = { resp -> + val url = requestUrl.get() + requestUrl.remove() + + // Cache any server-issued DPoP nonce for future requests to the same origin + val dpopNonce = resp.headers["dpop-nonce"]?.firstOrNull() + ?: resp.headers["DPoP-Nonce"]?.firstOrNull() + if (dpopNonce != null && url != null) { + ts.cacheNonce(url, dpopNonce) + } + logger.debug("DPoP path=unary-response url={} nonceCached={} status={}", + url, dpopNonce != null && url != null, resp.status) resp }, ) } -} \ No newline at end of file + + /** + * Returns an OkHttp interceptor that retries on RFC 9449 §9 DPoP nonce challenges + * from resource servers (KAS and the platform-services Connect client). + * A 401 is retried only when WWW-Authenticate carries scheme=DPoP and error=use_dpop_nonce; + * any other 401 (or any 401 with only a stray DPoP-Nonce header) is passed through unchanged. + * Rotated nonces are cached after every successful proceed so the next request picks them up. + */ + fun dpopRetryInterceptor(): okhttp3.Interceptor = okhttp3.Interceptor { chain -> + val url = chain.request().url.toUrl() + val outgoingMethod = chain.request().method + var response = chain.proceed(chain.request()) + + // RFC 9449 §9: cache any rotated nonce from the response, regardless of status. + cacheNonceIfPresent(url, response) + + logger.debug("DPoP path=okhttp url={} method={} status={} authScheme={} {}", + url, outgoingMethod, response.code, + authScheme(chain.request().header("Authorization")), + dpopSummary(chain.request().header("DPoP"))) + + if (response.code == 401 && isDpopNonceChallenge(response)) { + val dpopNonce = response.header("dpop-nonce") ?: response.header("DPoP-Nonce") + if (dpopNonce != null) { + response.close() + ts.cacheNonce(url, dpopNonce) + val authHeaders = ts.getAuthHeaders(url, chain.request().method) + val newRequestBuilder = chain.request().newBuilder() + .header("Authorization", authHeaders.authHeader) + authHeaders.dpopHeader?.let { newRequestBuilder.header("DPoP", it) } + val newRequest = newRequestBuilder.build() + logger.debug("DPoP path=okhttp-retry url={} method={} nonce={} authScheme={} {}", + url, chain.request().method, dpopNonce, + authScheme(authHeaders.authHeader), dpopSummary(authHeaders.dpopHeader)) + response = try { + chain.proceed(newRequest) + } catch (e: Exception) { + logger.debug("DPoP retry request to {} failed", url, e) + throw e + } + cacheNonceIfPresent(url, response) + logger.debug("DPoP path=okhttp-retry-response url={} status={}", url, response.code) + } + } + response + } + + private fun cacheNonceIfPresent(url: URL, response: okhttp3.Response) { + val nonce = response.header("dpop-nonce") ?: response.header("DPoP-Nonce") + if (nonce != null) { + ts.cacheNonce(url, nonce) + } + } + + private fun isDpopNonceChallenge(response: okhttp3.Response): Boolean { + return response.challenges().any { challenge -> + challenge.scheme.equals("DPoP", ignoreCase = true) && + challenge.authParams["error"].equals("use_dpop_nonce", ignoreCase = true) + } + } + + private fun authScheme(authHeader: String?): String { + if (authHeader == null) return "" + val idx = authHeader.indexOf(' ') + return if (idx > 0) authHeader.substring(0, idx) else "?" + } + + private fun dpopSummary(dpopProof: String?): String { + if (dpopProof == null) return "dpop=" + return try { + val claims = SignedJWT.parse(dpopProof).jwtClaimsSet + "dpop[htm=${claims.getStringClaim("htm")} htu=${claims.getStringClaim("htu")}" + + " jti=${claims.getStringClaim("jti")} nonce=${claims.getStringClaim("nonce")}]" + } catch (e: Exception) { + "dpop=" + } + } +} diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/DPoPRetryInterceptorTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/DPoPRetryInterceptorTest.java new file mode 100644 index 00000000..8091bb3f --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/DPoPRetryInterceptorTest.java @@ -0,0 +1,464 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; + +class DPoPRetryInterceptorTest { + + private static final String FAKE_TOKEN_RESPONSE = + "{\"access_token\":\"test-token\",\"token_type\":\"DPoP\",\"expires_in\":3600}"; + + private AuthInterceptor buildAuthInterceptor(MockWebServer tokenServer, RSAKey rsaKey) throws Exception { + return new AuthInterceptor(buildTokenSource(tokenServer, rsaKey)); + } + + private TokenSource buildTokenSource(MockWebServer tokenServer, RSAKey rsaKey) { + return new TokenSource( + new ClientSecretBasic(new ClientID("test-client"), new Secret("test-secret")), + rsaKey, + JWSAlgorithm.RS256, + tokenServer.url("/token").uri(), + new ClientCredentialsGrant(), + null + ); + } + + @Test + void retryOn401WithDPoPNonce() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + // Queue multiple token responses (one for each getAuthHeaders call during retry) + for (int i = 0; i < 5; i++) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + } + tokenServer.start(); + + // First request returns 401 + DPoP-Nonce + DPoP nonce challenge; second returns 200 + kasServer.enqueue(new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "server-issued-nonce") + .addHeader("WWW-Authenticate", "DPoP error=\"use_dpop_nonce\"")); + kasServer.enqueue(new MockResponse().setResponseCode(200)); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Request request = new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build(); + Response response = client.newCall(request).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(2); + assertThat(response.code()).isEqualTo(200); + + // Verify second request carries a DPoP proof with the nonce + kasServer.takeRequest(); // consume first request + RecordedRequest retryRequest = kasServer.takeRequest(); + String dpopHeader = retryRequest.getHeader("DPoP"); + assertThat(dpopHeader).isNotNull(); + + SignedJWT dpopJwt = SignedJWT.parse(dpopHeader); + String nonceClaim = dpopJwt.getJWTClaimsSet().getStringClaim("nonce"); + assertThat(nonceClaim).isEqualTo("server-issued-nonce"); + } + } + + @Test + void noRetryOn401WithoutDPoPNonce() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse().setResponseCode(401)); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Request request = new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build(); + Response response = client.newCall(request).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(1); + assertThat(response.code()).isEqualTo(401); + } + } + + @Test + void onlyRetriesOnceWhenSecondResponseAlsoChallengesWithNonce() throws Exception { + // Pins the single-retry guarantee: even if the retry response is also a + // 401 + DPoP-Nonce + use_dpop_nonce challenge, no further retry is attempted. + // This protects against an infinite-retry loop if an AS misbehaves or rotates + // nonces faster than the client can spend them. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "first-nonce") + .addHeader("WWW-Authenticate", "DPoP error=\"use_dpop_nonce\"")); + kasServer.enqueue(new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "second-nonce") + .addHeader("WWW-Authenticate", "DPoP error=\"use_dpop_nonce\"")); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Request request = new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build(); + Response response = client.newCall(request).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(2); + assertThat(response.code()).isEqualTo(401); + } + } + + @Test + void concurrentRequestsAllRetrySuccessfully() throws Exception { + // Smoke test: drive 10 parallel requests through the retry interceptor, each of + // which sees a 401+nonce followed by a 200. All 10 must eventually return 200 + // and each retry must carry a DPoP-Nonce claim. Regressions in the cross-thread + // safety of the nonce cache or interceptor state should surface here. + final int parallelism = 10; + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + // One token response per request — the cache keeps us to one in practice, + // but enqueue enough that any per-thread re-fetch doesn't deadlock the test. + for (int i = 0; i < parallelism * 2; i++) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + } + tokenServer.start(); + + // A FIFO queue can't deliver alternating 401/200 reliably under concurrent + // load — by the time request N's retry arrives, request N+1's first attempt + // may have already consumed N's 200. Use a stateful dispatcher that decides + // based on whether the request already carries a nonce. + kasServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + String dpop = request.getHeader("DPoP"); + boolean hasNonce = false; + if (dpop != null) { + try { + hasNonce = SignedJWT.parse(dpop) + .getJWTClaimsSet().getStringClaim("nonce") != null; + } catch (Exception ignored) { + } + } + if (hasNonce) { + return new MockResponse().setResponseCode(200); + } + return new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "concurrent-nonce") + .addHeader("WWW-Authenticate", "DPoP error=\"use_dpop_nonce\""); + } + }); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + ExecutorService pool = Executors.newFixedThreadPool(parallelism); + try { + List> tasks = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + tasks.add(() -> { + Request request = new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build(); + try (Response response = client.newCall(request).execute()) { + return response.code(); + } + }); + } + List> results = pool.invokeAll(tasks, 30, TimeUnit.SECONDS); + for (Future f : results) { + assertThat(f.get()).isEqualTo(200); + } + } finally { + pool.shutdownNow(); + } + + // Each request produced a 401 + a retry: 2 * parallelism total. + assertThat(kasServer.getRequestCount()).isEqualTo(parallelism * 2); + + // Every retry must carry a nonce — pin that the cross-thread URL/nonce + // bookkeeping never produced a retry without one. + int retriesWithNonce = 0; + int totalRetries = 0; + for (int i = 0; i < parallelism * 2; i++) { + RecordedRequest recorded = kasServer.takeRequest(); + String dpop = recorded.getHeader("DPoP"); + if (dpop == null) { + continue; + } + String nonce = SignedJWT.parse(dpop).getJWTClaimsSet().getStringClaim("nonce"); + if (nonce != null) { + retriesWithNonce++; + } + if (nonce != null) { + totalRetries++; + } + } + assertThat(totalRetries).isEqualTo(parallelism); + assertThat(retriesWithNonce).isEqualTo(parallelism); + } + } + + @Test + void noRetryOn401WithDPoPNonceButNoChallenge() throws Exception { + // A bare DPoP-Nonce header on a 401 (no WWW-Authenticate) must not trigger a retry — + // otherwise any rogue origin can poison the nonce cache and burn a token round-trip. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "spurious-nonce")); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Response response = client.newCall(new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build()).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(1); + assertThat(response.code()).isEqualTo(401); + } + } + + @Test + void noRetryOn401WithNonDpopChallenge() throws Exception { + // WWW-Authenticate: Basic must not trigger a DPoP retry even if DPoP-Nonce is present. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "spurious-nonce") + .addHeader("WWW-Authenticate", "Basic realm=\"x\"")); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Response response = client.newCall(new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build()).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(1); + assertThat(response.code()).isEqualTo(401); + } + } + + @Test + void noRetryOn401WithDpopErrorOtherThanUseDpopNonce() throws Exception { + // RFC 9449 §9 only signals retry on error=use_dpop_nonce. Other DPoP errors + // (invalid_token, insufficient_scope, etc.) must surface to the caller. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse() + .setResponseCode(401) + .addHeader("DPoP-Nonce", "fresh-nonce") + .addHeader("WWW-Authenticate", "DPoP error=\"invalid_token\"")); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Response response = client.newCall(new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build()).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(1); + assertThat(response.code()).isEqualTo(401); + } + } + + @Test + void rotatedNonceFromSuccessfulResponseIsCachedForNextRequest() throws Exception { + // RFC 9449 §9: any response (including 200) may rotate the nonce. The retry + // interceptor must pick that up so the *next* request picks it from the cache. + // Note: the retry interceptor itself does not stamp DPoP headers on the initial + // request — those come from the auth path that builds the request — so we + // verify cache population by querying the TokenSource directly afterward. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse() + .setResponseCode(200) + .addHeader("DPoP-Nonce", "rotated-nonce")); + kasServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(new AuthInterceptor(ts).dpopRetryInterceptor()) + .build(); + + client.newCall(new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build()).execute().close(); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders( + kasServer.url("/kas/rewrap").url(), "POST"); + String nonceClaim = SignedJWT.parse(headers.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce"); + assertThat(nonceClaim).isEqualTo("rotated-nonce"); + } + } + + @Test + void noRetryOnSuccessResponse() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer(); + MockWebServer kasServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + kasServer.enqueue(new MockResponse().setResponseCode(200)); + kasServer.start(); + + AuthInterceptor authInterceptor = buildAuthInterceptor(tokenServer, rsaKey); + OkHttpClient client = new OkHttpClient.Builder() + .addInterceptor(authInterceptor.dpopRetryInterceptor()) + .build(); + + Request request = new Request.Builder() + .url(kasServer.url("/kas/rewrap")) + .post(okhttp3.RequestBody.create(new byte[0])) + .build(); + Response response = client.newCall(request).execute(); + response.close(); + + assertThat(kasServer.getRequestCount()).isEqualTo(1); + assertThat(response.code()).isEqualTo(200); + } + } +} diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java index b24dcbd8..bc0f2659 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java @@ -198,6 +198,104 @@ public String alg() { } } + @Test + void ecDpopKeyAutoGeneratesRsaSrtSigner() throws Exception { + // When the caller supplies an EC DPoP key, the SDK must auto-generate a separate + // RSA-2048 key for SRT signing because DefaultSrtSigner uses RSASSASigner which + // rejects non-RSA keys. Without this separation, build() would throw inside + // DefaultSrtSigner's constructor. + try (MockWebServer oidcServer = startMockOidcServer()) { + String issuer = oidcServer.url("my_realm").toString(); + Server platformServices = startWellKnownGrpcServer(issuer); + try { + com.nimbusds.jose.jwk.ECKey ecDpopKey = new com.nimbusds.jose.jwk.gen.ECKeyGenerator(com.nimbusds.jose.jwk.Curve.P_256) + .keyUse(com.nimbusds.jose.jwk.KeyUse.SIGNATURE) + .keyID(java.util.UUID.randomUUID().toString()) + .generate(); + + var sdk = SDKBuilder.newBuilder() + .clientSecret("user", "password") + .platformEndpoint("http://localhost:" + platformServices.getPort()) + .useInsecurePlaintextConnection(true) + .protocol(ProtocolType.GRPC) + .dpopKey(ecDpopKey) + .build(); + + assertThat(sdk.getSrtSigner()).isPresent(); + assertThat(sdk.getSrtSigner().get().alg()).isEqualTo("RS256"); + // Sanity-check: the SRT signer can actually sign, which would fail if it was + // mistakenly handed the EC key (RSASSASigner constructor would have thrown). + byte[] signed = sdk.getSrtSigner().get().sign(new byte[]{1, 2, 3}); + assertThat(signed).isNotEmpty(); + } finally { + platformServices.shutdownNow(); + } + } + } + + @Test + void rsaDpopKeyReusesSameKeyForSrt() throws Exception { + // When the caller supplies an RSA DPoP key, SDKBuilder reuses it for SRT signing + // (no second RSA key is generated). This test pins that behavior so a regression + // that splits the keys (and burns a key-generation per build) is caught. + try (MockWebServer oidcServer = startMockOidcServer()) { + String issuer = oidcServer.url("my_realm").toString(); + Server platformServices = startWellKnownGrpcServer(issuer); + try { + com.nimbusds.jose.jwk.RSAKey rsaDpopKey = new com.nimbusds.jose.jwk.gen.RSAKeyGenerator(2048) + .keyUse(com.nimbusds.jose.jwk.KeyUse.SIGNATURE) + .keyID(java.util.UUID.randomUUID().toString()) + .generate(); + + var sdk = SDKBuilder.newBuilder() + .clientSecret("user", "password") + .platformEndpoint("http://localhost:" + platformServices.getPort()) + .useInsecurePlaintextConnection(true) + .protocol(ProtocolType.GRPC) + .dpopKey(rsaDpopKey) + .build(); + + assertThat(sdk.getSrtSigner()).isPresent(); + assertThat(sdk.getSrtSigner().get().alg()).isEqualTo("RS256"); + } finally { + platformServices.shutdownNow(); + } + } + } + + private MockWebServer startMockOidcServer() throws IOException { + MockWebServer httpServer = new MockWebServer(); + httpServer.start(); + String issuer = httpServer.url("my_realm").toString(); + String tokenEndpoint = httpServer.url("tokens").toString(); + String oidcConfig; + try (var in = SDKBuilderTest.class.getResourceAsStream("/oidc-config.json")) { + oidcConfig = new String(in.readAllBytes(), StandardCharsets.UTF_8) + .replace("", issuer) + .replace("", tokenEndpoint); + } + httpServer.enqueue(new MockResponse().setBody(oidcConfig).setHeader("Content-type", "application/json")); + return httpServer; + } + + private Server startWellKnownGrpcServer(String issuer) throws IOException { + WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() { + @Override + public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, + StreamObserver responseObserver) { + var val = Value.newBuilder().setStringValue(issuer).build(); + var config = Struct.newBuilder().putFields("platform_issuer", val).build(); + responseObserver.onNext(GetWellKnownConfigurationResponse.newBuilder().setConfiguration(config).build()); + responseObserver.onCompleted(); + } + }; + return ServerBuilder.forPort(getRandomPort()) + .directExecutor() + .addService(wellKnownService) + .build() + .start(); + } + void sdkServicesSetup(boolean useSSLPlatform, boolean useSSLIDP) throws Exception { HeldCertificate rootCertificate = new HeldCertificate.Builder() @@ -367,7 +465,7 @@ public ServerCall.Listener interceptCall(ServerCall responseObserver) { + responseObserver.onNext(GetWellKnownConfigurationResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } + }; + + Server platformServices = ServerBuilder + .forPort(getRandomPort()) + .directExecutor() + .addService(wellKnownService) + .build(); + try { + platformServices.start(); + + com.nimbusds.jose.jwk.RSAKey dpopKey = new com.nimbusds.jose.jwk.gen.RSAKeyGenerator(2048) + .keyUse(com.nimbusds.jose.jwk.KeyUse.SIGNATURE) + .keyID(java.util.UUID.randomUUID().toString()) + .generate(); + + SDKBuilder builder = SDKBuilder.newBuilder() + .clientSecret("user", "password") + .platformEndpoint("http://localhost:" + platformServices.getPort()) + .useInsecurePlaintextConnection(true) + .protocol(ProtocolType.GRPC) + .dpopKey(dpopKey); + + SDKException ex = assertThrows(SDKException.class, builder::build); + assertThat(ex.getMessage()).contains("DPoP").contains("platform_issuer"); + } finally { + platformServices.shutdownNow(); + } + } + @Test void testProtocolConfiguration() { // Test protocol setter and getter functionality diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TokenSourceTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/TokenSourceTest.java new file mode 100644 index 00000000..243fd357 --- /dev/null +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TokenSourceTest.java @@ -0,0 +1,514 @@ +package io.opentdf.platform.sdk; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.ECKey; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.KeyUse; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.gen.ECKeyGenerator; +import com.nimbusds.jose.jwk.gen.RSAKeyGenerator; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.Test; + +import java.net.URL; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TokenSourceTest { + + private static final String FAKE_TOKEN_RESPONSE = + "{\"access_token\":\"test-access-token\",\"token_type\":\"DPoP\",\"expires_in\":3600}"; + + private static final String BEARER_TOKEN_RESPONSE = + "{\"access_token\":\"plain-bearer-token\",\"token_type\":\"Bearer\",\"expires_in\":3600}"; + + private TokenSource buildTokenSource(MockWebServer tokenServer, RSAKey rsaKey) throws Exception { + return new TokenSource( + new ClientSecretBasic(new ClientID("test-client"), new Secret("test-secret")), + rsaKey, + JWSAlgorithm.RS256, + tokenServer.url("/token").uri(), + new ClientCredentialsGrant(), + null + ); + } + + @Test + void cachedNonceIsIncludedInNextDPoPProof() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + // Token endpoint queues two responses: one for initial fetch, one in case of re-fetch + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL testUrl = new URL("https://kas.example.com/kas"); + + ts.cacheNonce(testUrl, "server-nonce-abc"); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(testUrl, "POST"); + SignedJWT dpopJwt = SignedJWT.parse(headers.getDpopHeader()); + String nonceClaim = dpopJwt.getJWTClaimsSet().getStringClaim("nonce"); + + assertThat(nonceClaim).isEqualTo("server-nonce-abc"); + } + } + + @Test + void explicitNonceOverridesCachedNonce() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL testUrl = new URL("https://kas.example.com/kas"); + + ts.cacheNonce(testUrl, "cached-nonce"); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(testUrl, "POST", "explicit-nonce"); + SignedJWT dpopJwt = SignedJWT.parse(headers.getDpopHeader()); + String nonceClaim = dpopJwt.getJWTClaimsSet().getStringClaim("nonce"); + + assertThat(nonceClaim).isEqualTo("explicit-nonce"); + } + } + + @Test + void noNonceClaimWhenNoCachedNonce() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL testUrl = new URL("https://kas.example.com/kas"); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(testUrl, "POST"); + SignedJWT dpopJwt = SignedJWT.parse(headers.getDpopHeader()); + String nonceClaim = dpopJwt.getJWTClaimsSet().getStringClaim("nonce"); + + assertThat(nonceClaim).isNull(); + } + } + + @Test + void htuClaimStripsQueryAndFragment() throws Exception { + // RFC 9449 §4.2: htu must omit query and fragment, and Nimbus rejects + // anything else with IllegalArgumentException. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL urlWithQuery = new URL("https://kas.example.com/kas?foo=bar&baz=qux#frag"); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(urlWithQuery, "POST"); + SignedJWT dpopJwt = SignedJWT.parse(headers.getDpopHeader()); + String htu = dpopJwt.getJWTClaimsSet().getStringClaim("htu"); + + assertThat(htu).isEqualTo("https://kas.example.com/kas"); + } + } + + @Test + void ecKeyGeneratesDPoPProof() throws Exception { + ECKey ecKey = new ECKeyGenerator(Curve.P_256) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = new TokenSource( + new ClientSecretBasic(new ClientID("test-client"), new Secret("test-secret")), + ecKey, + JWSAlgorithm.ES256, + tokenServer.url("/token").uri(), + new ClientCredentialsGrant(), + null + ); + URL testUrl = new URL("https://kas.example.com/kas"); + ts.cacheNonce(testUrl, "ec-nonce"); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(testUrl, "POST"); + assertThat(headers.getAuthHeader()).startsWith("DPoP "); + + SignedJWT dpopJwt = SignedJWT.parse(headers.getDpopHeader()); + assertThat(dpopJwt.getHeader().getAlgorithm()).isEqualTo(JWSAlgorithm.ES256); + assertThat(dpopJwt.getJWTClaimsSet().getStringClaim("nonce")).isEqualTo("ec-nonce"); + } + } + + @Test + void noncesAreIsolatedByOrigin() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + for (int i = 0; i < 4; i++) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + } + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL kasUrl = new URL("https://kas.example.com/kas"); + URL otherUrl = new URL("https://other.example.com/kas"); + + ts.cacheNonce(kasUrl, "kas-nonce"); + + TokenSource.AuthHeaders headersForKas = ts.getAuthHeaders(kasUrl, "POST"); + TokenSource.AuthHeaders headersForOther = ts.getAuthHeaders(otherUrl, "POST"); + + String kasNonce = SignedJWT.parse(headersForKas.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce"); + String otherNonce = SignedJWT.parse(headersForOther.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce"); + + assertThat(kasNonce).isEqualTo("kas-nonce"); + assertThat(otherNonce).isNull(); + } + } + + @Test + void noncesAreIsolatedByPort() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + for (int i = 0; i < 2; i++) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + } + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL port8080 = new URL("https://kas.example.com:8080/kas"); + URL port9090 = new URL("https://kas.example.com:9090/kas"); + + ts.cacheNonce(port8080, "nonce-8080"); + TokenSource.AuthHeaders headers8080 = ts.getAuthHeaders(port8080, "POST"); + TokenSource.AuthHeaders headers9090 = ts.getAuthHeaders(port9090, "POST"); + + assertThat(SignedJWT.parse(headers8080.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce")).isEqualTo("nonce-8080"); + assertThat(SignedJWT.parse(headers9090.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce")).isNull(); + } + } + + @Test + void noncesAreIsolatedByScheme() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + for (int i = 0; i < 2; i++) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + } + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL httpsUrl = new URL("https://kas.example.com/kas"); + URL httpUrl = new URL("http://kas.example.com/kas"); + + ts.cacheNonce(httpsUrl, "https-nonce"); + TokenSource.AuthHeaders headersHttps = ts.getAuthHeaders(httpsUrl, "POST"); + TokenSource.AuthHeaders headersHttp = ts.getAuthHeaders(httpUrl, "POST"); + + assertThat(SignedJWT.parse(headersHttps.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce")).isEqualTo("https-nonce"); + assertThat(SignedJWT.parse(headersHttp.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce")).isNull(); + } + } + + @Test + void noncesShareCacheForImplicitAndExplicitDefaultPort() throws Exception { + // Pins the getDefaultPort() normalization in TokenSource.getOrigin — + // https://host and https://host:443 must share a cache entry. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL implicitPort = new URL("https://kas.example.com/kas"); + URL explicitPort = new URL("https://kas.example.com:443/kas"); + + ts.cacheNonce(explicitPort, "shared-nonce"); + TokenSource.AuthHeaders headers = ts.getAuthHeaders(implicitPort, "POST"); + + assertThat(SignedJWT.parse(headers.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce")).isEqualTo("shared-nonce"); + } + } + + @Test + void emptyNonceIsNotCached() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setBody(FAKE_TOKEN_RESPONSE) + .setHeader("Content-Type", "application/json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL testUrl = new URL("https://kas.example.com/kas"); + + ts.cacheNonce(testUrl, ""); + ts.cacheNonce(testUrl, null); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(testUrl, "POST"); + String nonceClaim = SignedJWT.parse(headers.getDpopHeader()) + .getJWTClaimsSet().getStringClaim("nonce"); + + assertThat(nonceClaim).isNull(); + } + } + + @Test + void getToken_retriesWithNonceOnUseDpopNonce() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + // First: 401 use_dpop_nonce + tokenServer.enqueue(new MockResponse() + .setResponseCode(401) + .setHeader("Content-Type", "application/json") + .addHeader("DPoP-Nonce", "retry-nonce-abc") + .setBody("{\"error\":\"use_dpop_nonce\",\"error_description\":\"nonce required\"}")); + // Second: success + tokenServer.enqueue(new MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("{\"access_token\":\"real-token\",\"token_type\":\"DPoP\",\"expires_in\":3600}")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + URL resourceUrl = new URL("https://kas.example.com/kas"); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(resourceUrl, "POST"); + + assertThat(headers.getAuthHeader()).isEqualTo("DPoP real-token"); + assertThat(tokenServer.getRequestCount()).isEqualTo(2); + + RecordedRequest first = tokenServer.takeRequest(); + String firstNonce = SignedJWT.parse(first.getHeader("DPoP")) + .getJWTClaimsSet().getStringClaim("nonce"); + assertThat(firstNonce).isNull(); + + RecordedRequest second = tokenServer.takeRequest(); + String secondNonce = SignedJWT.parse(second.getHeader("DPoP")) + .getJWTClaimsSet().getStringClaim("nonce"); + assertThat(secondNonce).isEqualTo("retry-nonce-abc"); + } + } + + @Test + void getToken_throwsDescriptiveErrorWhenUseDpopNonceLacksHeader() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + // 401 use_dpop_nonce but NO DPoP-Nonce header + tokenServer.enqueue(new MockResponse() + .setResponseCode(401) + .setHeader("Content-Type", "application/json") + .setBody("{\"error\":\"use_dpop_nonce\",\"error_description\":\"nonce required\"}")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + + assertThatThrownBy(() -> ts.getAuthHeaders(new URL("https://kas.example.com/kas"), "POST")) + .isInstanceOf(SDKException.class) + .satisfies(e -> assertThat(e.getMessage() + (e.getCause() != null ? e.getCause().getMessage() : "")) + .contains("use_dpop_nonce")); + } + } + + @Test + void getToken_surfacesMalformedTokenResponseDistinctly() throws Exception { + // Pre-fix every token-fetch failure surfaced as "failed to get token" with the + // cause buried. After C3 the parse failure is attributed to the token endpoint. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + // 200 with non-JSON body trips ParseException inside nimbus. + tokenServer.enqueue(new MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("this is not json")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + + assertThatThrownBy(() -> ts.getAuthHeaders(new URL("https://kas.example.com/kas"), "POST")) + .isInstanceOf(SDKException.class) + .hasMessageContaining("malformed token response") + .hasMessageContaining(tokenServer.url("/token").toString()); + } + } + + @Test + void constructor_rejectsRsaKeyWithEcAlgorithm() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048).keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()).generate(); + assertThatThrownBy(() -> new TokenSource( + new ClientSecretBasic(new ClientID("c"), new Secret("s")), + rsaKey, JWSAlgorithm.ES256, + new URL("https://idp.example.com/token").toURI(), + new ClientCredentialsGrant(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("RSA") + .hasMessageContaining("ES256"); + } + + @Test + void constructor_rejectsEcKeyWithMismatchedCurveAlgorithm() throws Exception { + ECKey ecKey = new ECKeyGenerator(Curve.P_256).keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()).generate(); + assertThatThrownBy(() -> new TokenSource( + new ClientSecretBasic(new ClientID("c"), new Secret("s")), + ecKey, JWSAlgorithm.ES384, + new URL("https://idp.example.com/token").toURI(), + new ClientCredentialsGrant(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("P-256") + .hasMessageContaining("ES384"); + } + + @Test + void constructor_rejectsUnsupportedJwkType() throws Exception { + // OKP type — parsed from a static JWK to avoid pulling in the Tink dependency + // that OctetKeyPairGenerator needs at runtime. + JWK okp = JWK.parse("{\"kty\":\"OKP\",\"crv\":\"Ed25519\"," + + "\"x\":\"11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo\"," + + "\"d\":\"nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A\"}"); + assertThatThrownBy(() -> new TokenSource( + new ClientSecretBasic(new ClientID("c"), new Secret("s")), + okp, JWSAlgorithm.EdDSA, + new URL("https://idp.example.com/token").toURI(), + new ClientCredentialsGrant(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unsupported JWK type"); + } + + @Test + void getToken_usesProactivelyCachedNonce() throws Exception { + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("{\"access_token\":\"proactive-token\",\"token_type\":\"DPoP\",\"expires_in\":3600}")); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + // Pre-seed the cache for the token endpoint origin + ts.cacheNonce(tokenServer.url("/token").url(), "proactive-nonce"); + + ts.getAuthHeaders(new URL("https://kas.example.com/kas"), "POST"); + + assertThat(tokenServer.getRequestCount()).isEqualTo(1); + + RecordedRequest request = tokenServer.takeRequest(); + String nonceClaim = SignedJWT.parse(request.getHeader("DPoP")) + .getJWTClaimsSet().getStringClaim("nonce"); + assertThat(nonceClaim).isEqualTo("proactive-nonce"); + } + } + + @Test + void getAuthHeaders_downgradesToBearerWhenTokenEndpointReturnsBearer() throws Exception { + // Keycloak realms with DPoP disabled return token_type=Bearer even when the client + // sent a DPoP proof. The SDK must not emit "Authorization: DPoP " — + // that scheme is reserved for DPoP-bound tokens (RFC 9449 §7.1) and a DPoP-enforcing + // resource server will reject it. Downgrade to Bearer and omit the proof. + RSAKey rsaKey = new RSAKeyGenerator(2048) + .keyUse(KeyUse.SIGNATURE) + .keyID(UUID.randomUUID().toString()) + .generate(); + try (MockWebServer tokenServer = new MockWebServer()) { + tokenServer.enqueue(new MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody(BEARER_TOKEN_RESPONSE)); + tokenServer.start(); + + TokenSource ts = buildTokenSource(tokenServer, rsaKey); + + TokenSource.AuthHeaders headers = ts.getAuthHeaders(new URL("https://kas.example.com/kas"), "POST"); + + assertThat(headers.getAuthHeader()).isEqualTo("Bearer plain-bearer-token"); + assertThat(headers.getDpopHeader()).isNull(); + } + } +}