From b4c9dccb741754348f14742a33412cfcb2b7e93b Mon Sep 17 00:00:00 2001 From: OmarAlJarrah Date: Wed, 17 Jun 2026 06:30:32 +0300 Subject: [PATCH] feat: add async retry and bearer-auth pipeline steps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the AsyncHttpStep counterparts for the RETRY and AUTH pillar stages so async calls get the same retry and authentication behaviour as the synchronous pipeline. DefaultAsyncRetryStep mirrors DefaultRetryStep's policy exactly — it reuses HttpRetryOptions, the shared BackoffCalculator, RetryAfterParser, the same Retry-After header set, and the same idempotency-aware re-sendability gating. Backoff delays are scheduled on a ScheduledExecutorService via Futures.delay rather than blocking a thread, and the retry loop is driven by an iterative trampoline (no per-attempt thenCompose recursion), so a long zero-delay retry sequence stays stack-safe. AsyncBearerTokenAuthStep stamps Authorization: Bearer via a new non-blocking BearerTokenProvider.fetchAsync seam. A token that is still valid but inside the refresh margin is returned and stamped immediately while a refresh runs off-thread; concurrent requests that observe an expiring or missing token share a single in-flight fetch (single-flight) so they don't stampede the token endpoint. The HTTPS guard, cross-origin credential suppression, and 401-challenge token eviction match the synchronous BearerTokenAuthStep. A ManualScheduler test fixture drives the scheduled delays deterministically so the retry tests run without real sleeps. --- sdk-core/api/sdk-core.api | 41 ++ .../sdk/core/http/auth/BearerTokenProvider.kt | 43 ++ .../core/http/pipeline/steps/AsyncAuthStep.kt | 150 +++++++ .../steps/AsyncBearerTokenAuthStep.kt | 276 ++++++++++++ .../http/pipeline/steps/AsyncRetryStep.kt | 26 ++ .../pipeline/steps/DefaultAsyncRetryStep.kt | 398 ++++++++++++++++++ .../steps/AsyncBearerTokenAuthStepTest.kt | 329 +++++++++++++++ .../steps/DefaultAsyncRetryStepTest.kt | 383 +++++++++++++++++ .../sdk/core/testing/ManualScheduler.kt | 192 +++++++++ 9 files changed, 1838 insertions(+) create mode 100644 sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncAuthStep.kt create mode 100644 sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStep.kt create mode 100644 sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncRetryStep.kt create mode 100644 sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep.kt create mode 100644 sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStepTest.kt create mode 100644 sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStepTest.kt create mode 100644 sdk-core/src/testFixtures/kotlin/org/dexpace/sdk/core/testing/ManualScheduler.kt diff --git a/sdk-core/api/sdk-core.api b/sdk-core/api/sdk-core.api index 7ebcc312..023f2296 100644 --- a/sdk-core/api/sdk-core.api +++ b/sdk-core/api/sdk-core.api @@ -138,10 +138,14 @@ public final class org/dexpace/sdk/core/http/auth/BearerToken : org/dexpace/sdk/ public abstract interface class org/dexpace/sdk/core/http/auth/BearerTokenProvider { public fun fetch (Ljava/util/List;)Lorg/dexpace/sdk/core/http/auth/BearerToken; public abstract fun fetch (Ljava/util/List;Ljava/util/Map;)Lorg/dexpace/sdk/core/http/auth/BearerToken; + public fun fetchAsync (Ljava/util/List;)Ljava/util/concurrent/CompletableFuture; + public fun fetchAsync (Ljava/util/List;Ljava/util/Map;)Ljava/util/concurrent/CompletableFuture; } public final class org/dexpace/sdk/core/http/auth/BearerTokenProvider$DefaultImpls { public static fun fetch (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;)Lorg/dexpace/sdk/core/http/auth/BearerToken; + public static fun fetchAsync (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;)Ljava/util/concurrent/CompletableFuture; + public static fun fetchAsync (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;Ljava/util/Map;)Ljava/util/concurrent/CompletableFuture; } public abstract interface class org/dexpace/sdk/core/http/auth/ChallengeHandler { @@ -747,6 +751,30 @@ public final class org/dexpace/sdk/core/http/pipeline/Stage : java/lang/Enum { public static fun values ()[Lorg/dexpace/sdk/core/http/pipeline/Stage; } +public abstract class org/dexpace/sdk/core/http/pipeline/steps/AsyncAuthStep : org/dexpace/sdk/core/http/pipeline/AsyncHttpStep { + public fun ()V + protected abstract fun authorizeRequestAsync (Lorg/dexpace/sdk/core/http/request/Request;)Ljava/util/concurrent/CompletableFuture; + protected fun authorizeRequestOnChallengeAsync (Lorg/dexpace/sdk/core/http/request/Request;Lorg/dexpace/sdk/core/http/response/Response;)Ljava/util/concurrent/CompletableFuture; + public final fun getStage ()Lorg/dexpace/sdk/core/http/pipeline/Stage; + public final fun processAsync (Lorg/dexpace/sdk/core/http/request/Request;Lorg/dexpace/sdk/core/http/pipeline/AsyncPipelineNext;)Ljava/util/concurrent/CompletableFuture; +} + +public class org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStep : org/dexpace/sdk/core/http/pipeline/steps/AsyncAuthStep { + public fun (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;)V + public fun (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;Ljava/time/Duration;)V + public fun (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;Ljava/time/Duration;Lorg/dexpace/sdk/core/util/Clock;)V + public fun (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;Ljava/time/Duration;Lorg/dexpace/sdk/core/util/Clock;Lorg/dexpace/sdk/core/instrumentation/ClientLogger;)V + public synthetic fun (Lorg/dexpace/sdk/core/http/auth/BearerTokenProvider;Ljava/util/List;Ljava/time/Duration;Lorg/dexpace/sdk/core/util/Clock;Lorg/dexpace/sdk/core/instrumentation/ClientLogger;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + protected fun authorizeRequestAsync (Lorg/dexpace/sdk/core/http/request/Request;)Ljava/util/concurrent/CompletableFuture; + protected fun authorizeRequestOnChallengeAsync (Lorg/dexpace/sdk/core/http/request/Request;Lorg/dexpace/sdk/core/http/response/Response;)Ljava/util/concurrent/CompletableFuture; + protected fun bearerHeaderValue (Ljava/lang/String;)Ljava/lang/String; +} + +public abstract class org/dexpace/sdk/core/http/pipeline/steps/AsyncRetryStep : org/dexpace/sdk/core/http/pipeline/AsyncHttpStep { + public fun ()V + public final fun getStage ()Lorg/dexpace/sdk/core/http/pipeline/Stage; +} + public abstract class org/dexpace/sdk/core/http/pipeline/steps/AuthStep : org/dexpace/sdk/core/http/pipeline/HttpStep { public fun ()V protected abstract fun authorizeRequest (Lorg/dexpace/sdk/core/http/request/Request;)Lorg/dexpace/sdk/core/http/request/Request; @@ -775,6 +803,19 @@ public final class org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncInstrume public fun processAsync (Lorg/dexpace/sdk/core/http/request/Request;Lorg/dexpace/sdk/core/http/pipeline/AsyncPipelineNext;)Ljava/util/concurrent/CompletableFuture; } +public class org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep : org/dexpace/sdk/core/http/pipeline/steps/AsyncRetryStep { + public static final field Companion Lorg/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep$Companion; + public fun (Ljava/util/concurrent/ScheduledExecutorService;)V + public fun (Ljava/util/concurrent/ScheduledExecutorService;Lorg/dexpace/sdk/core/http/pipeline/steps/HttpRetryOptions;)V + public fun (Ljava/util/concurrent/ScheduledExecutorService;Lorg/dexpace/sdk/core/http/pipeline/steps/HttpRetryOptions;Lorg/dexpace/sdk/core/util/Clock;)V + public fun (Ljava/util/concurrent/ScheduledExecutorService;Lorg/dexpace/sdk/core/http/pipeline/steps/HttpRetryOptions;Lorg/dexpace/sdk/core/util/Clock;Lorg/dexpace/sdk/core/instrumentation/ClientLogger;)V + public synthetic fun (Ljava/util/concurrent/ScheduledExecutorService;Lorg/dexpace/sdk/core/http/pipeline/steps/HttpRetryOptions;Lorg/dexpace/sdk/core/util/Clock;Lorg/dexpace/sdk/core/instrumentation/ClientLogger;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun processAsync (Lorg/dexpace/sdk/core/http/request/Request;Lorg/dexpace/sdk/core/http/pipeline/AsyncPipelineNext;)Ljava/util/concurrent/CompletableFuture; +} + +public final class org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep$Companion { +} + public final class org/dexpace/sdk/core/http/pipeline/steps/DefaultInstrumentationStep : org/dexpace/sdk/core/http/pipeline/steps/InstrumentationStep { public fun ()V public fun (Lorg/dexpace/sdk/core/http/pipeline/steps/HttpInstrumentationOptions;)V diff --git a/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/auth/BearerTokenProvider.kt b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/auth/BearerTokenProvider.kt index 2ba71503..830d518c 100644 --- a/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/auth/BearerTokenProvider.kt +++ b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/auth/BearerTokenProvider.kt @@ -7,6 +7,8 @@ package org.dexpace.sdk.core.http.auth +import java.util.concurrent.CompletableFuture + /** * Source of fresh [BearerToken]s for [org.dexpace.sdk.core.http.pipeline.steps.BearerTokenAuthStep]. * @@ -44,4 +46,45 @@ public fun interface BearerTokenProvider { /** Convenience for callers without extra params; forwards to [fetch] with an empty map. */ public fun fetch(scopes: List): BearerToken = fetch(scopes, emptyMap()) + + /** + * Asynchronous counterpart of [fetch], used by + * [org.dexpace.sdk.core.http.pipeline.steps.AsyncBearerTokenAuthStep] so a token refresh + * never blocks the request-dispatching thread. + * + * The default implementation invokes the blocking [fetch] **on the calling thread** and + * wraps the outcome into an already-completed [CompletableFuture] (completing exceptionally + * if [fetch] throws). That default is correct but not non-blocking: a provider that talks to + * a remote token endpoint should override this method to dispatch the fetch off-thread — + * e.g. submit to an [java.util.concurrent.Executor], or call an async OAuth client — so the + * returned future completes without parking the caller. + * + * Per-cloud providers (GCP / Azure / Kubernetes workload identity) and OAuth + * token-exchange flows belong in adapter modules, not in `sdk-core`; this seam is what they + * override. + * + * @param scopes OAuth scopes to request; service-specific. + * @param params extra parameters to pass through to the token endpoint. + * @return a future that completes with a fresh [BearerToken], or completes exceptionally + * with whatever [fetch] threw. + */ + public fun fetchAsync( + scopes: List, + params: Map, + ): CompletableFuture = + try { + CompletableFuture.completedFuture(fetch(scopes, params)) + } catch (t: Throwable) { + // A provider's blocking fetch may throw any Throwable. Surface it through the + // future rather than synchronously so async callers observe a uniform error model. + // Error subclasses (OOM, StackOverflow) are intentionally NOT special-cased here: + // the default just mirrors fetch()'s outcome into the future, and an Error in a + // user-supplied lambda is still that lambda's failure, not a JVM-fatal one for us. + val failed = CompletableFuture() + failed.completeExceptionally(t) + failed + } + + /** Convenience for callers without extra params; forwards to [fetchAsync] with an empty map. */ + public fun fetchAsync(scopes: List): CompletableFuture = fetchAsync(scopes, emptyMap()) } diff --git a/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncAuthStep.kt b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncAuthStep.kt new file mode 100644 index 00000000..4534e4e5 --- /dev/null +++ b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncAuthStep.kt @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.http.pipeline.steps + +import org.dexpace.sdk.core.http.common.HttpHeaderName +import org.dexpace.sdk.core.http.pipeline.AsyncHttpStep +import org.dexpace.sdk.core.http.pipeline.AsyncPipelineNext +import org.dexpace.sdk.core.http.pipeline.Stage +import org.dexpace.sdk.core.http.request.Request +import org.dexpace.sdk.core.http.response.Response +import org.dexpace.sdk.core.util.Futures +import java.util.concurrent.CompletableFuture + +/** + * Async pillar step at [Stage.AUTH] — the [AsyncHttpStep] counterpart of [AuthStep]. Stamps + * credentials onto outgoing requests via an async [authorizeRequestAsync] (so a token fetch / + * refresh never blocks the dispatching thread) and exposes the same 401 + `WWW-Authenticate` + * challenge hook. + * + * The stamping and challenge semantics mirror [AuthStep] exactly: + * + * - **HTTPS-only.** On the path that attaches a credential, [processAsync] rejects non-HTTPS + * schemes before any token fetch. The guard is skipped on the marker-suppressed cross-origin + * re-issue path, where no credential is attached. + * - **Cross-origin redirects.** A re-issue marked by the redirect step (see + * [CrossOriginRedirectMarker]) is forwarded credential-free; the marker is stripped before + * the request reaches the wire. + * - **Challenge retry.** On a 401 carrying `WWW-Authenticate`, [authorizeRequestOnChallengeAsync] + * is consulted; a non-null replacement is driven through the chain exactly once (no further + * challenge handling). The default returns a future of `null` (no retry). + * + * Unlike the synchronous [AuthStep] the credential-attaching guard checks and the downstream + * dispatches are composed on [CompletableFuture]s so the whole flow stays non-blocking. + * + * ## Thread-safety + * + * The stage is locked at the type level via `final override`. Concrete subclasses must be safe + * for concurrent invocation — see [AsyncBearerTokenAuthStep] (single-flight token refresh). + */ +public abstract class AsyncAuthStep : AsyncHttpStep { + final override val stage: Stage = Stage.AUTH + + final override fun processAsync( + request: Request, + next: AsyncPipelineNext, + ): CompletableFuture { + val authorizedFuture: CompletableFuture = + if (CrossOriginRedirectMarker.isMarked(request)) { + // Cross-origin redirect re-issue: strip the marker, attach no credential. + CompletableFuture.completedFuture( + request.newBuilder() + .headers(CrossOriginRedirectMarker.strip(request.headers)) + .build(), + ) + } else { + val scheme = request.url.protocol + if (!"https".equals(scheme, ignoreCase = true)) { + Futures.failed( + IllegalStateException( + "${this::class.simpleName} requires HTTPS to prevent credential leak " + + "(URL scheme: $scheme)", + ), + ) + } else { + authorizeRequestAsync(request) + } + } + + return authorizedFuture.thenCompose { authorized -> + next.copy().processAsync(authorized).thenCompose { response -> + handleChallenge(authorized, response, next) + } + } + } + + /** + * After the first downstream attempt, applies the 401 + `WWW-Authenticate` challenge hook. + * Returns the response unchanged unless [authorizeRequestOnChallengeAsync] yields a non-null + * replacement, in which case the original 401 is closed and the replacement is driven once. + */ + private fun handleChallenge( + authorized: Request, + response: Response, + next: AsyncPipelineNext, + ): CompletableFuture { + if (response.status.code != SC_UNAUTHORIZED) return CompletableFuture.completedFuture(response) + response.headers.get(HttpHeaderName.WWW_AUTHENTICATE) + ?: return CompletableFuture.completedFuture(response) + + val challengeFuture: CompletableFuture = + try { + authorizeRequestOnChallengeAsync(authorized, response) + } catch (t: Throwable) { + // A sync throw from the hook (caller-bug case) must still close the 401 body. + response.close() + return Futures.failed(t) + } + + return challengeFuture.handle { retryRequest, hookError -> + HookOutcome(retryRequest, hookError) + }.thenCompose { outcome -> + val hookError = outcome.error + if (hookError != null) { + response.close() + return@thenCompose Futures.failed(Futures.unwrap(hookError)) + } + val retryRequest = outcome.request ?: return@thenCompose CompletableFuture.completedFuture(response) + response.close() + next.copy().processAsync(retryRequest) + } + } + + /** Carrier so the challenge future's outcome (value or error) survives [CompletableFuture.handle]. */ + private class HookOutcome(val request: Request?, val error: Throwable?) + + /** + * Returns a future of [request] with the credential's auth header attached. Subclasses + * implement the concrete async stamping (e.g. fetch-or-refresh a bearer token off-thread, + * then stamp `Authorization: Bearer `). + * + * Called once per request before the downstream chain is invoked. + */ + protected abstract fun authorizeRequestAsync(request: Request): CompletableFuture + + /** + * Hook invoked on a 401 response that carries a `WWW-Authenticate` header. The default + * returns a future of `null` — surface the 401 with no retry. + * + * Subclasses override to refresh tokens or step up auth. A non-null [Request] in the + * returned future triggers a single retry through the downstream chain; the original 401 is + * closed first. + * + * @param request the request already stamped with the credential that produced the 401. + * @param response the 401 response; its body is still open at this point. + */ + protected open fun authorizeRequestOnChallengeAsync( + request: Request, + response: Response, + ): CompletableFuture = CompletableFuture.completedFuture(null) + + private companion object { + // HTTP 401 — the only status code AsyncAuthStep responds to. + private const val SC_UNAUTHORIZED = 401 + } +} diff --git a/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStep.kt b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStep.kt new file mode 100644 index 00000000..d877e177 --- /dev/null +++ b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStep.kt @@ -0,0 +1,276 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.http.pipeline.steps + +import org.dexpace.sdk.core.http.auth.AuthChallengeParser +import org.dexpace.sdk.core.http.auth.BearerToken +import org.dexpace.sdk.core.http.auth.BearerTokenProvider +import org.dexpace.sdk.core.http.common.HttpHeaderName +import org.dexpace.sdk.core.http.request.Request +import org.dexpace.sdk.core.http.response.Response +import org.dexpace.sdk.core.instrumentation.ClientLogger +import org.dexpace.sdk.core.util.Clock +import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock + +/** + * [AsyncAuthStep] that stamps `Authorization: Bearer ` on outgoing requests, fetching the + * token from [provider] via [BearerTokenProvider.fetchAsync] so a refresh never blocks the + * dispatching thread. The async counterpart of [BearerTokenAuthStep]. + * + * ## Background refresh + * + * The cached token is held in a `@Volatile` field; the hot-path read is wait-free. The expiry + * decision has three zones, keyed off [refreshMargin]: + * + * - **Fresh** — `now + margin < expiresAt`: stamp the cached token, no refresh. + * - **Expiring** — `expiresAt - margin <= now < expiresAt`: the token is still *valid*, so it + * is returned **immediately** and stamped, while a refresh is kicked off **off-thread** in the + * background. The in-flight request never waits on the token endpoint; the next request picks + * up the refreshed token once it lands. This is the behaviour issue #32 asks for: a + * valid-but-expiring token returned without blocking. + * - **Expired / missing** — `now >= expiresAt`, or no cached token: there is no usable + * credential, so the request **must** await a fresh fetch before it can be stamped. + * + * ## Single-flight + * + * Concurrent requests that all observe an expiring or missing token share **one** in-flight + * [provider] call rather than each hitting the token endpoint (no stampede). The in-flight future + * is published under a [ReentrantLock]; whoever wins the lock starts the fetch, everyone else + * joins the same future. On completion the in-flight slot is cleared so a later refresh starts a + * new fetch. A failed fetch is **not** cached — the next request retries. + * + * ## Eviction on 401 + * + * Like [BearerTokenAuthStep], a 401 + `WWW-Authenticate: Bearer` evicts the rejected token and + * re-stamps the single [AsyncAuthStep] retry with a freshly fetched one. Eviction is scoped to + * the exact token that produced the 401; a token a concurrent request already refreshed is left + * in place. A 401 with no bearer challenge (or none at all), or a request that reached the hook + * credential-free (cross-origin suppression), surfaces unchanged. + * + * ## Errors from [provider] + * + * - Future completes exceptionally → propagated; not cached, so a later request retries. + * - Future completes with `null` → surfaced as [IllegalStateException]. + * - Future completes with an already-expired token → surfaced as [IllegalStateException]. + * + * A background (expiring-zone) refresh that fails or returns an unusable token does **not** fail + * the in-flight request — the still-valid cached token was already stamped — it only logs and + * leaves the cache for the next request to refresh. + * + * ## Open for subclassing + * + * Override [bearerHeaderValue] to change the header format (and keep eviction matching), or + * [authorizeRequestOnChallengeAsync] to customise challenge handling. + */ +public open class AsyncBearerTokenAuthStep + @JvmOverloads + constructor( + private val provider: BearerTokenProvider, + private val scopes: List, + private val refreshMargin: Duration = Duration.ofSeconds(DEFAULT_REFRESH_MARGIN_SECONDS), + private val clock: Clock = Clock.SYSTEM, + private val logger: ClientLogger = ClientLogger(AsyncBearerTokenAuthStep::class), + ) : AsyncAuthStep() { + private val lock = ReentrantLock() + + @Volatile + private var cachedToken: BearerToken? = null + + // The single shared in-flight fetch, or null when none is running. Published / cleared + // under [lock] so concurrent expiring/missing requests coalesce onto one provider call. + @Volatile + private var inFlight: CompletableFuture? = null + + override fun authorizeRequestAsync(request: Request): CompletableFuture = + currentToken().thenApply { token -> stamp(request, token) } + + override fun authorizeRequestOnChallengeAsync( + request: Request, + response: Response, + ): CompletableFuture { + // No credential on the rejected request → stamping was suppressed (cross-origin + // redirect). Surface the 401 unchanged. + val rejectedHeader = + request.headers.get(HttpHeaderName.AUTHORIZATION) + ?: return CompletableFuture.completedFuture(null) + // A token refresh can only satisfy a Bearer challenge. + if (!offersBearerChallenge(response)) return CompletableFuture.completedFuture(null) + evictRejectedToken(rejectedHeader) + // forceFresh: the rejected token was just evicted; await a genuinely fresh fetch + // before re-stamping so the retry never carries the same rejected credential. + return forceFreshToken().thenApply { token -> stamp(request, token) as Request? } + } + + private fun stamp( + request: Request, + token: BearerToken, + ): Request = + request.newBuilder() + .setHeader(HttpHeaderName.AUTHORIZATION.caseSensitiveName, bearerHeaderValue(token.token)) + .build() + + /** + * Resolves the token to stamp for the current request, applying the three-zone expiry + * policy (fresh / expiring / expired). Never blocks: the returned future completes + * immediately when a usable cached token exists (even if a background refresh is also + * kicked off), and otherwise completes when the single-flight fetch lands. + */ + private fun currentToken(): CompletableFuture { + val now = clock.now() + val cached = cachedToken + if (cached != null && !cached.isExpiredAt(now)) { + // Token is still valid. If it is inside the refresh margin, return it now and + // refresh in the background; otherwise just return it. + if (cached.isExpiredAt(now, refreshMargin)) { + startBackgroundRefresh() + } + return CompletableFuture.completedFuture(cached) + } + // Missing or hard-expired: must await a fresh fetch. + return forceFreshToken() + } + + /** + * Kicks off a single-flight refresh whose result the caller does NOT await. Used on the + * expiring-but-valid path: the still-valid cached token is what the in-flight request + * stamps; this just warms the cache for the next request. + */ + private fun startBackgroundRefresh() { + // Reuse the single-flight machinery; ignore the returned future (fire-and-forget). + // exceptionally/handle keeps an unhandled background failure from surfacing as an + // uncaught CompletableFuture completion. + sharedFetch().whenComplete { _, error -> + if (error != null) { + logger.atWarning() + .event("http.auth.background_refresh_failed") + .field("error.type", error::class.java.simpleName ?: "Throwable") + .cause(error) + .log() + } + } + } + + /** + * Returns a future that completes with a usable token, awaiting the single-flight fetch. + * Used when there is no usable cached token (missing / hard-expired) and on the + * post-eviction challenge path. + */ + private fun forceFreshToken(): CompletableFuture = sharedFetch() + + /** + * Single-flight fetch coordinator. The first caller to find no in-flight fetch starts one + * (publishing it under [lock]); concurrent callers join the same future. The in-flight + * slot is cleared on completion so a subsequent refresh starts fresh. A re-check of the + * cache inside the lock means a token another thread just refreshed short-circuits the + * fetch. + */ + private fun sharedFetch(): CompletableFuture { + lock.withLock { + // Re-read inside the lock: another thread may have just refreshed. + val now = clock.now() + cachedToken?.takeIf { !it.isExpiredAt(now, refreshMargin) } + ?.let { return CompletableFuture.completedFuture(it) } + inFlight?.let { return it } + val fetch = launchFetch() + inFlight = fetch + // Attach cache bookkeeping AFTER publishing, so the clear-on-complete callback + // compares against the exact future stored in `inFlight`. Attaching here (rather + // than inside launchFetch) sidesteps the self-reference an inline-completed future + // would otherwise need. + fetch.whenComplete { token, error -> + lock.withLock { + if (inFlight === fetch) inFlight = null + if (error == null && token != null) cachedToken = token + } + } + return fetch + } + } + + /** + * Starts the provider fetch and applies validation (null token, already-expired token) so + * a misbehaving provider surfaces as [IllegalStateException] to the awaiting caller. Cache + * population and in-flight clearing are attached by [sharedFetch] against the published + * future. + */ + private fun launchFetch(): CompletableFuture { + val raw: CompletableFuture = + try { + fetchAsyncSafe() + } catch (t: Throwable) { + // A provider whose fetchAsync throws synchronously (caller-bug) — normalise. + val failed = CompletableFuture() + failed.completeExceptionally(t) + failed + } + return raw.thenApply { token -> validateFresh(token) } + } + + @Suppress( + "UNCHECKED_CAST", + "RedundantNullableReturnType", + ) + private fun fetchAsyncSafe(): CompletableFuture { + val future: CompletableFuture? = + provider.fetchAsync(scopes) as CompletableFuture? + return future ?: error("BearerTokenProvider.fetchAsync returned null") + } + + /** + * Validates a freshly fetched token: rejects a `null` (Kotlin intrinsics usually catch + * this earlier for Kotlin SAMs, but a platform-disabled-intrinsics context may not) and a + * token already expired at fetch time (no margin applied — a provider minting an + * effectively-expired token is misbehaving). + */ + @Suppress("UNCHECKED_CAST", "RedundantNullableReturnType") + private fun validateFresh(token: BearerToken): BearerToken { + val nonNull: BearerToken = (token as BearerToken?) ?: error("BearerTokenProvider returned null") + check(!nonNull.isExpiredAt(clock.now())) { + "BearerTokenProvider returned an already-expired token" + } + return nonNull + } + + /** + * Returns `true` when [response]'s `WWW-Authenticate` header advertises a `Bearer` + * challenge. + */ + private fun offersBearerChallenge(response: Response): Boolean { + val header = response.headers.get(HttpHeaderName.WWW_AUTHENTICATE) ?: return false + return AuthChallengeParser.parse(header).any { it.scheme == BEARER_SCHEME } + } + + /** + * Clears [cachedToken] iff it is still the token whose stamped header is [rejectedHeader]. + * Guarded by the same [lock] as the fetch path so the read-compare-clear is atomic against + * a concurrent refresh. + */ + private fun evictRejectedToken(rejectedHeader: String) { + lock.withLock { + val current = cachedToken ?: return + if (bearerHeaderValue(current.token) == rejectedHeader) { + cachedToken = null + } + } + } + + /** + * The `Authorization` header value for [token]. Single source of truth shared by the + * stamping path and [evictRejectedToken]. A subclass that emits a different header format + * must override this too, or eviction stops matching. + */ + protected open fun bearerHeaderValue(token: String): String = "Bearer $token" + + private companion object { + private const val DEFAULT_REFRESH_MARGIN_SECONDS = 30L + private const val BEARER_SCHEME = "bearer" + } + } diff --git a/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncRetryStep.kt b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncRetryStep.kt new file mode 100644 index 00000000..ad2ec8e1 --- /dev/null +++ b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncRetryStep.kt @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.http.pipeline.steps + +import org.dexpace.sdk.core.http.pipeline.AsyncHttpStep +import org.dexpace.sdk.core.http.pipeline.Stage + +/** + * Async pillar step at [Stage.RETRY] — the [AsyncHttpStep] counterpart of [RetryStep]. Drives + * an async retry loop with the same classification policy, backoff schedule, and + * `Retry-After` pacing as the synchronous stack, but schedules its delays on a + * [java.util.concurrent.ScheduledExecutorService] instead of blocking a thread. + * + * The base is `abstract` because the stage is locked to [Stage.RETRY] at the type level — + * users implementing custom async-retry behaviour override + * [AsyncHttpStep.processAsync] but inherit the pillar slot. The shipped concrete + * implementation is [DefaultAsyncRetryStep]. + */ +public abstract class AsyncRetryStep : AsyncHttpStep { + final override val stage: Stage = Stage.RETRY +} diff --git a/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep.kt b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep.kt new file mode 100644 index 00000000..ebeece9f --- /dev/null +++ b/sdk-core/src/main/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStep.kt @@ -0,0 +1,398 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.http.pipeline.steps + +import org.dexpace.sdk.core.http.common.HttpHeaderName +import org.dexpace.sdk.core.http.pipeline.AsyncPipelineNext +import org.dexpace.sdk.core.http.request.Method +import org.dexpace.sdk.core.http.request.Request +import org.dexpace.sdk.core.http.response.Response +import org.dexpace.sdk.core.instrumentation.ClientLogger +import org.dexpace.sdk.core.pipeline.step.retry.BackoffCalculator +import org.dexpace.sdk.core.pipeline.step.retry.RetryAfterParser +import org.dexpace.sdk.core.pipeline.step.retry.RetrySettings +import org.dexpace.sdk.core.util.Clock +import org.dexpace.sdk.core.util.Futures +import java.io.IOException +import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ScheduledExecutorService + +/** + * Default [AsyncRetryStep] — the async mirror of [DefaultRetryStep]. Re-invokes the downstream + * async chain on classified failures with exponential / fixed backoff, server-supplied + * [HttpHeaderName.RETRY_AFTER] pacing, and idempotency-aware re-sendability gating, sharing the + * exact same policy as the synchronous stack: the same [HttpRetryOptions], the same + * [BackoffCalculator], the same [RetryAfterParser], the same `Retry-After` header set, and the + * same idempotent-method allow-list. + * + * ## Non-blocking delays + * + * Backoff delays are scheduled on a [ScheduledExecutorService] via [Futures.delay] — no + * `Thread.sleep`, no `Timer`. While a delay is pending the dispatching thread is free; the loop + * re-arms when the scheduled future fires. + * + * ## Loop shape — stack-safe, no `thenCompose` recursion + * + * The retry loop is driven iteratively by [drive], not by chaining `thenCompose` per attempt. + * Each downstream attempt registers a single [CompletableFuture.whenComplete] callback. When the + * outcome warrants another attempt, the callback hands control back to [drive] through a + * trampoline ([RetryDriver.continuation]) instead of calling [drive] recursively, so a retry + * sequence of length N never builds an N-deep stack frame chain or an N-deep future + * continuation graph. This mirrors the iterative `while` loop of [DefaultRetryStep] while staying + * fully async. + * + * ## Re-sendability gating + * + * Identical to [DefaultRetryStep]: + * - **No body** — retried only when the method is idempotent ([IDEMPOTENT_METHODS]); a bare + * non-idempotent `POST` is not retried even though there is nothing to re-send. + * - **Has a body** — retried only when [org.dexpace.sdk.core.http.request.RequestBody.isReplayable]. + * + * When the request is not re-sendable the loop runs exactly one attempt and completes with the + * response (or the failure) as-is. + * + * ## Failure handling + * + * Downstream failures surface as exceptionally-completed futures; the loop unwraps the + * [java.util.concurrent.CompletionException] wrapper via [Futures.unwrap] before classifying. + * Only [Exception] subclasses are classified — an [Error] (OOM, StackOverflow) completes the + * call exceptionally without retry. On terminal failure every prior attempt's exception is + * attached to the surfaced exception via [Throwable.addSuppressed]. + * + * ## Thread-safety + * + * Stateless after construction (the per-call [RetryDriver] holds all mutable loop state). The + * immutable [options] / [clock] / [scheduler] and the [ClientLogger] are shared across + * concurrent calls. + */ +public open class DefaultAsyncRetryStep + @JvmOverloads + constructor( + private val scheduler: ScheduledExecutorService, + options: HttpRetryOptions = HttpRetryOptions(), + private val clock: Clock = Clock.SYSTEM, + internal val logger: ClientLogger = ClientLogger(DefaultAsyncRetryStep::class), + ) : AsyncRetryStep() { + /** Effective options. `maxRetries < 0` is clamped to [DefaultRetryStep.DEFAULT_MAX_RETRIES]. */ + private val options: HttpRetryOptions = clampOptions(options) + + /** + * The [options]' exponential parameters as a [RetrySettings] view so the shared + * [BackoffCalculator] computes this stack's schedule — built once, exactly as + * [DefaultRetryStep.backoffSettings]. `totalTimeout = ZERO` disables the deadline cap. + * Building it eagerly validates the delay magnitudes at construction. + */ + private val backoffSettings: RetrySettings = + RetrySettings.builder() + .initialDelay(this.options.baseDelay) + .maxDelay(this.options.maxDelay) + .delayMultiplier(RetrySettings.DEFAULT_DELAY_MULTIPLIER) + .jitter(RetrySettings.DEFAULT_JITTER) + .totalTimeout(Duration.ZERO) + .build() + + override fun processAsync( + request: Request, + next: AsyncPipelineNext, + ): CompletableFuture { + val result = CompletableFuture() + val driver = RetryDriver(next, isRetrySafe(request), result) + driver.drive() + return result + } + + /** + * Per-call mutable loop state plus the trampolining driver. One instance per + * [processAsync] call; never shared across calls or threads (each call's continuations + * run sequentially on the scheduler / completing thread, never concurrently with + * themselves). + */ + private inner class RetryDriver( + private val next: AsyncPipelineNext, + private val retrySafe: Boolean, + private val result: CompletableFuture, + ) { + private var tryCount = 0 + private val sequenceStartNanos = clock.monotonic() + + // Lazily allocated on first failure so the success path never pays for the list. + private var suppressed: MutableList? = null + + // Trampoline state. `pumping` is true while the synchronous pump loop in drive() is + // active; `rearm` records that another attempt should run. A re-arm that happens + // while the pump is active (an inline / zero-delay retry) just sets `rearm` and lets + // the loop pick it up — it never recurses into a new drive() frame. A re-arm that + // happens after the pump has exited (the common async case, fired from a scheduler + // or downstream-completion thread) starts a fresh, shallow pump. Both paths run + // sequentially per call — drive() is only ever entered by one thread at a time + // because each attempt's continuation fires exactly once and the previous pump has + // returned before the async callback runs. + private var pumping: Boolean = false + private var rearm: Boolean = false + + /** + * Entry point and trampoline. Marks that an attempt should run ([rearm]); if a pump + * loop is already active it returns immediately (the active loop will pick up the + * re-arm), otherwise it runs the loop. The loop keeps starting attempts as long as + * inline completions keep setting [rearm], so a burst of zero-delay retries unwinds + * iteratively instead of recursing. + */ + fun drive() { + rearm = true + if (pumping) return + pumping = true + try { + while (rearm) { + rearm = false + startAttempt() + } + } finally { + pumping = false + } + } + + /** + * Launches one downstream attempt and registers its completion handler. When the + * attempt (and its retry decision + zero-length delay) complete inline, the handler + * runs synchronously and re-enters [drive] — which, because [pumping] is still true, + * merely sets [rearm] for the active loop. When the attempt completes later, the + * handler runs on the completing thread and starts a fresh pump. + */ + private fun startAttempt() { + val attempt: CompletableFuture = + try { + next.copy().processAsync() + } catch (e: Exception) { + // The async chain contract permits sync exceptions only for caller-bug + // cases; normalise to a failed future so classification is uniform. + Futures.failed(e) + } + attempt.whenComplete { response, error -> handleOutcome(response, error) } + } + + /** Routes a completed attempt to the success or failure handler. */ + private fun handleOutcome( + response: Response?, + error: Throwable?, + ) { + if (error == null) { + onSuccess(response!!) + } else { + onFailure(error) + } + } + + private fun onSuccess(response: Response) { + val retry = + retrySafe && + tryCount < options.maxRetries && + shouldRetryResponse(response) + if (!retry) { + result.complete(response) + return + } + val delay = computeResponseDelay(response, tryCount) + logRetry(tryCount, delay, response.status.code, cause = null) + closeQuietly(response) + tryCount++ + scheduleNext(delay) + } + + private fun onFailure(rawError: Throwable) { + val error = Futures.unwrap(rawError) + // Errors (OOM, StackOverflow, …) are unrecoverable — never retry, never log. + if (error is Error) { + result.completeExceptionally(error) + return + } + val exception = error as Exception + val retry = + retrySafe && + tryCount < options.maxRetries && + shouldRetryException(exception) + if (!retry) { + suppressed?.forEach(exception::addSuppressed) + result.completeExceptionally(exception) + return + } + val accumulator = suppressed ?: ArrayList().also { suppressed = it } + val delay = computeExceptionDelay(exception, tryCount) + logRetry(tryCount, delay, statusCode = -1, cause = exception) + // Record the current failure BEFORE scheduling so it is attached to any later + // terminal exception's suppressed list rather than being silently dropped. + accumulator.add(exception) + tryCount++ + scheduleNext(delay) + } + + /** + * Schedules the next attempt after [delay]. [Futures.delay] returns an + * already-complete future for a zero delay, so the [CompletableFuture.whenComplete] + * callback runs inline and re-enters [drive] while the pump loop is still active — a + * synchronous re-arm with no extra stack frame. For a positive delay the scheduled + * future fires later on the scheduler thread and starts a fresh pump. + */ + private fun scheduleNext(delay: Duration) { + val safeDelay = if (delay.isNegative) Duration.ZERO else delay + Futures.delay(scheduler, safeDelay).whenComplete { _, scheduleError -> + if (scheduleError != null) { + // The scheduler rejected or failed the delay task — surface it with any + // accumulated prior failures attached. + suppressed?.forEach(scheduleError::addSuppressed) + result.completeExceptionally(scheduleError) + } else { + drive() + } + } + } + + // --------------- Classification --------------- + + private fun shouldRetryResponse(response: Response): Boolean { + val condition = HttpRetryCondition(response, null, tryCount, (suppressed ?: emptyList())) + return invokeShouldRetry(options.shouldRetryCondition, condition) + } + + private fun shouldRetryException(exception: Exception): Boolean { + val condition = HttpRetryCondition(null, exception, tryCount, (suppressed ?: emptyList())) + return invokeShouldRetry(options.shouldRetryException, condition) + } + + // --------------- Delay computation --------------- + + private fun computeResponseDelay( + response: Response, + tryCount: Int, + ): Duration { + val condition = HttpRetryCondition(response, null, tryCount, (suppressed ?: emptyList())) + invokeDelayFromCondition(condition)?.let { return it } + retryAfterFromHeaders(response)?.let { return it } + return backoffOrFixed(tryCount) + } + + private fun computeExceptionDelay( + exception: Exception, + tryCount: Int, + ): Duration { + val condition = HttpRetryCondition(null, exception, tryCount, (suppressed ?: emptyList())) + invokeDelayFromCondition(condition)?.let { return it } + return backoffOrFixed(tryCount) + } + + // --------------- Logging --------------- + + private fun logRetry( + tryCount: Int, + delay: Duration, + statusCode: Int, + cause: Throwable?, + ) { + val event = + logger.atInfo() + .event("http.retry") + .field("http.retry.try_count", tryCount.toLong()) + .field("http.retry.delay_ms", delay.toMillis()) + .field("retry.total_elapsed_ms", (clock.monotonic() - sequenceStartNanos) / NANOS_PER_MILLI) + if (statusCode > 0) { + event.field("http.response.status_code", statusCode.toLong()) + } + if (cause != null) { + event.field("error.type", cause::class.java.simpleName ?: "Throwable") + .field("retry.cause_class", cause::class.simpleName ?: "Throwable") + .cause(cause) + } + event.log() + } + } + + // --------------- Shared helpers (stateless across calls) --------------- + + private fun isRetrySafe(request: Request): Boolean { + val body = request.body ?: return request.method in IDEMPOTENT_METHODS + return body.isReplayable() + } + + private fun invokeShouldRetry( + predicate: HttpRetryConditionPredicate, + condition: HttpRetryCondition, + ): Boolean = + try { + predicate.shouldRetry(condition) + } catch (t: Throwable) { + @Suppress("InstanceOfCheckForException") + if (t is Error) throw t + throw IllegalStateException("shouldRetry predicate threw", t) + } + + private fun invokeDelayFromCondition(condition: HttpRetryCondition): Duration? = + try { + options.delayFromCondition.delayFor(condition) + } catch (t: Throwable) { + @Suppress("InstanceOfCheckForException") + if (t is Error) throw t + logger.atWarning() + .event("http.retry.delay_override_failed") + .field("error.type", t::class.java.simpleName ?: "Throwable") + .cause(t) + .log() + null + } + + private fun backoffOrFixed(tryCount: Int): Duration = + options.fixedDelay ?: BackoffCalculator.computeDelay(tryCount + 1, backoffSettings) + + private fun retryAfterFromHeaders(response: Response): Duration? { + val now = clock.now() + for (name in options.retryAfterHeaders) { + val raw = response.headers.get(name) ?: continue + RetryAfterParser.parseHeaderValue(name, raw, now)?.let { return it } + } + return null + } + + private fun closeQuietly(response: Response) { + try { + response.close() + } catch (closeErr: IOException) { + logger.atVerbose() + .event("http.retry.close_failed") + .field("error.type", closeErr::class.java.simpleName ?: "IOException") + .log() + } + } + + private fun clampOptions(opts: HttpRetryOptions): HttpRetryOptions { + if (opts.maxRetries >= 0) return opts + logger.atVerbose() + .event("http.retry.maxRetries_clamped") + .field("http.retry.max_retries.requested", opts.maxRetries.toLong()) + .field("http.retry.max_retries.applied", DefaultRetryStep.DEFAULT_MAX_RETRIES.toLong()) + .log() + return HttpRetryOptions( + maxRetries = DefaultRetryStep.DEFAULT_MAX_RETRIES, + baseDelay = opts.baseDelay, + maxDelay = opts.maxDelay, + fixedDelay = opts.fixedDelay, + retryAfterHeaders = opts.retryAfterHeaders, + shouldRetryCondition = opts.shouldRetryCondition, + shouldRetryException = opts.shouldRetryException, + delayFromCondition = opts.delayFromCondition, + ) + } + + public companion object { + // Nanoseconds in one millisecond — converts monotonic deltas to ms for log events. + private const val NANOS_PER_MILLI = 1_000_000L + + // Methods safe to re-send regardless of body replayability (idempotent per RFC 9110). + // Mirrors DefaultRetryStep.IDEMPOTENT_METHODS / RetrySettings.DEFAULT_RETRYABLE_METHODS. + private val IDEMPOTENT_METHODS: Set = + setOf(Method.GET, Method.HEAD, Method.OPTIONS, Method.PUT, Method.DELETE) + } + } diff --git a/sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStepTest.kt b/sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStepTest.kt new file mode 100644 index 00000000..d3ad74bd --- /dev/null +++ b/sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/AsyncBearerTokenAuthStepTest.kt @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.http.pipeline.steps + +import org.dexpace.sdk.core.client.AsyncHttpClient +import org.dexpace.sdk.core.http.auth.BearerToken +import org.dexpace.sdk.core.http.auth.BearerTokenProvider +import org.dexpace.sdk.core.http.common.Headers +import org.dexpace.sdk.core.http.common.HttpHeaderName +import org.dexpace.sdk.core.http.common.Protocol +import org.dexpace.sdk.core.http.pipeline.AsyncHttpPipelineBuilder +import org.dexpace.sdk.core.http.request.Method +import org.dexpace.sdk.core.http.request.Request +import org.dexpace.sdk.core.http.response.Response +import org.dexpace.sdk.core.http.response.Status +import org.dexpace.sdk.core.io.Io +import org.dexpace.sdk.core.testing.FixedClock +import org.dexpace.sdk.core.util.Futures +import org.dexpace.sdk.io.OkioIoProvider +import java.time.Duration +import java.time.Instant +import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class AsyncBearerTokenAuthStepTest { + private val now: Instant = Instant.parse("2026-01-01T12:00:00Z") + private val clock = FixedClock(now) + + @BeforeTest + fun setUp() { + Io.installProvider(OkioIoProvider) + } + + // ----------------- Basic stamping ----------------- + + @Test + fun `stamps a fresh token without blocking`() { + val provider = CountingProvider { BearerToken("tok", now.plusSeconds(3600)) } + val client = RecordingClient(200) + val future = pipeline(provider, client).sendAsync(getRequest()) + val response = future.join() + assertEquals(200, response.status.code) + assertEquals("Bearer tok", client.lastAuth) + assertEquals(1, provider.fetchCount) + } + + @Test + fun `caches the token across requests - one fetch for two sends`() { + val provider = CountingProvider { BearerToken("tok", now.plusSeconds(3600)) } + val client = RecordingClient(200) + val p = pipeline(provider, client) + p.sendAsync(getRequest()).join() + p.sendAsync(getRequest()).join() + assertEquals(1, provider.fetchCount) + } + + // ----------------- Background refresh of an expiring token ----------------- + + @Test + fun `valid-but-expiring token is returned immediately and refreshed in the background`() { + // Token expires in 10s; refresh margin is 30s → inside the margin but still valid. + val deferred = CompletableFuture() + // Seed the cache with the expiring token by handing it back on the first (awaited) fetch. + val seeded = CompletableFuture.completedFuture(BearerToken("old", now.plusSeconds(10))) + val seedingProvider = + object : BearerTokenProvider { + val fetches = AtomicInteger(0) + val gate = AtomicInteger(0) + + override fun fetch( + scopes: List, + params: Map, + ): BearerToken = error("blocking fetch must not be called") + + override fun fetchAsync( + scopes: List, + params: Map, + ): CompletableFuture { + fetches.incrementAndGet() + return if (gate.getAndIncrement() == 0) seeded else deferred + } + } + val client = RecordingClient(200) + val p = pipeline(seedingProvider, client) + + // First request: no cache → awaits the seeding fetch → stamps "old". + p.sendAsync(getRequest()).join() + assertEquals("Bearer old", client.lastAuth) + + // Second request: cached "old" is valid (expires in 10s) but inside the 30s margin → + // it is returned IMMEDIATELY and stamped; the refresh is kicked off in the background and + // is still pending (deferred not completed). + val secondFuture = p.sendAsync(getRequest()) + val second = secondFuture.join() + assertEquals(200, second.status.code) + assertEquals("Bearer old", client.lastAuth, "expiring-but-valid token must be stamped without waiting") + assertEquals(2, seedingProvider.fetches.get(), "a background refresh must have been started") + assertFalse(deferred.isDone) + + // Complete the background refresh; the NEXT request now sees the new token. + deferred.complete(BearerToken("new", now.plusSeconds(3600))) + p.sendAsync(getRequest()).join() + assertEquals("Bearer new", client.lastAuth) + } + + // ----------------- Single-flight ----------------- + + @Test + fun `concurrent expired-token requests share one fetch`() { + // No cached token → both requests must await; only ONE provider fetch should happen. + val deferred = CompletableFuture() + val fetches = AtomicInteger(0) + val provider = + object : BearerTokenProvider { + override fun fetch( + scopes: List, + params: Map, + ): BearerToken = error("blocking fetch must not be called") + + override fun fetchAsync( + scopes: List, + params: Map, + ): CompletableFuture { + fetches.incrementAndGet() + return deferred + } + } + val client = RecordingClient(200) + val p = pipeline(provider, client) + + val f1 = p.sendAsync(getRequest()) + val f2 = p.sendAsync(getRequest()) + // Both are parked on the single shared fetch. + assertFalse(f1.isDone) + assertFalse(f2.isDone) + assertEquals(1, fetches.get(), "concurrent requests must coalesce onto one fetch") + + deferred.complete(BearerToken("tok", now.plusSeconds(3600))) + assertEquals(200, f1.join().status.code) + assertEquals(200, f2.join().status.code) + assertEquals(1, fetches.get()) + } + + // ----------------- Provider errors ----------------- + + @Test + fun `provider failure propagates and is not cached`() { + val fetches = AtomicInteger(0) + val provider = + object : BearerTokenProvider { + override fun fetch( + scopes: List, + params: Map, + ): BearerToken = error("blocking fetch must not be called") + + override fun fetchAsync( + scopes: List, + params: Map, + ): CompletableFuture { + fetches.incrementAndGet() + return Futures.failed(RuntimeException("token endpoint down")) + } + } + val client = RecordingClient(200) + val p = pipeline(provider, client) + + val thrown = assertFails { p.sendAsync(getRequest()).join() } + assertTrue(Futures.unwrap(thrown) is RuntimeException) + // Not cached: a second request retries the fetch. + assertFails { p.sendAsync(getRequest()).join() } + assertEquals(2, fetches.get()) + } + + @Test + fun `provider returning an already-expired token surfaces IllegalStateException`() { + val provider = CountingProvider { BearerToken("stale", now.minusSeconds(1)) } + val client = RecordingClient(200) + val thrown = assertFails { pipeline(provider, client).sendAsync(getRequest()).join() } + assertTrue(Futures.unwrap(thrown) is IllegalStateException) + } + + // ----------------- HTTPS guard ----------------- + + @Test + fun `non-HTTPS request is rejected before any fetch`() { + val provider = CountingProvider { BearerToken("tok", now.plusSeconds(3600)) } + val client = RecordingClient(200) + val request = Request.builder().method(Method.GET).url("http://api.example.com/x").build() + val thrown = assertFails { pipeline(provider, client).sendAsync(request).join() } + assertTrue(Futures.unwrap(thrown) is IllegalStateException) + assertEquals(0, provider.fetchCount, "no token fetch on the rejected plaintext path") + } + + // ----------------- 401 challenge eviction ----------------- + + @Test + fun `401 bearer challenge evicts the token and retries with a fresh one`() { + var fetch = 0 + val provider = + object : BearerTokenProvider { + val fetches = AtomicInteger(0) + + override fun fetch( + scopes: List, + params: Map, + ): BearerToken = error("blocking fetch must not be called") + + override fun fetchAsync( + scopes: List, + params: Map, + ): CompletableFuture { + fetches.incrementAndGet() + fetch++ + val value = if (fetch == 1) "first" else "second" + return CompletableFuture.completedFuture(BearerToken(value, now.plusSeconds(3600))) + } + } + var call = 0 + val seenAuth = mutableListOf() + val client = + AsyncHttpClient { request -> + call++ + seenAuth.add(request.headers.get(HttpHeaderName.AUTHORIZATION)) + val code = if (call == 1) 401 else 200 + val headers = + if (call == 1) { + Headers.Builder().add(HttpHeaderName.WWW_AUTHENTICATE.caseSensitiveName, "Bearer").build() + } else { + Headers.Builder().build() + } + CompletableFuture.completedFuture( + Response.builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .status(Status.fromCode(code)) + .headers(headers) + .build(), + ) + } + val response = pipeline(provider, client).sendAsync(getRequest()).join() + assertEquals(200, response.status.code) + assertEquals(listOf("Bearer first", "Bearer second"), seenAuth) + assertEquals(2, provider.fetches.get()) + } + + @Test + fun `401 without a bearer challenge surfaces unchanged`() { + val provider = CountingProvider { BearerToken("tok", now.plusSeconds(3600)) } + var call = 0 + val client = + AsyncHttpClient { request -> + call++ + CompletableFuture.completedFuture( + Response.builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .status(Status.fromCode(401)) + .headers( + Headers.Builder() + .add(HttpHeaderName.WWW_AUTHENTICATE.caseSensitiveName, "Basic realm=x") + .build(), + ) + .build(), + ) + } + val response = pipeline(provider, client).sendAsync(getRequest()).join() + assertEquals(401, response.status.code) + assertEquals(1, call, "a Basic challenge must not trigger a bearer re-fetch + retry") + } + + // ----------------- Helpers ----------------- + + private fun pipeline( + provider: BearerTokenProvider, + client: AsyncHttpClient, + ) = AsyncHttpPipelineBuilder(client) + .append(AsyncBearerTokenAuthStep(provider, listOf("scope"), Duration.ofSeconds(30), clock)) + .build() + + private fun getRequest(): Request = Request.builder().method(Method.GET).url("https://api.example.com/x").build() + + /** Provider that counts fetchAsync calls and returns [supply]'s token, completed. */ + private class CountingProvider(private val supply: () -> BearerToken) : BearerTokenProvider { + private val fetches = AtomicInteger(0) + + val fetchCount: Int get() = fetches.get() + + override fun fetch( + scopes: List, + params: Map, + ): BearerToken = error("blocking fetch must not be called") + + override fun fetchAsync( + scopes: List, + params: Map, + ): CompletableFuture { + fetches.incrementAndGet() + return CompletableFuture.completedFuture(supply()) + } + } + + /** Async client returning a constant status and recording the Authorization header. */ + private class RecordingClient(private val code: Int) : AsyncHttpClient { + @Volatile + var lastAuth: String? = null + + override fun executeAsync(request: Request): CompletableFuture { + lastAuth = request.headers.get(HttpHeaderName.AUTHORIZATION) + return CompletableFuture.completedFuture( + Response.builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .status(Status.fromCode(code)) + .build(), + ) + } + } +} diff --git a/sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStepTest.kt b/sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStepTest.kt new file mode 100644 index 00000000..11088fb5 --- /dev/null +++ b/sdk-core/src/test/kotlin/org/dexpace/sdk/core/http/pipeline/steps/DefaultAsyncRetryStepTest.kt @@ -0,0 +1,383 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.http.pipeline.steps + +import org.dexpace.sdk.core.client.AsyncHttpClient +import org.dexpace.sdk.core.http.common.Headers +import org.dexpace.sdk.core.http.common.MediaType +import org.dexpace.sdk.core.http.common.Protocol +import org.dexpace.sdk.core.http.pipeline.AsyncHttpPipelineBuilder +import org.dexpace.sdk.core.http.pipeline.AsyncPipelineNext +import org.dexpace.sdk.core.http.pipeline.Stage +import org.dexpace.sdk.core.http.request.Method +import org.dexpace.sdk.core.http.request.Request +import org.dexpace.sdk.core.http.request.RequestBody +import org.dexpace.sdk.core.http.response.Response +import org.dexpace.sdk.core.http.response.ResponseBody +import org.dexpace.sdk.core.http.response.Status +import org.dexpace.sdk.core.io.BufferedSink +import org.dexpace.sdk.core.io.BufferedSource +import org.dexpace.sdk.core.testing.ManualScheduler +import org.dexpace.sdk.core.util.Clock +import org.dexpace.sdk.core.util.Futures +import java.io.IOException +import java.time.Duration +import java.time.Instant +import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.AfterTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.test.fail + +class DefaultAsyncRetryStepTest { + private val scheduler = ManualScheduler() + + @AfterTest + fun tearDown() { + scheduler.close() + } + + // ----------------- Type-level invariants ----------------- + + @Test + fun `stage is RETRY and final`() { + val step = DefaultAsyncRetryStep(scheduler) + assertEquals(Stage.RETRY, step.stage) + val custom = + object : AsyncRetryStep() { + override fun processAsync( + request: Request, + next: AsyncPipelineNext, + ): CompletableFuture = next.processAsync() + } + assertEquals(Stage.RETRY, custom.stage) + } + + // ----------------- maxRetries semantics ----------------- + + @Test + fun `maxRetries = 0 performs exactly one attempt`() { + val client = QueueClient().enqueue(503) + val future = pipeline(client, HttpRetryOptions(maxRetries = 0)).sendAsync(getRequest()) + scheduler.runAll() + assertEquals(503, future.join().status.code) + assertEquals(1, client.callCount) + } + + @Test + fun `retries a 503 until a 200 within the budget`() { + val client = QueueClient().enqueue(503).enqueue(503).enqueue(200) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ofMillis(50))) + .sendAsync(getRequest()) + // Each backoff schedules a task on the manual scheduler; drain them to advance the loop. + scheduler.runAll() + assertEquals(200, future.join().status.code) + assertEquals(3, client.callCount) + } + + @Test + fun `exhausts retries and returns the last retryable response`() { + val client = QueueClient().enqueue(503).enqueue(503).enqueue(503) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 2, delay = Duration.ofMillis(10))) + .sendAsync(getRequest()) + scheduler.runAll() + assertEquals(503, future.join().status.code) + // initial + 2 retries. + assertEquals(3, client.callCount) + } + + @Test + fun `non-retryable status is returned without retry`() { + val client = QueueClient().enqueue(404) + val future = pipeline(client, HttpRetryOptions()).sendAsync(getRequest()) + scheduler.runAll() + assertEquals(404, future.join().status.code) + assertEquals(1, client.callCount) + } + + // ----------------- Exception retry ----------------- + + @Test + fun `retries a retryable IOException then succeeds`() { + val client = FailNTimesClient(failures = 2, exception = IOException("boom")) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ofMillis(5))) + .sendAsync(getRequest()) + scheduler.runAll() + assertEquals(200, future.join().status.code) + assertEquals(3, client.callCount) + } + + @Test + fun `terminal exception carries prior attempts as suppressed`() { + val client = FailNTimesClient(failures = 5, exception = IOException("boom")) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 2, delay = Duration.ofMillis(5))) + .sendAsync(getRequest()) + scheduler.runAll() + val thrown = assertFails { future.join() } + val cause = Futures.unwrap(thrown) + assertTrue(cause is IOException) + // 2 prior failures attached as suppressed (initial + first retry), terminal is the 3rd. + assertEquals(2, cause.suppressed.size) + assertEquals(3, client.callCount) + } + + @Test + fun `non-retryable exception is surfaced immediately`() { + val client = FailNTimesClient(failures = 5, exception = IllegalArgumentException("nope")) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ZERO)) + .sendAsync(getRequest()) + scheduler.runAll() + val thrown = assertFails { future.join() } + assertTrue(Futures.unwrap(thrown) is IllegalArgumentException) + assertEquals(1, client.callCount) + } + + // ----------------- Idempotency awareness ----------------- + + @Test + fun `bare POST without body is not retried`() { + val client = QueueClient().enqueue(503) + val request = + Request.builder().method(Method.POST).url("https://api.example.com/x").build() + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ZERO)) + .sendAsync(request) + scheduler.runAll() + assertEquals(503, future.join().status.code) + assertEquals(1, client.callCount) + } + + @Test + fun `POST with a replayable body is retried`() { + val client = QueueClient().enqueue(503).enqueue(200) + val request = + Request.builder() + .method(Method.POST) + .url("https://api.example.com/x") + .body(RequestBody.create("payload".toByteArray())) + .build() + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ZERO)) + .sendAsync(request) + scheduler.runAll() + assertEquals(200, future.join().status.code) + assertEquals(2, client.callCount) + } + + @Test + fun `POST with a non-replayable body is not retried`() { + val client = QueueClient().enqueue(503) + val request = + Request.builder() + .method(Method.POST) + .url("https://api.example.com/x") + .body(NonReplayableBody()) + .build() + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ZERO)) + .sendAsync(request) + scheduler.runAll() + assertEquals(503, future.join().status.code) + assertEquals(1, client.callCount) + } + + // ----------------- Retry-After honoring ----------------- + + @Test + fun `Retry-After seconds header is honored as the delay`() { + val client = + QueueClient() + .enqueue(503, Headers.Builder().add("Retry-After", "2").build()) + .enqueue(200) + val future = + pipeline(client, HttpRetryOptions(maxRetries = 3)) + .sendAsync(getRequest()) + // The first retry should be scheduled for 2 seconds. + scheduler.runAll() + assertEquals(200, future.join().status.code) + val scheduled = scheduler.recordedDelays + assertTrue(scheduled.any { it == Duration.ofSeconds(2) }, "expected a 2s scheduled delay, got $scheduled") + } + + // ----------------- Body close before retry ----------------- + + @Test + fun `prior retryable response body is closed before retrying`() { + val closes = AtomicInteger(0) + var n = 0 + val client = + AsyncHttpClient { request -> + n++ + val code = if (n == 1) 503 else 200 + CompletableFuture.completedFuture( + Response.builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .status(Status.fromCode(code)) + .body(CountingCloseBody(closes)) + .build(), + ) + } + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ZERO)) + .sendAsync(getRequest()) + scheduler.runAll() + future.join().close() + assertTrue(closes.get() >= 1, "the 503 body should have been closed before retry") + } + + // ----------------- Stack safety ----------------- + + @Test + fun `a long zero-delay retry sequence does not overflow the stack`() { + // 5000 zero-delay retries. An implementation that recursed per attempt (thenCompose + // chains or self-recursive drive) would blow the stack; the iterative trampoline must not. + val attempts = 5000 + val client = AlwaysFailClient(IOException("io")) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = attempts, delay = Duration.ZERO)) + .sendAsync(getRequest()) + scheduler.runAll() + val thrown = assertFails { future.join() } + assertTrue(Futures.unwrap(thrown) is IOException) + assertEquals(attempts + 1, client.callCount) + } + + // ----------------- Cancellation / no real sleep ----------------- + + @Test + fun `delay is scheduled not slept - pending future before the scheduler runs`() { + val client = QueueClient().enqueue(503).enqueue(200) + val future = + pipeline(client, HttpRetryOptions.fixed(maxRetries = 3, delay = Duration.ofSeconds(30))) + .sendAsync(getRequest()) + // The first attempt already ran (503), but the retry is parked on the scheduler — the + // future is NOT complete and no thread is blocked. + assertFalse(future.isDone) + assertEquals(1, client.callCount) + scheduler.runAll() + assertEquals(200, future.join().status.code) + } + + // ----------------- Helpers ----------------- + + private fun pipeline( + client: AsyncHttpClient, + options: HttpRetryOptions, + ) = AsyncHttpPipelineBuilder(client) + .append(DefaultAsyncRetryStep(scheduler, options, fixedClock())) + .build() + + private fun getRequest(): Request = Request.builder().method(Method.GET).url("https://api.example.com/x").build() + + private fun fixedClock(): Clock = + object : Clock { + override fun now(): Instant = Instant.EPOCH + + override fun monotonic(): Long = 0L + + override fun sleep(duration: Duration) = fail("async retry must not call Clock.sleep") + } + + /** Async client returning a FIFO queue of canned responses; throws if the queue is empty. */ + private class QueueClient : AsyncHttpClient { + private val queue = ArrayDeque>() + private val calls = AtomicInteger(0) + + val callCount: Int get() = calls.get() + + fun enqueue( + code: Int, + headers: Headers = Headers.Builder().build(), + ): QueueClient = apply { queue.addLast(code to headers) } + + override fun executeAsync(request: Request): CompletableFuture { + calls.incrementAndGet() + val (code, headers) = + queue.removeFirstOrNull() ?: error("QueueClient: no response enqueued") + return CompletableFuture.completedFuture( + Response.builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .status(Status.fromCode(code)) + .headers(headers) + .build(), + ) + } + } + + /** Fails the first [failures] attempts with [exception], then returns 200. */ + private class FailNTimesClient( + private val failures: Int, + private val exception: Exception, + ) : AsyncHttpClient { + private val calls = AtomicInteger(0) + + val callCount: Int get() = calls.get() + + override fun executeAsync(request: Request): CompletableFuture { + val n = calls.incrementAndGet() + return if (n <= failures) { + Futures.failed(exception) + } else { + CompletableFuture.completedFuture( + Response.builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .status(Status.OK) + .build(), + ) + } + } + } + + /** Always fails with [exception]. */ + private class AlwaysFailClient(private val exception: Exception) : AsyncHttpClient { + private val calls = AtomicInteger(0) + + val callCount: Int get() = calls.get() + + override fun executeAsync(request: Request): CompletableFuture { + calls.incrementAndGet() + return Futures.failed(exception) + } + } + + private class NonReplayableBody : RequestBody() { + override fun mediaType(): MediaType? = MediaType.parse("text/plain") + + override fun contentLength(): Long = 5 + + override fun isReplayable(): Boolean = false + + override fun writeTo(sink: BufferedSink) { + sink.write("hello".toByteArray(Charsets.UTF_8)) + } + } + + private class CountingCloseBody(private val closes: AtomicInteger) : ResponseBody() { + override fun mediaType(): MediaType? = null + + override fun contentLength(): Long = 0 + + override fun source(): BufferedSource = fail("body should not be read") + + override fun close() { + closes.incrementAndGet() + } + } +} diff --git a/sdk-core/src/testFixtures/kotlin/org/dexpace/sdk/core/testing/ManualScheduler.kt b/sdk-core/src/testFixtures/kotlin/org/dexpace/sdk/core/testing/ManualScheduler.kt new file mode 100644 index 00000000..b03bccd5 --- /dev/null +++ b/sdk-core/src/testFixtures/kotlin/org/dexpace/sdk/core/testing/ManualScheduler.kt @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2026 dexpace and Omar Aljarrah + * + * Licensed under the MIT License. See LICENSE in the project root. + * SPDX-License-Identifier: MIT + */ + +package org.dexpace.sdk.core.testing + +import java.time.Duration +import java.util.concurrent.Callable +import java.util.concurrent.Delayed +import java.util.concurrent.ScheduledExecutorService +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit + +/** + * Deterministic [ScheduledExecutorService] for testing async, time-dependent pipeline steps + * (async retry backoff) without real sleeps or background threads. + * + * Only the `schedule(Runnable, delay, unit)` overload used by + * [org.dexpace.sdk.core.util.Futures.delay] is implemented; every other method throws + * [UnsupportedOperationException]. Scheduled tasks are NOT run automatically — the test drives + * them explicitly via [runAll] (or [runNext]), so the test thread controls exactly when each + * delayed continuation fires. Each scheduled delay is recorded in [recordedDelays] so a test can + * assert on the requested backoff schedule (e.g. that a `Retry-After: 2` produced a 2-second + * delay) without observing wall-clock time. + * + * Not thread-safe: tests run on a single thread. + */ +class ManualScheduler : ScheduledExecutorService { + private val pending: ArrayDeque = ArrayDeque() + + /** Every delay requested via [schedule], in submission order. Read-only snapshot semantics. */ + val recordedDelays: List get() = pending.map { it.delay } + ran + + private val ran: MutableList = mutableListOf() + private var closed = false + + /** Number of tasks still queued and not yet run. */ + val pendingCount: Int get() = pending.size + + /** + * Runs queued tasks until the queue is empty, including tasks that earlier tasks schedule + * while running (the async retry loop re-arms by scheduling a new delay). Cancelled tasks are + * skipped. Bounded so a misbehaving infinite re-schedule fails loudly instead of hanging. + */ + fun runAll() { + var guard = 0 + while (pending.isNotEmpty()) { + check(guard++ < MAX_DRAIN_ITERATIONS) { + "ManualScheduler.runAll exceeded $MAX_DRAIN_ITERATIONS iterations — likely an " + + "unbounded re-schedule loop" + } + runNext() + } + } + + /** Runs the next queued task (FIFO). No-op if the queue is empty. */ + fun runNext() { + val task = pending.removeFirstOrNull() ?: return + ran.add(task.delay) + if (!task.cancelled) task.command.run() + } + + override fun schedule( + command: Runnable, + delay: Long, + unit: TimeUnit, + ): ScheduledFuture<*> { + check(!closed) { "ManualScheduler is closed" } + val task = ScheduledTask(command, Duration.ofNanos(unit.toNanos(delay))) + pending.addLast(task) + return task + } + + /** Marks the scheduler closed and drops any queued tasks. Not an override on Java 8's + * [ScheduledExecutorService] (which gained `close()` only in Java 19); a plain helper. */ + fun close() { + closed = true + pending.clear() + } + + override fun shutdown() { + close() + } + + override fun shutdownNow(): List { + val drained = pending.map { it.command } + close() + return drained + } + + override fun isShutdown(): Boolean = closed + + override fun isTerminated(): Boolean = closed + + override fun awaitTermination( + timeout: Long, + unit: TimeUnit, + ): Boolean = closed + + // -- Unused ScheduledExecutorService surface ------------------------------------------------ + + override fun schedule( + callable: Callable, + delay: Long, + unit: TimeUnit, + ): ScheduledFuture = unsupported() + + override fun scheduleAtFixedRate( + command: Runnable, + initialDelay: Long, + period: Long, + unit: TimeUnit, + ): ScheduledFuture<*> = unsupported() + + override fun scheduleWithFixedDelay( + command: Runnable, + initialDelay: Long, + delay: Long, + unit: TimeUnit, + ): ScheduledFuture<*> = unsupported() + + override fun execute(command: Runnable) { + command.run() + } + + override fun submit(task: Callable): java.util.concurrent.Future = unsupported() + + override fun submit( + task: Runnable, + result: T, + ): java.util.concurrent.Future = unsupported() + + override fun submit(task: Runnable): java.util.concurrent.Future<*> = unsupported() + + override fun invokeAll( + tasks: MutableCollection>, + ): MutableList> = unsupported() + + override fun invokeAll( + tasks: MutableCollection>, + timeout: Long, + unit: TimeUnit, + ): MutableList> = unsupported() + + override fun invokeAny(tasks: MutableCollection>): T = unsupported() + + override fun invokeAny( + tasks: MutableCollection>, + timeout: Long, + unit: TimeUnit, + ): T = unsupported() + + private fun unsupported(): Nothing = + throw UnsupportedOperationException("ManualScheduler only supports schedule(Runnable, delay, unit)") + + /** A queued task. [ScheduledFuture] is implemented minimally — only cancellation matters. */ + private class ScheduledTask( + val command: Runnable, + val delay: Duration, + ) : ScheduledFuture { + var cancelled: Boolean = false + private set + + override fun cancel(mayInterruptIfRunning: Boolean): Boolean { + cancelled = true + return true + } + + override fun isCancelled(): Boolean = cancelled + + override fun isDone(): Boolean = cancelled + + override fun get(): Any? = null + + override fun get( + timeout: Long, + unit: TimeUnit, + ): Any? = null + + override fun getDelay(unit: TimeUnit): Long = unit.convert(delay.toNanos(), TimeUnit.NANOSECONDS) + + override fun compareTo(other: Delayed): Int = + getDelay(TimeUnit.NANOSECONDS).compareTo(other.getDelay(TimeUnit.NANOSECONDS)) + } + + private companion object { + private const val MAX_DRAIN_ITERATIONS = 100_000 + } +}