From 9ead2749eb0739dc2e9e2e2afa9485eaa6b14063 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Tue, 26 May 2026 15:16:12 +0000 Subject: [PATCH 1/8] docs: issues with replication --- docs/issues/replication.md | 325 +++++++++++++++++++++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 docs/issues/replication.md diff --git a/docs/issues/replication.md b/docs/issues/replication.md new file mode 100644 index 000000000..c33673515 --- /dev/null +++ b/docs/issues/replication.md @@ -0,0 +1,325 @@ +# Replication Issues + +--- + +## 🚧 Issue 1 — Lag inflation when source databases share a PostgreSQL instance + +### Description + +The replication lag metric used to gate cutover is computed as: + +```sql +SELECT pg_current_wal_lsn() - confirmed_flush_lsn +FROM pg_replication_slots +WHERE slot_name = '...'; +``` + +`pg_current_wal_lsn()` returns the current write-ahead log position for the **entire PostgreSQL instance**, not for a specific database or publication. When multiple source shards are hosted on the same PostgreSQL instance (different databases, one slot per database), `pg_current_wal_lsn()` advances with every write to any database on that instance. + +A logical replication slot only decodes and delivers changes that belong to its publication. Changes from other databases are physically present in the WAL stream but are invisible to the slot's decoder — `confirmed_flush_lsn` only advances when the client acknowledges a decoded logical message (i.e., a `Commit` record from the publication). Once a slot has replayed all of its publication's data, `confirmed_flush_lsn` stagnates at the LSN of the last commit in that publication. It will never advance past WAL records from other databases, regardless of whether writes have stopped. + +This means the lag metric permanently overstates the remaining work. 
On a three-shard benchmark where all source databases are on one instance, the observed lag was ~3.5 GB per slot even after each slot had replayed all of its own publication's data. The lag never dropped below the cutover threshold, so `wait_for_replication()` looped indefinitely and cutover never fired. + +### Cause + +`pg_current_wal_lsn()` is instance-scoped. `confirmed_flush_lsn` is publication-scoped. Their difference is only meaningful when a single database accounts for all writes to the instance. + +### Code references + +| Symbol | File | +|---|---| +| `ReplicationSlot::replication_lag()` — the lag query | [`pgdog/src/backend/replication/logical/publisher/slot.rs`](../../pgdog/src/backend/replication/logical/publisher/slot.rs) | +| `ReplicationWaiter::wait_for_replication()` — the cutover gate | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | +| Keepalive handler / `flush_lsn` reply (`data_since_keepalive` flag) | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +### Fix + +The PostgreSQL walsender sends a keepalive message after exhausting all decoded changes available for the publication. The keepalive carries `wal_end` — the server's current WAL write position. When a keepalive arrives and no xlog data was received since the previous keepalive, the slot has drained: there is nothing for the publication between `committed_lsn` and `wal_end`, so that gap consists entirely of other databases' WAL. + +In this state the client can safely reply with `flush_lsn = wal_end`. PostgreSQL sets `confirmed_flush_lsn` on the slot to `wal_end`, and the lag query returns ~0. + +Reporting `wal_end` is only valid when the slot is caught up. During active streaming — where the server is sending transactions and keepalives may arrive between commits — `flush_lsn` must remain at `committed_lsn`. 
Reporting `wal_end` prematurely would advance `confirmed_flush_lsn` past unapplied commits; if the connection dropped, the server would restart from `wal_end` and those commits would be lost. + +The implemented guard: `data_since_keepalive` flag. Set to `true` when any xlog data message arrives; cleared to `false` when a keepalive arrives. A keepalive is the catch-up signal only when this flag is `false` — meaning no data arrived between the last keepalive and this one. + +``` +keepalive received + data_since_keepalive = true -> reply flush_lsn = committed_lsn (active, mid-stream) + data_since_keepalive = false -> reply flush_lsn = wal_end (caught up) +``` + +### PostgreSQL references + +- [Streaming Replication Protocol](https://www.postgresql.org/docs/current/protocol-replication.html) — Standby Status Update message (`flush_lsn` field); Primary Keepalive message (`wal_end` field). +- [`pg_replication_slots`](https://www.postgresql.org/docs/current/view-pg-replication-slots.html) — `confirmed_flush_lsn` column: the last LSN confirmed received by the standby/subscriber. +- [Logical Replication](https://www.postgresql.org/docs/current/logical-replication.html) — publication filtering and how the walsender decodes only changes relevant to the subscriber's publication. + +--- + +## 🚧 Issue 2 — Stop signal only unblocked one task instead of all + +### Description + +When cutover initiates, `Waiter::stop()` must signal all N per-shard replication tasks to terminate. Each task blocks in a `select!` loop waiting on either incoming WAL data or a stop signal. All N tasks must receive the signal and break out of their loop before `Waiter::wait()` can return. + +The original implementation used `Arc<Notify>` for the stop signal. `Notify::notify_one()` wakes exactly one waiting task. 
`Notify::notify_waiters()` wakes only tasks that are *currently parked* on the future at the moment of the call — any task that polls `notified()` after `notify_waiters()` returns will park again and never wake. With N tasks the result was: one task exited, the others remained blocked indefinitely. `Waiter::wait()` joined all task handles and hung. + +### Cause + +`Notify` does not persist state. A permit stored by `notify_one()` is consumed by the first task that calls `notified().await`; subsequent tasks see no permit and park. `notify_waiters()` is a snapshot operation: it only affects tasks already parked at the instant of the call. + +### Code references + +| Symbol | File | +|---|---| +| `Publisher::stop` / `Waiter::stop` — the stop channel | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +| `Waiter::stop()` — sends the signal | same file | +| `stop_rx.changed()` arm in per-shard task `select!` | same file | +### Fix + +Replace `Arc<Notify>` with `Arc<watch::Sender<bool>>` / `watch::Receiver<bool>`. `watch::Sender::send(true)` persists the value in the channel. Every receiver — whether currently parked or polling later — resolves `changed()` as soon as it observes the new value. One `send(true)` call unblocks all N tasks regardless of scheduling order. + +```rust +// Publisher and Waiter both hold Arc<watch::Sender<bool>>. +// Each task subscribes before spawning: +let mut stop_rx = self.stop.subscribe(); // watch::Receiver<bool> + +// Inside the task select!: +_ = stop_rx.changed() => { + slot.stop_replication().await?; + break; +} + +// Waiter::stop() -- one call, all tasks unblock: +pub fn stop(&self) { + let _ = self.stop.send(true); +} +``` + +### Tokio references + +- [`tokio::sync::Notify`](https://docs.rs/tokio/latest/tokio/sync/struct.Notify.html) — `notify_one()` stores at most one permit; `notify_waiters()` is not persistent. 
+- [`tokio::sync::watch`](https://docs.rs/tokio/latest/tokio/sync/watch/index.html) — `Sender::send()` updates the shared value; all receivers observe it on their next `changed()` poll. + +--- + +## 🚧 Issue 3 — Premature cutover when lag map is empty at startup + +### Description + +The orchestrator computes the current replication lag as the maximum value across all per-shard lag entries: + +```rust +lag.values().copied().max().unwrap_or(i64::MAX) as u64 +``` + +Per-shard lag values are written into `replication_lag: HashMap` by each task's `check_lag.tick()` interval. This interval fires once per second, but the first tick does not fire until one second after the task spawns. + +The orchestrator's `wait_for_replication()` loop runs immediately after `replicate()` returns. If it evaluates `replication_lag()` before any task's first tick, the map is empty. `Iterator::max()` on an empty iterator returns `None`. The original code used `unwrap_or_default()`, which returns `0`. The orchestrator saw lag = 0 bytes, concluded replication was already caught up, entered maintenance mode, and triggered cutover before any data had been replicated. + +### Cause + +Two independent races: +1. `unwrap_or_default()` on an empty map returned 0, which is indistinguishable from a legitimately zero lag. +2. No synchronization between task startup and the first lag measurement. 
+ +### Code references + +| Symbol | File | +|---|---| +| `Publisher::replicate()` — pre-population loop for `replication_lag` map | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +| `Publisher::replication_lag()` — reads the map | same file | +| `Orchestrator::replication_lag()` — takes the max across shards | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | +### Fix + +Pre-populate the map with `i64::MAX` for every shard index before spawning any tasks. Use `or_insert` so a real measurement written by a task is never overwritten by the sentinel. Change `unwrap_or_default()` to `unwrap_or(i64::MAX)` so an empty map (which should not occur after pre-population) is also treated as maximally lagged rather than zero. + +```rust +// Before spawning tasks: +let mut guard = self.replication_lag.lock(); +for number in 0..n_sources { + guard.entry(number).or_insert(i64::MAX); +} + +// Orchestrator reads: +lag.values().copied().max().unwrap_or(i64::MAX) as u64 +``` + +The sentinel is overwritten by the first real measurement from each task's `check_lag.tick()`. Until that point the orchestrator treats every uninitialized shard as having infinite lag and will not proceed. 
+ +--- + +## 🚧 Issue 4 — Divergent code paths for the same operation + +### Description + +The resharding pipeline — schema sync, data sync, replication, cutover — can be initiated through four independent entry points: + +| Entry point | Phase coverage | +|---|---| +| `RESHARD` admin command | full flow: schema sync, data sync, replication, cutover | +| CLI `replicate-and-cutover` | same full flow | +| `REPLICATE` admin command | replication only; schema/data sync must already be done | +| `REPLICATE` + `CUTOVER` admin commands | replication started as background task; cutover triggered externally | + +Each path composes the same underlying primitives (`Orchestrator`, `ReplicationWaiter`, `Publisher`) but assembles them differently and makes different assumptions about which phases have already run and which signals will arrive. When one path is fixed, the fix is not applied to the others because there is no single place that owns the shared contract. + +The concrete instance that surfaced during debugging: the background task registered by `REPLICATE` exits without performing cutover when the slot drains naturally. It waits in: + +```rust +select! { + _ = abort_rx => { waiter.stop(); } + _ = cutover.notified() => { waiter.cutover().await; } + result = waiter.wait() => { /* log error only */ } +} +``` + +When the source slot drains (`CopyDone`), `waiter.wait()` returns and the task exits via the third arm. `waiter.cutover()` is never called: `wait_for_replication()`, `wait_for_cutover()`, `schema_sync_post_cutover()`, and `databases::cutover()` (which flips traffic) all go unexecuted. The destination is fully populated and replication has stopped, but pgdog still routes to the source. The direct paths (`RESHARD`, CLI) always call `cutover()` at the end — they do not have this gap. + +A secondary consequence of path divergence: `AsyncTasks::cutover()` still calls `notify_one()` on `Arc<Notify>` (the primitive replaced in Issue 2). 
The fix in `publisher_impl.rs` was not propagated to the admin layer. This creates a race in the `REPLICATE` + `CUTOVER` path: if the slot drains at the same instant the operator sends `CUTOVER`, `select!` picks one arm non-deterministically and the notification may be silently discarded. + +### Cause + +Each entry point was built to satisfy a specific operational need without consolidating around a shared flow. There is no contract enforcing which phases run in which order, so paths accumulate independent deviations over time. + +### Code references + +| Symbol | File | +|---|---| +| `TaskType::Replication` `select!` / `waiter.wait()` arm (line ~150) | [`pgdog/src/backend/replication/logical/admin.rs`](../../pgdog/src/backend/replication/logical/admin.rs) | +| `AsyncTasks::cutover()` / `notify_one()` (line ~75) | same file | +| `Replicate::execute()` | [`pgdog/src/admin/replicate.rs`](../../pgdog/src/admin/replicate.rs) | +| `Reshard::execute()` | [`pgdog/src/admin/reshard.rs`](../../pgdog/src/admin/reshard.rs) | +| `Orchestrator::replicate_and_cutover()` — the canonical flow | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | +### Fix + +Two independent fixes are needed. + +#### Immediate fix — cutover on natural drain + +The `waiter.wait()` arm in the background task `select!` currently only logs errors; on `Ok(())` it exits silently without performing cutover. The slot has fully drained, the destination is populated, but pgdog still routes traffic to the source. The fix is to call `waiter.cutover()` on successful completion: + +```rust +result = waiter.wait() => { + match result { + Ok(()) => { + // Slot drained naturally — still perform cutover. 
+ if let Err(err) = waiter.cutover().await { + error!(...); + } + } + Err(err) => error!(...), + } +} +``` + +#### Secondary fix — `notify_one()` race in `AsyncTasks::cutover()` + +`AsyncTasks::cutover()` still calls `notify_one()` on `Arc<Notify>` (the primitive replaced in Issue 2 for the stop signal). The fix from Issue 2 was not propagated to the admin layer. This creates a race in the `REPLICATE` + `CUTOVER` path: if the slot drains at the same instant the operator sends `CUTOVER`, `select!` picks one arm non-deterministically and the notification may be silently discarded. + +Replace `Arc<Notify>` with `watch::Sender<bool>` in `AsyncTasks` so that `cutover.send(true)` persists the value and any task polling `cutover.changed()` — whether parked or not at the moment of the call — sees it. + +#### Structural fix + +Make `Orchestrator::replicate_and_cutover()` the single canonical implementation of the full flow and have the background-task path call it rather than assembling phases independently. The background-task model's only responsibility should be *when* to trigger cutover (immediately, on external signal, on timeout) — not *how* to execute it. + +### References + +- [Tokio `select!` macro](https://docs.rs/tokio/latest/tokio/macro.select.html) — when multiple branches are ready simultaneously, one is chosen pseudo-randomly; no particular branch is guaranteed to be the one that executes. + +--- + +## 🚧 Issue 5 — `AbortSignal` is not an abort signal; it is a coordinator-gone detector + +### Description + +`AbortSignal` is used inside the parallel table-copy path to interrupt in-flight `COPY` loops when the sync coordinator exits. The name implies an active cancellation primitive, but the mechanism is entirely passive: it wraps an `UnboundedSender` and calls `tx.closed().await`, which resolves only when the corresponding `rx` (owned by `ParallelSyncManager::run()`) is dropped. + +There is no `abort()` method. Nothing sends a signal. 
The only way the future resolves is if the receiver end of the channel is dropped — which happens as a side effect of the manager returning, not as an intentional cancellation act. + +This creates four concrete problems. + +**1. `rx` is only dropped when `manager.run()` returns.** + +`ParallelSyncManager::run()` is called inside a `tokio::spawn`ed task. That task runs independently of its caller. Dropping or cancelling `Orchestrator::data_sync()` or `Publisher::data_sync()` mid-await does not cancel the spawned task, does not drop `rx`, and does not fire the abort signal in any worker. Workers keep running until `manager.run()` finishes on its own. + +**2. The only trigger for `rx` dropping mid-run is a worker error propagating through `?`.** + +Inside `run()`: + +```rust +while let Some(table) = rx.recv().await { + tables.push(table?); // ← error here short-circuits the loop +}; +``` + +When one worker sends `Err(...)`, the `?` unwinds `run()`, `rx` drops, and `tx.closed()` resolves in every remaining worker — all concurrent `COPY` loops abort simultaneously. This is the only operational abort path. There is no way to cancel a single table's copy without bringing down every other table in the same manager. + +**3. The `tx.is_closed()` guard fires after permit acquisition, not before.** + +```rust +let _permit = Arc::clone(&self.permit) + .acquire_owned() + .await + .map_err(|_| Error::ParallelConnection)?; + +if self.tx.is_closed() { // ← checked here + return Err(Error::DataSyncAborted); +} +``` + +Tasks that are queued behind the semaphore when the coordinator dies continue to wait for a permit. They do not observe that the coordinator is gone until after they acquire a permit. Under high concurrency this means every queued task wakes, acquires a permit, checks `is_closed()`, and immediately returns an error — burning a permit round-trip for each one. + +**4. 
The name is a lie told to the next reader.** + +A reader seeing `AbortSignal` at a call site infers that an active abort can be issued. The implementation has no such capability. The name suppresses the question "who calls abort?" and makes the passive channel-closed mechanism invisible. This is the most dangerous property: future code that tries to use `AbortSignal` to implement a real abort — a timeout, a graceful stop, a per-table cancel — will find no mechanism to do so and is likely to add a second, parallel cancellation path instead of understanding the existing one. + +### Cause + +The signal was built to satisfy a narrow requirement — stop in-flight copies when any copy fails — and named optimistically. The implementation accidentally works for that single case because error propagation through `?` drops `rx` as a side effect. The gap between the name and the mechanism was not visible until the wider behaviour of `rx` (its lifetime, its relationship to `tokio::spawn`, the absence of an explicit abort path) was examined. 
+ +### Code references + +| Symbol | File | +---|---| +| `AbortSignal` — full definition | [`pgdog/src/backend/replication/logical/publisher/abort.rs`](../../pgdog/src/backend/replication/logical/publisher/abort.rs) | +| `ParallelSync::run_with_retry()` — constructs `AbortSignal` per attempt | [`pgdog/src/backend/replication/logical/publisher/parallel_sync.rs`](../../pgdog/src/backend/replication/logical/publisher/parallel_sync.rs) | +| `Table::data_sync()` — `select!` arm polling `abort.aborted()` | [`pgdog/src/backend/replication/logical/publisher/table.rs`](../../pgdog/src/backend/replication/logical/publisher/table.rs) | +| `ParallelSyncManager::run()` — owns `rx`; drops it on exit | [`pgdog/src/backend/replication/logical/publisher/parallel_sync.rs`](../../pgdog/src/backend/replication/logical/publisher/parallel_sync.rs) | + +### Fix + +Replace `AbortSignal` with a `CancellationToken` (from `tokio-util`) or a `watch::Sender` with an explicit `cancel()` method. The cancellation handle must be cloneable, sendable to workers before they start, and triggerable by an external caller — not just by the death of a channel receiver. + +```rust +// tokio-util approach: +use tokio_util::sync::CancellationToken; + +let cancel = CancellationToken::new(); + +// Pass a clone to each worker: +let worker_cancel = cancel.clone(); + +// Inside data_sync COPY loop: +select! { + _ = worker_cancel.cancelled() => { + return Err(Error::CopyAborted(self.table.clone())); + } + result = copy_sub.copy_data(data_row) => { ... } +} + +// To stop all workers at any time: +cancel.cancel(); +``` + +With this shape: +- The manager can cancel all workers explicitly without dying first. +- A timeout or external stop signal can call `cancel.cancel()` without needing to propagate an error through the channel. +- Per-table cancellation is possible by giving each worker its own child token: `cancel.child_token()`. 
+- The `tx.is_closed()` pre-flight check becomes `cancel.is_cancelled()`, which is honest about what it is testing. + +The `AbortSignal` type should be deleted. It carries no state that cannot be replaced by the token directly, and its existence perpetuates the misleading name. + +### References + +- [`tokio_util::sync::CancellationToken`](https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html) — cooperative cancellation token with child-token support and `cancelled().await`. +- [`tokio::sync::watch`](https://docs.rs/tokio/latest/tokio/sync/watch/index.html) — persistent value channel; used for the stop-signal fix in Issue 2. \ No newline at end of file From 0112e03a0b5802beb21fa0d812c6c46e108c9157 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Wed, 27 May 2026 12:10:50 +0000 Subject: [PATCH 2/8] feat(api): add new async_tasks management for background_tasks --- pgdog/Cargo.toml | 3 +- pgdog/src/admin/show_tasks.rs | 1 + pgdog/src/api/async_task.rs | 1068 +++++++++++++++++++++++++++++++++ pgdog/src/api/mod.rs | 8 + pgdog/src/api/resharding.rs | 42 ++ pgdog/src/lib.rs | 4 + 6 files changed, 1125 insertions(+), 1 deletion(-) create mode 100644 pgdog/src/api/async_task.rs create mode 100644 pgdog/src/api/mod.rs create mode 100644 pgdog/src/api/resharding.rs diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index e306afc29..1b2829a07 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -22,7 +22,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "std"] tracing-throttle = "0.4" parking_lot = "0.12" thiserror = "2" -derive_more = { version = "2", features = ["display", "error"] } +derive_more = { version = "2", features = ["display", "error", "from"] } bytes = "1" clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } @@ -97,6 +97,7 @@ tempfile = "3.23.0" stats_alloc = "0.1.10" brunch = "0.5" wiremock = "0.6" +tokio = { version = "1", 
features = ["full", "test-util"] } [[bench]] name = "comment_parser" diff --git a/pgdog/src/admin/show_tasks.rs b/pgdog/src/admin/show_tasks.rs index 3b56d6aaa..51565c812 100644 --- a/pgdog/src/admin/show_tasks.rs +++ b/pgdog/src/admin/show_tasks.rs @@ -22,6 +22,7 @@ impl Command for ShowTasks { async fn execute(&self) -> Result, Error> { let rd = RowDescription::new(&[ Field::bigint("id"), + // Field::bigint("parent_id"), Field::text("type"), Field::text("started_at"), Field::text("elapsed"), diff --git a/pgdog/src/api/async_task.rs b/pgdog/src/api/async_task.rs new file mode 100644 index 000000000..6123fee67 --- /dev/null +++ b/pgdog/src/api/async_task.rs @@ -0,0 +1,1068 @@ +use std::fmt::Debug; +use std::fmt::Display; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::task::{Context, Poll}; +use std::time::{Duration, SystemTime}; + +use dashmap::DashMap; +use parking_lot::RwLock; +use tokio::select; +use tokio::sync::oneshot::{self, Receiver}; +use tokio::time::timeout; +use tokio_util::sync::CancellationToken; + +#[derive(Copy, Clone, Debug, Display, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct AsyncTaskId(u64); +pub trait TaskInfoStatus: Display + Debug + Send + Sync + 'static { + fn task_name() -> &'static str; + + /// Grace period for cooperative shutdown after cancellation; + /// once it expires, the task is force-aborted. + fn cancel_timeout() -> Duration { + Duration::from_secs(5) + } +} + +#[derive(Display, Debug)] +pub struct AnonymousTask; + +impl TaskInfoStatus for AnonymousTask { + fn task_name() -> &'static str { + "anonymous" + } +} + +#[derive(Display, Debug, Clone)] +pub enum TaskStatus { + Started, + Pending(T), + Finished, + Cancelled, + Error(String), + Panic(String), +} + +/// Type-erased snapshot of a task's current state, +/// readable through the registry without knowing `T`. 
+#[derive(Debug, Clone)] +pub struct TaskState { + pub name: &'static str, + pub status: TaskStatus, + pub started_at: SystemTime, + pub updated_at: SystemTime, +} + +/// Why a task did not complete, delivered to the waiter +/// as the error half of its `Result`. +#[derive(Debug, Display, Error)] +pub enum TaskError { + /// The task itself returned an error. + #[display("task failed: {_0}")] + Failed(E), + #[display("task was cancelled")] + Cancelled, + #[display("task panicked: {_0}")] + Panicked(#[error(ignore)] String), + /// The task's result was never delivered: the watcher + /// died without reporting (e.g. runtime shutdown). + #[display("task result was never delivered")] + Abandoned, +} + +impl TaskStatus { + /// Terminal states are write-once; late writers + /// (e.g. ctx clones outliving the task) are ignored. + fn is_terminal(&self) -> bool { + matches!( + self, + Self::Finished | Self::Cancelled | Self::Error(_) | Self::Panic(_) + ) + } + + /// Snapshot for the registry: keep the variant, render `T`. + fn stringify(&self) -> TaskStatus + where + T: Display, + { + match self { + Self::Started => TaskStatus::Started, + Self::Pending(status) => TaskStatus::Pending(status.to_string()), + Self::Finished => TaskStatus::Finished, + Self::Cancelled => TaskStatus::Cancelled, + Self::Error(err) => TaskStatus::Error(err.clone()), + Self::Panic(msg) => TaskStatus::Panic(msg.clone()), + } + } +} + +impl TaskState { + /// Reached a terminal state more than `ttl` ago? 
+ fn expired(&self, now: SystemTime, ttl: Duration) -> bool { + self.status.is_terminal() + && now + .duration_since(self.updated_at) + .is_ok_and(|age| age >= ttl) + } +} + +type SharedStatus = Arc>>; + +#[derive(Default)] +struct TasksMap { + map: DashMap>, + counter: AtomicU64, +} + +impl TasksMap { + fn insert_next(&self, value: Arc) -> AsyncTaskId { + let id = AsyncTaskId(self.counter.fetch_add(1, Ordering::Relaxed)); + + self.map.insert(id, value); + + id + } +} + +struct AsyncTaskState { + updated_at: SystemTime, + status: TaskStatus, +} + +impl AsyncTaskState { + fn new() -> Self { + Self { + updated_at: SystemTime::now(), + status: TaskStatus::Started, + } + } +} + +struct AsyncTask { + started_at: SystemTime, + cancellation_token: CancellationToken, + /// Set once the task asks for its cancellation token: only + /// then can it react to cancellation, so only then is the + /// cooperative-shutdown grace period worth waiting out. + cooperative: AtomicBool, + state: Arc>>, + subtasks: Arc, +} + +trait TaskMapEntry: Send + Sync + 'static { + fn cancel(&self); + fn state(&self) -> TaskState; + fn subtasks(&self) -> &TasksMap; +} + +impl TaskMapEntry for AsyncTask { + fn cancel(&self) { + self.cancellation_token.cancel(); + } + + fn state(&self) -> TaskState { + let state = self.state.read(); + + TaskState { + name: T::task_name(), + status: state.status.stringify(), + started_at: self.started_at, + updated_at: state.updated_at, + } + } + + fn subtasks(&self) -> &TasksMap { + &self.subtasks + } +} + +pub struct AsyncTaskContext { + task: Arc>, +} + +impl Clone for AsyncTaskContext { + fn clone(&self) -> Self { + Self { + task: self.task.clone(), + } + } +} + +/// Handle to a spawned task. Resolves, as a future, to the +/// task's result; also exposes the registry id of the task. 
+pub struct AsyncTaskWaiter { + id: AsyncTaskId, + waiter: Receiver>>, +} + +impl AsyncTaskWaiter { + pub fn id(&self) -> AsyncTaskId { + self.id + } +} + +impl Debug for AsyncTaskWaiter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncTaskWaiter") + .field("id", &self.id) + .finish_non_exhaustive() + } +} + +impl Future for AsyncTaskWaiter { + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.get_mut().waiter).poll(cx).map(|res| { + // A dropped sender means the watcher died without + // reporting; surface that as a task error instead of + // leaking the channel's RecvError. + res.unwrap_or_else(|_| Err(TaskError::Abandoned)) + }) + } +} + +/// Point-in-time view of a root task and its subtasks. +/// +/// All descendants register with the root task, so `subtasks` is a +/// flat list of every descendant (children and grandchildren alike), +/// ordered by id; their own `subtasks` are always empty. +#[derive(Debug, Clone)] +pub struct TaskSnapshot { + pub id: AsyncTaskId, + pub state: TaskState, + pub subtasks: Vec, +} + +impl TaskSnapshot { + fn new(id: AsyncTaskId, entry: &dyn TaskMapEntry) -> Self { + let mut subtasks: Vec<_> = entry + .subtasks() + .map + .iter() + .map(|sub| TaskSnapshot { + id: *sub.key(), + state: sub.value().state(), + subtasks: Vec::new(), + }) + .collect(); + subtasks.sort_unstable_by_key(|task| task.id); + + Self { + id, + state: entry.state(), + subtasks, + } + } +} + +/// Finished tasks stay visible in the registry for this long +/// before being pruned. 
+const TASK_RETENTION: Duration = Duration::from_secs(24 * 60 * 60); + +pub struct AsyncTasksStorage { + tasks: Arc, + retention: Duration, +} + +impl Default for AsyncTasksStorage { + fn default() -> Self { + Self::new(TASK_RETENTION) + } +} + +fn run_task( + current_task: Option<&Arc>>, + tasks: &TasksMap, + execute: impl FnOnce(AsyncTaskContext) -> F, +) -> AsyncTaskWaiter +where + T: TaskInfoStatus, + T1: TaskInfoStatus, + F: Future> + Send + 'static, + R: Send + 'static, + E: std::error::Error + Send + 'static, +{ + let state = Arc::new(RwLock::new(AsyncTaskState::new())); + + let task = AsyncTask { + started_at: SystemTime::now(), + cancellation_token: match current_task { + Some(current_task) => current_task.cancellation_token.child_token(), + None => CancellationToken::new(), + }, + cooperative: AtomicBool::new(false), + // Subtasks share the root task's registry: every descendant + // registers as a direct child of the root. + subtasks: match current_task { + Some(current_task) => current_task.subtasks.clone(), + None => Arc::new(TasksMap::default()), + }, + state: state.clone(), + }; + + let task = Arc::new(task); + // Make sure we insert task to map before it's actually started + let id = tasks.insert_next(task.clone()); + + let ctx = AsyncTaskContext { task: task.clone() }; + + let mut handle = tokio::spawn(execute(ctx.clone())); + let (sender, receiver) = oneshot::channel(); + + let cancellation_token = task.cancellation_token.clone(); + + tokio::spawn(async move { + let res = select! { + _ = cancellation_token.cancelled() => { + if ctx.task.cooperative.load(Ordering::Relaxed) { + match timeout(T1::cancel_timeout(), &mut handle).await { + Ok(res) => res, + Err(_) => { + handle.abort(); + handle.await + } + } + } else { + // The task never took its cancellation token, so it + // cannot react to it: abort right away instead of + // letting it run on through the grace period. 
+ handle.abort(); + handle.await + } + } + res = &mut handle => { + res + } + }; + + match res { + Ok(Ok(res)) => { + ctx.set_inner_status(TaskStatus::Finished); + let _ = sender.send(Ok(res)); + } + Ok(Err(err)) => { + ctx.set_inner_status(TaskStatus::Error(err.to_string())); + let _ = sender.send(Err(TaskError::Failed(err))); + } + Err(err) if err.is_cancelled() => { + ctx.set_inner_status(TaskStatus::Cancelled); + let _ = sender.send(Err(TaskError::Cancelled)); + } + Err(err) => { + let panic = err.to_string(); + ctx.set_inner_status(TaskStatus::Panic(panic.clone())); + let _ = sender.send(Err(TaskError::Panicked(panic))); + } + } + }); + + AsyncTaskWaiter { + id, + waiter: receiver, + } +} + +impl AsyncTaskContext { + fn set_inner_status(&self, status: TaskStatus) { + let mut state = self.task.state.write(); + if state.status.is_terminal() { + return; + } + state.status = status; + state.updated_at = SystemTime::now(); + } + + pub fn set_status(&self, status: T) { + self.set_inner_status(TaskStatus::Pending(status)); + } + + /// Hand out this task's cancellation token. Taking the token + /// opts the task into cooperative shutdown: on `cancel_task` + /// it gets [`TaskInfoStatus::cancel_timeout`] to wind down + /// before being aborted. Tasks that never take it are aborted + /// immediately. Take it early. 
+ pub fn cancellation_token(&self) -> CancellationToken { + self.task.cooperative.store(true, Ordering::Relaxed); + + self.task.cancellation_token.clone() + } + + pub fn run( + &self, + execute: impl FnOnce(AsyncTaskContext) -> F, + ) -> AsyncTaskWaiter + where + T1: TaskInfoStatus, + F: Future> + Send + 'static, + R: Send + 'static, + E: std::error::Error + Send + 'static, + { + run_task(Some(&self.task), &self.task.subtasks, execute) + } +} + +impl AsyncTasksStorage { + pub fn new(retention: Duration) -> Self { + Self { + tasks: Arc::default(), + retention, + } + } + + pub fn run( + &self, + execute: impl FnOnce(AsyncTaskContext) -> F, + ) -> AsyncTaskWaiter + where + T: TaskInfoStatus, + F: Future> + Send + 'static, + R: Send + 'static, + E: std::error::Error + Send + 'static, + { + self.prune(); + + run_task(Option::<&Arc>>::None, &self.tasks, execute) + } + + /// Request cancellation of a task. A task that took its token winds + /// down cooperatively (or is aborted after the grace period); any + /// other task is aborted right away. Either way it stays in the + /// registry with a terminal status until pruned. Returns the state + /// read right after the request, or `None` for an unknown id. + pub fn cancel_task(&self, id: AsyncTaskId) -> Option { + let entry = self.tasks.map.get(&id)?; + + entry.cancel(); + + Some(entry.state()) + } + + /// Drop every root task, and every subtask in a root's registry, that + /// reached a terminal state more than `retention` ago; running tasks stay. + fn prune(&self) { + let now = SystemTime::now(); + + self.tasks.map.retain(|_, entry| { + entry + .subtasks() + .map + .retain(|_, sub| !sub.state().expired(now, self.retention)); + + !entry.state().expired(now, self.retention) + }); + } + + /// Snapshot every root task with its subtasks, ordered by id.
+ pub fn tasks(&self) -> Vec { + self.prune(); + + let mut tasks: Vec<_> = self + .tasks + .map + .iter() + .map(|entry| TaskSnapshot::new(*entry.key(), entry.value().as_ref())) + .collect(); + tasks.sort_unstable_by_key(|task| task.id); + tasks + } + + /// Snapshot a single root task by id. + pub fn task(&self, id: AsyncTaskId) -> Option { + self.prune(); + + self.tasks + .map + .get(&id) + .map(|entry| TaskSnapshot::new(id, entry.value().as_ref())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use parking_lot::Mutex; + use std::convert::Infallible; + use std::sync::Arc; + use tokio::sync::Notify; + use tokio::task::yield_now; + use tokio::time::sleep; + use tokio::{join, test}; + + #[derive(Display, Debug)] + enum TestTask { + StepOne, + StepTwo, + } + + impl TaskInfoStatus for TestTask { + fn task_name() -> &'static str { + "test_task" + } + } + + macro_rules! mock_successful { + ($state_id:ident, $notify:ident) => {{ + let state = $state_id.clone(); + let notify = $notify.clone(); + + async move |_ctx: AsyncTaskContext| { + *state.lock() = "started"; + + notify.notified().await; + + *state.lock() = "finished"; + + Ok::<_, Infallible>(()) + } + }}; + } + + macro_rules! mock_failing { + ($state_id:ident, $notify:ident) => {{ + let state = $state_id.clone(); + let notify = $notify.clone(); + + async move |_ctx: AsyncTaskContext| { + *state.lock() = "started"; + + notify.notified().await; + + *state.lock() = "failed"; + + Err::<(), _>(std::io::Error::other("mock task failure")) + } + }}; + } + + /// Spawned tasks are only guaranteed to be polled after the + /// current task yields; a few rounds cover spawn chains. 
+ async fn settle() { + for _ in 0..5 { + yield_now().await; + } + } + + #[test] + async fn test_single_execution() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + notify.notify_one(); + + task.await.unwrap(); + + assert_eq!(*state_a.lock(), "finished"); + } + + #[test] + async fn test_multiple_execution() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_b = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + let b = mock_successful!(state_b, notify); + let c = mock_successful!(state_c, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task_a = async_storage.run(a); + let task_b = async_storage.run(b); + + let info = async_storage.task(task_a.id()).unwrap(); + + assert_eq!(async_storage.tasks().len(), 2); + + assert!(matches!(info.state.status, TaskStatus::Started)); + + let info = async_storage.task(task_b.id()).unwrap(); + + assert!(matches!(info.state.status, TaskStatus::Started)); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_b.lock(), "started"); + assert_eq!(*state_c.lock(), "initial"); + + notify.notify_one(); + notify.notify_one(); + + let task_c = async_storage.run(c); + + notify.notify_one(); + + task_c.await.unwrap(); + + assert_eq!(*state_a.lock(), "finished"); + assert_eq!(*state_b.lock(), "finished"); + assert_eq!(*state_c.lock(), "finished"); + } + + #[test] + async fn test_inner_execution() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_b = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); 
+ let b = mock_successful!(state_b, notify); + let c = { + let state = state_c.clone(); + + async move |ctx: AsyncTaskContext| { + *state.lock() = "started"; + + let (a, b) = join!(ctx.run(a), ctx.run(b)); + a.unwrap(); + b.unwrap(); + + *state.lock() = "finished"; + + Ok::<_, Infallible>(()) + } + }; + + let async_storage = AsyncTasksStorage::default(); + + let task_c = async_storage.run(c); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_b.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + + notify.notify_waiters(); + + task_c.await.unwrap(); + + assert_eq!(*state_a.lock(), "finished"); + assert_eq!(*state_b.lock(), "finished"); + assert_eq!(*state_c.lock(), "finished"); + } + + #[test(start_paused = true)] + async fn test_single_cancel() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + let task_id = task.id(); + async_storage.cancel_task(task_id); + + let res = task.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + assert_eq!(*state_a.lock(), "started"); + + // Cancelled tasks stay visible with a terminal status. + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); + } + #[test(start_paused = true)] + async fn test_graceful_exit_during_cancel_grace() { + let state = Arc::new(Mutex::new("initial")); + let a = { + let state = state.clone(); + + async move |ctx: AsyncTaskContext| { + *state.lock() = "started"; + + ctx.cancellation_token().cancelled().await; + + // Cooperative wind-down, well within the 5s grace window. 
+ sleep(Duration::from_secs(1)).await; + + *state.lock() = "graceful"; + + Ok::<_, Infallible>(42) + } + }; + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state.lock(), "started"); + + async_storage.cancel_task(task.id()); + + // The task observed the token and finished gracefully: its + // result must be delivered, not discarded or lost to a + // watcher panic. + let res = task.await; + assert_eq!(res.unwrap(), 42); + + assert_eq!(*state.lock(), "graceful"); + } + + #[test(start_paused = true)] + async fn test_inner_cancel() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_b = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + let b = mock_successful!(state_b, notify); + let c = { + let state = state_c.clone(); + let notify = notify.clone(); + + async move |ctx: AsyncTaskContext| { + *state.lock() = "started"; + + let _ = join!(ctx.run(a), ctx.run(b)); + + notify.notified().await; + + *state.lock() = "finished"; + + Ok::<_, Infallible>(()) + } + }; + + let async_storage = AsyncTasksStorage::default(); + + let task_c = async_storage.run(c); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_b.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + + async_storage.cancel_task(task_c.id()); + + let res = task_c.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_b.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + } + + #[test] + async fn test_single_error() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_failing!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), 
"started"); + + notify.notify_one(); + + let task_id = task.id(); + let res = task.await; + assert!(matches!(res, Err(TaskError::Failed(_)))); + + assert_eq!(*state_a.lock(), "failed"); + + let info = async_storage.task(task_id).unwrap(); + assert!(matches!(info.state.status, TaskStatus::Error(_))); + } + + #[test] + async fn test_inner_error() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let a = mock_failing!(state_a, notify); + let c = { + let state = state_c.clone(); + + async move |ctx: AsyncTaskContext| { + *state.lock() = "started"; + + let res = ctx.run(a).await; + assert!(matches!(res, Err(TaskError::Failed(_)))); + + *state.lock() = "inner_failed"; + + Ok::<_, Infallible>(()) + } + }; + + let async_storage = AsyncTasksStorage::default(); + + let task_c = async_storage.run(c); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + + notify.notify_one(); + + task_c.await.unwrap(); + + assert_eq!(*state_a.lock(), "failed"); + assert_eq!(*state_c.lock(), "inner_failed"); + } + + #[test] + async fn test_panic() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = { + let state = state_a.clone(); + let notify = notify.clone(); + + async move |_ctx: AsyncTaskContext| { + *state.lock() = "started"; + + notify.notified().await; + + if *state.lock() == "started" { + panic!("panicking task"); + } + + Ok::<_, Infallible>(()) + } + }; + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + notify.notify_one(); + + let task_id = task.id(); + let res = task.await; + assert!(matches!(res, Err(TaskError::Panicked(_)))); + + assert_eq!(*state_a.lock(), "started"); + + let info = async_storage.task(task_id).unwrap(); + assert!(matches!(info.state.status, TaskStatus::Panic(_))); + 
} + + #[test] + async fn test_traverse_statuses() { + let sub_gate = Arc::new(Notify::new()); + let parent_gate = Arc::new(Notify::new()); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run({ + let sub_gate = sub_gate.clone(); + let parent_gate = parent_gate.clone(); + + async move |ctx: AsyncTaskContext| { + ctx.set_status(TestTask::StepOne); + + let sub = ctx.run({ + let sub_gate = sub_gate.clone(); + + async move |ctx: AsyncTaskContext| { + // Grandchild: registers with the root task. + ctx.run(async move |_ctx: AsyncTaskContext| { + sub_gate.notified().await; + Ok::<_, Infallible>(()) + }) + .await + .unwrap(); + + Ok::<_, Infallible>(()) + } + }); + + sub.await.unwrap(); + + ctx.set_status(TestTask::StepTwo); + + parent_gate.notified().await; + + Ok::<_, Infallible>(()) + } + }); + + let id = task.id(); + + settle().await; + + // Top-level listing: one named task with a live status. + let tasks = async_storage.tasks(); + assert_eq!(tasks.len(), 1); + + let snapshot = &tasks[0]; + assert_eq!(snapshot.id, id); + assert_eq!(snapshot.state.name, "test_task"); + assert!( + matches!(&snapshot.state.status, TaskStatus::Pending(s) if s.as_str() == "StepOne") + ); + + // Both the subtask and its grandchild appear as direct + // children of the root, flat. + assert_eq!(snapshot.subtasks.len(), 2); + for sub in &snapshot.subtasks { + assert_eq!(sub.state.name, "anonymous"); + assert!(matches!(sub.state.status, TaskStatus::Started)); + assert!(sub.subtasks.is_empty()); + } + + sub_gate.notify_one(); + settle().await; + + // Subtask finished, parent moved to the next phase. + let snapshot = async_storage.task(id).unwrap(); + assert!( + matches!(&snapshot.state.status, TaskStatus::Pending(s) if s.as_str() == "StepTwo") + ); + for sub in &snapshot.subtasks { + assert!(matches!(sub.state.status, TaskStatus::Finished)); + } + + parent_gate.notify_one(); + + task.await.unwrap(); + + // Terminal status stays observable after completion. 
+ let snapshot = async_storage.task(id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Finished)); + } + + #[test] + async fn test_prune_expired_tasks() { + let notify_a = Arc::new(Notify::new()); + let notify_b = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_b = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify_a); + let b = mock_successful!(state_b, notify_b); + + // Zero retention: terminal tasks are pruned on next access. + let async_storage = AsyncTasksStorage::new(Duration::ZERO); + + let task_a = async_storage.run(a); + let task_b = async_storage.run(b); + let id_a = task_a.id(); + let id_b = task_b.id(); + + settle().await; + + // Both running: nothing to prune. + assert_eq!(async_storage.tasks().len(), 2); + + notify_a.notify_one(); + task_a.await.unwrap(); + + // The finished task expired; the running one survives. + let tasks = async_storage.tasks(); + assert_eq!(tasks.len(), 1); + assert_eq!(tasks[0].id, id_b); + assert!(async_storage.task(id_a).is_none()); + + notify_b.notify_one(); + task_b.await.unwrap(); + + assert!(async_storage.tasks().is_empty()); + } + + #[test(start_paused = true)] + async fn test_cancel_timeout_override() { + #[derive(Display, Debug)] + struct SlowExitTask; + + impl TaskInfoStatus for SlowExitTask { + fn task_name() -> &'static str { + "slow_exit" + } + + fn cancel_timeout() -> Duration { + Duration::from_secs(30) + } + } + + let notify = Arc::new(Notify::new()); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run({ + let notify = notify.clone(); + + // Takes its cancellation token (opting into the grace + // period) but ignores it: only the forced abort after + // the grace period can stop it. 
+ async move |ctx: AsyncTaskContext| { + let _token = ctx.cancellation_token(); + notify.notified().await; + Ok::<_, Infallible>(()) + } + }); + + settle().await; + + let started = tokio::time::Instant::now(); + + async_storage.cancel_task(task.id()); + + let res = task.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + // The paused clock advances exactly by the overridden + // grace period, not the default 5s. + assert_eq!(started.elapsed(), Duration::from_secs(30)); + } + + #[test(start_paused = true)] + async fn test_non_cooperative_cancel_aborts_immediately() { + let notify = Arc::new(Notify::new()); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run({ + let notify = notify.clone(); + + // Never takes the cancellation token: no grace period. + async move |_ctx: AsyncTaskContext| { + notify.notified().await; + Ok::<_, Infallible>(()) + } + }); + + settle().await; + + let started = tokio::time::Instant::now(); + + async_storage.cancel_task(task.id()); + + let res = task.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + // Aborted right away: the paused clock did not advance. + assert_eq!(started.elapsed(), Duration::ZERO); + } +} diff --git a/pgdog/src/api/mod.rs b/pgdog/src/api/mod.rs new file mode 100644 index 000000000..d1ea61c98 --- /dev/null +++ b/pgdog/src/api/mod.rs @@ -0,0 +1,8 @@ +//! PgDog API handlers. +//! +//! The interfaces that call the API: +//! - pgdog CLI +//! - admin DB API + +mod async_task; +pub mod resharding; diff --git a/pgdog/src/api/resharding.rs b/pgdog/src/api/resharding.rs new file mode 100644 index 000000000..3fb1fb335 --- /dev/null +++ b/pgdog/src/api/resharding.rs @@ -0,0 +1,42 @@ +use crate::backend::replication::Error; +use crate::backend::replication::orchestrator::Orchestrator; + +pub struct Options { + /// Source database name. + from_database: String, + + /// Destination database. + to_database: String, + + /// Publication name.
+ publication: String, + + /// Name of the replication slot to create/use. + replication_slot: Option, + + /// Only replicate, skipping the data copy (inferred from the field name; TODO confirm). + replicate_only: bool, + + /// Only copy data, skipping replication (inferred from the field name; TODO confirm). + sync_only: bool, + + /// Don't perform pre-data schema sync. + skip_schema_sync: bool, +} + +pub fn reshard(options: Options) -> Result<(), Error> { + let Options { + from_database, + to_database, + publication, + replication_slot, + replicate_only, + sync_only, + skip_schema_sync, + } = options; + + let orchestrator = + Orchestrator::new(&from_database, &to_database, &publication, replication_slot); + + Ok(()) +} diff --git a/pgdog/src/lib.rs b/pgdog/src/lib.rs index ee9e757dd..c3f4aad45 100644 --- a/pgdog/src/lib.rs +++ b/pgdog/src/lib.rs @@ -2,7 +2,11 @@ #![allow(clippy::result_unit_err)] #![deny(clippy::print_stdout)] +#[macro_use] +extern crate derive_more; + pub mod admin; +pub mod api; pub mod auth; pub mod backend; pub mod cli; From 919831d01abea56138e76f3a4ca6227f6e1b632d Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Thu, 18 Jun 2026 11:59:34 +0000 Subject: [PATCH 3/8] test(api): add tests for admin commands --- .config/nextest.toml | 1 + .../rust/tests/integration/admin/mod.rs | 103 +++++ .../tests/integration/admin/show_config.rs | 26 ++ .../tests/integration/admin/show_version.rs | 23 + .../rust/tests/integration/admin/tasks.rs | 393 ++++++++++++++++++ integration/rust/tests/integration/mod.rs | 1 + integration/users.toml | 7 + 7 files changed, 554 insertions(+) create mode 100644 integration/rust/tests/integration/admin/mod.rs create mode 100644 integration/rust/tests/integration/admin/show_config.rs create mode 100644 integration/rust/tests/integration/admin/show_version.rs create mode 100644 integration/rust/tests/integration/admin/tasks.rs diff --git a/.config/nextest.toml b/.config/nextest.toml index e1836202c..959c1bc19 100644 --- a/.config/nextest.toml +++ b/.config/nextest.toml @@ -5,3 +5,4 @@
default-filter = "not package(rust)" [profile.default] slow-timeout = "15s" +test-threads = 1 diff --git a/integration/rust/tests/integration/admin/mod.rs b/integration/rust/tests/integration/admin/mod.rs new file mode 100644 index 000000000..f29650406 --- /dev/null +++ b/integration/rust/tests/integration/admin/mod.rs @@ -0,0 +1,103 @@ +//! Integration tests asserting the output of admin commands. +//! +//! Each submodule connects to the live PgDog admin database (see +//! `rust::setup::admin_sqlx`) and verifies the shape and contents of a +//! command's output over the wire. +pub mod show_config; +pub mod show_version; +pub mod tasks; + +use sqlx::{Column, Executor, Pool, Postgres, Row, TypeInfo}; + +/// Wire layout expected from `SHOW TASKS`. +const SHOW_TASKS_LAYOUT: &[(&str, &str)] = &[ + ("id", "INT8"), + ("type", "TEXT"), + ("started_at", "TEXT"), + ("elapsed", "TEXT"), + ("elapsed_ms", "INT8"), +]; + +/// A single row returned by `SHOW TASKS` with all fields already parsed and +/// validated. Rows are meant to be obtained through [`Tasks::fetch`], which +/// checks the wire layout and field invariants before handing rows out. +#[derive(Debug, Clone)] +pub struct Task { + pub id: i64, + pub kind: String, + pub started_at: String, + pub elapsed: String, + pub elapsed_ms: i64, +} + +/// Parsed result of a `SHOW TASKS` admin command. +/// +/// Call [`Tasks::fetch`] to issue the command, validate column layout and +/// every row's field invariants in one shot, and get back a typed collection +/// you can query with [`Tasks::find`] or iterate over [`Tasks::rows`]. +pub struct Tasks { + pub rows: Vec, +} + +impl Tasks { + /// Issue `SHOW TASKS` against `pool`, assert the wire layout, parse and + /// validate every row, and return the collection. + /// + /// Panics on a failed query, any layout mismatch, or a field that violates + /// an invariant (empty started_at, empty elapsed, negative elapsed_ms).
+ pub async fn fetch(pool: &Pool) -> Self { + let raw = pool.fetch_all("SHOW TASKS").await.unwrap(); + + // assert_layout requires at least one row; skip when empty (valid — no tasks running). + if !raw.is_empty() { + assert_layout(&raw, SHOW_TASKS_LAYOUT); + } + + let rows = raw + .iter() + .map(|row| { + let id: i64 = row.get("id"); + let started_at: String = row.get("started_at"); + let elapsed: String = row.get("elapsed"); + let elapsed_ms: i64 = row.get("elapsed_ms"); + + assert!(!started_at.is_empty(), "task {id}: started_at is empty"); + assert!(!elapsed.is_empty(), "task {id}: elapsed is empty"); + assert!(elapsed_ms >= 0, "task {id}: elapsed_ms is negative"); + + Task { + id, + kind: row.get("type"), + started_at, + elapsed, + elapsed_ms, + } + }) + .collect(); + + Self { rows } + } + + /// Return the task with the given id, if present. + pub fn find(&self, id: i64) -> Option<&Task> { + self.rows.iter().find(|t| t.id == id) + } + + pub fn is_empty(&self) -> bool { + self.rows.is_empty() + } +} + +/// Assert that `rows` is non-empty and that the first row's column layout +/// (name, wire type) matches `expected` exactly, in order. +/// +/// Used by [`Tasks::fetch`] and by submodule tests for other commands.
+pub fn assert_layout(rows: &[sqlx::postgres::PgRow], expected: &[(&str, &str)]) { + assert!(!rows.is_empty(), "expected at least one row"); + let actual: Vec<(&str, &str)> = rows[0] + .columns() + .iter() + .map(|col| (col.name(), col.type_info().name())) + .collect(); + assert_eq!(actual, expected, "column layout mismatch"); +} diff --git a/integration/rust/tests/integration/admin/show_config.rs b/integration/rust/tests/integration/admin/show_config.rs new file mode 100644 index 000000000..1d23de6be --- /dev/null +++ b/integration/rust/tests/integration/admin/show_config.rs @@ -0,0 +1,26 @@ +use std::collections::HashMap; + +use rust::setup::admin_sqlx; +use sqlx::{Executor, Row}; + +use super::assert_layout; + +/// `SHOW CONFIG` returns rows described by two TEXT columns, `name` and +/// `value`, one per configuration setting. +#[tokio::test] +async fn test_show_config_reports_settings() { + let admin = admin_sqlx().await; + let rows = admin.fetch_all("SHOW CONFIG").await.unwrap(); + + assert_layout(&rows, &[("name", "TEXT"), ("value", "TEXT")]); + + let settings: HashMap = rows + .iter() + .map(|row| (row.get("name"), row.get("value"))) + .collect(); + + assert_eq!(settings["host"], "0.0.0.0"); + assert_eq!(settings["port"], "6432"); + + admin.close().await; +} diff --git a/integration/rust/tests/integration/admin/show_version.rs b/integration/rust/tests/integration/admin/show_version.rs new file mode 100644 index 000000000..8afc82431 --- /dev/null +++ b/integration/rust/tests/integration/admin/show_version.rs @@ -0,0 +1,23 @@ +use rust::setup::admin_sqlx; +use sqlx::{Executor, Row}; + +use super::assert_layout; + +/// `SHOW VERSION` returns a single row described by one `version` TEXT column, +/// carrying the PgDog version banner. 
+#[tokio::test] +async fn test_show_version_reports_banner() { + let admin = admin_sqlx().await; + let rows = admin.fetch_all("SHOW VERSION").await.unwrap(); + + assert_eq!(rows.len(), 1, "SHOW VERSION should return exactly one row"); + assert_layout(&rows, &[("version", "TEXT")]); + + let version: &str = rows[0].get("version"); + assert!( + version.starts_with("PgDog v"), + "version should start with the PgDog banner, got: {version:?}" + ); + + admin.close().await; +} diff --git a/integration/rust/tests/integration/admin/tasks.rs b/integration/rust/tests/integration/admin/tasks.rs new file mode 100644 index 000000000..efde1da74 --- /dev/null +++ b/integration/rust/tests/integration/admin/tasks.rs @@ -0,0 +1,393 @@ +use std::time::Duration; + +use rust::setup::{admin_sqlx, connection_sqlx_direct, connection_sqlx_direct_db}; +use sqlx::{Executor, Pool, Postgres, Row}; +use tokio::time::sleep; + +use super::Tasks; + +// ─── Constants ────────────────────────────────────────────────────────────── + +/// Shared table created in the source `pgdog` database and propagated to the +/// destination shards by schema_sync and copy_data tests. Sequential +/// execution (`test-threads = 1`) means each test owns it exclusively. +const TEST_TABLE: &str = "_pgdog_test_task"; + +const STOP_TASK_PUB: &str = "pgdog_stop_task_test_pub"; +const STOP_TASK_SLOT: &str = "pgdog_stop_task_test_slot"; +const CUTOVER_PUB: &str = "pgdog_cutover_test_pub"; +const CUTOVER_SLOT: &str = "pgdog_cutover_test_slot"; +const SCHEMA_SYNC_PRE_PUB: &str = "pgdog_schema_sync_pre_test_pub"; +const SCHEMA_SYNC_POST_PUB: &str = "pgdog_schema_sync_post_test_pub"; +const COPY_DATA_PUB: &str = "pgdog_copy_data_test_pub"; + +/// WHERE predicate that matches every replication slot created by these tests. +/// +/// Used verbatim in three consecutive queries inside [`cleanup`]: +/// terminate active WAL senders → wait until inactive → drop. 
+const SLOT_FILTER: &str = " slot_name LIKE 'pgdog_stop_task_test_slot_%' \ + OR slot_name LIKE 'pgdog_cutover_test_slot_%' \ + OR slot_name LIKE '__pgdog_repl_%'"; + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +/// Drop `table` and its orphaned row-type (left behind by interrupted DDL) +/// from `pool`. Both statements use `IF EXISTS` so the function is safe to +/// call when the objects do not exist. +async fn drop_table(pool: &Pool, table: &str) { + let _ = pool + .execute(format!("DROP TABLE IF EXISTS {table} CASCADE").as_str()) + .await; + let _ = pool + .execute(format!("DROP TYPE IF EXISTS {table} CASCADE").as_str()) + .await; +} + +/// Drop `table` from the source database (`direct`) and from every shard, +/// reusing the same [`drop_table`] call for each. +/// +/// `pgdog_sharded` uses shards `shard_0` and `shard_1` in the integration +/// setup. Errors from the `DROP` statements are ignored by [`drop_table`]; +/// NOTE(review): a failed shard connection is not handled here — confirm. +async fn drop_table_everywhere(table: &str, direct: &Pool) { + drop_table(direct, table).await; + for db in &["shard_0", "shard_1"] { + drop_table(&connection_sqlx_direct_db(db).await, table).await; + } +} + +/// Full cleanup for all task tests — idempotent and safe to call as both +/// pre-flight (evict prior-run leftovers) and post-flight (leave state clean). +/// +/// 1. Stop every live PgDog task via `STOP_TASK`. +/// 2. Wait for the task map to drain. +/// 3. Terminate WAL senders still holding any test slot. +/// 4. Wait until all those slots are inactive. +/// 5. Drop the now-inactive test slots. +/// 6. Drop all test publications (`IF EXISTS` — idempotent). +/// 7. Drop [`TEST_TABLE`] from the source database and from every shard. +async fn cleanup(admin: &Pool, direct: &Pool) { + // 1. Cooperative stop. + for task in &Tasks::fetch(admin).await.rows { + let _ = admin + .execute(format!("STOP_TASK {}", task.id).as_str()) + .await; + } + + // 2.
Wait for the task map to drain. + for _ in 0..20 { + if Tasks::fetch(admin).await.is_empty() { + break; + } + sleep(Duration::from_millis(500)).await; + } + + // 3. Terminate WAL senders on any test slot. + let _ = direct + .execute( + format!( + "SELECT pg_terminate_backend(active_pid) \ + FROM pg_replication_slots \ + WHERE ({SLOT_FILTER}) AND active_pid IS NOT NULL" + ) + .as_str(), + ) + .await; + + // 4. Wait for those slots to deactivate. + for _ in 0..20 { + let any_active = direct + .fetch_optional(sqlx::query(&format!( + "SELECT bool_or(active) AS active FROM pg_replication_slots WHERE {SLOT_FILTER}" + ))) + .await + .ok() + .flatten() + .and_then(|row: sqlx::postgres::PgRow| row.get::, _>("active")) + .unwrap_or(false); + if !any_active { + break; + } + sleep(Duration::from_millis(500)).await; + } + + // 5. Drop inactive test slots. + let _ = direct + .execute( + format!( + "SELECT pg_drop_replication_slot(slot_name) \ + FROM pg_replication_slots \ + WHERE ({SLOT_FILTER}) AND NOT active" + ) + .as_str(), + ) + .await; + + // 6. Drop all test publications. + for pub_name in &[ + STOP_TASK_PUB, + CUTOVER_PUB, + SCHEMA_SYNC_PRE_PUB, + SCHEMA_SYNC_POST_PUB, + COPY_DATA_PUB, + ] { + let _ = direct + .execute(format!("DROP PUBLICATION IF EXISTS {pub_name}").as_str()) + .await; + } + + // 7. Drop shared test table from source and every shard. + drop_table_everywhere(TEST_TABLE, direct).await; +} + +/// Start `pgdog` → `pgdog_sharded` replication using a `FOR ALL TABLES` +/// publication. Waits until the task appears in `SHOW TASKS` with kind +/// `"replication"` and returns its id. 
+async fn start_replication( + pub_name: &str, + slot_name: &str, + admin: &Pool, + direct: &Pool, +) -> i64 { + admin.execute("RELOAD").await.unwrap(); + sleep(Duration::from_millis(500)).await; + + direct + .execute(format!("CREATE PUBLICATION {pub_name} FOR ALL TABLES").as_str()) + .await + .unwrap(); + + let row = admin + .fetch_one(format!("REPLICATE pgdog pgdog_sharded {pub_name} {slot_name}").as_str()) + .await + .unwrap(); + // REPLICATE returns task_id as TEXT on the wire. + let task_id: i64 = row.get::("task_id").parse().unwrap(); + + let mut appeared = false; + for _ in 0..20 { + if Tasks::fetch(admin) + .await + .find(task_id) + .is_some_and(|t| t.kind == "replication") + { + appeared = true; + break; + } + sleep(Duration::from_millis(500)).await; + } + assert!( + appeared, + "replication task {task_id} did not appear in SHOW TASKS within 10s" + ); + + task_id +} + +/// Poll until `task_id` is absent from `SHOW TASKS` (up to 30 s). +async fn wait_for_task_gone(admin: &Pool, task_id: i64) { + for _ in 0..60 { + if Tasks::fetch(admin).await.find(task_id).is_none() { + return; + } + sleep(Duration::from_millis(500)).await; + } + panic!("task {task_id} still present in SHOW TASKS after 30s"); +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +/// `STOP_TASK` on an id that does not exist returns `"task not found"` rather +/// than a connection error. +#[tokio::test] +async fn test_stop_nonexistent_task() { + let admin = admin_sqlx().await; + + let row = admin.fetch_one("STOP_TASK 999999999").await.unwrap(); + assert_eq!(row.get::("stop_task"), "task not found"); +} + +/// `CUTOVER` with no replication task running returns a server error; the +/// connection pool stays healthy afterward. 
+#[tokio::test] +async fn test_cutover_without_replication_task() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + let err = admin.fetch_one("CUTOVER").await.unwrap_err(); + assert!( + matches!(err, sqlx::Error::Database(_)), + "expected a database error, got: {err:?}" + ); + // Pool must still be usable. + admin.fetch_one("SHOW VERSION").await.unwrap(); +} + +/// A replication task can be cancelled via `STOP_TASK` with its id, which returns +/// `"OK"` and removes the task from `SHOW TASKS`. +#[tokio::test] +async fn test_stop_task() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + let task_id = start_replication(STOP_TASK_PUB, STOP_TASK_SLOT, &admin, &direct).await; + + let row = admin + .fetch_one(format!("STOP_TASK {task_id}").as_str()) + .await + .unwrap(); + assert_eq!(row.get::("stop_task"), "OK"); + + wait_for_task_gone(&admin, task_id).await; + cleanup(&admin, &direct).await; +} + +/// A replication task can also be stopped via `CUTOVER`, which triggers a +/// final sync, returns `"OK"`, and removes the task from `SHOW TASKS`. +#[tokio::test] +async fn test_cutover() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + let task_id = start_replication(CUTOVER_PUB, CUTOVER_SLOT, &admin, &direct).await; + + let row = admin.fetch_one("CUTOVER").await.unwrap(); + assert_eq!(row.get::("cutover"), "OK"); + + wait_for_task_gone(&admin, task_id).await; + cleanup(&admin, &direct).await; +} + +/// `SCHEMA_SYNC pre` registers a `schema_sync` task synchronously before +/// returning the `task_id`, so the task is in `SHOW TASKS` immediately.
+#[tokio::test] +async fn test_schema_sync_pre() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + direct + .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) + .await + .unwrap(); + direct + .execute( + format!("CREATE PUBLICATION {SCHEMA_SYNC_PRE_PUB} FOR TABLE {TEST_TABLE}").as_str(), + ) + .await + .unwrap(); + + let row = admin + .fetch_one(format!("SCHEMA_SYNC pre pgdog pgdog_sharded {SCHEMA_SYNC_PRE_PUB}").as_str()) + .await + .unwrap(); + let task_id: i64 = row.get::("task_id").parse().unwrap(); + + // Task is registered before the command returns; verify kind if still running. + if let Some(task) = Tasks::fetch(&admin).await.find(task_id) { + assert_eq!(task.kind, "schema_sync"); + } + + let stop = admin + .fetch_one(format!("STOP_TASK {task_id}").as_str()) + .await + .unwrap(); + let status = stop.get::("stop_task"); + assert!( + status == "OK" || status == "task not found", + "unexpected STOP_TASK response: {status}" + ); + + cleanup(&admin, &direct).await; +} + +/// `SCHEMA_SYNC post` follows the same task lifecycle as `pre`. 
+#[tokio::test] +async fn test_schema_sync_post() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + direct + .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) + .await + .unwrap(); + direct + .execute( + format!("CREATE PUBLICATION {SCHEMA_SYNC_POST_PUB} FOR TABLE {TEST_TABLE}").as_str(), + ) + .await + .unwrap(); + + let row = admin + .fetch_one(format!("SCHEMA_SYNC post pgdog pgdog_sharded {SCHEMA_SYNC_POST_PUB}").as_str()) + .await + .unwrap(); + let task_id: i64 = row.get::("task_id").parse().unwrap(); + + if let Some(task) = Tasks::fetch(&admin).await.find(task_id) { + assert_eq!(task.kind, "schema_sync"); + } + + let stop = admin + .fetch_one(format!("STOP_TASK {task_id}").as_str()) + .await + .unwrap(); + let status = stop.get::("stop_task"); + assert!( + status == "OK" || status == "task not found", + "unexpected STOP_TASK response: {status}" + ); + + cleanup(&admin, &direct).await; +} + +/// `COPY_DATA` returns `task_id TEXT` and `replication_slot TEXT`. A +/// `copy_data` task is registered synchronously; it internally spawns a +/// `replication` task when complete. +#[tokio::test] +async fn test_copy_data() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + direct + .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) + .await + .unwrap(); + direct + .execute(format!("CREATE PUBLICATION {COPY_DATA_PUB} FOR TABLE {TEST_TABLE}").as_str()) + .await + .unwrap(); + + // Response: task_id TEXT, replication_slot TEXT. 
+ let row = admin + .fetch_one(format!("COPY_DATA pgdog pgdog_sharded {COPY_DATA_PUB}").as_str()) + .await + .unwrap(); + let task_id: i64 = row.get::("task_id").parse().unwrap(); + let slot_name: String = row.get("replication_slot"); + assert!(!slot_name.is_empty(), "replication_slot must be non-empty"); + + // Verify kind while still running (may already be gone if fast). + if let Some(task) = Tasks::fetch(&admin).await.find(task_id) { + assert_eq!(task.kind, "copy_data"); + } + + // Abort early; STOP_TASK on a CopyData task emits a WARNING notice that + // sqlx ignores. "task not found" is valid if the task finished first. + let stop = admin + .fetch_one(format!("STOP_TASK {task_id}").as_str()) + .await + .unwrap(); + let status = stop.get::("stop_task"); + assert!( + status == "OK" || status == "task not found", + "unexpected STOP_TASK response: {status}" + ); + + cleanup(&admin, &direct).await; +} diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index ce9d2f971..60132d930 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -1,3 +1,4 @@ +pub mod admin; pub mod admin_termination; pub mod auth; pub mod auto_id; diff --git a/integration/users.toml b/integration/users.toml index 3a552e640..d430e022a 100644 --- a/integration/users.toml +++ b/integration/users.toml @@ -3,6 +3,13 @@ name = "pgdog" database = "pgdog" password = "pgdog" +[[users]] +name = "pgdog_migrator" +database = "pgdog" +password = "pgdog" +server_user = "pgdog" +schema_admin = true + [[users]] name = "pgdog_hashed" database = "pgdog" From 93e1e2b3c98d895b83800706b8fd124c2bcec26a Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Thu, 11 Jun 2026 16:22:15 +0000 Subject: [PATCH 4/8] feat: add new api tasks --- pgdog/src/api/async_task.rs | 672 +++++++++++------- pgdog/src/api/copy_data.rs | 80 +++ pgdog/src/api/mod.rs | 63 +- pgdog/src/api/replication.rs | 99 
+++ pgdog/src/api/resharding.rs | 118 ++- pgdog/src/api/schema_sync.rs | 83 +++ .../replication/logical/cutover_signal.rs | 43 ++ pgdog/src/backend/replication/logical/mod.rs | 1 + .../replication/logical/orchestrator.rs | 3 + 9 files changed, 877 insertions(+), 285 deletions(-) create mode 100644 pgdog/src/api/copy_data.rs create mode 100644 pgdog/src/api/replication.rs create mode 100644 pgdog/src/api/schema_sync.rs create mode 100644 pgdog/src/backend/replication/logical/cutover_signal.rs diff --git a/pgdog/src/api/async_task.rs b/pgdog/src/api/async_task.rs index 6123fee67..b80cf8c7e 100644 --- a/pgdog/src/api/async_task.rs +++ b/pgdog/src/api/async_task.rs @@ -16,29 +16,50 @@ use tokio_util::sync::CancellationToken; #[derive(Copy, Clone, Debug, Display, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct AsyncTaskId(u64); -pub trait TaskInfoStatus: Display + Debug + Send + Sync + 'static { - fn task_name() -> &'static str; + +impl From for AsyncTaskId { + fn from(id: u64) -> Self { + Self(id) + } +} + +/// Status type for tasks that report no intermediate progress. +/// +/// [`Infallible`](std::convert::Infallible) is uninhabited, so a task +/// with this status type can never call +/// [`set_status`](AsyncTaskContext::set_status). +pub type Empty = std::convert::Infallible; + +/// A composable background task: a value carrying its own arguments, +/// driven to completion by [`run`](Task::run). Launch it top-level with +/// [`AsyncTasksStorage::run`], or nested under a running task with +/// [`AsyncTaskContext::run`]. +pub trait Task: Display + Debug + Send + Sync + Sized + 'static { + /// Progress-status payload reported through + /// [`set_status`](AsyncTaskContext::set_status) while the task + /// runs; the [`Empty`] type reports no intermediate progress. + type Status: Display + Send + Sync + 'static; + /// Value the task resolves to on success. + type Output: Send + 'static; + /// Error the task may fail with. 
+ type Error: std::error::Error + Send + 'static; /// Grace period for cooperative shutdown after cancellation; /// once it expires, the task is force-aborted. fn cancel_timeout() -> Duration { Duration::from_secs(5) } -} -#[derive(Display, Debug)] -pub struct AnonymousTask; - -impl TaskInfoStatus for AnonymousTask { - fn task_name() -> &'static str { - "anonymous" - } + fn run( + self, + ctx: AsyncTaskContext, + ) -> impl Future> + Send + 'static; } #[derive(Display, Debug, Clone)] -pub enum TaskStatus { +pub enum TaskStatus { Started, - Pending(T), + Pending(S), Finished, Cancelled, Error(String), @@ -49,7 +70,7 @@ pub enum TaskStatus { /// readable through the registry without knowing `T`. #[derive(Debug, Clone)] pub struct TaskState { - pub name: &'static str, + pub name: String, pub status: TaskStatus, pub started_at: SystemTime, pub updated_at: SystemTime, @@ -72,7 +93,7 @@ pub enum TaskError { Abandoned, } -impl TaskStatus { +impl TaskStatus { /// Terminal states are write-once; late writers /// (e.g. ctx clones outliving the task) are ignored. fn is_terminal(&self) -> bool { @@ -85,7 +106,7 @@ impl TaskStatus { /// Snapshot for the registry: keep the variant, render `T`. fn stringify(&self) -> TaskStatus where - T: Display, + S: Display, { match self { Self::Started => TaskStatus::Started, @@ -98,16 +119,6 @@ impl TaskStatus { } } -impl TaskState { - /// Reached a terminal state more than `ttl` ago? 
- fn expired(&self, now: SystemTime, ttl: Duration) -> bool { - self.status.is_terminal() - && now - .duration_since(self.updated_at) - .is_ok_and(|age| age >= ttl) - } -} - type SharedStatus = Arc>>; #[derive(Default)] @@ -126,12 +137,12 @@ impl TasksMap { } } -struct AsyncTaskState { +struct AsyncTaskState { updated_at: SystemTime, - status: TaskStatus, + status: TaskStatus, } -impl AsyncTaskState { +impl AsyncTaskState { fn new() -> Self { Self { updated_at: SystemTime::now(), @@ -140,14 +151,15 @@ impl AsyncTaskState { } } -struct AsyncTask { +struct AsyncTask { started_at: SystemTime, + name: String, cancellation_token: CancellationToken, /// Set once the task asks for its cancellation token: only /// then can it react to cancellation, so only then is the /// cooperative-shutdown grace period worth waiting out. cooperative: AtomicBool, - state: Arc>>, + state: Arc>>, subtasks: Arc, } @@ -155,9 +167,12 @@ trait TaskMapEntry: Send + Sync + 'static { fn cancel(&self); fn state(&self) -> TaskState; fn subtasks(&self) -> &TasksMap; + /// Cheap expiry check for pruning: terminal and older than `ttl`, + /// without building a full [`TaskState`]. 
+ fn expired(&self, now: SystemTime, ttl: Duration) -> bool; } -impl TaskMapEntry for AsyncTask { +impl TaskMapEntry for AsyncTask { fn cancel(&self) { self.cancellation_token.cancel(); } @@ -166,7 +181,7 @@ impl TaskMapEntry for AsyncTask { let state = self.state.read(); TaskState { - name: T::task_name(), + name: self.name.clone(), status: state.status.stringify(), started_at: self.started_at, updated_at: state.updated_at, @@ -176,13 +191,21 @@ impl TaskMapEntry for AsyncTask { fn subtasks(&self) -> &TasksMap { &self.subtasks } + + fn expired(&self, now: SystemTime, ttl: Duration) -> bool { + let state = self.state.read(); + state.status.is_terminal() + && now + .duration_since(state.updated_at) + .is_ok_and(|age| age >= ttl) + } } -pub struct AsyncTaskContext { +pub struct AsyncTaskContext { task: Arc>, } -impl Clone for AsyncTaskContext { +impl Clone for AsyncTaskContext { fn clone(&self) -> Self { Self { task: self.task.clone(), @@ -273,52 +296,46 @@ impl Default for AsyncTasksStorage { } } -fn run_task( - current_task: Option<&Arc>>, - tasks: &TasksMap, - execute: impl FnOnce(AsyncTaskContext) -> F, -) -> AsyncTaskWaiter -where - T: TaskInfoStatus, - T1: TaskInfoStatus, - F: Future> + Send + 'static, - R: Send + 'static, - E: std::error::Error + Send + 'static, -{ +fn run_task( + parent_token: Option<&CancellationToken>, + register_into: &TasksMap, + subtasks: Arc, + task: T, +) -> AsyncTaskWaiter { let state = Arc::new(RwLock::new(AsyncTaskState::new())); - let task = AsyncTask { + let entry = AsyncTask { started_at: SystemTime::now(), - cancellation_token: match current_task { - Some(current_task) => current_task.cancellation_token.child_token(), + name: task.to_string(), + cancellation_token: match parent_token { + Some(token) => token.child_token(), None => CancellationToken::new(), }, cooperative: AtomicBool::new(false), - // Subtasks share the root task's registry: every descendant + // Descendants share the root task's registry: every descendant // 
registers as a direct child of the root. - subtasks: match current_task { - Some(current_task) => current_task.subtasks.clone(), - None => Arc::new(TasksMap::default()), - }, + subtasks, state: state.clone(), }; - let task = Arc::new(task); + let entry = Arc::new(entry); // Make sure we insert task to map before it's actually started - let id = tasks.insert_next(task.clone()); + let id = register_into.insert_next(entry.clone()); - let ctx = AsyncTaskContext { task: task.clone() }; + let ctx = AsyncTaskContext { + task: entry.clone(), + }; - let mut handle = tokio::spawn(execute(ctx.clone())); + let mut handle = tokio::spawn(task.run(ctx.clone())); let (sender, receiver) = oneshot::channel(); - let cancellation_token = task.cancellation_token.clone(); + let cancellation_token = entry.cancellation_token.clone(); tokio::spawn(async move { let res = select! { _ = cancellation_token.cancelled() => { if ctx.task.cooperative.load(Ordering::Relaxed) { - match timeout(T1::cancel_timeout(), &mut handle).await { + match timeout(T::cancel_timeout(), &mut handle).await { Ok(res) => res, Err(_) => { handle.abort(); @@ -365,8 +382,8 @@ where } } -impl AsyncTaskContext { - fn set_inner_status(&self, status: TaskStatus) { +impl AsyncTaskContext { + fn set_inner_status(&self, status: TaskStatus) { let mut state = self.task.state.write(); if state.status.is_terminal() { return; @@ -375,13 +392,13 @@ impl AsyncTaskContext { state.updated_at = SystemTime::now(); } - pub fn set_status(&self, status: T) { + pub fn set_status(&self, status: T::Status) { self.set_inner_status(TaskStatus::Pending(status)); } /// Hand out this task's cancellation token. Taking the token /// opts the task into cooperative shutdown: on `cancel_task` - /// it gets [`TaskInfoStatus::cancel_timeout`] to wind down + /// it gets [`Task::cancel_timeout`] to wind down /// before being aborted. Tasks that never take it are aborted /// immediately. Take it early. 
pub fn cancellation_token(&self) -> CancellationToken { @@ -390,17 +407,13 @@ impl AsyncTaskContext { self.task.cancellation_token.clone() } - pub fn run( - &self, - execute: impl FnOnce(AsyncTaskContext) -> F, - ) -> AsyncTaskWaiter - where - T1: TaskInfoStatus, - F: Future> + Send + 'static, - R: Send + 'static, - E: std::error::Error + Send + 'static, - { - run_task(Some(&self.task), &self.task.subtasks, execute) + pub fn run(&self, task: T1) -> AsyncTaskWaiter { + run_task( + Some(&self.task.cancellation_token), + &self.task.subtasks, + self.task.subtasks.clone(), + task, + ) } } @@ -412,19 +425,10 @@ impl AsyncTasksStorage { } } - pub fn run( - &self, - execute: impl FnOnce(AsyncTaskContext) -> F, - ) -> AsyncTaskWaiter - where - T: TaskInfoStatus, - F: Future> + Send + 'static, - R: Send + 'static, - E: std::error::Error + Send + 'static, - { + pub fn run(&self, task: T) -> AsyncTaskWaiter { self.prune(); - run_task(Option::<&Arc>>::None, &self.tasks, execute) + run_task(None, &self.tasks, Arc::new(TasksMap::default()), task) } /// Request cancellation of a task. The task winds down @@ -449,9 +453,9 @@ impl AsyncTasksStorage { entry .subtasks() .map - .retain(|_, sub| !sub.state().expired(now, self.retention)); + .retain(|_, sub| !sub.expired(now, self.retention)); - !entry.state().expired(now, self.retention) + !entry.expired(now, self.retention) }); } @@ -491,50 +495,320 @@ mod tests { use tokio::time::sleep; use tokio::{join, test}; + type State = Arc>; + #[derive(Display, Debug)] - enum TestTask { + enum TestTaskStatus { + #[display("StepOne")] StepOne, + #[display("StepTwo")] StepTwo, } - impl TaskInfoStatus for TestTask { - fn task_name() -> &'static str { - "test_task" - } + /// Sets "started", waits on `notify`, then "finished" and succeeds. + #[derive(Display, Debug)] + #[display("mock")] + struct MockSuccessful { + state: State, + notify: Arc, } - macro_rules! 
mock_successful { - ($state_id:ident, $notify:ident) => {{ - let state = $state_id.clone(); - let notify = $notify.clone(); + impl Task for MockSuccessful { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + *self.state.lock() = "started"; + self.notify.notified().await; + *self.state.lock() = "finished"; + Ok(()) + } + } - async move |_ctx: AsyncTaskContext| { - *state.lock() = "started"; + /// Sets "started", waits on `notify`, then "failed" and errors. + #[derive(Display, Debug)] + #[display("mock")] + struct MockFailing { + state: State, + notify: Arc, + } - notify.notified().await; + impl Task for MockFailing { + type Status = Empty; + type Output = (); + type Error = std::io::Error; - *state.lock() = "finished"; + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), std::io::Error> { + *self.state.lock() = "started"; + self.notify.notified().await; + *self.state.lock() = "failed"; + Err(std::io::Error::other("mock task failure")) + } + } - Ok::<_, Infallible>(()) + macro_rules! mock_successful { + ($state:ident, $notify:ident) => { + MockSuccessful { + state: $state.clone(), + notify: $notify.clone(), } - }}; + }; } macro_rules! mock_failing { - ($state_id:ident, $notify:ident) => {{ - let state = $state_id.clone(); - let notify = $notify.clone(); + ($state:ident, $notify:ident) => { + MockFailing { + state: $state.clone(), + notify: $notify.clone(), + } + }; + } + + /// Waits on its gate, then succeeds. Never takes its cancellation + /// token, so it is aborted (not wound down) on cancellation. 
+ #[derive(Display, Debug)] + #[display("anonymous")] + struct Gate { + gate: Arc, + } + + impl Task for Gate { + type Status = Empty; + type Output = (); + type Error = Infallible; - async move |_ctx: AsyncTaskContext| { - *state.lock() = "started"; + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + self.gate.notified().await; + Ok(()) + } + } + + /// Immediately succeeds. + #[derive(Display, Debug)] + #[display("noop")] + struct Noop; + + impl Task for Noop { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + Ok(()) + } + } + + /// Runs two children concurrently and joins them, then finishes. + #[derive(Display, Debug)] + #[display("inner")] + struct InnerJoin { + state: State, + a: MockSuccessful, + b: MockSuccessful, + } + + impl Task for InnerJoin { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + *self.state.lock() = "started"; + + let (a, b) = join!(ctx.run(self.a), ctx.run(self.b)); + a.unwrap(); + b.unwrap(); + + *self.state.lock() = "finished"; + + Ok(()) + } + } + + /// Runs two children concurrently, then waits on `notify`. + #[derive(Display, Debug)] + #[display("inner")] + struct InnerCancel { + state: State, + notify: Arc, + a: MockSuccessful, + b: MockSuccessful, + } + + impl Task for InnerCancel { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + *self.state.lock() = "started"; + + let _ = join!(ctx.run(self.a), ctx.run(self.b)); - notify.notified().await; + self.notify.notified().await; - *state.lock() = "failed"; + *self.state.lock() = "finished"; + + Ok(()) + } + } + + /// Runs a failing child and asserts it failed. 
+ #[derive(Display, Debug)] + #[display("inner")] + struct InnerError { + state: State, + a: MockFailing, + } + + impl Task for InnerError { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + *self.state.lock() = "started"; + + let res = ctx.run(self.a).await; + assert!(matches!(res, Err(TaskError::Failed(_)))); + + *self.state.lock() = "inner_failed"; + + Ok(()) + } + } + + /// Takes its cancellation token and winds down gracefully within + /// the grace window, then succeeds with a value. + #[derive(Display, Debug)] + #[display("graceful")] + struct Graceful { + state: State, + } + + impl Task for Graceful { + type Status = Empty; + type Output = i32; + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result { + *self.state.lock() = "started"; + + ctx.cancellation_token().cancelled().await; + + // Cooperative wind-down, well within the 5s grace window. + sleep(Duration::from_secs(1)).await; + + *self.state.lock() = "graceful"; + + Ok(42) + } + } - Err::<(), _>(std::io::Error::other("mock task failure")) + /// Panics after being notified, if still in the "started" state. + #[derive(Display, Debug)] + #[display("panicker")] + struct Panicker { + state: State, + notify: Arc, + } + + impl Task for Panicker { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + *self.state.lock() = "started"; + + self.notify.notified().await; + + if *self.state.lock() == "started" { + panic!("panicking task"); } - }}; + + Ok(()) + } + } + + /// Subtask that spawns a grandchild gate (registered with the root) + /// and waits for it. 
+ #[derive(Display, Debug)] + #[display("anonymous")] + struct Sub { + sub_gate: Arc, + } + + impl Task for Sub { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + ctx.run(Gate { + gate: self.sub_gate, + }) + .await + .unwrap(); + + Ok(()) + } + } + + /// Root task that reports `TestTaskStatus`, runs a subtask, then + /// advances its status and waits. + #[derive(Display, Debug)] + #[display("test_task")] + struct TraverseRoot { + sub_gate: Arc, + parent_gate: Arc, + } + + impl Task for TraverseRoot { + type Status = TestTaskStatus; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + ctx.set_status(TestTaskStatus::StepOne); + + let sub = ctx.run(Sub { + sub_gate: self.sub_gate, + }); + + sub.await.unwrap(); + + ctx.set_status(TestTaskStatus::StepTwo); + + self.parent_gate.notified().await; + + Ok(()) + } + } + + /// Takes its cancellation token (opting into the grace period) but + /// ignores it, with an overridden 30s grace period. 
+ #[derive(Display, Debug)] + #[display("slow_exit")] + struct SlowExit { + notify: Arc, + } + + impl Task for SlowExit { + type Status = Empty; + type Output = (); + type Error = Infallible; + + fn cancel_timeout() -> Duration { + Duration::from_secs(30) + } + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + let _token = ctx.cancellation_token(); + self.notify.notified().await; + Ok(()) + } } /// Spawned tasks are only guaranteed to be polled after the @@ -617,22 +891,10 @@ mod tests { let state_a = Arc::new(Mutex::new("initial")); let state_b = Arc::new(Mutex::new("initial")); let state_c = Arc::new(Mutex::new("initial")); - let a = mock_successful!(state_a, notify); - let b = mock_successful!(state_b, notify); - let c = { - let state = state_c.clone(); - - async move |ctx: AsyncTaskContext| { - *state.lock() = "started"; - - let (a, b) = join!(ctx.run(a), ctx.run(b)); - a.unwrap(); - b.unwrap(); - - *state.lock() = "finished"; - - Ok::<_, Infallible>(()) - } + let c = InnerJoin { + state: state_c.clone(), + a: mock_successful!(state_a, notify), + b: mock_successful!(state_b, notify), }; let async_storage = AsyncTasksStorage::default(); @@ -680,24 +942,12 @@ mod tests { let snapshot = async_storage.task(task_id).unwrap(); assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); } + #[test(start_paused = true)] async fn test_graceful_exit_during_cancel_grace() { let state = Arc::new(Mutex::new("initial")); - let a = { - let state = state.clone(); - - async move |ctx: AsyncTaskContext| { - *state.lock() = "started"; - - ctx.cancellation_token().cancelled().await; - - // Cooperative wind-down, well within the 5s grace window. 
- sleep(Duration::from_secs(1)).await; - - *state.lock() = "graceful"; - - Ok::<_, Infallible>(42) - } + let a = Graceful { + state: state.clone(), }; let async_storage = AsyncTasksStorage::default(); @@ -725,23 +975,11 @@ mod tests { let state_a = Arc::new(Mutex::new("initial")); let state_b = Arc::new(Mutex::new("initial")); let state_c = Arc::new(Mutex::new("initial")); - let a = mock_successful!(state_a, notify); - let b = mock_successful!(state_b, notify); - let c = { - let state = state_c.clone(); - let notify = notify.clone(); - - async move |ctx: AsyncTaskContext| { - *state.lock() = "started"; - - let _ = join!(ctx.run(a), ctx.run(b)); - - notify.notified().await; - - *state.lock() = "finished"; - - Ok::<_, Infallible>(()) - } + let c = InnerCancel { + state: state_c.clone(), + notify: notify.clone(), + a: mock_successful!(state_a, notify), + b: mock_successful!(state_b, notify), }; let async_storage = AsyncTasksStorage::default(); @@ -795,20 +1033,9 @@ mod tests { let notify = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); let state_c = Arc::new(Mutex::new("initial")); - let a = mock_failing!(state_a, notify); - let c = { - let state = state_c.clone(); - - async move |ctx: AsyncTaskContext| { - *state.lock() = "started"; - - let res = ctx.run(a).await; - assert!(matches!(res, Err(TaskError::Failed(_)))); - - *state.lock() = "inner_failed"; - - Ok::<_, Infallible>(()) - } + let c = InnerError { + state: state_c.clone(), + a: mock_failing!(state_a, notify), }; let async_storage = AsyncTasksStorage::default(); @@ -832,21 +1059,9 @@ mod tests { async fn test_panic() { let notify = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); - let a = { - let state = state_a.clone(); - let notify = notify.clone(); - - async move |_ctx: AsyncTaskContext| { - *state.lock() = "started"; - - notify.notified().await; - - if *state.lock() == "started" { - panic!("panicking task"); - } - - Ok::<_, Infallible>(()) - } + let a = 
Panicker { + state: state_a.clone(), + notify: notify.clone(), }; let async_storage = AsyncTasksStorage::default(); @@ -876,37 +1091,9 @@ mod tests { let async_storage = AsyncTasksStorage::default(); - let task = async_storage.run({ - let sub_gate = sub_gate.clone(); - let parent_gate = parent_gate.clone(); - - async move |ctx: AsyncTaskContext| { - ctx.set_status(TestTask::StepOne); - - let sub = ctx.run({ - let sub_gate = sub_gate.clone(); - - async move |ctx: AsyncTaskContext| { - // Grandchild: registers with the root task. - ctx.run(async move |_ctx: AsyncTaskContext| { - sub_gate.notified().await; - Ok::<_, Infallible>(()) - }) - .await - .unwrap(); - - Ok::<_, Infallible>(()) - } - }); - - sub.await.unwrap(); - - ctx.set_status(TestTask::StepTwo); - - parent_gate.notified().await; - - Ok::<_, Infallible>(()) - } + let task = async_storage.run(TraverseRoot { + sub_gate: sub_gate.clone(), + parent_gate: parent_gate.clone(), }); let id = task.id(); @@ -993,34 +1180,12 @@ mod tests { #[test(start_paused = true)] async fn test_cancel_timeout_override() { - #[derive(Display, Debug)] - struct SlowExitTask; - - impl TaskInfoStatus for SlowExitTask { - fn task_name() -> &'static str { - "slow_exit" - } - - fn cancel_timeout() -> Duration { - Duration::from_secs(30) - } - } - let notify = Arc::new(Notify::new()); let async_storage = AsyncTasksStorage::default(); - let task = async_storage.run({ - let notify = notify.clone(); - - // Takes its cancellation token (opting into the grace - // period) but ignores it: only the forced abort after - // the grace period can stop it. 
- async move |ctx: AsyncTaskContext| { - let _token = ctx.cancellation_token(); - notify.notified().await; - Ok::<_, Infallible>(()) - } + let task = async_storage.run(SlowExit { + notify: notify.clone(), }); settle().await; @@ -1043,14 +1208,9 @@ mod tests { let async_storage = AsyncTasksStorage::default(); - let task = async_storage.run({ - let notify = notify.clone(); - - // Never takes the cancellation token: no grace period. - async move |_ctx: AsyncTaskContext| { - notify.notified().await; - Ok::<_, Infallible>(()) - } + // `Gate` never takes the cancellation token: no grace period. + let task = async_storage.run(Gate { + gate: notify.clone(), }); settle().await; @@ -1065,4 +1225,12 @@ mod tests { // Aborted right away: the paused clock did not advance. assert_eq!(started.elapsed(), Duration::ZERO); } + + #[test] + async fn global_storage_runs_and_lists() { + let waiter = crate::api::storage().run(Noop); + let id = waiter.id(); + waiter.await.unwrap(); + assert!(crate::api::storage().task(id).is_some()); + } } diff --git a/pgdog/src/api/copy_data.rs b/pgdog/src/api/copy_data.rs new file mode 100644 index 000000000..1a6a124e5 --- /dev/null +++ b/pgdog/src/api/copy_data.rs @@ -0,0 +1,80 @@ +//! Copy-data background task: schema sync + data sync, then a replication +//! task that catches up and (on `CUTOVER`) cuts over. + +use crate::api::async_task::AsyncTaskContext; +use crate::api::replication::ReplicationTask; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; +use crate::api::{MigrationError, Task}; +use crate::backend::replication::logical::orchestrator::Orchestrator; + +/// Copy data from a source database to a target: schema sync, data sync, +/// then replication catch-up and cutover. +#[derive(Display, Debug)] +#[display("copy_data")] +pub(crate) struct CopyDataTask { + pub orchestrator: Orchestrator, +} + +/// Stages of the copy-data flow, reported as the task's status. 
The +/// fine-grained schema-sync and replication stages live on the child tasks. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum CopyDataStatus { + /// Running the schema-sync child task. + #[display("syncing schema")] + SchemaSync, + /// Copying table data to the destination. + #[display("syncing data")] + SyncingData, + /// Running the replication child task. + #[display("replicating")] + Replication, +} + +impl Task for CopyDataTask { + type Status = CopyDataStatus; + type Output = (); + type Error = MigrationError; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), MigrationError> { + // Sync the schema as a child task so it reports its own stages. + ctx.set_status(CopyDataStatus::SchemaSync); + let mut orchestrator = ctx + .run(SchemaSyncTask { + orchestrator: self.orchestrator, + phase: SchemaSyncPhase::Pre, + }) + .await?; + + ctx.set_status(CopyDataStatus::SyncingData); + orchestrator.data_sync().await?; + + // data_sync can run for hours; pools may have reloaded. Re-fetch + // live cluster refs before starting replication. + orchestrator.refresh()?; + + // Replication runs as a child until cutover, reporting its own stages. + // Awaiting keeps copy_data non-terminal while it runs; its outcome is + // intentionally not propagated here. 
+ ctx.set_status(CopyDataStatus::Replication); + let _ = ctx.run(ReplicationTask { orchestrator }).await; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn copy_data_status_renders_distinct_labels() { + let labels = [ + CopyDataStatus::SchemaSync.to_string(), + CopyDataStatus::SyncingData.to_string(), + CopyDataStatus::Replication.to_string(), + ]; + assert!(labels.iter().all(|label| !label.is_empty())); + let unique: std::collections::HashSet<_> = labels.iter().collect(); + assert_eq!(unique.len(), labels.len()); + } +} diff --git a/pgdog/src/api/mod.rs b/pgdog/src/api/mod.rs index d1ea61c98..4f77eb599 100644 --- a/pgdog/src/api/mod.rs +++ b/pgdog/src/api/mod.rs @@ -4,5 +4,66 @@ //! - pgdog CLI //! - admin db api -mod async_task; +use std::sync::LazyLock; + +use crate::backend::replication::logical::Error; +use async_task::{AsyncTaskWaiter, AsyncTasksStorage, TaskError}; + +pub mod async_task; +pub mod copy_data; +pub mod replication; pub mod resharding; +pub mod schema_sync; + +/// Process-global task registry shared by all `crate::api` task modules. +static TASKS: LazyLock = LazyLock::new(AsyncTasksStorage::default); + +/// Accessor for the process-global task registry. +pub(crate) fn storage() -> &'static AsyncTasksStorage { + &TASKS +} + +/// A composable background task: implement [`Task`] (see +/// [`async_task`]) to define one, then launch it as a top-level task +/// with [`start`] or nested under a running task through its +/// [`AsyncTaskContext`](async_task::AsyncTaskContext). +pub(crate) use async_task::Task; + +/// Launch `task` as a top-level task in the global registry. +pub(crate) fn start(task: T) -> AsyncTaskWaiter { + storage().run(task) +} + +/// Error returned by the API migration tasks: either an error from the +/// replication/orchestrator machinery, or a child task's [`TaskError`] +/// (failure, cancellation, panic, or abandonment) surfaced to its parent. 
+#[derive(Debug, Display, Error, From)] +pub(crate) enum MigrationError { + #[display("{_0}")] + Replication(Error), + #[display("{_0}")] + Task(TaskError), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn migration_error_wraps_replication_and_task_errors() { + // A replication/orchestrator error converts directly. + let err = MigrationError::from(Error::NoSchema); + assert!(matches!(err, MigrationError::Replication(Error::NoSchema))); + + // A child task's failure is wrapped, preserving the inner error. + let err = MigrationError::from(TaskError::Failed(Error::NoSchema)); + assert!(matches!( + err, + MigrationError::Task(TaskError::Failed(Error::NoSchema)) + )); + + // Non-failure child outcomes are preserved too (not stringified). + let err = MigrationError::from(TaskError::::Cancelled); + assert!(matches!(err, MigrationError::Task(TaskError::Cancelled))); + } +} diff --git a/pgdog/src/api/replication.rs b/pgdog/src/api/replication.rs new file mode 100644 index 000000000..fcb3243ac --- /dev/null +++ b/pgdog/src/api/replication.rs @@ -0,0 +1,99 @@ +//! Logical-replication background task. +//! +//! Drives a `ReplicationWaiter` to completion: it stops on cancellation +//! (`STOP_TASK`), performs cutover on an external `cutover_signal::request()` +//! (`CUTOVER`), and otherwise finishes when the source slot drains (no cutover +//! on natural drain). Launch it top-level with [`super::start`], or as a child +//! by spawning it through a parent task's [`AsyncTaskContext`]. + +use std::time::Duration; + +use tokio::select; + +use crate::api::Task; +use crate::api::async_task::AsyncTaskContext; +use crate::backend::replication::logical::orchestrator::Orchestrator; +use crate::backend::replication::logical::{Error, cutover_signal}; + +/// Stages of logical replication, reported as the task's status. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum ReplicationStatus { + /// Streaming changes to catch the destination up. 
+ #[display("replicating")] + Replicating, + /// Cutting traffic over to the destination. + #[display("cutting over")] + CuttingOver, + /// Winding down on a stop request. + #[display("stopping")] + Stopping, +} + +/// Replicate from a source database to a target, owning the orchestrator +/// that produces the replication waiter. +#[derive(Display, Debug)] +#[display("replication")] +pub(crate) struct ReplicationTask { + pub orchestrator: Orchestrator, +} + +impl Task for ReplicationTask { + type Status = ReplicationStatus; + type Output = (); + type Error = Error; + + /// A cutover, once started, must run to completion. `STOP_TASK` during + /// the waiting phase is handled by the `select!` arm below (graceful + /// `waiter.stop()`); a `STOP_TASK` during an in-flight `waiter.cutover()` + /// waits out this (effectively unbounded) grace period instead of + /// force-aborting mid-cutover. + fn cancel_timeout() -> Duration { + Duration::from_secs(24 * 60 * 60) + } + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Error> { + let mut waiter = self.orchestrator.replicate().await?; + let token = ctx.cancellation_token(); + + ctx.set_status(ReplicationStatus::Replicating); + + select! { + _ = token.cancelled() => { + ctx.set_status(ReplicationStatus::Stopping); + waiter.stop(); + } + _ = cutover_signal::requested() => { + ctx.set_status(ReplicationStatus::CuttingOver); + waiter.cutover().await?; + } + res = waiter.wait() => { + res?; + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cancel_timeout_far_exceeds_default() { + // Far larger than the 5s default: cutover must not be force-aborted. 
+ assert!(ReplicationTask::cancel_timeout() > Duration::from_secs(60)); + } + + #[test] + fn replication_status_renders_distinct_labels() { + let labels = [ + ReplicationStatus::Replicating.to_string(), + ReplicationStatus::CuttingOver.to_string(), + ReplicationStatus::Stopping.to_string(), + ]; + assert!(labels.iter().all(|label| !label.is_empty())); + let unique: std::collections::HashSet<_> = labels.iter().collect(); + assert_eq!(unique.len(), labels.len()); + } +} diff --git a/pgdog/src/api/resharding.rs b/pgdog/src/api/resharding.rs index 3fb1fb335..9ac5acc05 100644 --- a/pgdog/src/api/resharding.rs +++ b/pgdog/src/api/resharding.rs @@ -1,42 +1,96 @@ -use crate::backend::replication::Error; -use crate::backend::replication::orchestrator::Orchestrator; +//! Reshard background task: the full automatic schema-sync + data-sync + +//! replication + cutover flow. -pub struct Options { - /// Source database name. - from_database: String, +use crate::api::async_task::AsyncTaskContext; +use crate::api::replication::ReplicationTask; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; +use crate::api::{MigrationError, Task}; +use crate::backend::replication::logical::cutover_signal; +use crate::backend::replication::logical::orchestrator::Orchestrator; - /// Destination database. - to_database: String, +/// Run the complete replicate-and-cutover flow from a source database to a +/// target. +#[derive(Display, Debug)] +#[display("reshard")] +pub(crate) struct ReshardTask { + pub orchestrator: Orchestrator, +} + +/// Stages of the reshard flow, reported as the task's status. The +/// fine-grained schema-sync and replication stages live on the child tasks. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum ReshardStatus { + /// Running the pre-data schema-sync child task. + #[display("syncing schema")] + SchemaSync, + /// Copying table data to the destination. 
+ #[display("syncing data")] + SyncingData, + /// Running the post-data schema-sync child task (indexes, constraints). + #[display("finalizing schema")] + FinalizingSchema, + /// Running the replication child task through cutover. + #[display("replicating")] + Replication, +} - /// Publication name. - publication: String, +impl Task for ReshardTask { + type Status = ReshardStatus; + type Output = (); + type Error = MigrationError; - /// Name of the replication slot to create/use. - replication_slot: Option, + async fn run(self, ctx: AsyncTaskContext) -> Result<(), MigrationError> { + // Sync the pre-data schema (tables) as a child task. + ctx.set_status(ReshardStatus::SchemaSync); + let orchestrator = ctx + .run(SchemaSyncTask { + orchestrator: self.orchestrator, + phase: SchemaSyncPhase::Pre, + }) + .await?; - /// Replicate or copy data over. - replicate_only: bool, + // Sync the data to destination. + ctx.set_status(ReshardStatus::SyncingData); + orchestrator.data_sync().await?; - /// Replicate or copy data over. - sync_only: bool, + // Create secondary indexes as a child task (schema already loaded). + ctx.set_status(ReshardStatus::FinalizingSchema); + let mut orchestrator = ctx + .run(SchemaSyncTask { + orchestrator, + phase: SchemaSyncPhase::Post, + }) + .await?; - /// Don't perform pre-data schema sync. - skip_schema_sync: bool, + // Refresh cluster references: data_sync can take hours and the pools + // may have been reloaded (e.g. by a client DDL) in the meantime. + orchestrator.refresh()?; + + // Reshard cuts over automatically: request it up front (the cutover + // signal is buffered), then run replication as a child that consumes + // the request and cuts over once it has caught up. 
+ ctx.set_status(ReshardStatus::Replication); + cutover_signal::request(); + ctx.run(ReplicationTask { orchestrator }).await?; + + Ok(()) + } } -pub fn reshard(options: Options) -> Result<(), Error> { - let Options { - from_database, - to_database, - publication, - replication_slot, - replicate_only, - sync_only, - skip_schema_sync, - } = options; - - let orchestrator = - Orchestrator::new(&from_database, &to_database, &publication, replication_slot); - - Ok(()) +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn reshard_status_renders_distinct_labels() { + let labels = [ + ReshardStatus::SchemaSync.to_string(), + ReshardStatus::SyncingData.to_string(), + ReshardStatus::FinalizingSchema.to_string(), + ReshardStatus::Replication.to_string(), + ]; + assert!(labels.iter().all(|label| !label.is_empty())); + let unique: std::collections::HashSet<_> = labels.iter().collect(); + assert_eq!(unique.len(), labels.len()); + } } diff --git a/pgdog/src/api/schema_sync.rs b/pgdog/src/api/schema_sync.rs new file mode 100644 index 000000000..9266dba15 --- /dev/null +++ b/pgdog/src/api/schema_sync.rs @@ -0,0 +1,83 @@ +//! Schema-sync background task (pre-data or post-data). + +use crate::api::Task; +use crate::api::async_task::AsyncTaskContext; +use crate::backend::replication::logical::Error; +use crate::backend::replication::logical::orchestrator::Orchestrator; + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum SchemaSyncPhase { + Pre, + Post, +} + +/// Stages of a schema sync, reported as the task's status. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum SchemaSyncStatus { + /// Dumping the schema from the source. + #[display("loading schema")] + LoadingSchema, + /// Restoring tables on the destination (pre-data). + #[display("syncing tables")] + SyncingTables, + /// Creating indexes and constraints on the destination (post-data). 
+ #[display("creating indexes")] + CreatingIndexes, +} + +/// Sync the schema (pre- or post-data) from a source database to a target. +#[derive(Display, Debug)] +#[display("schema_sync")] +pub(crate) struct SchemaSyncTask { + pub orchestrator: Orchestrator, + pub phase: SchemaSyncPhase, +} + +impl Task for SchemaSyncTask { + type Status = SchemaSyncStatus; + type Output = Orchestrator; + type Error = Error; + + /// Returns the orchestrator with its schema loaded and synced so a parent + /// task can thread it into the next phase. The schema dump is skipped when + /// the orchestrator already carries one (e.g. a parent that runs `Pre` + /// then `Post` on the same orchestrator). + async fn run(self, ctx: AsyncTaskContext) -> Result { + let mut orchestrator = self.orchestrator; + + if orchestrator.schema().is_err() { + ctx.set_status(SchemaSyncStatus::LoadingSchema); + orchestrator.load_schema().await?; + } + + match self.phase { + SchemaSyncPhase::Pre => { + ctx.set_status(SchemaSyncStatus::SyncingTables); + orchestrator.schema_sync_pre(true).await?; + } + SchemaSyncPhase::Post => { + ctx.set_status(SchemaSyncStatus::CreatingIndexes); + orchestrator.schema_sync_post(true).await?; + } + } + + Ok(orchestrator) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn schema_sync_status_renders_distinct_labels() { + let labels = [ + SchemaSyncStatus::LoadingSchema.to_string(), + SchemaSyncStatus::SyncingTables.to_string(), + SchemaSyncStatus::CreatingIndexes.to_string(), + ]; + assert!(labels.iter().all(|label| !label.is_empty())); + let unique: std::collections::HashSet<_> = labels.iter().collect(); + assert_eq!(unique.len(), labels.len()); + } +} diff --git a/pgdog/src/backend/replication/logical/cutover_signal.rs b/pgdog/src/backend/replication/logical/cutover_signal.rs new file mode 100644 index 000000000..e81ea7767 --- /dev/null +++ b/pgdog/src/backend/replication/logical/cutover_signal.rs @@ -0,0 +1,43 @@ +//! 
Cutover signal for the logical replication task. +//! +//! Standalone channel between the `CUTOVER` admin command and the +//! running replication task: the command [`request`]s the cutover +//! from anywhere, the replication task [`requested`] waits for it. +//! +//! One request is buffered, so a request arriving before the task +//! starts waiting is not lost. There is at most one replication +//! task per process, matching the single buffered permit. + +use std::sync::LazyLock; + +use tokio::sync::Notify; + +static CUTOVER: LazyLock = LazyLock::new(Notify::new); + +/// Request a cutover from the running replication task. +pub fn request() { + CUTOVER.notify_one(); +} + +/// Wait until a cutover is requested. Only the replication task +/// waits on this. +pub async fn requested() { + CUTOVER.notified().await; +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[tokio::test] + async fn test_request_is_buffered() { + // Request lands before anyone waits: still delivered. + request(); + + tokio::time::timeout(Duration::from_secs(1), requested()) + .await + .unwrap(); + } +} diff --git a/pgdog/src/backend/replication/logical/mod.rs b/pgdog/src/backend/replication/logical/mod.rs index f8cbf5445..e10a23b0c 100644 --- a/pgdog/src/backend/replication/logical/mod.rs +++ b/pgdog/src/backend/replication/logical/mod.rs @@ -1,5 +1,6 @@ pub mod admin; pub mod copy_statement; +pub mod cutover_signal; pub mod ee; pub mod error; pub mod orchestrator; diff --git a/pgdog/src/backend/replication/logical/orchestrator.rs b/pgdog/src/backend/replication/logical/orchestrator.rs index 572100583..ff394cfcd 100644 --- a/pgdog/src/backend/replication/logical/orchestrator.rs +++ b/pgdog/src/backend/replication/logical/orchestrator.rs @@ -168,6 +168,9 @@ impl Orchestrator { } /// Perform the entire flow in one swoop. 
+ #[deprecated(note = "phase orchestration now lives in the migration tasks (see \ + `crate::api::resharding::ReshardTask`); drive the individual \ + steps directly. Remove once the remaining callers migrate.")] pub(crate) async fn replicate_and_cutover(&mut self) -> Result<(), Error> { // Load the schema from source. self.load_schema().await?; From 7a45799c7bfbdc5b6f1b580023e9a01309f920cb Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:48:11 +0000 Subject: [PATCH 5/8] migrate to api tasks --- Cargo.lock | 26 ++ Cargo.toml | 1 + docs/ASYNC_TASKS.md | 172 +++++++++ docs/issues/replication.md | 162 +++++++- .../rust/tests/integration/admin/mod.rs | 22 ++ .../rust/tests/integration/admin/tasks.rs | 111 +++--- pgdog/Cargo.toml | 3 +- pgdog/src/admin/copy_data.rs | 22 +- pgdog/src/admin/cutover.rs | 21 +- pgdog/src/admin/parser.rs | 10 +- pgdog/src/admin/replicate.rs | 5 +- pgdog/src/admin/reshard.rs | 20 +- pgdog/src/admin/schema_sync.rs | 57 +-- pgdog/src/admin/show_tasks.rs | 52 ++- pgdog/src/admin/stop_task.rs | 23 +- pgdog/src/api/async_task.rs | 121 +++++- pgdog/src/api/copy_data.rs | 102 ++--- pgdog/src/api/mod.rs | 27 ++ pgdog/src/api/replication.rs | 238 +++++++++++- pgdog/src/api/resharding.rs | 172 ++++++--- pgdog/src/api/schema_sync.rs | 72 +++- .../src/backend/replication/logical/admin.rs | 351 ------------------ .../replication/logical/cutover_signal.rs | 43 --- pgdog/src/backend/replication/logical/mod.rs | 3 - .../replication/logical/orchestrator.rs | 97 ++--- .../logical/publisher/publisher_impl.rs | 56 ++- pgdog/src/backend/schema/sync/pg_dump.rs | 16 +- pgdog/src/cli.rs | 150 ++++---- 28 files changed, 1291 insertions(+), 864 deletions(-) create mode 100644 docs/ASYNC_TASKS.md delete mode 100644 pgdog/src/backend/replication/logical/admin.rs delete mode 100644 pgdog/src/backend/replication/logical/cutover_signal.rs diff --git a/Cargo.lock b/Cargo.lock index 648ca143b..7c9cc8083 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -850,6 +850,31 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "bon" +version = "3.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a602c73c7b0148ec6d12af6fd5cc7a46e2eacc8878271a999abac56eed12f561" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dee98b0db6a962de883bf5d20362dee4d7ca0d12fe39a7c6c73c844e1cd7c1f" +dependencies = [ + "darling 0.23.0", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.118", +] + [[package]] name = "borsh" version = "1.6.1" @@ -3293,6 +3318,7 @@ dependencies = [ "azure_identity", "base64 0.22.1", "bit-vec 0.8.0", + "bon", "brunch", "bytes", "cc", diff --git a/Cargo.toml b/Cargo.toml index a6db5313b..dc553ed15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ edition = "2024" [workspace.dependencies] pgdog-plugin = { path = "./pgdog-plugin", version = "0.3.0" } pgdog-config = { path = "./pgdog-config", version = "0.1.0" } +bon = "3.9" schemars = { version = "1.2.1", features = ["uuid1"] } serde_json = "1.0" indexmap = { version = "2.14", features = ["serde"] } diff --git a/docs/ASYNC_TASKS.md b/docs/ASYNC_TASKS.md new file mode 100644 index 000000000..52840525d --- /dev/null +++ b/docs/ASYNC_TASKS.md @@ -0,0 +1,172 @@ +# Async Task Framework — Architecture + +The `crate::api` module ([`pgdog/src/api/`](../pgdog/src/api/)) is the execution layer that sits +between PgDog's two user interfaces and its long-running operations. Any operation that may run for +seconds to hours runs here as a background *task*. + +The central principle is that **the task is the single source of truth for execution**. A user +interface only assembles options and starts the task; all behaviour, status transitions, and error +handling live inside it. 
Whether an operation is started through a SQL command on the admin +database or a terminal invocation of the CLI, the same task runs the same code. + +This document covers the framework itself — how tasks are started, tracked, composed, cancelled, +and observed. It deliberately does not enumerate the individual task implementations; each task's +behaviour is documented alongside its own code in [`pgdog/src/api/`](../pgdog/src/api/). + +--- + +## Architecture + +```mermaid +flowchart TD + subgraph Interfaces + ADMIN["admin database API\nSQL commands"] + CLI["pgdog CLI\nsubcommands"] + end + + subgraph API["crate::api — execution layer"] + REG["process-global registry\nAsyncTasksStorage"] + TASKS["task implementations\n(impl Task)"] + end + + WORK["underlying operation\n(engine / pipeline / I/O)"] + + ADMIN -->|start task, get id| REG + CLI -->|start task, await result| REG + REG -->|spawns and tracks| TASKS + TASKS -->|drives| WORK + ADMIN -->|SHOW TASKS / STOP_TASK| REG + CLI -->|Ctrl-C → cancel_task| REG +``` + +A task is any type that implements the `Task` trait +([`api/async_task.rs`](../pgdog/src/api/async_task.rs)): it defines its own status type, output, and +error, and provides an `async run`. The framework owns everything around that `run` — spawning, +registration, id assignment, status storage, cancellation, and retention. + +--- + +## The registry + +When a task is started via `crate::api::start()` ([`api/mod.rs`](../pgdog/src/api/mod.rs)), it is +spawned as an async future and immediately registered in `AsyncTasksStorage` +([`api/async_task.rs`](../pgdog/src/api/async_task.rs)) under a monotonically-increasing integer id. +The id is returned before any work begins, so the caller can track the operation while it runs in +the background. + +One registry serves the entire process. 
A task started by the CLI and a task started through the +admin SQL API are both registered in the same store and are equally visible to `SHOW TASKS` and +equally cancellable by `STOP_TASK`. A task spawned automatically by another task registers itself +the same way and is just as addressable from either interface. + +**Terminal tasks are retained for 24 hours** after they finish so their outcome can be inspected. +A running task is never pruned. `SHOW TASKS` filters terminal tasks back out and shows only what +is currently running; the retention period exists only so that an id returned by a command +continues to be addressable for a reasonable window after completion. + +--- + +## Task composition + +A task can spawn *child tasks* through its execution context. Children share the root task's +registration entry and appear as a flat subtask list on the root's snapshot in the registry. The +root's status describes which high-level phase is active; the child's status describes what is +happening within that phase. A composite task that sequences several phases therefore reports +fine-grained progress without any special support from the registry — each phase is just a child +task whose status bubbles up. + +`SHOW TASKS` surfaces both the root and its running children as separate rows, so a child appears +with its own type even though it was never started as a top-level command. `STOP_TASK` only +addresses root tasks by id; cancelling the root propagates to all its children through the +cancellation token hierarchy (see [Cancellation](#cancellation) below). + +--- + +## Status lifecycle + +Every task type defines its own set of progress stages. The registry stores a type-erased +snapshot of the current status on every write and exposes it through `SHOW TASKS`. A task begins +in `Started` the moment it is registered, transitions through `Pending(stage)` as it reports +progress, and ends in one of four terminal states: + +- **`Finished`** — completed successfully. 
+- **`Cancelled`** — stopped by `STOP_TASK`, Ctrl-C, or parent cancellation. +- **`Error`** — the task's code returned an error; the message is stored in the status. +- **`Panic`** — the task's future panicked; the message is stored. + +Terminal states are **write-once**: a context clone that outlives the task cannot overwrite a +recorded outcome. + +--- + +## Cancellation + +Every task holds a `CancellationToken`; a child's token is `parent.child_token()`, so cancelling a +task cancels its whole subtree. + +Cooperative vs. not is decided by one call. `ctx.cancellation_token()` sets `cooperative = true` as +a side effect and hands back the token: + +```rust +pub fn cancellation_token(&self) -> CancellationToken { + self.task.cooperative.store(true, Ordering::Relaxed); + self.task.cancellation_token.clone() +} +``` + +The watcher branches on that flag when the token fires: + +```rust +ctx.transition(TaskStatus::Cancelling); +if cooperative { + // grace period, then force-abort + match timeout(T::cancel_timeout(), &mut handle).await { + Ok(res) => res, + Err(_) => { handle.abort(); handle.await } + } +} else { + handle.abort(); handle.await // never took the token → abort now +} +``` + +A cooperative task typically `select!`s its work against the token and runs its own shutdown +before `run` returns, within `cancel_timeout()` (default 5s). A non-cooperative task never takes +the token and is aborted immediately — fine when dropping the future already tears the work down +(e.g. in-flight units in a `JoinSet`). + +`STOP_TASK` calls `cancel_task(id)`, which returns `None` for an unknown or already-terminal id +(so callers don't claim success or emit cleanup warnings for a finished task) and otherwise calls +`entry.cancel()`. 
+ +## Composition + +`ctx.run(child)` spawns a subtask: + +```rust +pub fn run(&self, task: T1) -> AsyncTaskWaiter { + run_task(Some(&self.task), &self.task.subtasks, task) +} +``` + +All descendants register in the *root's* `subtasks` map, so `TaskSnapshot.subtasks` is a flat list +of every descendant ordered by id (their own `subtasks` are always empty). `SHOW TASKS` renders +the root and its running children as separate rows. `STOP_TASK` only addresses roots; the token +hierarchy propagates the cancel downward. + +## Retention + +`AsyncTasksStorage` prunes on every `run`/`tasks`/`task` call. `prune()` drops entries whose +terminal state is older than `retention` (`TASK_RETENTION = 24h`); running tasks are never dropped. +So an id stays addressable, with its final status and last `inner_status`, for 24h after it +finishes. + +## The two callers + +Both go through `AsyncTaskWaiter`; the difference is what they do with it. + +- **Admin** ([`pgdog/src/admin/`](../pgdog/src/admin/)): fire-and-forget. Take `.id()`, drop the + waiter, return the id to the client. The client polls `SHOW TASKS` and runs `STOP_TASK <id>`. +- **CLI** ([`cli.rs`](../pgdog/src/cli.rs)): await the waiter in a loop. On Ctrl-C, call into the + registry to cancel, then keep awaiting until the task winds down before exiting. + +Same task, same options, same status transitions, same error path either way. diff --git a/docs/issues/replication.md b/docs/issues/replication.md index c33673515..1dac7989e 100644 --- a/docs/issues/replication.md +++ b/docs/issues/replication.md @@ -20,6 +20,18 @@ A logical replication slot only decodes and delivers changes that belong to its This means the lag metric permanently overstates the remaining work. On a three-shard benchmark where all source databases are on one instance, the observed lag was ~3.5 GB per slot even after each slot had replayed all of its own publication's data. 
The lag never dropped below the cutover threshold, so `wait_for_replication()` looped indefinitely and cutover never fired. +The destination is a second, often worse, source of the same inflation. When the +destination shards live on the **same PostgreSQL instance** as the source (a +single-instance dev/test setup, or any deployment that co-locates them), every +row pgdog copies or replicates *into* the destination is itself a write to that +instance and advances `pg_current_wal_lsn()`. So the very act of catching the +destination up pushes the instance WAL position forward, while the source slot's +`confirmed_flush_lsn` only tracks the source publication — the measured lag rises +as replication makes progress instead of falling. Every source slot on a shared +instance reads back a near-identical, ever-growing lag, because they all subtract +their publication-scoped `confirmed_flush_lsn` from the same instance-wide +`pg_current_wal_lsn()`. + ### Cause `pg_current_wal_lsn()` is instance-scoped. `confirmed_flush_lsn` is publication-scoped. Their difference is only meaningful when a single database accounts for all writes to the instance. @@ -55,7 +67,18 @@ keepalive received --- -## 🚧 Issue 2 — Stop signal only unblocked one task instead of all +## ✅ Issue 2 — Stop signal only unblocked one task instead of all (resolved) + +> **Resolved.** The `Arc<Notify>` stop signal was replaced with a +> [`CancellationToken`](https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html) +> held by both `Publisher` and `Waiter` +> ([`publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs)). +> `Waiter::stop()` calls `stop.cancel()`, which latches permanently and wakes +> every per-shard stream task — whether already parked on `stop.cancelled()` or +> polling later — so one call unblocks all N tasks. 
The token (not the +> `watch::Sender` proposed below) is the primitive that shipped; both +> satisfy the "persistent value, wakes every receiver" requirement. Each stream +> latches the signal with a `stopping` flag, then drains its slot to completion. ### Description @@ -148,7 +171,24 @@ The sentinel is overwritten by the first real measurement from each task's `chec --- -## 🚧 Issue 4 — Divergent code paths for the same operation +## ✅ Issue 4 — Divergent code paths for the same operation (resolved) + +> **Resolved.** The divergent paths were consolidated onto the `crate::api` +> background-task framework. [`ReshardTask`](../../pgdog/src/api/resharding.rs) +> is the single composite flow (pre-data schema sync → data copy → post-data +> schema sync → replication), with `auto_cutover` toggling the final cutover; +> [`CopyDataTask`](../../pgdog/src/api/copy_data.rs) is the bulk data-copy leaf +> it composes; and `REPLICATE`/`CUTOVER`/`STOP_TASK` all drive one +> [`ReplicationTask`](../../pgdog/src/api/replication.rs). The old +> `backend/replication/logical/admin.rs` (`AsyncTasks`, `TaskType`) and +> `Orchestrator::replicate_and_cutover()` referenced below no longer exist. +> +> The `notify_one()` race is gone: a cutover is delivered to the specific +> running task through a latching `CancellationToken` held in a per-task +> `CUTOVERS` map keyed by task id (`ReplicationTask::cutover`), so it cannot be +> lost or leak into a later task. Natural slot drain intentionally does *not* +> auto-cut-over in the operator flow (`REPLICATE`/`copy_data`); cutover is +> explicit (operator `CUTOVER`) or automatic only under reshard's `auto_cutover`. ### Description @@ -322,4 +362,120 @@ The `AbortSignal` type should be deleted. 
It carries no state that cannot be rep ### References - [`tokio_util::sync::CancellationToken`](https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html) — cooperative cancellation token with child-token support and `cancelled().await`. -- [`tokio::sync::watch`](https://docs.rs/tokio/latest/tokio/sync/watch/index.html) — persistent value channel; used for the stop-signal fix in Issue 2. \ No newline at end of file +- [`tokio::sync::watch`](https://docs.rs/tokio/latest/tokio/sync/watch/index.html) — persistent value channel; used for the stop-signal fix in Issue 2. + +--- + +## ✅ Issue 6 — `STOP_TASK` during cutover removed the task but left it running (resolved) + +> **Resolved.** In the `crate::api` framework `STOP_TASK` only *requests* +> cancellation through a `CancellationToken`; it never removes the registry +> entry. The entry is dropped only after the task future actually reaches a +> terminal state, so the registry always reflects real execution state. A +> cutover that has passed its point of no return is allowed to finish rather +> than being torn down mid-traffic-swap: `ReplicationTask` overrides +> `cancel_timeout()` with a 60 s grace +> ([`replication.rs`](../../pgdog/src/api/replication.rs)) — comfortably longer +> than the committed swap (WAL drain + schema-sync DDL + reverse-replication +> setup), so `STOP_TASK` lets it run to completion instead of aborting it after +> the 5 s default. A genuinely hung swap is still reaped once the grace expires. + +### Description (old implementation) + +The old `AsyncTasks` registry (deleted `backend/replication/logical/admin.rs`) +tracked each task in a `DashMap` of `TaskInfo` entries, where: + +```rust +struct TaskInfo { + abort_tx: oneshot::Sender<()>, // dropped => abort_rx resolves + cutover: Arc<Notify>, + task_kind: TaskKind, + started_at: SystemTime, +} +``` + +A replication task was driven by a detached `tokio::spawn`ed future selecting +over three arms: + +```rust +spawn(async move { + select! 
{ + _ = abort_rx => { waiter.stop(); } // STOP_TASK + _ = cutover.notified() => { waiter.cutover().await; } // CUTOVER + result = waiter.wait() => { /* slot drained */ } + } + AsyncTasks::get().tasks.remove(&id); // self-remove on exit +}); +``` + +`STOP_TASK` called `AsyncTasks::remove(id)`, which did +`tasks.remove(&id)` — synchronously deleting the map entry (and dropping +`abort_tx`). Two things made the cutover-then-stop sequence broken: + +1. **Removal was immediate and unconditional.** The entry vanished from the map + the instant `remove` was called, and `remove` returned the `TaskKind` as if + it had succeeded. Nothing tied the entry's presence to the task having + actually stopped — the map said "gone" while the spawned future was still + running. There was no terminal state retained; `SHOW TASKS` simply stopped + listing it. + +2. **The abort arm could not win once cutover had started.** Dropping + `abort_tx` resolves `abort_rx`, but `select!` had already committed to the + `cutover.notified()` arm, so `waiter.cutover().await` ran to completion + regardless. Unlike the `SchemaSync`/`CopyData` variants — which held an + `abort_handle` and called `abort_handle.abort()` — the replication variant + had no hard-abort path at all; `STOP_TASK` could only ever request a + graceful `waiter.stop()`, and only if it won the race. + +So after `CUTOVER` then `STOP_TASK`, the registry reported the task removed and +stopped, while the detached future kept flipping traffic, running post-data +schema sync, and setting up reverse replication — mutating cluster state with +no visibility and no way to stop it. + +### How the new implementation works + +The registry entry's lifetime is tied to the task future through the supervisor +in `run_task` ([`async_task.rs`](../../pgdog/src/api/async_task.rs)), not to +admin-side bookkeeping: + +1. 
**`STOP_TASK` requests; it does not remove.** + `AsyncTasksStorage::cancel_task` only calls `entry.cancel()` on the task's + `CancellationToken` and leaves the entry in the map. The entry transitions + to a terminal status (`Cancelled`/`Finished`) only when the spawned future + returns; `prune()` then drops it after the 24h retention. A running task is + never removed, so `SHOW TASKS` keeps showing it as `Stopping`/`CuttingOver` + until it has genuinely finished — the map always reflects real state. + +2. **A started cutover runs to completion.** `ReplicationTask::run` + ([`replication.rs`](../../pgdog/src/api/replication.rs)) `select!`s the + cutover arm (`waiter.cutover().await`) against `token.cancelled()`. Once the + `CUTOVER` signal fires and that arm is chosen, `select!` is committed to + awaiting it; a later `STOP_TASK` cancels the token but no longer has an arm + to win. The supervisor sees the cancelled token and — because + `ReplicationTask` is cooperative (it took its token) — waits + `cancel_timeout()` before aborting. `ReplicationTask` overrides + `cancel_timeout()` to **60 s** (vs. the 5 s trait default in + [`async_task.rs`](../../pgdog/src/api/async_task.rs)), comfortably longer than + the committed swap, so the traffic switch finishes instead of being aborted + mid-flight. Only a swap that hangs past the grace is force-aborted. + +3. **The reverse order winds down without cutover.** If `STOP_TASK` lands + first, the `token.cancelled()` arm sets status `Stopping`, calls + `waiter.stop()` (graceful slot drain), and returns — no traffic switch. A + subsequent `CUTOVER` finds the `CutoverWaiter` already dropped from the + `CUTOVERS` map, so `ReplicationTask::cutover` returns `false` and the admin + command rejects with `NotReplication`. 
+ +Net: `STOP_TASK` gracefully stops a task that has not yet cut over, and is a +deliberate no-op against an in-flight cutover — which completes within its 60 s +grace — and in both cases the task stays visible in the registry until it has +actually stopped. + +### Code references + +| Symbol | File | +|---|---| +| `AsyncTasksStorage::cancel_task` — cancels the token, keeps the entry | [`pgdog/src/api/async_task.rs`](../../pgdog/src/api/async_task.rs) | +| `AsyncTasksStorage::prune` — drops only terminal entries past retention | same file | +| `run_task` supervisor — cooperative grace via `cancel_timeout`; sets terminal status on completion | same file | +| `ReplicationTask::run` / `cancel_timeout` — cutover-vs-stop `select!` | [`pgdog/src/api/replication.rs`](../../pgdog/src/api/replication.rs) | diff --git a/integration/rust/tests/integration/admin/mod.rs b/integration/rust/tests/integration/admin/mod.rs index f29650406..605d8ece7 100644 --- a/integration/rust/tests/integration/admin/mod.rs +++ b/integration/rust/tests/integration/admin/mod.rs @@ -12,8 +12,12 @@ use sqlx::{Column, Executor, Pool, Postgres, Row, TypeInfo}; /// Wire layout expected from `SHOW TASKS`. 
const SHOW_TASKS_LAYOUT: &[(&str, &str)] = &[ ("id", "INT8"), + ("root_id", "INT8"), + ("scope", "TEXT"), ("type", "TEXT"), + ("status", "TEXT"), ("started_at", "TEXT"), + ("updated_at", "TEXT"), ("elapsed", "TEXT"), ("elapsed_ms", "INT8"), ]; @@ -24,8 +28,12 @@ const SHOW_TASKS_LAYOUT: &[(&str, &str)] = &[ #[derive(Debug, Clone)] pub struct Task { pub id: i64, + pub root_id: i64, + pub scope: String, pub kind: String, + pub status: String, pub started_at: String, + pub updated_at: String, pub elapsed: String, pub elapsed_ms: i64, } @@ -57,18 +65,32 @@ impl Tasks { .iter() .map(|row| { let id: i64 = row.get("id"); + let root_id: i64 = row.get("root_id"); + let scope: String = row.get("scope"); + let status: String = row.get("status"); let started_at: String = row.get("started_at"); + let updated_at: String = row.get("updated_at"); let elapsed: String = row.get("elapsed"); let elapsed_ms: i64 = row.get("elapsed_ms"); assert!(!started_at.is_empty(), "task {id}: started_at is empty"); + assert!(!updated_at.is_empty(), "task {id}: updated_at is empty"); assert!(!elapsed.is_empty(), "task {id}: elapsed is empty"); + assert!(!status.is_empty(), "task {id}: status is empty"); assert!(elapsed_ms >= 0, "task {id}: elapsed_ms is negative"); + assert!( + scope == "root" || scope == "subtask", + "task {id}: unexpected scope {scope:?}" + ); Task { id, + root_id, + scope, kind: row.get("type"), + status, started_at, + updated_at, elapsed, elapsed_ms, } diff --git a/integration/rust/tests/integration/admin/tasks.rs b/integration/rust/tests/integration/admin/tasks.rs index efde1da74..d9888d748 100644 --- a/integration/rust/tests/integration/admin/tasks.rs +++ b/integration/rust/tests/integration/admin/tasks.rs @@ -169,7 +169,7 @@ async fn start_replication( if Tasks::fetch(admin) .await .find(task_id) - .is_some_and(|t| t.kind == "replication") + .is_some_and(|t| t.kind == "replication pgdog -> pgdog_sharded") { appeared = true; break; @@ -195,6 +195,30 @@ async fn 
wait_for_task_gone(admin: &Pool, task_id: i64) { panic!("task {task_id} still present in SHOW TASKS after 30s"); } +/// Whether the relation `name` (table or index) exists on `db`, resolved +/// through the connection's search_path — these tests create objects in the +/// `pgdog` schema (the `$user` schema for role `pgdog`). +async fn relation_present(pool: &Pool, name: &str) -> bool { + pool.fetch_one(format!("SELECT to_regclass('{name}') IS NOT NULL AS present").as_str()) + .await + .unwrap() + .get::("present") +} + +/// Poll until relation `name` exists on both destination shards (up to 30 s), +/// proving a schema sync actually propagated it. Panics on timeout. +async fn wait_for_relation_on_shards(name: &str) { + let shard_0 = connection_sqlx_direct_db("shard_0").await; + let shard_1 = connection_sqlx_direct_db("shard_1").await; + for _ in 0..60 { + if relation_present(&shard_0, name).await && relation_present(&shard_1, name).await { + return; + } + sleep(Duration::from_millis(500)).await; + } + panic!("relation {name} did not propagate to all shards within 30s"); +} + // ─── Tests ────────────────────────────────────────────────────────────────── /// `STOP_TASK` on an id that does not exist returns `"task not found"` rather @@ -261,8 +285,9 @@ async fn test_cutover() { cleanup(&admin, &direct).await; } -/// `SCHEMA_SYNC pre` registers a `schema_sync` task synchronously before -/// returning the `task_id`, so the task is in `SHOW TASKS` immediately. +/// `SCHEMA_SYNC pre` syncs the table structure from the source to the +/// destination shards. Asserts the table actually appears on both shards — +/// not merely that the task registered. 
#[tokio::test] async fn test_schema_sync_pre() { let direct = connection_sqlx_direct().await; @@ -284,37 +309,34 @@ async fn test_schema_sync_pre() { .fetch_one(format!("SCHEMA_SYNC pre pgdog pgdog_sharded {SCHEMA_SYNC_PRE_PUB}").as_str()) .await .unwrap(); - let task_id: i64 = row.get::("task_id").parse().unwrap(); - - // Task is registered before the command returns; verify kind if still running. - if let Some(task) = Tasks::fetch(&admin).await.find(task_id) { - assert_eq!(task.kind, "schema_sync"); - } + // Response carries the task id as TEXT; ensure it parses. + let _task_id: i64 = row.get::("task_id").parse().unwrap(); - let stop = admin - .fetch_one(format!("STOP_TASK {task_id}").as_str()) - .await - .unwrap(); - let status = stop.get::("stop_task"); - assert!( - status == "OK" || status == "task not found", - "unexpected STOP_TASK response: {status}" - ); + // cleanup() dropped the table from the shards pre-flight, so its presence + // here proves the pre sync created it on both shards. + wait_for_relation_on_shards(TEST_TABLE).await; cleanup(&admin, &direct).await; } -/// `SCHEMA_SYNC post` follows the same task lifecycle as `pre`. +/// `SCHEMA_SYNC post` adds indexes/constraints to tables that already exist on +/// the destination. Syncs the table with `pre` first, then asserts `post` +/// propagates a secondary index — an effect `pre` does not produce. 
#[tokio::test] async fn test_schema_sync_post() { let direct = connection_sqlx_direct().await; let admin = admin_sqlx().await; cleanup(&admin, &direct).await; + let secondary_index = format!("{TEST_TABLE}_val_idx"); direct .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) .await .unwrap(); + direct + .execute(format!("CREATE INDEX {secondary_index} ON {TEST_TABLE} (val)").as_str()) + .await + .unwrap(); direct .execute( format!("CREATE PUBLICATION {SCHEMA_SYNC_POST_PUB} FOR TABLE {TEST_TABLE}").as_str(), @@ -322,32 +344,29 @@ async fn test_schema_sync_post() { .await .unwrap(); - let row = admin - .fetch_one(format!("SCHEMA_SYNC post pgdog pgdog_sharded {SCHEMA_SYNC_POST_PUB}").as_str()) + // pre creates the table (and primary key) on the shards, but not the + // secondary index — that is post-data. + admin + .fetch_one(format!("SCHEMA_SYNC pre pgdog pgdog_sharded {SCHEMA_SYNC_POST_PUB}").as_str()) .await .unwrap(); - let task_id: i64 = row.get::("task_id").parse().unwrap(); - - if let Some(task) = Tasks::fetch(&admin).await.find(task_id) { - assert_eq!(task.kind, "schema_sync"); - } + wait_for_relation_on_shards(TEST_TABLE).await; - let stop = admin - .fetch_one(format!("STOP_TASK {task_id}").as_str()) + // post adds the secondary index on both destination shards. + let row = admin + .fetch_one(format!("SCHEMA_SYNC post pgdog pgdog_sharded {SCHEMA_SYNC_POST_PUB}").as_str()) .await .unwrap(); - let status = stop.get::("stop_task"); - assert!( - status == "OK" || status == "task not found", - "unexpected STOP_TASK response: {status}" - ); + let _task_id: i64 = row.get::("task_id").parse().unwrap(); + + wait_for_relation_on_shards(&secondary_index).await; cleanup(&admin, &direct).await; } -/// `COPY_DATA` returns `task_id TEXT` and `replication_slot TEXT`. A -/// `copy_data` task is registered synchronously; it internally spawns a -/// `replication` task when complete. 
+/// `COPY_DATA` syncs the schema, copies data, then starts replication. Asserts +/// the table is actually created on both destination shards (the schema phase), +/// then `cleanup` stops the long-running replication the task spawns. #[tokio::test] async fn test_copy_data() { let direct = connection_sqlx_direct().await; @@ -368,26 +387,12 @@ async fn test_copy_data() { .fetch_one(format!("COPY_DATA pgdog pgdog_sharded {COPY_DATA_PUB}").as_str()) .await .unwrap(); - let task_id: i64 = row.get::("task_id").parse().unwrap(); + let _task_id: i64 = row.get::("task_id").parse().unwrap(); let slot_name: String = row.get("replication_slot"); assert!(!slot_name.is_empty(), "replication_slot must be non-empty"); - // Verify kind while still running (may already be gone if fast). - if let Some(task) = Tasks::fetch(&admin).await.find(task_id) { - assert_eq!(task.kind, "copy_data"); - } - - // Abort early; STOP_TASK on a CopyData task emits a WARNING notice that - // sqlx ignores. "task not found" is valid if the task finished first. - let stop = admin - .fetch_one(format!("STOP_TASK {task_id}").as_str()) - .await - .unwrap(); - let status = stop.get::("stop_task"); - assert!( - status == "OK" || status == "task not found", - "unexpected STOP_TASK response: {status}" - ); + // copy_data's schema-sync phase must create the table on both shards. 
+ wait_for_relation_on_shards(TEST_TABLE).await; cleanup(&admin, &direct).await; } diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 1b2829a07..175b2caa7 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -15,6 +15,7 @@ tui = ["ratatui"] new_parser = ["pg_raw_parse"] [dependencies] +bon.workspace = true pin-project = "1" tokio = { version = "1", features = ["full"] } tracing = "0.1" @@ -22,7 +23,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "std"] tracing-throttle = "0.4" parking_lot = "0.12" thiserror = "2" -derive_more = { version = "2", features = ["display", "error", "from"] } +derive_more = { version = "2", features = ["display", "error", "from", "from_str"] } bytes = "1" clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } diff --git a/pgdog/src/admin/copy_data.rs b/pgdog/src/admin/copy_data.rs index 869064edc..3fa6da9dc 100644 --- a/pgdog/src/admin/copy_data.rs +++ b/pgdog/src/admin/copy_data.rs @@ -1,10 +1,9 @@ //! COPY_DATA command. -use tokio::spawn; use tracing::info; -use crate::backend::replication::AsyncTasks; -use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::api::resharding::ReshardTask; +use crate::api::start; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -54,7 +53,7 @@ impl Command for CopyData { self.from_database, self.to_database, self.publication ); - let mut orchestrator = Orchestrator::new( + let orchestrator = Orchestrator::new( &self.from_database, &self.to_database, &self.publication, @@ -63,20 +62,7 @@ impl Command for CopyData { let slot_name = orchestrator.replication_slot().to_owned(); - let task_id = Task::register(TaskType::CopyData(spawn(async move { - orchestrator.load_schema().await?; - orchestrator.schema_sync_pre(true).await?; - orchestrator.data_sync().await?; - // data_sync can run for hours; any pool reload during the copy marks self.source - // offline. 
Re-fetch live cluster refs from databases() before starting replication. - orchestrator.refresh()?; - - AsyncTasks::insert(TaskType::Replication(Box::new( - orchestrator.replicate().await?, - ))); - - Ok(()) - }))); + let task_id = start(ReshardTask::builder().orchestrator(orchestrator).build()).id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()).add(slot_name); diff --git a/pgdog/src/admin/cutover.rs b/pgdog/src/admin/cutover.rs index c5f8f8435..b8042813b 100644 --- a/pgdog/src/admin/cutover.rs +++ b/pgdog/src/admin/cutover.rs @@ -1,8 +1,12 @@ -use crate::backend::replication::logical::admin::AsyncTasks; +use crate::api::async_task::AsyncTaskId; +use crate::api::replication::ReplicationTask; +use crate::backend::replication::logical::Error as ReplicationError; use super::prelude::*; -pub struct Cutover; +pub struct Cutover { + task_id: Option, +} #[async_trait] impl Command for Cutover { @@ -14,13 +18,22 @@ impl Command for Cutover { let parts: Vec<&str> = sql.split_whitespace().collect(); match parts[..] { - ["cutover"] => Ok(Cutover), + ["cutover"] => Ok(Cutover { task_id: None }), + ["cutover", id] => { + let task_id = id.parse().map_err(|_| Error::Syntax)?; + Ok(Cutover { + task_id: Some(task_id), + }) + } _ => Err(Error::Syntax), } } async fn execute(&self) -> Result, Error> { - AsyncTasks::cutover()?; + // With an id, cut over that task; without, the first running one. 
+ if !ReplicationTask::cutover(self.task_id.map(AsyncTaskId::from)) { + return Err(ReplicationError::NotReplication.into()); + } let mut dr = DataRow::new(); dr.add("OK"); diff --git a/pgdog/src/admin/parser.rs b/pgdog/src/admin/parser.rs index 692ec874b..b878a405d 100644 --- a/pgdog/src/admin/parser.rs +++ b/pgdog/src/admin/parser.rs @@ -277,7 +277,13 @@ mod tests { #[test] fn parses_cutover_command() { - let result = Parser::parse("CUTOVER"); - assert!(matches!(result, Ok(ParseResult::Cutover(_)))); + assert!(matches!( + Parser::parse("CUTOVER"), + Ok(ParseResult::Cutover(_)) + )); + assert!(matches!( + Parser::parse("CUTOVER 1"), + Ok(ParseResult::Cutover(_)) + )); } } diff --git a/pgdog/src/admin/replicate.rs b/pgdog/src/admin/replicate.rs index 9586366f3..3c2d9820d 100644 --- a/pgdog/src/admin/replicate.rs +++ b/pgdog/src/admin/replicate.rs @@ -2,7 +2,8 @@ use tracing::info; -use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::api::replication::ReplicationTask; +use crate::api::start; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -60,7 +61,7 @@ impl Command for Replicate { )?; let waiter = orchestrator.replicate().await?; - let task_id = Task::register(TaskType::Replication(Box::new(waiter))); + let task_id = start(ReplicationTask::builder().waiter(waiter).build()).id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()); diff --git a/pgdog/src/admin/reshard.rs b/pgdog/src/admin/reshard.rs index eb9bf2da4..695ac6e79 100644 --- a/pgdog/src/admin/reshard.rs +++ b/pgdog/src/admin/reshard.rs @@ -2,6 +2,8 @@ use tracing::info; +use crate::api::resharding::ReshardTask; +use crate::api::start; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -50,15 +52,27 @@ impl Command for Reshard { r#"resharding "{}" to "{}", publication="{}""#, self.from_database, self.to_database, self.publication ); - let mut orchestrator = Orchestrator::new( + let orchestrator = 
Orchestrator::new( &self.from_database, &self.to_database, &self.publication, self.replication_slot.clone(), )?; - orchestrator.replicate_and_cutover().await?; + let task_id = start( + ReshardTask::builder() + .orchestrator(orchestrator) + .auto_cutover(true) + .build(), + ) + .id(); - Ok(vec![]) + let mut dr = DataRow::new(); + dr.add(task_id.to_string()); + + Ok(vec![ + RowDescription::new(&[Field::text("task_id")]).message()?, + dr.message()?, + ]) } } diff --git a/pgdog/src/admin/schema_sync.rs b/pgdog/src/admin/schema_sync.rs index 2b4947756..8ed27944d 100644 --- a/pgdog/src/admin/schema_sync.rs +++ b/pgdog/src/admin/schema_sync.rs @@ -1,19 +1,13 @@ //! SCHEMA_SYNC command. -use tokio::spawn; use tracing::info; -use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; +use crate::api::start; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum SchemaSyncPhase { - Pre, - Post, -} - pub struct SchemaSync { pub from_database: String, pub to_database: String, @@ -25,10 +19,7 @@ pub struct SchemaSync { #[async_trait] impl Command for SchemaSync { fn name(&self) -> String { - match self.phase { - SchemaSyncPhase::Pre => "SCHEMA_SYNC PRE".into(), - SchemaSyncPhase::Post => "SCHEMA_SYNC POST".into(), - } + format!("SCHEMA_SYNC {}", self.phase).to_uppercase() } fn parse(sql: &str) -> Result { @@ -46,7 +37,7 @@ impl Command for SchemaSync { to_database: to_database.to_owned(), publication: publication.to_owned(), replication_slot: None, - phase: parse_phase(phase)?, + phase: phase.parse().map_err(|_| Error::Syntax)?, }), [ "schema_sync", @@ -60,43 +51,33 @@ impl Command for SchemaSync { to_database: to_database.to_owned(), publication: publication.to_owned(), replication_slot: Some(replication_slot.to_owned()), - phase: parse_phase(phase)?, + phase: phase.parse().map_err(|_| Error::Syntax)?, }), _ => 
Err(Error::Syntax), } } async fn execute(&self) -> Result, Error> { - let phase_name = match self.phase { - SchemaSyncPhase::Pre => "pre", - SchemaSyncPhase::Post => "post", - }; - info!( r#"schema_sync {} "{}" to "{}", publication="{}""#, - phase_name, self.from_database, self.to_database, self.publication + self.phase, self.from_database, self.to_database, self.publication ); - let mut orchestrator = Orchestrator::new( + let orchestrator = Orchestrator::new( &self.from_database, &self.to_database, &self.publication, self.replication_slot.clone(), )?; - let phase = self.phase; - let handle = spawn(async move { - orchestrator.load_schema().await?; - - match phase { - SchemaSyncPhase::Pre => orchestrator.schema_sync_pre(true).await, - SchemaSyncPhase::Post => orchestrator.schema_sync_post(true).await, - }?; - - Ok(()) - }); - - let task_id = Task::register(TaskType::SchemaSync(handle)); + let task_id = start( + SchemaSyncTask::builder() + .orchestrator(orchestrator) + .phase(self.phase) + .ignore_errors(true) + .build(), + ) + .id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()); @@ -107,11 +88,3 @@ impl Command for SchemaSync { ]) } } - -fn parse_phase(phase: &str) -> Result { - match phase { - "pre" => Ok(SchemaSyncPhase::Pre), - "post" => Ok(SchemaSyncPhase::Post), - _ => Err(Error::Syntax), - } -} diff --git a/pgdog/src/admin/show_tasks.rs b/pgdog/src/admin/show_tasks.rs index 51565c812..65a40c769 100644 --- a/pgdog/src/admin/show_tasks.rs +++ b/pgdog/src/admin/show_tasks.rs @@ -2,7 +2,7 @@ use std::time::SystemTime; use chrono::{DateTime, Local}; -use crate::backend::replication::logical::admin::AsyncTasks; +use crate::api::storage; use crate::util::{format_time, human_duration_display}; use super::prelude::*; @@ -22,29 +22,53 @@ impl Command for ShowTasks { async fn execute(&self) -> Result, Error> { let rd = RowDescription::new(&[ Field::bigint("id"), - // Field::bigint("parent_id"), + Field::bigint("root_id"), + Field::text("scope"), 
Field::text("type"), + Field::text("status"), Field::text("started_at"), + Field::text("updated_at"), Field::text("elapsed"), Field::bigint("elapsed_ms"), ]); let mut messages = vec![rd.message()?]; let now = SystemTime::now(); - for (id, task_kind, started_at) in AsyncTasks::get().iter() { - let elapsed = now.duration_since(started_at).unwrap_or_default(); - let elapsed_ms = elapsed.as_millis() as i64; - let elapsed_str = human_duration_display(elapsed); + for task in storage().tasks() { + // A root task plus its subtasks (e.g. the replication child of a + // copy_data/reshard task). Each row carries its own `id` and the + // `root_id` it belongs to — only root tasks are cancellable, so + // STOP_TASK targets `root_id`. Terminal tasks are retained for + // reporting but filtered out here. + let root_id = task.id; + let entries = std::iter::once((task.id, true, &task.state)) + .chain(task.subtasks.iter().map(|sub| (sub.id, false, &sub.state))); - let started_at_str = format_time(DateTime::::from(started_at)); + for (id, is_root, state) in entries { + if state.is_terminal() { + continue; + } - let mut row = DataRow::new(); - row.add(id as i64) - .add(task_kind.to_string().as_str()) - .add(started_at_str.as_str()) - .add(elapsed_str.as_str()) - .add(elapsed_ms); - messages.push(row.message()?); + let elapsed = now.duration_since(state.started_at).unwrap_or_default(); + let elapsed_ms = elapsed.as_millis() as i64; + let elapsed_str = human_duration_display(elapsed); + let started_at_str = format_time(DateTime::::from(state.started_at)); + let updated_at_str = format_time(DateTime::::from(state.updated_at)); + let status_str = state.status.to_string(); + let scope = if is_root { "root" } else { "subtask" }; + + let mut row = DataRow::new(); + row.add(id) + .add(root_id) + .add(scope) + .add(state.name.as_str()) + .add(status_str.as_str()) + .add(started_at_str.as_str()) + .add(updated_at_str.as_str()) + .add(elapsed_str.as_str()) + .add(elapsed_ms); + 
messages.push(row.message()?); + } } Ok(messages) diff --git a/pgdog/src/admin/stop_task.rs b/pgdog/src/admin/stop_task.rs index 70be0b857..c7f07cac0 100644 --- a/pgdog/src/admin/stop_task.rs +++ b/pgdog/src/admin/stop_task.rs @@ -1,5 +1,5 @@ -use crate::backend::replication::logical::admin::{AsyncTasks, TaskKind}; -use crate::net::messages::{ErrorResponse, NoticeResponse}; +use crate::api::async_task::AsyncTaskId; +use crate::api::storage; use super::prelude::*; @@ -26,23 +26,14 @@ impl Command for StopTask { } async fn execute(&self) -> Result, Error> { - let task_kind = AsyncTasks::remove(self.task_id); + let cancelled = storage().cancel_task(AsyncTaskId::from(self.task_id)); let mut messages = vec![]; - if task_kind == Some(TaskKind::CopyData) { - let notice = NoticeResponse::from(ErrorResponse { - severity: "WARNING".into(), - code: "01000".into(), - message: "replication slot was not dropped and requires manual cleanup".into(), - ..Default::default() - }); - messages.push(notice.message()?); - } - - let result = match task_kind { - Some(_) => "OK", - None => "task not found", + let result = if cancelled.is_some() { + "OK" + } else { + "task not found" }; let mut dr = DataRow::new(); diff --git a/pgdog/src/api/async_task.rs b/pgdog/src/api/async_task.rs index b80cf8c7e..45b2899dd 100644 --- a/pgdog/src/api/async_task.rs +++ b/pgdog/src/api/async_task.rs @@ -23,6 +23,12 @@ impl From for AsyncTaskId { } } +impl From for u64 { + fn from(id: AsyncTaskId) -> Self { + id.0 + } +} + /// Status type for tasks that report no intermediate progress. /// /// [`Infallible`](std::convert::Infallible) is uninhabited, so a task @@ -76,6 +82,14 @@ pub struct TaskState { pub updated_at: SystemTime, } +impl TaskState { + /// Whether the task reached a terminal state (finished, cancelled, + /// errored, or panicked) and is only retained for status reporting. 
+ pub fn is_terminal(&self) -> bool { + self.status.is_terminal() + } +} + /// Why a task did not complete, delivered to the waiter /// as the error half of its `Result`. #[derive(Debug, Display, Error)] @@ -119,8 +133,6 @@ impl TaskStatus { } } -type SharedStatus = Arc>>; - #[derive(Default)] struct TasksMap { map: DashMap>, @@ -128,12 +140,12 @@ struct TasksMap { } impl TasksMap { - fn insert_next(&self, value: Arc) -> AsyncTaskId { - let id = AsyncTaskId(self.counter.fetch_add(1, Ordering::Relaxed)); + fn next_id(&self) -> AsyncTaskId { + AsyncTaskId(self.counter.fetch_add(1, Ordering::Relaxed)) + } + fn insert(&self, id: AsyncTaskId, value: Arc) { self.map.insert(id, value); - - id } } @@ -153,6 +165,10 @@ impl AsyncTaskState { struct AsyncTask { started_at: SystemTime, + /// Id of the root task this task belongs to — its own id when it is a + /// root, inherited from the parent otherwise. Stable for the task's + /// lifetime and used to address its cutover signal. + root_id: AsyncTaskId, name: String, cancellation_token: CancellationToken, /// Set once the task asks for its cancellation token: only @@ -300,13 +316,20 @@ fn run_task( parent_token: Option<&CancellationToken>, register_into: &TasksMap, subtasks: Arc, + root: Option, task: T, ) -> AsyncTaskWaiter { + // Allocate the id up front so a root task can record its own id as its + // root id; descendants inherit the root's id from their parent. 
+ let id = register_into.next_id(); + let root_id = root.unwrap_or(id); + let state = Arc::new(RwLock::new(AsyncTaskState::new())); let entry = AsyncTask { started_at: SystemTime::now(), name: task.to_string(), + root_id, cancellation_token: match parent_token { Some(token) => token.child_token(), None => CancellationToken::new(), @@ -319,8 +342,8 @@ fn run_task( }; let entry = Arc::new(entry); - // Make sure we insert task to map before it's actually started - let id = register_into.insert_next(entry.clone()); + // Make sure we insert task to map before it's actually started. + register_into.insert(id, entry.clone()); let ctx = AsyncTaskContext { task: entry.clone(), @@ -412,9 +435,16 @@ impl AsyncTaskContext { Some(&self.task.cancellation_token), &self.task.subtasks, self.task.subtasks.clone(), + Some(self.task.root_id), task, ) } + + /// Id of the root task this task belongs to (its own id when it is a + /// root). Used to address the task's cutover signal. + pub fn root_id(&self) -> AsyncTaskId { + self.task.root_id + } } impl AsyncTasksStorage { @@ -428,20 +458,28 @@ impl AsyncTasksStorage { pub fn run(&self, task: T) -> AsyncTaskWaiter { self.prune(); - run_task(None, &self.tasks, Arc::new(TasksMap::default()), task) + run_task(None, &self.tasks, Arc::new(TasksMap::default()), None, task) } /// Request cancellation of a task. The task winds down /// cooperatively (or is aborted after the grace period) and /// stays in the registry with a terminal status until pruned. - /// Returns the state the task was in when cancellation was - /// requested, or `None` for an unknown id. + /// Returns the state the task was in when cancellation was requested, + /// or `None` for an unknown or already-terminal id. pub fn cancel_task(&self, id: AsyncTaskId) -> Option { let entry = self.tasks.map.get(&id)?; + let state = entry.state(); + + // Already terminal (finished/cancelled/errored) but still retained for + // status reporting: nothing to cancel. 
Report as not found so callers + // don't claim success or emit cleanup warnings for a finished task. + if state.is_terminal() { + return None; + } entry.cancel(); - Some(entry.state()) + Some(state) } /// Drop every task that reached a terminal state more than @@ -1233,4 +1271,63 @@ mod tests { waiter.await.unwrap(); assert!(crate::api::storage().task(id).is_some()); } + + #[test] + async fn test_cancel_finished_task_is_not_found() { + let notify = Arc::new(Notify::new()); + let state = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state, notify); + + let async_storage = AsyncTasksStorage::default(); + let task = async_storage.run(a); + let id = task.id(); + + notify.notify_one(); + task.await.unwrap(); + + // The task is terminal (finished) but still retained for reporting: + // cancelling it now reports not-found, so STOP_TASK won't claim it + // stopped a completed task (nor emit a bogus cleanup warning). + assert!(async_storage.task(id).is_some()); + assert!(async_storage.cancel_task(id).is_none()); + + // An unknown id is not-found too. + assert!( + async_storage + .cancel_task(AsyncTaskId::from(99_u64)) + .is_none() + ); + } + + #[test] + async fn test_prune_expired_subtasks_under_running_root() { + let sub_gate = Arc::new(Notify::new()); + let parent_gate = Arc::new(Notify::new()); + + // Zero retention: terminal tasks are pruned on next access. + let async_storage = AsyncTasksStorage::new(Duration::ZERO); + + let task = async_storage.run(TraverseRoot { + sub_gate: sub_gate.clone(), + parent_gate: parent_gate.clone(), + }); + let root_id = task.id(); + + settle().await; + + // Root running with its two descendants registered. + assert_eq!(async_storage.task(root_id).unwrap().subtasks.len(), 2); + + // Let the subtasks finish; the root parks on `parent_gate`, still running. + sub_gate.notify_one(); + settle().await; + + // The finished subtasks are pruned while the still-running root survives. 
+ let snapshot = async_storage.task(root_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Pending(_))); + assert!(snapshot.subtasks.is_empty()); + + parent_gate.notify_one(); + task.await.unwrap(); + } } diff --git a/pgdog/src/api/copy_data.rs b/pgdog/src/api/copy_data.rs index 1a6a124e5..9dbbe337f 100644 --- a/pgdog/src/api/copy_data.rs +++ b/pgdog/src/api/copy_data.rs @@ -1,80 +1,42 @@ -//! Copy-data background task: schema sync + data sync, then a replication -//! task that catches up and (on `CUTOVER`) cuts over. +//! Copy-data leaf task: bulk-copies table data from a source to a target. +//! +//! This task only copies data. The schema sync (pre-data tables, post-data +//! indexes) and replication around it are composed by +//! [`ReshardTask`](crate::api::resharding::ReshardTask). -use crate::api::async_task::AsyncTaskContext; -use crate::api::replication::ReplicationTask; -use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; -use crate::api::{MigrationError, Task}; +use tokio::select; + +use crate::api::Task; +use crate::api::async_task::{AsyncTaskContext, Empty}; +use crate::backend::replication::logical::Error; use crate::backend::replication::logical::orchestrator::Orchestrator; -/// Copy data from a source database to a target: schema sync, data sync, -/// then replication catch-up and cutover. -#[derive(Display, Debug)] -#[display("copy_data")] +/// Bulk-copy table data from a source database to a target, returning the +/// orchestrator so the composing task can thread it into the next phase. +#[derive(Display, Debug, bon::Builder)] +#[display("copy_data {orchestrator}")] pub(crate) struct CopyDataTask { pub orchestrator: Orchestrator, } -/// Stages of the copy-data flow, reported as the task's status. The -/// fine-grained schema-sync and replication stages live on the child tasks. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] -pub(crate) enum CopyDataStatus { - /// Running the schema-sync child task. 
- #[display("syncing schema")] - SchemaSync, - /// Copying table data to the destination. - #[display("syncing data")] - SyncingData, - /// Running the replication child task. - #[display("replicating")] - Replication, -} - impl Task for CopyDataTask { - type Status = CopyDataStatus; - type Output = (); - type Error = MigrationError; - - async fn run(self, ctx: AsyncTaskContext) -> Result<(), MigrationError> { - // Sync the schema as a child task so it reports its own stages. - ctx.set_status(CopyDataStatus::SchemaSync); - let mut orchestrator = ctx - .run(SchemaSyncTask { - orchestrator: self.orchestrator, - phase: SchemaSyncPhase::Pre, - }) - .await?; - - ctx.set_status(CopyDataStatus::SyncingData); - orchestrator.data_sync().await?; - - // data_sync can run for hours; pools may have reloaded. Re-fetch - // live cluster refs before starting replication. - orchestrator.refresh()?; - - // Replication runs as a child until cutover, reporting its own stages. - // Awaiting keeps copy_data non-terminal while it runs; its outcome is - // intentionally not propagated here. - ctx.set_status(CopyDataStatus::Replication); - let _ = ctx.run(ReplicationTask { orchestrator }).await; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn copy_data_status_renders_distinct_labels() { - let labels = [ - CopyDataStatus::SchemaSync.to_string(), - CopyDataStatus::SyncingData.to_string(), - CopyDataStatus::Replication.to_string(), - ]; - assert!(labels.iter().all(|label| !label.is_empty())); - let unique: std::collections::HashSet<_> = labels.iter().collect(); - assert_eq!(unique.len(), labels.len()); + type Status = Empty; + type Output = Orchestrator; + type Error = Error; + + async fn run(self, ctx: AsyncTaskContext) -> Result { + let token = ctx.cancellation_token(); + let orchestrator = self.orchestrator; + + select! 
{ + res = orchestrator.data_sync() => res?, + // Cancellation drops the `data_sync()` future, whose internal + // `JoinSet` aborts every in-flight shard copy; closing those + // connections releases the temporary data-sync slots. The + // composing task drops the persistent replication slots afterward. + _ = token.cancelled() => return Err(Error::DataSyncAborted), + } + + Ok(orchestrator) } } diff --git a/pgdog/src/api/mod.rs b/pgdog/src/api/mod.rs index 4f77eb599..81c1631ff 100644 --- a/pgdog/src/api/mod.rs +++ b/pgdog/src/api/mod.rs @@ -45,6 +45,23 @@ pub(crate) enum MigrationError { Task(TaskError), } +/// Flatten a nested migration task's outcome into a single [`MigrationError`], +/// so a composite task (e.g. `reshard`) can run another composite task (e.g. +/// `copy_data`, whose error is already a `MigrationError`) as a child and +/// `?`-propagate its result without double-wrapping. +impl From> for MigrationError { + fn from(err: TaskError) -> Self { + match err { + // The child's own error: surface it directly. + TaskError::Failed(inner) => inner, + // Non-failure outcomes carry no inner error; re-wrap them. + TaskError::Cancelled => MigrationError::Task(TaskError::Cancelled), + TaskError::Panicked(msg) => MigrationError::Task(TaskError::Panicked(msg)), + TaskError::Abandoned => MigrationError::Task(TaskError::Abandoned), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -65,5 +82,15 @@ mod tests { // Non-failure child outcomes are preserved too (not stringified). let err = MigrationError::from(TaskError::::Cancelled); assert!(matches!(err, MigrationError::Task(TaskError::Cancelled))); + + // A nested migration task's failure is flattened, not double-wrapped. + let err = MigrationError::from(TaskError::Failed(MigrationError::Replication( + Error::NoSchema, + ))); + assert!(matches!(err, MigrationError::Replication(Error::NoSchema))); + + // A nested non-failure outcome is preserved as a task error. 
+ let err = MigrationError::from(TaskError::::Cancelled); + assert!(matches!(err, MigrationError::Task(TaskError::Cancelled))); } } diff --git a/pgdog/src/api/replication.rs b/pgdog/src/api/replication.rs index fcb3243ac..3b32f1c8b 100644 --- a/pgdog/src/api/replication.rs +++ b/pgdog/src/api/replication.rs @@ -1,19 +1,26 @@ //! Logical-replication background task. //! -//! Drives a `ReplicationWaiter` to completion: it stops on cancellation -//! (`STOP_TASK`), performs cutover on an external `cutover_signal::request()` -//! (`CUTOVER`), and otherwise finishes when the source slot drains (no cutover -//! on natural drain). Launch it top-level with [`super::start`], or as a child -//! by spawning it through a parent task's [`AsyncTaskContext`]. +//! Drives a `ReplicationWaiter` to completion. Without `auto_cutover` +//! (standalone `REPLICATE`, `copy_data`) it stops on cancellation +//! (`STOP_TASK`), cuts over on an operator `CUTOVER` addressed to this task +//! (delivered through [`ReplicationTask::cutover`]), and otherwise finishes +//! when the source slot drains (no cutover on natural drain). With +//! `auto_cutover` set (reshard) it cuts over automatically once the +//! destination has caught up. Launch it top-level with [`super::start`], or +//! as a child by spawning it through a parent task's [`AsyncTaskContext`]. +use std::collections::HashMap; +use std::sync::LazyLock; use std::time::Duration; +use parking_lot::Mutex; use tokio::select; +use tokio_util::sync::CancellationToken; use crate::api::Task; -use crate::api::async_task::AsyncTaskContext; -use crate::backend::replication::logical::orchestrator::Orchestrator; -use crate::backend::replication::logical::{Error, cutover_signal}; +use crate::api::async_task::{AsyncTaskContext, AsyncTaskId}; +use crate::backend::replication::logical::Error; +use crate::backend::replication::logical::orchestrator::ReplicationWaiter; /// Stages of logical replication, reported as the task's status. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] @@ -24,17 +31,73 @@ pub(crate) enum ReplicationStatus { /// Cutting traffic over to the destination. #[display("cutting over")] CuttingOver, + /// Cutting traffic back to the original after a prior cutover (rollback). + #[display("rolling back")] + RollingBack, /// Winding down on a stop request. #[display("stopping")] Stopping, } -/// Replicate from a source database to a target, owning the orchestrator -/// that produces the replication waiter. -#[derive(Display, Debug)] -#[display("replication")] +/// Direction of a replication task: the initial migration (`Forward`) or the +/// post-cutover reverse stream that backs a rollback (`Reverse`). A `CUTOVER` +/// on a `Reverse` task is therefore a rollback. Affects reported status only, +/// not control flow. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub(crate) enum Direction { + #[default] + Forward, + Reverse, +} + +/// Drive a [`ReplicationWaiter`] to completion. The caller creates the waiter +/// (via `Orchestrator::replicate`); this task owns only the waiter, not the +/// orchestrator. +#[derive(Display, Debug, bon::Builder)] +#[display("replication {waiter}")] pub(crate) struct ReplicationTask { - pub orchestrator: Orchestrator, + /// The running replication waiter this task drives to completion. + pub waiter: ReplicationWaiter, + /// Cut over automatically once the destination has caught up, instead + /// of waiting for an operator `CUTOVER`. Set by the reshard flow, + /// which drives its own cutover; standalone `REPLICATE` and + /// `copy_data` leave it `false` and wait for an external `CUTOVER`. + #[builder(default)] + pub auto_cutover: bool, + /// Replication direction. `Reverse` marks the post-cutover stream that + /// backs a rollback; it only affects reported status, not control flow. 
+ #[builder(default)] + pub direction: Direction, +} + +/// Cutover tokens of the replication tasks currently awaiting an operator +/// `CUTOVER`, keyed by the root task id they belong to. A cutover token is +/// *separate* from the task's `STOP_TASK` cancellation token — signalling it +/// means "cut over", not "abandon". Registrations are dropped with the task, +/// so a cutover can never outlive its task and leak into a later one. +static CUTOVERS: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Guard held by a running replication task: removes its cutover +/// registration on drop. Awaiting [`requested`](CutoverWaiter::requested) +/// resolves when an operator `CUTOVER` targets the task. +struct CutoverWaiter { + root_id: AsyncTaskId, + token: CancellationToken, +} + +impl CutoverWaiter { + /// Wait until a cutover is requested for this task. The token latches, so + /// a cutover that arrived earlier is delivered immediately. + async fn requested(&self) { + self.token.cancelled().await; + } +} + +impl Drop for CutoverWaiter { + fn drop(&mut self) { + CUTOVERS.lock().remove(&self.root_id); + } } impl Task for ReplicationTask { @@ -52,21 +115,32 @@ impl Task for ReplicationTask { } async fn run(self, ctx: AsyncTaskContext) -> Result<(), Error> { - let mut waiter = self.orchestrator.replicate().await?; let token = ctx.cancellation_token(); + // Operator flow (`REPLICATE`, `copy_data`) registers for an external + // `CUTOVER` addressed to this task; the reshard flow (`auto_cutover`) + // cuts over on its own and registers nothing. + let cutover = (!self.auto_cutover).then(|| Self::register_cutover(ctx.root_id())); + + let mut waiter = self.waiter; ctx.set_status(ReplicationStatus::Replicating); select! { + // STOP_TASK: wind down without cutting over. 
_ = token.cancelled() => { ctx.set_status(ReplicationStatus::Stopping); waiter.stop(); } - _ = cutover_signal::requested() => { - ctx.set_status(ReplicationStatus::CuttingOver); + // Operator CUTOVER, or immediately under `auto_cutover`: switch traffic. + _ = async { if let Some(cutover) = &cutover { cutover.requested().await } } => { + ctx.set_status(match self.direction { + Direction::Forward => ReplicationStatus::CuttingOver, + Direction::Reverse => ReplicationStatus::RollingBack, + }); waiter.cutover().await?; } - res = waiter.wait() => { + // Source slot drained without a cutover (operator flow only): done. + res = waiter.wait(), if cutover.is_some() => { res?; } } @@ -75,10 +149,49 @@ impl Task for ReplicationTask { } } +impl ReplicationTask { + /// Trigger a cutover on a running replication task, returning whether one + /// was there to receive it. `Some(root_id)` targets that task; `None` + /// targets the first (lowest-id) running replication task. The `CUTOVER` + /// admin command rejects with `NotReplication` when this is `false`. + pub(crate) fn cutover(target: Option) -> bool { + let tokens = CUTOVERS.lock(); + + let token = match target { + Some(id) => tokens.get(&id), + // No id: cut over the first (lowest-id) running task. + None => tokens.keys().min().copied().and_then(|id| tokens.get(&id)), + }; + + match token { + Some(token) => { + token.cancel(); + true + } + None => false, + } + } + + /// Register this task (by its `root_id`) to receive operator cutovers for + /// as long as the returned guard is held. + fn register_cutover(root_id: AsyncTaskId) -> CutoverWaiter { + let token = CancellationToken::new(); + CUTOVERS.lock().insert(root_id, token.clone()); + CutoverWaiter { root_id, token } + } +} + #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; + // Serialize tests that touch the process-global `CUTOVERS` map so they + // never observe each other's registrations under a multi-threaded harness. 
+ static CUTOVER_TEST_LOCK: std::sync::LazyLock> = + std::sync::LazyLock::new(|| tokio::sync::Mutex::new(())); + #[test] fn cancel_timeout_far_exceeds_default() { // Far larger than the 5s default: cutover must not be force-aborted. @@ -96,4 +209,95 @@ mod tests { let unique: std::collections::HashSet<_> = labels.iter().collect(); assert_eq!(unique.len(), labels.len()); } + + #[tokio::test] + async fn cutover_delivers_even_when_buffered() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // Cutover lands before the task awaits: still delivered (latches). + let waiter = ReplicationTask::register_cutover(AsyncTaskId::from(1)); + assert!( + ReplicationTask::cutover(Some(AsyncTaskId::from(1))), + "the named task must receive the cutover" + ); + + tokio::time::timeout(Duration::from_secs(1), waiter.requested()) + .await + .expect("buffered cutover was not delivered"); + } + + #[tokio::test] + async fn cutover_targets_only_the_named_task() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // A cutover for one id must never disturb a task registered under a + // different id — the whole point of keying by task id. + let waiter = ReplicationTask::register_cutover(AsyncTaskId::from(7)); + + assert!( + !ReplicationTask::cutover(Some(AsyncTaskId::from(8))), + "no task is registered under id 8" + ); + assert!( + tokio::time::timeout(Duration::from_millis(200), waiter.requested()) + .await + .is_err(), + "a cutover for a different id leaked to this task" + ); + + assert!(ReplicationTask::cutover(Some(AsyncTaskId::from(7)))); + tokio::time::timeout(Duration::from_secs(1), waiter.requested()) + .await + .expect("targeted cutover was not delivered"); + } + + #[tokio::test] + async fn cutover_without_id_targets_the_first_task() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // No id: the lowest-id (first) registered task is cut over, and only + // it. 
+ let first = ReplicationTask::register_cutover(AsyncTaskId::from(3)); + let second = ReplicationTask::register_cutover(AsyncTaskId::from(9)); + + assert!( + ReplicationTask::cutover(None), + "the first registered task must be cut over" + ); + + tokio::time::timeout(Duration::from_secs(1), first.requested()) + .await + .expect("the first task was not cut over"); + assert!( + tokio::time::timeout(Duration::from_millis(200), second.requested()) + .await + .is_err(), + "cutover(None) disturbed a task other than the first" + ); + } + + #[tokio::test] + async fn cutover_does_not_leak_to_the_next_task() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // A cutover to a task that never consumes it must die with that task, + // never reaching the next one. Regression guard for the signal leak. + { + let first = ReplicationTask::register_cutover(AsyncTaskId::from(1)); + assert!(ReplicationTask::cutover(Some(AsyncTaskId::from(1)))); + drop(first); // ends without ever awaiting `requested()` + } + + let next = ReplicationTask::register_cutover(AsyncTaskId::from(2)); + assert!( + tokio::time::timeout(Duration::from_millis(200), next.requested()) + .await + .is_err(), + "stale cutover leaked into the next replication task" + ); + } + + #[tokio::test] + async fn cutover_with_no_task_is_rejected() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // Nothing registered: `CUTOVER` (with or without an id) is rejected. + assert!(!ReplicationTask::cutover(None)); + assert!(!ReplicationTask::cutover(Some(AsyncTaskId::from(404)))); + } } diff --git a/pgdog/src/api/resharding.rs b/pgdog/src/api/resharding.rs index 9ac5acc05..bb528e639 100644 --- a/pgdog/src/api/resharding.rs +++ b/pgdog/src/api/resharding.rs @@ -1,35 +1,59 @@ -//! Reshard background task: the full automatic schema-sync + data-sync + -//! replication + cutover flow. +//! Reshard / migration composer task. +//! +//! Composes the full migration from a source database to a target: pre-data +//! 
schema sync, the [`CopyDataTask`] bulk copy, post-data schema sync, then +//! replication. With `auto_cutover` it cuts over automatically once replication +//! has caught up (reshard); otherwise it waits for an operator `CUTOVER` +//! (`copy_data`). The admin `COPY_DATA`/`RESHARD` commands and the CLI +//! `data_sync` all run this task, differing only in their options. + +use std::time::Duration; + +use tracing::warn; use crate::api::async_task::AsyncTaskContext; +use crate::api::copy_data::CopyDataTask; use crate::api::replication::ReplicationTask; use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; use crate::api::{MigrationError, Task}; -use crate::backend::replication::logical::cutover_signal; use crate::backend::replication::logical::orchestrator::Orchestrator; -/// Run the complete replicate-and-cutover flow from a source database to a -/// target. -#[derive(Display, Debug)] -#[display("reshard")] +/// Run the full migration from a source database to a target: schema sync +/// (pre-data tables, then post-data indexes around the bulk copy), data copy, +/// then replication. With `auto_cutover` it also performs the cutover. +#[derive(Display, Debug, bon::Builder)] +#[display("reshard {orchestrator}")] pub(crate) struct ReshardTask { pub orchestrator: Orchestrator, + /// Skip the pre- and post-data schema sync. + #[builder(default)] + pub skip_schema_sync: bool, + /// Only replicate; skip the initial data copy. + #[builder(default)] + pub replicate_only: bool, + /// Only copy data; skip replication. + #[builder(default)] + pub sync_only: bool, + /// Cut over automatically once replication has caught up, instead of + /// waiting for an operator `CUTOVER`. Set by the reshard flow. + #[builder(default)] + pub auto_cutover: bool, } -/// Stages of the reshard flow, reported as the task's status. The -/// fine-grained schema-sync and replication stages live on the child tasks. +/// Stages of the migration, reported as the task's status. 
The fine-grained +/// schema-sync, copy, and replication stages live on the child tasks. #[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] pub(crate) enum ReshardStatus { /// Running the pre-data schema-sync child task. #[display("syncing schema")] SchemaSync, - /// Copying table data to the destination. + /// Running the data-copy child task. #[display("syncing data")] SyncingData, /// Running the post-data schema-sync child task (indexes, constraints). #[display("finalizing schema")] FinalizingSchema, - /// Running the replication child task through cutover. + /// Running the replication child task. #[display("replicating")] Replication, } @@ -39,41 +63,99 @@ impl Task for ReshardTask { type Output = (); type Error = MigrationError; + /// Composes long-lived child tasks; match their generous grace so a + /// `STOP_TASK` lets them wind down (and clean up replication slots) before + /// this task returns. + fn cancel_timeout() -> Duration { + Duration::from_secs(24 * 60 * 60) + } + async fn run(self, ctx: AsyncTaskContext) -> Result<(), MigrationError> { - // Sync the pre-data schema (tables) as a child task. - ctx.set_status(ReshardStatus::SchemaSync); - let orchestrator = ctx - .run(SchemaSyncTask { - orchestrator: self.orchestrator, - phase: SchemaSyncPhase::Pre, - }) - .await?; - - // Sync the data to destination. - ctx.set_status(ReshardStatus::SyncingData); - orchestrator.data_sync().await?; - - // Create secondary indexes as a child task (schema already loaded). - ctx.set_status(ReshardStatus::FinalizingSchema); - let mut orchestrator = ctx - .run(SchemaSyncTask { - orchestrator, - phase: SchemaSyncPhase::Post, - }) - .await?; - - // Refresh cluster references: data_sync can take hours and the pools - // may have been reloaded (e.g. by a client DDL) in the meantime. 
- orchestrator.refresh()?; - - // Reshard cuts over automatically: request it up front (the cutover - // signal is buffered), then run replication as a child that consumes - // the request and cuts over once it has caught up. - ctx.set_status(ReshardStatus::Replication); - cutover_signal::request(); - ctx.run(ReplicationTask { orchestrator }).await?; - - Ok(()) + // Take the cancellation token so a `STOP_TASK` winds the children down + // cooperatively (they'd otherwise outlive this task). + let _token = ctx.cancellation_token(); + let mut orchestrator = self.orchestrator; + + // Pre-data schema sync, unless skipped. It runs before any replication + // slots exist, so it stays outside the cleanup guard below. + if !self.skip_schema_sync { + ctx.set_status(ReshardStatus::SchemaSync); + orchestrator = ctx + .run( + SchemaSyncTask::builder() + .orchestrator(orchestrator) + .phase(SchemaSyncPhase::Pre) + .ignore_errors(true) + .build(), + ) + .await?; + } + + // From the data copy onward the orchestrator may hold replication slots + // (created during data_sync, kept until replication takes them over). + // Awaiting this guard on every exit drops whatever the publisher still + // owns — a no-op once replication has claimed the slots — so a failed or + // aborted migration never leaves them lingering on the source. + let guard = orchestrator.publication_guard(); + let result: Result<(), MigrationError> = async { + // Copy the data, unless replicate-only. + if !self.replicate_only { + ctx.set_status(ReshardStatus::SyncingData); + orchestrator = ctx + .run(CopyDataTask::builder().orchestrator(orchestrator).build()) + .await?; + } + + // Post-data schema sync (secondary indexes, constraints): the + // second half of schema sync, after the bulk load. 
+ if !self.skip_schema_sync { + ctx.set_status(ReshardStatus::FinalizingSchema); + orchestrator = ctx + .run( + SchemaSyncTask::builder() + .orchestrator(orchestrator) + .phase(SchemaSyncPhase::Post) + .ignore_errors(true) + .build(), + ) + .await?; + } + + // Replication, unless sync-only. + if !self.sync_only { + ctx.set_status(ReshardStatus::Replication); + + // data_sync / schema sync can run for hours; pools may have + // reloaded. Re-fetch live cluster refs before replicating. + orchestrator.refresh()?; + + // `auto_cutover` (reshard) cuts over on its own; otherwise the + // task runs until an operator `CUTOVER`/`STOP_TASK`. Both of + // those resolve to `Ok`, so awaiting surfaces only a genuine + // replication failure. + let waiter = orchestrator.replicate().await?; + ctx.run( + ReplicationTask::builder() + .waiter(waiter) + .auto_cutover(self.auto_cutover) + .build(), + ) + .await?; + } + + Ok(()) + } + .await; + + // Drop any replication slots the publisher still owns only when the + // migration failed or was aborted mid-copy. + if result.is_err() + && let Err(err) = guard.cleanup().await + { + warn!("failed to clean up replication slots after migration: {err}"); + } + + result } } diff --git a/pgdog/src/api/schema_sync.rs b/pgdog/src/api/schema_sync.rs index 9266dba15..54b20826b 100644 --- a/pgdog/src/api/schema_sync.rs +++ b/pgdog/src/api/schema_sync.rs @@ -1,14 +1,31 @@ -//! Schema-sync background task (pre-data or post-data). +//! Schema-sync background task (pre-data, post-data, or cutover). 
+ +use std::ops::Deref; use crate::api::Task; use crate::api::async_task::AsyncTaskContext; use crate::backend::replication::logical::Error; use crate::backend::replication::logical::orchestrator::Orchestrator; +use crate::backend::schema::sync::pg_dump::SyncState; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, FromStr)] pub enum SchemaSyncPhase { + #[display("pre")] Pre, + #[display("post")] Post, + #[display("cutover")] + Cutover, +} + +impl From for SyncState { + fn from(phase: SchemaSyncPhase) -> Self { + match phase { + SchemaSyncPhase::Pre => SyncState::PreData, + SchemaSyncPhase::Post => SyncState::PostData, + SchemaSyncPhase::Cutover => SyncState::Cutover, + } + } } /// Stages of a schema sync, reported as the task's status. @@ -23,14 +40,21 @@ pub(crate) enum SchemaSyncStatus { /// Creating indexes and constraints on the destination (post-data). #[display("creating indexes")] CreatingIndexes, + /// Restoring cutover-time schema on the destination. + #[display("syncing cutover schema")] + Cutover, } -/// Sync the schema (pre- or post-data) from a source database to a target. -#[derive(Display, Debug)] -#[display("schema_sync")] +/// Sync the schema (pre-data, post-data, or cutover) from a source database to a target. +#[derive(Display, Debug, bon::Builder)] +#[display("schema_sync({phase}) {orchestrator}")] pub(crate) struct SchemaSyncTask { pub orchestrator: Orchestrator, pub phase: SchemaSyncPhase, + #[builder(default)] + pub ignore_errors: bool, + #[builder(default)] + pub dry_run: bool, } impl Task for SchemaSyncTask { @@ -42,6 +66,7 @@ impl Task for SchemaSyncTask { /// task can thread it into the next phase. The schema dump is skipped when /// the orchestrator already carries one (e.g. a parent that runs `Pre` /// then `Post` on the same orchestrator). 
+ #[allow(clippy::print_stdout)] async fn run(self, ctx: AsyncTaskContext) -> Result { let mut orchestrator = self.orchestrator; @@ -50,14 +75,28 @@ impl Task for SchemaSyncTask { orchestrator.load_schema().await?; } + // Dry run prints the SQL this task would execute and stops short of + // touching the destination. The schema load above is required for it. + if self.dry_run { + let schema = orchestrator.schema()?; + for statement in schema.statements(self.phase.into())? { + println!("{}", statement.deref()); + } + return Ok(orchestrator); + } + match self.phase { SchemaSyncPhase::Pre => { ctx.set_status(SchemaSyncStatus::SyncingTables); - orchestrator.schema_sync_pre(true).await?; + orchestrator.schema_sync_pre(self.ignore_errors).await?; } SchemaSyncPhase::Post => { ctx.set_status(SchemaSyncStatus::CreatingIndexes); - orchestrator.schema_sync_post(true).await?; + orchestrator.schema_sync_post(self.ignore_errors).await?; + } + SchemaSyncPhase::Cutover => { + ctx.set_status(SchemaSyncStatus::Cutover); + orchestrator.schema_sync_cutover(self.ignore_errors).await?; } } @@ -75,9 +114,28 @@ mod tests { SchemaSyncStatus::LoadingSchema.to_string(), SchemaSyncStatus::SyncingTables.to_string(), SchemaSyncStatus::CreatingIndexes.to_string(), + SchemaSyncStatus::Cutover.to_string(), ]; assert!(labels.iter().all(|label| !label.is_empty())); let unique: std::collections::HashSet<_> = labels.iter().collect(); assert_eq!(unique.len(), labels.len()); } + + #[test] + fn schema_sync_phase_parses_and_displays() { + for (text, phase) in [ + ("pre", SchemaSyncPhase::Pre), + ("post", SchemaSyncPhase::Post), + ("cutover", SchemaSyncPhase::Cutover), + ] { + assert_eq!(text.parse::().unwrap(), phase); + assert_eq!(phase.to_string(), text); + } + // Parsing is case-insensitive; unknown phases are rejected. 
+ assert_eq!( + "CUTOVER".parse::().unwrap(), + SchemaSyncPhase::Cutover + ); + assert!("bogus".parse::().is_err()); + } } diff --git a/pgdog/src/backend/replication/logical/admin.rs b/pgdog/src/backend/replication/logical/admin.rs deleted file mode 100644 index 13a5ce046..000000000 --- a/pgdog/src/backend/replication/logical/admin.rs +++ /dev/null @@ -1,351 +0,0 @@ -use std::{ - fmt, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, - time::SystemTime, -}; - -use crate::backend::replication::orchestrator::ReplicationWaiter; - -use super::Error; -use dashmap::DashMap; -use once_cell::sync::Lazy; -use tokio::{ - select, spawn, - sync::{Notify, oneshot}, - task::JoinHandle, -}; -use tracing::error; - -static TASKS: Lazy = Lazy::new(AsyncTasks::default); - -pub struct Task; - -impl Task { - pub(crate) fn register(task: TaskType) -> u64 { - AsyncTasks::insert(task) - } -} - -pub enum TaskType { - SchemaSync(JoinHandle>), - CopyData(JoinHandle>), - Replication(Box), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TaskKind { - SchemaSync, - CopyData, - Replication, -} - -impl fmt::Display for TaskKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TaskKind::SchemaSync => write!(f, "schema_sync"), - TaskKind::CopyData => write!(f, "copy_data"), - TaskKind::Replication => write!(f, "replication"), - } - } -} - -pub struct TaskInfo { - #[allow(dead_code)] - abort_tx: oneshot::Sender<()>, - cutover: Arc, - pub task_kind: TaskKind, - pub started_at: SystemTime, -} - -#[derive(Clone, Default)] -pub struct AsyncTasks { - tasks: Arc>, - counter: Arc, -} - -impl AsyncTasks { - pub fn get() -> Self { - TASKS.clone() - } - - /// Perform cutover. 
- pub fn cutover() -> Result<(), Error> { - let this = Self::get(); - let task = this - .tasks - .iter() - .find(|t| t.task_kind == TaskKind::Replication) - .ok_or(Error::NotReplication)?; - - task.cutover.notify_one(); - - Ok(()) - } - - pub fn insert(task: TaskType) -> u64 { - let this = Self::get(); - let id = this.counter.fetch_add(1, Ordering::SeqCst); - let (abort_tx, abort_rx) = oneshot::channel(); - - match task { - TaskType::SchemaSync(handle) => { - this.tasks.insert( - id, - TaskInfo { - abort_tx, - cutover: Arc::new(Notify::new()), - task_kind: TaskKind::SchemaSync, - started_at: SystemTime::now(), - }, - ); - let abort_handle = handle.abort_handle(); - spawn(async move { - select! { - _ = abort_rx => { - abort_handle.abort(); - } - result = handle => { - match result { - Ok(Ok(())) => {} - Ok(Err(err)) => error!("[task: {}] {}", id, err), - Err(err) => error!("[task: {}] {}", id, err), - } - } - } - AsyncTasks::get().tasks.remove(&id); - }); - } - - TaskType::CopyData(handle) => { - this.tasks.insert( - id, - TaskInfo { - abort_tx, - cutover: Arc::new(Notify::new()), - task_kind: TaskKind::CopyData, - started_at: SystemTime::now(), - }, - ); - let abort_handle = handle.abort_handle(); - spawn(async move { - select! { - _ = abort_rx => { - abort_handle.abort(); - } - result = handle => { - match result { - Ok(Ok(())) => {} - Ok(Err(err)) => error!("[task: {}] {}", id, err), - Err(err) => error!("[task: {}] {}", id, err), - } - } - } - AsyncTasks::get().tasks.remove(&id); - }); - } - - TaskType::Replication(mut waiter) => { - let cutover = Arc::new(Notify::new()); - - this.tasks.insert( - id, - TaskInfo { - abort_tx, - cutover: cutover.clone(), - task_kind: TaskKind::Replication, - started_at: SystemTime::now(), - }, - ); - - spawn(async move { - select! 
{ - _ = abort_rx => { - waiter.stop(); - } - - _ = cutover.notified() => { - if let Err(err) = waiter.cutover().await { - error!("[task: {}] {}", id, err); - } - } - - result = waiter.wait() => { - if let Err(err) = result { - error!("[task: {}] {}", id, err); - } - } - } - - AsyncTasks::get().tasks.remove(&id); - }); - } - } - - id - } - - pub fn remove(id: u64) -> Option { - // Dropping the sender signals abort to the waiting task - Self::get() - .tasks - .remove(&id) - .map(|(_, info)| info.task_kind) - } - - pub fn iter(&self) -> impl Iterator + '_ { - self.tasks - .iter() - .map(|e| (*e.key(), e.value().task_kind, e.value().started_at)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::time::Duration; - use tokio::time::sleep; - - #[test] - fn test_task_kind_display() { - assert_eq!(TaskKind::SchemaSync.to_string(), "schema_sync"); - assert_eq!(TaskKind::CopyData.to_string(), "copy_data"); - assert_eq!(TaskKind::Replication.to_string(), "replication"); - } - - #[tokio::test] - async fn test_task_registration_and_removal() { - // Create a task that completes immediately - let handle = spawn(async { Ok::<(), Error>(()) }); - let id = Task::register(TaskType::SchemaSync(handle)); - - // Task should be visible briefly - // Give it a moment to register - sleep(Duration::from_millis(10)).await; - - // Try to remove it - it may already be gone if it completed - let result = AsyncTasks::remove(id); - // Either we removed it, or it already completed and removed itself - assert!(result.is_none() || result == Some(TaskKind::SchemaSync)); - } - - #[tokio::test] - async fn test_task_abort_via_remove() { - // Create a long-running task - let handle = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - let id = Task::register(TaskType::CopyData(handle)); - - // Give it time to register - sleep(Duration::from_millis(10)).await; - - // Remove should abort the task - let result = AsyncTasks::remove(id); - assert_eq!(result, 
Some(TaskKind::CopyData)); - - // Task should be gone now - sleep(Duration::from_millis(50)).await; - let result = AsyncTasks::remove(id); - assert_eq!(result, None); - } - - #[tokio::test] - async fn test_task_iter() { - // Create multiple tasks - let handle1 = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - let handle2 = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - - let id1 = Task::register(TaskType::SchemaSync(handle1)); - let id2 = Task::register(TaskType::CopyData(handle2)); - - sleep(Duration::from_millis(10)).await; - - // Should see both tasks - let tasks: Vec<_> = AsyncTasks::get().iter().collect(); - let task_ids: Vec<_> = tasks.iter().map(|(id, _, _)| *id).collect(); - assert!(task_ids.contains(&id1)); - assert!(task_ids.contains(&id2)); - - // Verify task kinds - for (id, kind, _) in &tasks { - if *id == id1 { - assert_eq!(*kind, TaskKind::SchemaSync); - } else if *id == id2 { - assert_eq!(*kind, TaskKind::CopyData); - } - } - - // Cleanup - AsyncTasks::remove(id1); - AsyncTasks::remove(id2); - } - - #[tokio::test] - async fn test_task_auto_cleanup_on_completion() { - // Create a task that completes quickly - let handle = spawn(async { - sleep(Duration::from_millis(10)).await; - Ok::<(), Error>(()) - }); - let id = Task::register(TaskType::SchemaSync(handle)); - - // Wait for task to complete and cleanup - sleep(Duration::from_millis(100)).await; - - // Task should have removed itself - let result = AsyncTasks::remove(id); - assert_eq!(result, None); - } - - #[tokio::test] - async fn test_cutover_fails_without_replication_task() { - // Create a non-replication task - let handle = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - let id = Task::register(TaskType::SchemaSync(handle)); - sleep(Duration::from_millis(10)).await; - - // Cutover should fail because there's no replication task - let result = AsyncTasks::cutover(); - 
assert!(matches!(result, Err(Error::NotReplication)), "{:?}", result); - - // Cleanup - AsyncTasks::remove(id); - } - - #[tokio::test] - async fn test_cutover_returns_not_found_when_no_replication_task() { - // Register several non-replication tasks - let mut task_ids = vec![]; - for _ in 0..5 { - let handle = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - task_ids.push(Task::register(TaskType::SchemaSync(handle))); - } - - sleep(Duration::from_millis(10)).await; - - // With only non-replication tasks, cutover should return TaskNotFound - let result = AsyncTasks::cutover(); - assert!(matches!(result, Err(Error::NotReplication))); - - // Cleanup - for id in task_ids { - AsyncTasks::remove(id); - } - } -} diff --git a/pgdog/src/backend/replication/logical/cutover_signal.rs b/pgdog/src/backend/replication/logical/cutover_signal.rs deleted file mode 100644 index e81ea7767..000000000 --- a/pgdog/src/backend/replication/logical/cutover_signal.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! Cutover signal for the logical replication task. -//! -//! Standalone channel between the `CUTOVER` admin command and the -//! running replication task: the command [`request`]s the cutover -//! from anywhere, the replication task [`requested`] waits for it. -//! -//! One request is buffered, so a request arriving before the task -//! starts waiting is not lost. There is at most one replication -//! task per process, matching the single buffered permit. - -use std::sync::LazyLock; - -use tokio::sync::Notify; - -static CUTOVER: LazyLock = LazyLock::new(Notify::new); - -/// Request a cutover from the running replication task. -pub fn request() { - CUTOVER.notify_one(); -} - -/// Wait until a cutover is requested. Only the replication task -/// waits on this. 
-pub async fn requested() { - CUTOVER.notified().await; -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use super::*; - - #[tokio::test] - async fn test_request_is_buffered() { - // Request lands before anyone waits: still delivered. - request(); - - tokio::time::timeout(Duration::from_secs(1), requested()) - .await - .unwrap(); - } -} diff --git a/pgdog/src/backend/replication/logical/mod.rs b/pgdog/src/backend/replication/logical/mod.rs index e10a23b0c..3bc799b33 100644 --- a/pgdog/src/backend/replication/logical/mod.rs +++ b/pgdog/src/backend/replication/logical/mod.rs @@ -1,6 +1,4 @@ -pub mod admin; pub mod copy_statement; -pub mod cutover_signal; pub mod ee; pub mod error; pub mod orchestrator; @@ -8,7 +6,6 @@ pub mod publisher; pub mod status; pub mod subscriber; -pub use admin::*; pub use copy_statement::CopyStatement; pub use error::*; diff --git a/pgdog/src/backend/replication/logical/orchestrator.rs b/pgdog/src/backend/replication/logical/orchestrator.rs index ff394cfcd..7e98b98ae 100644 --- a/pgdog/src/backend/replication/logical/orchestrator.rs +++ b/pgdog/src/backend/replication/logical/orchestrator.rs @@ -28,6 +28,23 @@ pub(crate) struct Orchestrator { replication_slot: String, } +/// A handle to a publication's replication slots, decoupled from the rest of +/// the orchestrator. Awaiting [`PublicationGuard::cleanup`] drops every slot +/// the publisher still owns — a no-op once `replicate` has handed them off to +/// the streaming tasks. Take one before data sync and await it on every exit +/// so a failed or aborted migration never leaves slots lingering on the source +/// (holding back WAL). +pub(crate) struct PublicationGuard { + publisher: Arc>, +} + +impl PublicationGuard { + /// Drop any replication slots the publisher still owns. + pub(crate) async fn cleanup(self) -> Result<(), Error> { + self.publisher.lock().await.cleanup().await + } +} + impl Orchestrator { /// Create new orchestrator. 
pub(crate) fn new( @@ -145,6 +162,13 @@ impl Orchestrator { Ok(()) } + /// Take a [`PublicationGuard`] over this orchestrator's replication slots. + pub(crate) fn publication_guard(&self) -> PublicationGuard { + PublicationGuard { + publisher: self.publisher.clone(), + } + } + /// Replicate forever. /// /// Useful for CLI interface only, since this will never stop. @@ -162,38 +186,6 @@ impl Orchestrator { }) } - /// Request replication stop. - pub(crate) async fn request_stop(&self) { - self.publisher.lock().await.request_stop(); - } - - /// Perform the entire flow in one swoop. - #[deprecated(note = "phase orchestration now lives in the migration tasks (see \ - `crate::api::resharding::ReshardTask`); drive the individual \ - steps directly. Remove once the remaining callers migrate.")] - pub(crate) async fn replicate_and_cutover(&mut self) -> Result<(), Error> { - // Load the schema from source. - self.load_schema().await?; - - // Sync the schema to destination. - self.schema_sync_pre(true).await?; - - // Sync the data to destination. - self.data_sync().await?; - - // Create secondary indexes on destination. - self.schema_sync_post(true).await?; - - // Start replication to catch up and cutover once done. - // Refresh cluster references: data_sync can take hours and the pools - // may have been reloaded (e.g. by a client DDL) in the meantime. 
- self.refresh()?; - - self.replicate().await?.cutover().await?; - - Ok(()) - } - pub(crate) async fn schema_sync_post(&mut self, ignore_errors: bool) -> Result<(), Error> { let schema = self.schema.as_ref().ok_or(Error::NoSchema)?; @@ -224,16 +216,21 @@ impl Orchestrator { let lag = self.publisher.lock().await.replication_lag(); lag.values().copied().max().unwrap_or_default() as u64 } +} - pub(crate) async fn cleanup(&mut self) -> Result<(), Error> { - let mut guard = self.publisher.lock().await; - guard.cleanup().await?; - - Ok(()) +impl Display for Orchestrator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} -> {}", + self.source.identifier().database, + self.destination.identifier().database + ) } } -#[derive(Debug)] +#[derive(Debug, Display)] +#[display("{orchestrator}")] pub struct ReplicationWaiter { orchestrator: Orchestrator, waiter: Waiter, @@ -318,7 +315,7 @@ impl ReplicationWaiter { maintenance_mode::start(None); // Cancel any running queries. - cancel_all(&self.orchestrator.source.identifier().database).await?; + ok_or_abort!(cancel_all(&self.orchestrator.source.identifier().database).await); break; // TODO: wait for clients to all stop. @@ -397,7 +394,7 @@ impl ReplicationWaiter { // In case replication breaks now. res = self.waiter.wait() => { - res?; + ok_or_abort!(res); } } @@ -467,16 +464,24 @@ impl ReplicationWaiter { // Fix any reverse replication blockers. ok_or_abort!(self.orchestrator.schema_sync_post_cutover(true).await); - // Create reverse replication in case we need to rollback. + // Create the reverse-replication slot synchronously, while traffic is + // still paused, so its consistent-point LSN captures every write made + // to the new primary after cutover. On failure, resume traffic and + // abort — the forward switch is already committed, so this surfaces as + // an error for the operator (rollback won't be available). 
let waiter = ok_or_abort!(self.orchestrator.replicate().await); - // Let it run in the background. - AsyncTasks::insert(TaskType::Replication(Box::new(waiter))); + // Drive the running waiter as a background api task so it stays visible + // in SHOW TASKS and can be cut over (rollback) or stopped. + crate::api::start( + crate::api::replication::ReplicationTask::builder() + .waiter(waiter) + .direction(crate::api::replication::Direction::Reverse) + .build(), + ); - // It's not safe to resume traffic. + // Slot is established and capturing — now safe to resume traffic. info!("[cutover] complete, resuming traffic"); - - // Point traffic to the other database and resume. maintenance_mode::stop(None); cutover_state(CutoverState::Complete); diff --git a/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs b/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs index dd7065696..f79c656ad 100644 --- a/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs +++ b/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs @@ -4,11 +4,11 @@ use std::time::Duration; use parking_lot::Mutex; use pgdog_config::QueryParserEngine; -use tokio::sync::Notify; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; use tokio::time::{Instant, sleep}; use tokio::try_join; use tokio::{select, spawn, time::interval}; +use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; use super::super::{Error, ensure_validation, publisher::Table}; @@ -55,7 +55,7 @@ pub struct Publisher { /// Last transaction. last_transaction: Arc>>, /// Stop signal. - stop: Arc, + stop: CancellationToken, /// Slot name. 
slot_name: String, } @@ -72,7 +72,7 @@ impl Publisher { slots: HashMap::new(), query_parser_engine, replication_lag: Arc::new(Mutex::new(HashMap::new())), - stop: Arc::new(Notify::new()), + stop: CancellationToken::new(), last_transaction: Arc::new(Mutex::new(None)), slot_name, } @@ -234,10 +234,15 @@ impl Publisher { let max_attempts = dest.resharding_replication_retry_max_attempts(); let delay = dest.resharding_replication_retry_min_delay(); let mut attempt = 0usize; + // Latches on the first cancellation so the `cancelled()` arm fires + // once (it stays ready forever after `cancel()`); the drain below + // then runs to completion. + let mut stopping = false; loop { select! { - _ = stop.notified() => { + _ = stop.cancelled(), if !stopping => { slot.stop_replication().await?; + stopping = true; } // This is cancellation-safe. @@ -337,7 +342,7 @@ impl Publisher { /// Request the publisher to stop replication. pub fn request_stop(&self) { - self.stop.notify_one(); + self.stop.cancel(); } /// Get current replication lag. @@ -371,10 +376,14 @@ impl Publisher { ensure_validation!(validation_errors); // Create replication slots only after validation passes — a slot - // created before valid() would be orphaned on a NoIdentityColumns error. + // created before valid() would be orphaned on validation errors. self.create_slots(source).await?; - let mut handles = vec![]; + // A JoinSet aborts every in-flight shard copy when this future is + // dropped (e.g. the owning task is cancelled), so cancellation actually + // stops the copy instead of leaving detached syncs running in the + // background. 
+ let mut set: JoinSet), Error>> = JoinSet::new(); for (number, shard) in source.shards().iter().enumerate() { let tables = self @@ -411,16 +420,16 @@ impl Publisher { }; let dest = dest.clone(); - handles.push(spawn(async move { + set.spawn(async move { let manager = ParallelSyncManager::new(tables, replicas, dest)?; let tables = manager.run().await?; - Ok::, Error>(tables) - })); + Ok::<(usize, Vec), Error>((number, tables)) + }); } - for (number, handle) in handles.into_iter().enumerate() { - let tables = handle.await??; + while let Some(joined) = set.join_next().await { + let (number, tables) = joined??; info!( "table sync for {} tables complete [{}, shard: {}]", @@ -436,13 +445,20 @@ impl Publisher { Ok(()) } - /// Cleanup after replication. + /// Drop the replication slots created during data sync. + /// + /// Idempotent: the slot map is taken out up front, so repeated calls — or a + /// call after replication already took the slots over — are no-ops. Every + /// slot is attempted even if one fails; the first error is returned. 
pub async fn cleanup(&mut self) -> Result<(), Error> { - for slot in self.slots.values_mut() { - slot.drop_slot().await?; + let mut error = None; + for (_, mut slot) in std::mem::take(&mut self.slots) { + if let Err(err) = slot.drop_slot().await { + error.get_or_insert(err); + } } - Ok(()) + error.map_or(Ok(()), Err) } } @@ -460,12 +476,12 @@ impl Publisher { #[derive(Debug)] pub struct Waiter { streams: Vec>>, - stop: Arc, + stop: CancellationToken, } impl Waiter { pub fn stop(&self) { - self.stop.notify_one(); + self.stop.cancel(); } pub async fn wait(&mut self) -> Result<(), Error> { @@ -482,7 +498,7 @@ impl Waiter { pub fn new_test() -> Self { Self { streams: vec![], - stop: Arc::new(Notify::new()), + stop: CancellationToken::new(), } } } diff --git a/pgdog/src/backend/schema/sync/pg_dump.rs b/pgdog/src/backend/schema/sync/pg_dump.rs index 7a15e20d2..b6d628f39 100644 --- a/pgdog/src/backend/schema/sync/pg_dump.rs +++ b/pgdog/src/backend/schema/sync/pg_dump.rs @@ -111,7 +111,7 @@ fn should_convert_to_bigint<'a>( is_integer_type(sval.as_str()) } -use tokio::{process::Command, spawn}; +use tokio::{process::Command, task::JoinSet}; #[derive(Debug, Clone)] pub struct PgDump { @@ -1041,7 +1041,11 @@ impl PgDumpOutput { }) .collect::>(), )); - let mut handles = vec![]; + // A JoinSet aborts every in-flight per-shard sync when it is dropped, + // so cancelling the task (dropping this future) actually stops the + // schema apply instead of leaving detached spawns running in the + // background against the destination. 
+ let mut set: JoinSet> = JoinSet::new(); for (num, shard) in dest.shards().iter().enumerate() { let mut primary = shard.primary(&Request::default()).await?; @@ -1056,7 +1060,7 @@ impl PgDumpOutput { let trackers = trackers.clone(); let output = self.clone(); - handles.push(spawn(async move { + set.spawn(async move { let stmts = output.statements(state)?; let mut progress = Progress::new(stmts.len()); @@ -1107,11 +1111,11 @@ impl PgDumpOutput { } Ok::<(), Error>(()) - })); + }); } - for handle in handles { - handle.await??; + while let Some(joined) = set.join_next().await { + joined??; } Ok(()) diff --git a/pgdog/src/cli.rs b/pgdog/src/cli.rs index 9b59bbb6f..5570a55c1 100644 --- a/pgdog/src/cli.rs +++ b/pgdog/src/cli.rs @@ -1,18 +1,19 @@ -use std::ops::Deref; use std::path::PathBuf; -use std::time::Duration; use clap::{Parser, Subcommand}; use std::fs::read_to_string; use thiserror::Error; -use tokio::time::sleep; use tokio::{select, signal::ctrl_c}; use tracing::{info, warn}; +use crate::api::Task; +use crate::api::resharding::ReshardTask; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; +use crate::api::start; +use crate::api::storage; use crate::backend::databases::databases; use crate::backend::replication::orchestrator::Orchestrator; use crate::backend::schema::sync::config::ShardConfig; -use crate::backend::schema::sync::pg_dump::SyncState; use crate::config::{Config, Users}; use crate::frontend::router::cli::RouterCli; @@ -279,6 +280,28 @@ pub fn config_check( } } +/// Run an api task to completion in the foreground, cancelling it on Ctrl-C so +/// it can wind down (e.g. stop replication) instead of the process being +/// hard-killed. Returns the task output, or its error/cancellation outcome. +async fn run_to_completion(task: T) -> Result> +where + T::Error: std::error::Error + 'static, +{ + let mut waiter = start(task); + let id = waiter.id(); + + loop { + select! 
{ + result = &mut waiter => return Ok(result?), + signal = ctrl_c() => { + signal?; + warn!("interrupt received, cancelling task {id}"); + storage().cancel_task(id); + } + } + } +} + /// FOR TESTING PURPOSES ONLY. pub async fn replicate_and_cutover(commands: Commands) -> Result<(), Box> { if let Commands::ReplicateAndCutover { @@ -288,22 +311,26 @@ pub async fn replicate_and_cutover(commands: Commands) -> Result<(), Box Result<(), Box> { - use crate::backend::replication::logical::Error; - if let Commands::DataSync { from_database, to_database, @@ -314,65 +341,23 @@ pub async fn data_sync(commands: Commands) -> Result<(), Box { - result?; - } - - _ = ctrl_c() => { - warn!("abort signal received, waiting 5 seconds and performing cleanup"); - sleep(Duration::from_secs(5)).await; - - orchestrator.cleanup().await?; - - return Err(Error::DataSyncAborted.into()); - } - } - } - - if !sync_only { - let mut waiter = orchestrator.replicate().await?; - - select! { - result = waiter.wait() => { - result?; - } - - _ = ctrl_c() => { - warn!("abort signal received"); - - orchestrator.request_stop().await; - - info!("waiting for replication to stop"); - - waiter.wait().await?; - orchestrator.cleanup().await?; - - return Err(Error::DataSyncAborted.into()); - } - } - } - } else { - return Ok(()); + let orchestrator = + Orchestrator::new(&from_database, &to_database, &publication, replication_slot)?; + + run_to_completion( + ReshardTask::builder() + .orchestrator(orchestrator) + .skip_schema_sync(skip_schema_sync) + .replicate_only(replicate_only) + .sync_only(sync_only) + .build(), + ) + .await?; } Ok(()) } -#[allow(clippy::print_stdout)] pub async fn schema_sync(commands: Commands) -> Result<(), Box> { if let Commands::SchemaSync { from_database, @@ -384,32 +369,25 @@ pub async fn schema_sync(commands: Commands) -> Result<(), Box Date: Tue, 23 Jun 2026 13:32:52 +0000 Subject: [PATCH 6/8] fixes & improvements --- .config/nextest.toml | 1 - docs/ASYNC_TASKS.md | 141 ++-- 
docs/RESHARDING.md | 22 +- docs/issues/replication.md | 77 ++- .../rust/tests/integration/admin/mod.rs | 57 +- .../tests/integration/admin/show_config.rs | 3 +- .../tests/integration/admin/show_version.rs | 3 +- .../rust/tests/integration/admin/tasks.rs | 602 +++++++++++------ pgdog-config/src/core.rs | 65 ++ pgdog/Cargo.toml | 2 +- pgdog/src/admin/copy_data.rs | 4 +- pgdog/src/admin/cutover.rs | 4 +- pgdog/src/admin/replicate.rs | 4 +- pgdog/src/admin/reshard.rs | 4 +- pgdog/src/admin/schema_sync.rs | 4 +- pgdog/src/admin/show_tasks.rs | 49 +- pgdog/src/admin/stop_task.rs | 6 +- pgdog/src/api/async_task.rs | 603 +++++++++--------- pgdog/src/api/mod.rs | 58 +- pgdog/src/api/replication.rs | 116 ++-- pgdog/src/api/resharding.rs | 31 +- pgdog/src/api/schema_sync.rs | 6 - .../replication/logical/orchestrator.rs | 46 +- pgdog/src/cli.rs | 8 +- 24 files changed, 1066 insertions(+), 850 deletions(-) diff --git a/.config/nextest.toml b/.config/nextest.toml index 959c1bc19..e1836202c 100644 --- a/.config/nextest.toml +++ b/.config/nextest.toml @@ -5,4 +5,3 @@ default-filter = "not package(rust)" [profile.default] slow-timeout = "15s" -test-threads = 1 diff --git a/docs/ASYNC_TASKS.md b/docs/ASYNC_TASKS.md index 52840525d..68b15f5a8 100644 --- a/docs/ASYNC_TASKS.md +++ b/docs/ASYNC_TASKS.md @@ -1,103 +1,90 @@ -# Async Task Framework — Architecture +# Async Task Framework -The `crate::api` module ([`pgdog/src/api/`](../pgdog/src/api/)) is the execution layer that sits -between PgDog's two user interfaces and its long-running operations. Any operation that may run for -seconds to hours runs here as a background *task*. +Long-running operations (resharding, copy, replication, schema sync) run as background *tasks* in +`crate::api` ([`pgdog/src/api/`](../pgdog/src/api/)). The admin SQL API and the `pgdog` CLI both +start the same task through the same registry; only how they consume the result differs. 
-The central principle is that **the task is the single source of truth for execution**. A user -interface only assembles options and starts the task; all behaviour, status transitions, and error -handling live inside it. Whether an operation is started through a SQL command on the admin -database or a terminal invocation of the CLI, the same task runs the same code. +## The `Task` trait -This document covers the framework itself — how tasks are started, tracked, composed, cancelled, -and observed. It deliberately does not enumerate the individual task implementations; each task's -behaviour is documented alongside its own code in [`pgdog/src/api/`](../pgdog/src/api/). +A task is any type implementing `Task` ([`api/async_task.rs`](../pgdog/src/api/async_task.rs)): ---- - -## Architecture - -```mermaid -flowchart TD - subgraph Interfaces - ADMIN["admin database API\nSQL commands"] - CLI["pgdog CLI\nsubcommands"] - end - - subgraph API["crate::api — execution layer"] - REG["process-global registry\nAsyncTasksStorage"] - TASKS["task implementations\n(impl Task)"] - end +```rust +pub trait Task: Display + Debug + Send + Sync + Sized + 'static { + type Status: Display + Send + Sync + 'static; // inner progress; Empty = none + type Output: Send + 'static; + type Error: std::error::Error + Send + 'static; - WORK["underlying operation\n(engine / pipeline / I/O)"] + fn cancel_timeout() -> Duration { Duration::from_secs(5) } - ADMIN -->|start task, get id| REG - CLI -->|start task, await result| REG - REG -->|spawns and tracks| TASKS - TASKS -->|drives| WORK - ADMIN -->|SHOW TASKS / STOP_TASK| REG - CLI -->|Ctrl-C → cancel_task| REG + fn run(self, ctx: AsyncTaskContext) + -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + 'static; +} ``` -A task is any type that implements the `Task` trait -([`api/async_task.rs`](../pgdog/src/api/async_task.rs)): it defines its own status type, output, and -error, and provides an `async run`.
The framework owns everything around that `run` — spawning, -registration, id assignment, status storage, cancellation, and retention. +`run` is the whole task. Everything else — spawning, ids, status storage, cancellation, +retention — is handled by the framework around it. + +## Starting a task ---- +`crate::api::run_task(task)` ([`api/mod.rs`](../pgdog/src/api/mod.rs)) calls +`tasks_storage().run(task)`, which delegates to the private `run_task` in `async_task.rs`. That +function: -## The registry +1. allocates an id from the registry (`tasks.next_id()`); a root task's `root_id` is its own id, +2. builds the `AsyncTask` entry (cancellation token, state, tracing span) and inserts it into the + map *before* spawning, +3. spawns the task future: `tokio::spawn(task.run(ctx).instrument(span))`, +4. spawns a second watcher future that `select!`s the task handle against its cancellation token + and records the terminal status, +5. returns an `AsyncTaskWaiter { id, waiter }`. -When a task is started via `crate::api::start()` ([`api/mod.rs`](../pgdog/src/api/mod.rs)), it is -spawned as an async future and immediately registered in `AsyncTasksStorage` -([`api/async_task.rs`](../pgdog/src/api/async_task.rs)) under a monotonically-increasing integer id. -The id is returned before any work begins, so the caller can track the operation while it runs in -the background. +The id is known before `run` does any work, so the caller can address the task immediately. -One registry serves the entire process. A task started by the CLI and a task started through the -admin SQL API are both registered in the same store and are equally visible to `SHOW TASKS` and -equally cancellable by `STOP_TASK`. A task spawned automatically by another task registers itself -the same way and is just as addressable from either interface. 
+```rust +pub struct AsyncTaskWaiter<T: Task> { + id: AsyncTaskId, + waiter: Receiver<Result<T::Output, TaskError<T::Error>>>, // oneshot +} +``` -**Terminal tasks are retained for 24 hours** after they finish so their outcome can be inspected. -A running task is never pruned. `SHOW TASKS` filters terminal tasks back out and shows only what -is currently running; the retention period exists only so that an id returned by a command -continues to be addressable for a reasonable window after completion. +`AsyncTaskWaiter` is a `Future`; awaiting it yields the task's `Result`. A dropped sender (watcher +gone) maps to `Err(TaskError::Abandoned)`. `.id()` returns the id without awaiting. ---- +The registry is process-global: -## Task composition +```rust +static TASKS: LazyLock<AsyncTasksStorage> = LazyLock::new(AsyncTasksStorage::default); +``` -A task can spawn *child tasks* through its execution context. Children share the root task's -registration entry and appear as a flat subtask list on the root's snapshot in the registry. The -root's status describes which high-level phase is active; the child's status describes what is -happening within that phase. A composite task that sequences several phases therefore reports -fine-grained progress without any special support from the registry — each phase is just a child -task whose status bubbles up. +So a CLI task and an admin task land in the same `AsyncTasksStorage`, both visible to `SHOW TASKS` +and cancellable by `STOP_TASK`. -`SHOW TASKS` surfaces both the root and its running children as separate rows, so a child appears -with its own type even though it was never started as a top-level command. `STOP_TASK` only -addresses root tasks by id; cancelling the root propagates to all its children through the -cancellation token hierarchy (see [Cancellation](#cancellation) below). +## Status ---- +Two separate axes.
The lifecycle status is a fixed enum: -## Status lifecycle +```rust +pub enum TaskStatus { + Started, Running, Cancelling, // non-terminal + Finished, Cancelled, Error(String), Panic(String), // terminal +} +``` -Every task type defines its own set of progress stages. The registry stores a type-erased -snapshot of the current status on every write and exposes it through `SHOW TASKS`. A task begins -in `Started` the moment it is registered, transitions through `Pending(stage)` as it reports -progress, and ends in one of four terminal states: +The domain-specific progress is `Task::Status`, reported by the task itself via +`ctx.set_status(...)` and surfaced separately as `inner_status`. `set_status` flips the lifecycle +to `Running` (but won't regress out of `Cancelling`) and stores the rendered inner status. -- **`Finished`** — completed successfully. -- **`Cancelled`** — stopped by `STOP_TASK`, Ctrl-C, or parent cancellation. -- **`Error`** — the task's code returned an error; the message is stored in the status. -- **`Panic`** — the task's future panicked; the message is stored. +The registry stores both in a type-erased `TaskState` (`name`, `status`, `inner_status`, +`started_at`, `updated_at`), so `SHOW TASKS` reads it without knowing `T`. -Terminal states are **write-once**: a context clone that outlives the task cannot overwrite a -recorded outcome. +Transitions are write-once at the terminal boundary: `transition` and `set_status` both bail early +if `state.status.is_terminal()`, so a context clone that outlives the task can't clobber a +recorded outcome. The watcher sets the terminal status based on the `select!` arm that won: ---- +- task returned `Ok` → `Finished` (or `Cancelled` if the token was already cancelled), +- task returned `Err(e)` → `Error(e)`, waiter gets `TaskError::Failed(e)`, +- join handle cancelled → `Cancelled` / `TaskError::Cancelled`, +- join handle panicked → `Panic(msg)` / `TaskError::Panicked(msg)`. 
## Cancellation diff --git a/docs/RESHARDING.md b/docs/RESHARDING.md index 259a5b49b..24d5149bc 100644 --- a/docs/RESHARDING.md +++ b/docs/RESHARDING.md @@ -167,8 +167,10 @@ traffic immediately via `maintenance_mode::stop()` and returns an error. Steps i 1. `publisher.request_stop()` + `waiter.wait()` — stops the replication stream; drains remaining WAL. 2. `schema_sync_cutover()` — applies `SyncState::Cutover` operations (e.g. drops sequences that won't be used in the sharded cluster). -3. `cutover(source_db, dest_db)` in [`pgdog/src/backend/databases.rs`](../pgdog/src/backend/databases.rs) — atomically swaps source and - destination in the in-memory routing table. +3. `cutover(source_db, dest_db)` in [`pgdog/src/backend/databases.rs`](../pgdog/src/backend/databases.rs) — + atomically swaps the two clusters' logical identity in the routing table (and config refs via + `Config::cutover`/`Users::cutover`); no data moves. Persisted to disk when + `cutover_save_config = true`. 4. `orchestrator.refresh()` — re-fetches both clusters from `databases()` so the orchestrator now treats the new cluster as source for reverse replication. 5. `schema_sync_post_cutover()` — applies `SyncState::PostCutover` (removes blockers that would @@ -214,19 +216,19 @@ returned by any task surfaces via `table?` and aborts the manager's loop. Remain completion or abort via their own `AbortSignal`, but their results are ignored once the channel is dropped. -### Temporary vs permanent replication slots +On a failed or aborted migration, `ReshardTask::run` ([`api/resharding.rs`](../pgdog/src/api/resharding.rs)) +obtains a guard via `Orchestrator::publication_guard()` and calls `PublicationGuard::cleanup()`, which +locks the publisher and has `Publisher::cleanup()` drop the permanent WAL slot via +`DROP_REPLICATION_SLOT "name" WAIT`. On success the slot is kept so reverse replication can roll +back. 
If the process crashes before this runs, the slot survives and keeps accumulating WAL on the +source — drop it manually before retrying. + +### Temporary replication slots Per-table slots created in [`Table::data_sync()`](../pgdog/src/backend/replication/logical/publisher/table.rs) are `TEMPORARY` — PostgreSQL drops them automatically when the replication connection closes, including on error or panic. A failed copy task leaves no orphaned per-table slot. -The `Publisher`'s named replication slot (the one used for the WAL streaming phase) is permanent. -[`Publisher::cleanup()`](../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) drops it by calling `slot.drop_slot()`, which issues -`DROP_REPLICATION_SLOT "name" WAIT` over the replication protocol connection. `cleanup()` is an -explicit method on `Orchestrator` — it is not called automatically inside `replicate_and_cutover()`. -If the orchestrator is dropped after Step 5 begins but before `cleanup()` is called (e.g. a -process crash), the permanent slot survives and continues accumulating WAL on the source. 
- ### The `ok_or_abort!` macro — guaranteed traffic resumption after cutover starts ```rust diff --git a/docs/issues/replication.md b/docs/issues/replication.md index 1dac7989e..6c3f46a5d 100644 --- a/docs/issues/replication.md +++ b/docs/issues/replication.md @@ -42,7 +42,7 @@ their publication-scoped `confirmed_flush_lsn` from the same instance-wide |---|---| | `ReplicationSlot::replication_lag()` — the lag query | [`pgdog/src/backend/replication/logical/publisher/slot.rs`](../../pgdog/src/backend/replication/logical/publisher/slot.rs) | | `ReplicationWaiter::wait_for_replication()` — the cutover gate | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | -| Keepalive handler / `flush_lsn` reply (`data_since_keepalive` flag) | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +| Keepalive handler / `flush_lsn` reply (proposed `data_since_keepalive` flag) | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | ### Fix The PostgreSQL walsender sends a keepalive message after exhausting all decoded changes available for the publication. The keepalive carries `wal_end` — the server's current WAL write position. When a keepalive arrives and no xlog data was received since the previous keepalive, the slot has drained: there is nothing for the publication between `committed_lsn` and `wal_end`, so that gap consists entirely of other databases' WAL. @@ -51,7 +51,7 @@ In this state the client can safely reply with `flush_lsn = wal_end`. PostgreSQL Reporting `wal_end` is only valid when the slot is caught up. During active streaming — where the server is sending transactions and keepalives may arrive between commits — `flush_lsn` must remain at `committed_lsn`. 
Reporting `wal_end` prematurely would advance `confirmed_flush_lsn` past unapplied commits; if the connection dropped, the server would restart from `wal_end` and those commits would be lost. -The implemented guard: `data_since_keepalive` flag. Set to `true` when any xlog data message arrives; cleared to `false` when a keepalive arrives. A keepalive is the catch-up signal only when this flag is `false` — meaning no data arrived between the last keepalive and this one. +The proposed guard: a `data_since_keepalive` flag. Set to `true` when any xlog data message arrives; cleared to `false` when a keepalive arrives. A keepalive is the catch-up signal only when this flag is `false` — meaning no data arrived between the last keepalive and this one. ``` keepalive received @@ -225,8 +225,8 @@ Each entry point was built to satisfy a specific operational need without consol | Symbol | File | |---|---| -| `TaskType::Replication` `select!` / `waiter.wait()` arm (line ~150) | [`pgdog/src/backend/replication/logical/admin.rs`](../../pgdog/src/backend/replication/logical/admin.rs) | -| `AsyncTasks::cutover()` / `notify_one()` (line ~75) | same file | +| `TaskType::Replication` `select!` / `waiter.wait()` arm (historical, line ~150) | now `ReplicationTask::run` in [`pgdog/src/api/replication.rs`](../../pgdog/src/api/replication.rs) | +| `AsyncTasks::cutover()` / `notify_one()` (historical, line ~75) | now the cutover signal in [`pgdog/src/api/async_task.rs`](../../pgdog/src/api/async_task.rs) | | `Replicate::execute()` | [`pgdog/src/admin/replicate.rs`](../../pgdog/src/admin/replicate.rs) | | `Reshard::execute()` | [`pgdog/src/admin/reshard.rs`](../../pgdog/src/admin/reshard.rs) | | `Orchestrator::replicate_and_cutover()` — the canonical flow | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | @@ -479,3 +479,72 @@ actually stopped. 
| `AsyncTasksStorage::prune` — drops only terminal entries past retention | same file | | `run_task` supervisor — cooperative grace via `cancel_timeout`; sets terminal status on completion | same file | | `ReplicationTask::run` / `cancel_timeout` — cutover-vs-stop `select!` | [`pgdog/src/api/replication.rs`](../../pgdog/src/api/replication.rs) | + +--- + +## ✅ Issue 7 — Cutover swapped the cluster name but not `database_name`, repointing entries at a nonexistent database (resolved) + +> **Resolved.** `Config::cutover` now pins the effective `database_name` on the +> two clusters being swapped *before* exchanging their logical `name`, so the +> name swap only changes routing identity and never the physical database an +> entry connects to. + +### Description + +Traffic cutover swaps the *logical* `name` of the source and destination +clusters (`Config::cutover`, `pgdog-config/src/core.rs`) — clients keep +connecting to the same name, which now resolves to the other cluster's backends. + +The physical Postgres database an entry connects to, however, is resolved as +`database_name` **if set, otherwise the cluster `name`** (`Address::from`, +`pgdog/src/backend/pool/address.rs`). `Config::cutover` rewrote `name` but left +`database_name` untouched. So any database entry that relied on the default +(no explicit `database_name`, i.e. cluster name == physical DB name) was, after +the swap, silently repointed at a physical database named after the *other* +cluster — which usually does not exist. + +Two symptoms, one cause: + +- The pool monitor logs `FATAL: 3D000 database "<name>" does not exist` and + retries forever against the dead config. +- The post-cutover schema restore (`schema_sync_post_cutover` → + `PgDumpOutput::restore`) checks out a connection from that pool; since it can + never connect, the checkout waits out `checkout_timeout` and the cutover task + fails with `schema: checkout timeout`, past the point of no return.
+ +Configs that set `database_name` explicitly on every entry (e.g. the +`integration/resharding` suite, where all entries use `database_name = "postgres"`) +were unaffected, which is why this stayed hidden — the swap is a no-op for the +physical target when `database_name` is already pinned. + +### Cause + +`database_name` defaults to the cluster `name` at connection time, and +`Config::cutover` changed `name` without first materializing that default. The +two fields are only equivalent until the name is swapped. + +### Code references + +| Symbol | File | +|---|---| +| `Config::cutover` — swaps `name`; now pins `database_name` first | [`pgdog-config/src/core.rs`](../../pgdog-config/src/core.rs) | +| `Address::from` — `database_name` falls back to cluster `name` | [`pgdog/src/backend/pool/address.rs`](../../pgdog/src/backend/pool/address.rs) | +| `databases::cutover` — applies the config swap and relaunches pools | [`pgdog/src/backend/databases.rs`](../../pgdog/src/backend/databases.rs) | +| `test_cutover_preserves_physical_database_name` — regression test | [`pgdog-config/src/core.rs`](../../pgdog-config/src/core.rs) | + +### Fix + +Before the name swap, set `database_name` to the current `name` on any entry of +the two clusters where it is unset: + +```rust +for db in self.databases.iter_mut() { + if (db.name == source || db.name == destination) && db.database_name.is_none() { + db.database_name = Some(db.name.clone()); + } +} +``` + +The swap then moves only the logical name; the physical connection target is +preserved. Covered by `test_cutover_preserves_physical_database_name` (asymmetric: +one cluster relies on the default, the other sets `database_name` explicitly). diff --git a/integration/rust/tests/integration/admin/mod.rs b/integration/rust/tests/integration/admin/mod.rs index 605d8ece7..8f2ee06b3 100644 --- a/integration/rust/tests/integration/admin/mod.rs +++ b/integration/rust/tests/integration/admin/mod.rs @@ -1,8 +1,6 @@ -//! 
Integration tests asserting the output of admin commands. +//! Integration tests asserting admin command output over the wire. //! -//! Each submodule connects to the live PgDog admin database (see -//! `rust::setup::admin_sqlx`) and verifies the shape and contents of a -//! command's output over the wire. +//! Each submodule connects to the live PgDog admin database (`rust::setup::admin_sqlx`). pub mod show_config; pub mod show_version; pub mod tasks; @@ -12,51 +10,42 @@ use sqlx::{Column, Executor, Pool, Postgres, Row, TypeInfo}; /// Wire layout expected from `SHOW TASKS`. const SHOW_TASKS_LAYOUT: &[(&str, &str)] = &[ ("id", "INT8"), - ("root_id", "INT8"), ("scope", "TEXT"), ("type", "TEXT"), ("status", "TEXT"), + ("inner_status", "TEXT"), ("started_at", "TEXT"), ("updated_at", "TEXT"), ("elapsed", "TEXT"), ("elapsed_ms", "INT8"), ]; -/// A single row returned by `SHOW TASKS` with all fields already parsed and -/// validated. Construction only succeeds through [`Tasks::fetch`], which -/// checks the wire layout and field invariants before handing rows out. +/// A parsed, validated `SHOW TASKS` row. Built only via [`Tasks::fetch`]. #[derive(Debug, Clone)] pub struct Task { - pub id: i64, - pub root_id: i64, + pub id: Option, pub scope: String, pub kind: String, pub status: String, + pub inner_status: String, pub started_at: String, pub updated_at: String, pub elapsed: String, pub elapsed_ms: i64, } -/// Parsed result of a `SHOW TASKS` admin command. -/// -/// Call [`Tasks::fetch`] to issue the command, validate column layout and -/// every row's field invariants in one shot, and get back a typed collection -/// you can query with [`Tasks::find`] or iterate over [`Tasks::rows`]. +/// Parsed result of `SHOW TASKS`; query with [`Tasks::find`] or [`Tasks::rows`]. pub struct Tasks { pub rows: Vec, } impl Tasks { - /// Issue `SHOW TASKS` against `pool`, assert the wire layout, parse and - /// validate every row, and return the collection. 
- /// - /// Panics on any layout mismatch, unexpected wire type, or field that - /// violates an invariant (empty timestamp, negative elapsed_ms). + /// Run `SHOW TASKS`, assert the wire layout and per-row invariants, and + /// return the parsed rows. Panics on any layout/type/invariant violation. pub async fn fetch(pool: &Pool) -> Self { let raw = pool.fetch_all("SHOW TASKS").await.unwrap(); - // assert_layout requires at least one row; skip when empty (valid — no tasks running). + // assert_layout needs a row; an empty result is valid (no tasks). if !raw.is_empty() { assert_layout(&raw, SHOW_TASKS_LAYOUT); } @@ -64,8 +53,7 @@ impl Tasks { let rows = raw .iter() .map(|row| { - let id: i64 = row.get("id"); - let root_id: i64 = row.get("root_id"); + let id: Option = row.get("id"); let scope: String = row.get("scope"); let status: String = row.get("status"); let started_at: String = row.get("started_at"); @@ -73,22 +61,22 @@ impl Tasks { let elapsed: String = row.get("elapsed"); let elapsed_ms: i64 = row.get("elapsed_ms"); - assert!(!started_at.is_empty(), "task {id}: started_at is empty"); - assert!(!updated_at.is_empty(), "task {id}: updated_at is empty"); - assert!(!elapsed.is_empty(), "task {id}: elapsed is empty"); - assert!(!status.is_empty(), "task {id}: status is empty"); - assert!(elapsed_ms >= 0, "task {id}: elapsed_ms is negative"); + assert!(!started_at.is_empty(), "task {id:?}: started_at is empty"); + assert!(!updated_at.is_empty(), "task {id:?}: updated_at is empty"); + assert!(!elapsed.is_empty(), "task {id:?}: elapsed is empty"); + assert!(!status.is_empty(), "task {id:?}: status is empty"); + assert!(elapsed_ms >= 0, "task {id:?}: elapsed_ms is negative"); assert!( scope == "root" || scope == "subtask", - "task {id}: unexpected scope {scope:?}" + "task {id:?}: unexpected scope {scope:?}" ); Task { id, - root_id, scope, kind: row.get("type"), status, + inner_status: row.get("inner_status"), started_at, updated_at, elapsed, @@ -100,9 +88,9 @@ impl 
Tasks { Self { rows } } - /// Return the task with the given id, if present. + /// Return the (root) task with the given id, if present. pub fn find(&self, id: i64) -> Option<&Task> { - self.rows.iter().find(|t| t.id == id) + self.rows.iter().find(|t| t.id == Some(id)) } pub fn is_empty(&self) -> bool { @@ -110,10 +98,7 @@ impl Tasks { } } -/// Assert that `rows` is non-empty and that the first row's column layout -/// (name, wire type) matches `expected` exactly, in order. -/// -/// Used by submodule tests for commands other than `SHOW TASKS`. +/// Assert the first row's column layout (name, wire type) matches `expected` exactly. pub fn assert_layout(rows: &[sqlx::postgres::PgRow], expected: &[(&str, &str)]) { assert!(!rows.is_empty(), "expected at least one row"); let actual: Vec<(&str, &str)> = rows[0] diff --git a/integration/rust/tests/integration/admin/show_config.rs b/integration/rust/tests/integration/admin/show_config.rs index 1d23de6be..902c8f6bc 100644 --- a/integration/rust/tests/integration/admin/show_config.rs +++ b/integration/rust/tests/integration/admin/show_config.rs @@ -5,8 +5,7 @@ use sqlx::{Executor, Row}; use super::assert_layout; -/// `SHOW CONFIG` returns rows described by two TEXT columns, `name` and -/// `value`, one per configuration setting. +/// `SHOW CONFIG` returns `name`/`value` TEXT columns, one row per setting. #[tokio::test] async fn test_show_config_reports_settings() { let admin = admin_sqlx().await; diff --git a/integration/rust/tests/integration/admin/show_version.rs b/integration/rust/tests/integration/admin/show_version.rs index 8afc82431..ad7f9f073 100644 --- a/integration/rust/tests/integration/admin/show_version.rs +++ b/integration/rust/tests/integration/admin/show_version.rs @@ -3,8 +3,7 @@ use sqlx::{Executor, Row}; use super::assert_layout; -/// `SHOW VERSION` returns a single row described by one `version` TEXT column, -/// carrying the PgDog version banner. 
+/// `SHOW VERSION` returns a single `version` TEXT row carrying the banner. #[tokio::test] async fn test_show_version_reports_banner() { let admin = admin_sqlx().await; diff --git a/integration/rust/tests/integration/admin/tasks.rs b/integration/rust/tests/integration/admin/tasks.rs index d9888d748..436e75592 100644 --- a/integration/rust/tests/integration/admin/tasks.rs +++ b/integration/rust/tests/integration/admin/tasks.rs @@ -2,38 +2,28 @@ use std::time::Duration; use rust::setup::{admin_sqlx, connection_sqlx_direct, connection_sqlx_direct_db}; use sqlx::{Executor, Pool, Postgres, Row}; -use tokio::time::sleep; +use tokio::time::{sleep, timeout}; use super::Tasks; // ─── Constants ────────────────────────────────────────────────────────────── -/// Shared table created in the source `pgdog` database and propagated to the -/// destination shards by schema_sync and copy_data tests. Sequential -/// execution (`test-threads = 1`) means each test owns it exclusively. +/// Source table propagated to the shards; tests run serially and own it exclusively. const TEST_TABLE: &str = "_pgdog_test_task"; -const STOP_TASK_PUB: &str = "pgdog_stop_task_test_pub"; -const STOP_TASK_SLOT: &str = "pgdog_stop_task_test_slot"; -const CUTOVER_PUB: &str = "pgdog_cutover_test_pub"; -const CUTOVER_SLOT: &str = "pgdog_cutover_test_slot"; -const SCHEMA_SYNC_PRE_PUB: &str = "pgdog_schema_sync_pre_test_pub"; -const SCHEMA_SYNC_POST_PUB: &str = "pgdog_schema_sync_post_test_pub"; -const COPY_DATA_PUB: &str = "pgdog_copy_data_test_pub"; - -/// WHERE predicate that matches every replication slot created by these tests. -/// -/// Used verbatim in three consecutive queries inside [`cleanup`]: -/// terminate active WAL senders → wait until inactive → drop. 
-const SLOT_FILTER: &str = " slot_name LIKE 'pgdog_stop_task_test_slot_%' \ - OR slot_name LIKE 'pgdog_cutover_test_slot_%' \ - OR slot_name LIKE '__pgdog_repl_%'"; +const TEST_PUB: &str = "pgdog_test_pub"; + +/// Matches every replication slot these tests create — all auto-named by +/// COPY_DATA / RESHARD / REPLICATE. Used by [`cleanup`]. +const SLOT_FILTER: &str = "slot_name LIKE '__pgdog_repl_%'"; + +/// Interval between status polls in the `wait_*` / `cleanup` loops below; each +/// loop is bounded by a `timeout(window, …)` rather than a fixed iteration count. +const POLL: Duration = Duration::from_millis(200); // ─── Helpers ──────────────────────────────────────────────────────────────── -/// Drop `table` and its orphaned row-type (left behind by interrupted DDL) -/// from `pool`. Both statements use `IF EXISTS` so the function is safe to -/// call when the objects do not exist. +/// Drop `table` and any orphaned row-type from `pool` (idempotent). async fn drop_table(pool: &Pool, table: &str) { let _ = pool .execute(format!("DROP TABLE IF EXISTS {table} CASCADE").as_str()) @@ -43,12 +33,7 @@ async fn drop_table(pool: &Pool, table: &str) { .await; } -/// Drop `table` from the source database (`direct`) and from every shard, -/// reusing the same [`drop_table`] call for each. -/// -/// `pgdog_sharded` uses shards `shard_0` and `shard_1` in the integration -/// setup. Connection failures for individual shards are ignored so that a -/// single bad shard does not block cleanup of the rest. +/// Drop `table` from the source and both shards; per-shard failures are ignored. async fn drop_table_everywhere(table: &str, direct: &Pool) { drop_table(direct, table).await; for db in &["shard_0", "shard_1"] { @@ -56,33 +41,49 @@ async fn drop_table_everywhere(table: &str, direct: &Pool) { } } -/// Full cleanup for all task tests — idempotent and safe to call as both -/// pre-flight (evict prior-run leftovers) and post-flight (leave state clean). -/// -/// 1. 
Stop every live PgDog task via `STOP_TASK`. -/// 2. Wait for the task map to drain. -/// 3. Terminate WAL senders still holding any test slot. -/// 4. Wait until all those slots are inactive. -/// 5. Drop the now-inactive test slots. -/// 6. Drop all test publications (`IF EXISTS` — idempotent). -/// 7. Drop [`TEST_TABLE`] from the source database and from every shard. +/// Idempotent cleanup, safe to run before and after each test. async fn cleanup(admin: &Pool, direct: &Pool) { - // 1. Cooperative stop. - for task in &Tasks::fetch(admin).await.rows { - let _ = admin - .execute(format!("STOP_TASK {}", task.id).as_str()) - .await; - } - - // 2. Wait for the task map to drain. - for _ in 0..20 { - if Tasks::fetch(admin).await.is_empty() { - break; + // Drain every task to a terminal state *before* touching anything else. An + // in-flight cutover holds backend pools and, past its point of no return, + // runs to completion (then spawns a reverse-replication task) — a `RELOAD` + // mid-cutover would shut those pools down and fail it with "pool is shut + // down". STOP_TASK every still-running task each pass (the reverse-replication + // task spawned during winddown is caught on a later pass) until all are terminal. + timeout(Duration::from_secs(60), async { + let is_terminal = |status: &str| { + matches!(status, "finished" | "cancelled") + || status.starts_with("failed") + || status.starts_with("panicked") + }; + loop { + let tasks = Tasks::fetch(admin).await; + if tasks.rows.iter().all(|t| is_terminal(t.status.as_str())) { + break; + } + // Only stop tasks that are still running; terminal ones are left as-is + // (re-stopping a reverse-replication task spawned during winddown is + // handled on the next pass once it appears as running). 
+ for task in &tasks.rows { + if is_terminal(task.status.as_str()) { + continue; + } + if let Some(id) = task.id { + let _ = admin.execute(format!("STOP_TASK {id}").as_str()).await; + } + } + sleep(POLL).await; } - sleep(Duration::from_millis(500)).await; - } + }) + .await + .expect("tasks did not drain to a terminal state"); + + // No task is mid-cutover now, so it is safe to restore the topology: a + // completed auto_cutover swaps the db configs in memory (not persisted — + // `cutover_save_config` is off), so RELOAD reloads the pristine on-disk config. + let _ = admin.execute("RELOAD").await; + sleep(Duration::from_millis(500)).await; - // 3. Terminate WAL senders on any test slot. + // Terminate WAL senders on any test slot. let _ = direct .execute( format!( @@ -94,24 +95,28 @@ async fn cleanup(admin: &Pool, direct: &Pool) { ) .await; - // 4. Wait for those slots to deactivate. - for _ in 0..20 { - let any_active = direct - .fetch_optional(sqlx::query(&format!( - "SELECT bool_or(active) AS active FROM pg_replication_slots WHERE {SLOT_FILTER}" - ))) - .await - .ok() - .flatten() - .and_then(|row: sqlx::postgres::PgRow| row.get::, _>("active")) - .unwrap_or(false); - if !any_active { - break; + // Wait for those slots to deactivate. + timeout(Duration::from_secs(10), async { + loop { + let any_active = direct + .fetch_optional(sqlx::query(&format!( + "SELECT bool_or(active) AS active FROM pg_replication_slots WHERE {SLOT_FILTER}" + ))) + .await + .ok() + .flatten() + .and_then(|row: sqlx::postgres::PgRow| row.get::, _>("active")) + .unwrap_or(false); + if !any_active { + break; + } + sleep(POLL).await; } - sleep(Duration::from_millis(500)).await; - } + }) + .await + .expect("replication slots did not deactivate"); - // 5. Drop inactive test slots. + // Drop inactive test slots. let _ = direct .execute( format!( @@ -123,81 +128,101 @@ async fn cleanup(admin: &Pool, direct: &Pool) { ) .await; - // 6. Drop all test publications. 
- for pub_name in &[ - STOP_TASK_PUB, - CUTOVER_PUB, - SCHEMA_SYNC_PRE_PUB, - SCHEMA_SYNC_POST_PUB, - COPY_DATA_PUB, - ] { - let _ = direct - .execute(format!("DROP PUBLICATION IF EXISTS {pub_name}").as_str()) - .await; - } + // Drop the test publication. + let _ = direct + .execute(format!("DROP PUBLICATION IF EXISTS {TEST_PUB}").as_str()) + .await; - // 7. Drop shared test table from source and every shard. + // Drop the shared test table everywhere. drop_table_everywhere(TEST_TABLE, direct).await; } -/// Start `pgdog` → `pgdog_sharded` replication using a `FOR ALL TABLES` -/// publication. Waits until the task appears in `SHOW TASKS` with kind -/// `"replication"` and returns its id. -async fn start_replication( - pub_name: &str, - slot_name: &str, - admin: &Pool, - direct: &Pool, -) -> i64 { +/// Start `pgdog` -> `pgdog_sharded` replication; return the task id once it +/// appears in `SHOW TASKS`. +async fn start_replication(admin: &Pool, direct: &Pool) -> i64 { admin.execute("RELOAD").await.unwrap(); sleep(Duration::from_millis(500)).await; direct - .execute(format!("CREATE PUBLICATION {pub_name} FOR ALL TABLES").as_str()) + .execute(format!("CREATE PUBLICATION {TEST_PUB} FOR ALL TABLES").as_str()) .await .unwrap(); let row = admin - .fetch_one(format!("REPLICATE pgdog pgdog_sharded {pub_name} {slot_name}").as_str()) + .fetch_one(format!("REPLICATE pgdog pgdog_sharded {TEST_PUB}").as_str()) .await .unwrap(); - // REPLICATE returns task_id as TEXT on the wire. 
let task_id: i64 = row.get::("task_id").parse().unwrap(); - let mut appeared = false; - for _ in 0..20 { - if Tasks::fetch(admin) - .await - .find(task_id) - .is_some_and(|t| t.kind == "replication pgdog -> pgdog_sharded") - { - appeared = true; - break; + let appeared = timeout(Duration::from_secs(10), async { + loop { + if Tasks::fetch(admin) + .await + .find(task_id) + .is_some_and(|t| t.kind == "replication pgdog -> pgdog_sharded") + { + return; + } + sleep(POLL).await; } - sleep(Duration::from_millis(500)).await; - } + }) + .await; assert!( - appeared, - "replication task {task_id} did not appear in SHOW TASKS within 10s" + appeared.is_ok(), + "replication task {task_id} did not appear in SHOW TASKS in time" ); task_id } -/// Poll until `task_id` is absent from `SHOW TASKS` (up to 30 s). -async fn wait_for_task_gone(admin: &Pool, task_id: i64) { - for _ in 0..60 { - if Tasks::fetch(admin).await.find(task_id).is_none() { - return; +/// Poll until `task_id` reaches `status` in `SHOW TASKS` (up to 30 s). +async fn wait_for_task_status(admin: &Pool, task_id: i64, status: &str) { + let result = timeout(Duration::from_secs(30), async { + loop { + if let Some(task) = Tasks::fetch(admin).await.find(task_id) { + if task.status == status { + return; + } + // Fail fast on an unexpected error state — the status text carries + // the failure message — instead of waiting out the window. + if task.status.starts_with("failed") || task.status.starts_with("panicked") { + panic!( + "task {task_id} errored while waiting for {status:?}: {} (inner_status {:?})", + task.status, task.inner_status + ); + } + } + sleep(POLL).await; } - sleep(Duration::from_millis(500)).await; + }) + .await; + if result.is_err() { + panic!("task {task_id} did not reach status {status:?} in SHOW TASKS in time"); + } +} + +/// Current `SHOW TASKS` status line for `task_id`, for timeout diagnostics. 
+async fn task_status_line(admin: &Pool, task_id: i64) -> String { + match Tasks::fetch(admin).await.find(task_id) { + Some(t) => format!("status {:?}, inner_status {:?}", t.status, t.inner_status), + None => "task absent from SHOW TASKS".to_string(), } - panic!("task {task_id} still present in SHOW TASKS after 30s"); } -/// Whether the relation `name` (table or index) exists on `db`, resolved -/// through the connection's search_path — these tests create objects in the -/// `pgdog` schema (the `$user` schema for role `pgdog`). +/// Panic with the task's status (which carries the error message) if `task_id` +/// reached an error state. +async fn fail_if_task_errored(admin: &Pool, task_id: i64) { + if let Some(t) = Tasks::fetch(admin).await.find(task_id) + && (t.status.starts_with("failed") || t.status.starts_with("panicked")) + { + panic!( + "task {task_id} errored: {} (inner_status {:?})", + t.status, t.inner_status + ); + } +} + +/// Whether relation `name` exists on `db` (resolved via the connection's search_path). async fn relation_present(pool: &Pool, name: &str) -> bool { pool.fetch_one(format!("SELECT to_regclass('{name}') IS NOT NULL AS present").as_str()) .await @@ -205,24 +230,134 @@ async fn relation_present(pool: &Pool, name: &str) -> bool { .get::("present") } -/// Poll until relation `name` exists on both destination shards (up to 30 s), -/// proving a schema sync actually propagated it. Panics on timeout. -async fn wait_for_relation_on_shards(name: &str) { +/// Poll until relation `name` exists on both shards (up to 30 s), failing fast +/// with the task's status if `task_id` errors meanwhile. 
+async fn wait_for_relation_on_shards(admin: &Pool, task_id: i64, name: &str) { let shard_0 = connection_sqlx_direct_db("shard_0").await; let shard_1 = connection_sqlx_direct_db("shard_1").await; - for _ in 0..60 { - if relation_present(&shard_0, name).await && relation_present(&shard_1, name).await { - return; + let result = timeout(Duration::from_secs(30), async { + loop { + fail_if_task_errored(admin, task_id).await; + if relation_present(&shard_0, name).await && relation_present(&shard_1, name).await { + return; + } + sleep(POLL).await; + } + }) + .await; + if result.is_err() { + panic!( + "relation {name} did not propagate to all shards in time ({})", + task_status_line(admin, task_id).await + ); + } +} + +/// Rows of `table` on shard `db`, queried directly (bypassing pgdog); an absent +/// table counts 0. +async fn shard_row_count(db: &str, table: &str) -> i64 { + let pool = connection_sqlx_direct_db(db).await; + if !relation_present(&pool, table).await { + return 0; + } + pool.fetch_one(format!("SELECT COUNT(*)::bigint AS n FROM {table}").as_str()) + .await + .unwrap() + .get::("n") +} + +/// Poll until every destination shard holds exactly `expected` rows of `table` +/// (up to 30 s). `table` is omni (not in `sharded_tables`), so each shard ends +/// up with a full copy — the reference-table semantics also asserted by +/// `check_omni_each_shard` in the resharding suite. Fails fast on task error. 
+async fn wait_for_rows_each_shard( + admin: &Pool, + task_id: i64, + table: &str, + expected: i64, +) { + let result = timeout(Duration::from_secs(30), async { + loop { + fail_if_task_errored(admin, task_id).await; + if shard_row_count("shard_0", table).await == expected + && shard_row_count("shard_1", table).await == expected + { + return; + } + sleep(POLL).await; } - sleep(Duration::from_millis(500)).await; + }) + .await; + if result.is_err() { + panic!( + "table {table} did not reach {expected} rows on each shard in time \ + (shard_0={}, shard_1={}, {})", + shard_row_count("shard_0", table).await, + shard_row_count("shard_1", table).await, + task_status_line(admin, task_id).await + ); } - panic!("relation {name} did not propagate to all shards within 30s"); +} + +/// Poll `SHOW TASKS` until a row satisfies `pred` (up to 30 s). +async fn wait_for_task(admin: &Pool, desc: &str, pred: impl Fn(&super::Task) -> bool) { + let result = timeout(Duration::from_secs(30), async { + loop { + if Tasks::fetch(admin).await.rows.iter().any(&pred) { + return; + } + sleep(POLL).await; + } + }) + .await; + if result.is_err() { + panic!("no task matching {desc:?} appeared in SHOW TASKS in time"); + } +} + +/// Create the test table on the source. +async fn create_test_table(direct: &Pool) { + direct + .execute(format!("CREATE TABLE {TEST_TABLE} (id BIGSERIAL PRIMARY KEY, val TEXT)").as_str()) + .await + .unwrap(); +} + +/// Insert `n` rows into the test table. +async fn seed_rows(direct: &Pool, n: i64) { + direct + .execute( + format!( + "INSERT INTO {TEST_TABLE} (val) SELECT 'v' || g FROM generate_series(1, {n}) g" + ) + .as_str(), + ) + .await + .unwrap(); +} + +/// Create the test publication (`FOR TABLE TEST_TABLE`) on the source. +async fn create_publication(direct: &Pool) { + direct + .execute(format!("CREATE PUBLICATION {TEST_PUB} FOR TABLE {TEST_TABLE}").as_str()) + .await + .unwrap(); +} + +/// Run an admin command that returns a `task_id` column; parse and return it. 
+async fn run_task_command(admin: &Pool, command: &str) -> i64 { + admin + .fetch_one(command) + .await + .unwrap() + .get::("task_id") + .parse() + .unwrap() } // ─── Tests ────────────────────────────────────────────────────────────────── -/// `STOP_TASK` on an id that does not exist returns `"task not found"` rather -/// than a connection error. +/// `STOP_TASK` on an unknown id returns `"task not found"`. #[tokio::test] async fn test_stop_nonexistent_task() { let admin = admin_sqlx().await; @@ -231,8 +366,7 @@ async fn test_stop_nonexistent_task() { assert_eq!(row.get::("stop_task"), "task not found"); } -/// `CUTOVER` with no replication task running returns a server error; the -/// connection pool stays healthy afterward. +/// `CUTOVER` with no replication task errors but leaves the pool usable. #[tokio::test] async fn test_cutover_without_replication_task() { let direct = connection_sqlx_direct().await; @@ -244,19 +378,17 @@ async fn test_cutover_without_replication_task() { matches!(err, sqlx::Error::Database(_)), "expected a database error, got: {err:?}" ); - // Pool must still be usable. admin.fetch_one("SHOW VERSION").await.unwrap(); } -/// A replication task can be cancelled via `STOP_TASK `, which returns -/// `"OK"` and removes the task from `SHOW TASKS`. +/// `STOP_TASK ` cancels a replication task (returns `"OK"`, status -> `cancelled`). 
#[tokio::test] async fn test_stop_task() { let direct = connection_sqlx_direct().await; let admin = admin_sqlx().await; cleanup(&admin, &direct).await; - let task_id = start_replication(STOP_TASK_PUB, STOP_TASK_SLOT, &admin, &direct).await; + let task_id = start_replication(&admin, &direct).await; let row = admin .fetch_one(format!("STOP_TASK {task_id}").as_str()) @@ -264,64 +396,84 @@ async fn test_stop_task() { .unwrap(); assert_eq!(row.get::("stop_task"), "OK"); - wait_for_task_gone(&admin, task_id).await; + wait_for_task_status(&admin, task_id, "cancelled").await; cleanup(&admin, &direct).await; } -/// A replication task can also be stopped via `CUTOVER`, which triggers a -/// final sync, returns `"OK"`, and removes the task from `SHOW TASKS`. +/// Operator `CUTOVER` drives a running replication through the full traffic +/// switch to completion. The replication is established with `COPY_DATA` (which +/// loads the schema and leaves a task awaiting an operator cutover — a bare +/// `REPLICATE` never loads the schema, so its cutover would fail in +/// `schema_sync_cutover`). `CUTOVER` returns `"OK"` and the task reaches +/// `finished` once the swap, post-cutover schema sync, and reverse-replication +/// setup all complete. `cleanup` RELOADs to revert the resulting config swap. #[tokio::test] async fn test_cutover() { let direct = connection_sqlx_direct().await; let admin = admin_sqlx().await; cleanup(&admin, &direct).await; - let task_id = start_replication(CUTOVER_PUB, CUTOVER_SLOT, &admin, &direct).await; - - let row = admin.fetch_one("CUTOVER").await.unwrap(); - assert_eq!(row.get::("cutover"), "OK"); + create_test_table(&direct).await; + seed_rows(&direct, 20).await; + create_publication(&direct).await; + + let task_id = + run_task_command(&admin, &format!("COPY_DATA pgdog pgdog_sharded {TEST_PUB}")).await; + + // Wait until the task reaches its replication phase: schema is loaded and it + // is registered to receive an operator CUTOVER. 
+ wait_for_task(&admin, "copy_data replicating", |t| { + t.id == Some(task_id) && t.inner_status == "replicating" + }) + .await; + + // Issue CUTOVER; retry briefly to cover the gap between the replication + // phase being reported and the task registering for operator cutover. + let cutover_ok = timeout(Duration::from_secs(10), async { + loop { + if let Ok(row) = admin.fetch_one("CUTOVER").await + && row.get::("cutover") == "OK" + { + return; + } + sleep(POLL).await; + } + }) + .await; + assert!( + cutover_ok.is_ok(), + "CUTOVER never returned OK ({})", + task_status_line(&admin, task_id).await + ); - wait_for_task_gone(&admin, task_id).await; + // The switch runs to completion: the task ends in `finished`. + wait_for_task_status(&admin, task_id, "finished").await; cleanup(&admin, &direct).await; } -/// `SCHEMA_SYNC pre` syncs the table structure from the source to the -/// destination shards. Asserts the table actually appears on both shards — -/// not merely that the task registered. +/// `SCHEMA_SYNC pre` creates the table on both shards. #[tokio::test] async fn test_schema_sync_pre() { let direct = connection_sqlx_direct().await; let admin = admin_sqlx().await; cleanup(&admin, &direct).await; - direct - .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) - .await - .unwrap(); - direct - .execute( - format!("CREATE PUBLICATION {SCHEMA_SYNC_PRE_PUB} FOR TABLE {TEST_TABLE}").as_str(), - ) - .await - .unwrap(); + create_test_table(&direct).await; + create_publication(&direct).await; - let row = admin - .fetch_one(format!("SCHEMA_SYNC pre pgdog pgdog_sharded {SCHEMA_SYNC_PRE_PUB}").as_str()) - .await - .unwrap(); - // Response carries the task id as TEXT; ensure it parses. 
- let _task_id: i64 = row.get::("task_id").parse().unwrap(); + let task_id = run_task_command( + &admin, + &format!("SCHEMA_SYNC pre pgdog pgdog_sharded {TEST_PUB}"), + ) + .await; - // cleanup() dropped the table from the shards pre-flight, so its presence - // here proves the pre sync created it on both shards. - wait_for_relation_on_shards(TEST_TABLE).await; + // cleanup dropped the table pre-flight, so its presence proves pre created it. + wait_for_relation_on_shards(&admin, task_id, TEST_TABLE).await; cleanup(&admin, &direct).await; } -/// `SCHEMA_SYNC post` adds indexes/constraints to tables that already exist on -/// the destination. Syncs the table with `pre` first, then asserts `post` -/// propagates a secondary index — an effect `pre` does not produce. +/// `SCHEMA_SYNC post` adds the secondary index, which `pre` does not. #[tokio::test] async fn test_schema_sync_post() { let direct = connection_sqlx_direct().await; @@ -329,70 +481,110 @@ async fn test_schema_sync_post() { cleanup(&admin, &direct).await; let secondary_index = format!("{TEST_TABLE}_val_idx"); - direct - .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) - .await - .unwrap(); + create_test_table(&direct).await; direct .execute(format!("CREATE INDEX {secondary_index} ON {TEST_TABLE} (val)").as_str()) .await .unwrap(); - direct - .execute( - format!("CREATE PUBLICATION {SCHEMA_SYNC_POST_PUB} FOR TABLE {TEST_TABLE}").as_str(), - ) - .await - .unwrap(); + create_publication(&direct).await; - // pre creates the table (and primary key) on the shards, but not the - // secondary index — that is post-data. - admin - .fetch_one(format!("SCHEMA_SYNC pre pgdog pgdog_sharded {SCHEMA_SYNC_POST_PUB}").as_str()) - .await - .unwrap(); - wait_for_relation_on_shards(TEST_TABLE).await; + // pre creates the table but not the secondary index (that's post-data). 
+ let pre_task_id = run_task_command( + &admin, + &format!("SCHEMA_SYNC pre pgdog pgdog_sharded {TEST_PUB}"), + ) + .await; + wait_for_relation_on_shards(&admin, pre_task_id, TEST_TABLE).await; - // post adds the secondary index on both destination shards. - let row = admin - .fetch_one(format!("SCHEMA_SYNC post pgdog pgdog_sharded {SCHEMA_SYNC_POST_PUB}").as_str()) - .await - .unwrap(); - let _task_id: i64 = row.get::("task_id").parse().unwrap(); + // post adds the secondary index on both shards. + let task_id = run_task_command( + &admin, + &format!("SCHEMA_SYNC post pgdog pgdog_sharded {TEST_PUB}"), + ) + .await; - wait_for_relation_on_shards(&secondary_index).await; + wait_for_relation_on_shards(&admin, task_id, &secondary_index).await; cleanup(&admin, &direct).await; } -/// `COPY_DATA` syncs the schema, copies data, then starts replication. Asserts -/// the table is actually created on both destination shards (the schema phase), -/// then `cleanup` stops the long-running replication the task spawns. +/// `COPY_DATA` syncs schema, copies data, then replicates; assert the table and +/// its rows reach both shards. #[tokio::test] async fn test_copy_data() { let direct = connection_sqlx_direct().await; let admin = admin_sqlx().await; cleanup(&admin, &direct).await; - direct - .execute(format!("CREATE TABLE {TEST_TABLE} (id SERIAL PRIMARY KEY, val TEXT)").as_str()) - .await - .unwrap(); - direct - .execute(format!("CREATE PUBLICATION {COPY_DATA_PUB} FOR TABLE {TEST_TABLE}").as_str()) - .await - .unwrap(); + create_test_table(&direct).await; + seed_rows(&direct, 20).await; + create_publication(&direct).await; - // Response: task_id TEXT, replication_slot TEXT. 
let row = admin - .fetch_one(format!("COPY_DATA pgdog pgdog_sharded {COPY_DATA_PUB}").as_str()) + .fetch_one(format!("COPY_DATA pgdog pgdog_sharded {TEST_PUB}").as_str()) .await .unwrap(); - let _task_id: i64 = row.get::("task_id").parse().unwrap(); + let task_id: i64 = row.get::("task_id").parse().unwrap(); let slot_name: String = row.get("replication_slot"); assert!(!slot_name.is_empty(), "replication_slot must be non-empty"); - // copy_data's schema-sync phase must create the table on both shards. - wait_for_relation_on_shards(TEST_TABLE).await; + // Schema phase creates the table; the omni table is fully copied to every + // shard, so each holds all 20 rows. + wait_for_relation_on_shards(&admin, task_id, TEST_TABLE).await; + wait_for_rows_each_shard(&admin, task_id, TEST_TABLE, 20).await; + + cleanup(&admin, &direct).await; +} + +/// `RESHARD` (auto_cutover) is non-blocking — returns a task id immediately and +/// registers a `reshard` task — copies schema + data to the shards, and enters +/// its auto-cutover replication phase. +#[tokio::test] +async fn test_reshard() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + create_test_table(&direct).await; + seed_rows(&direct, 20).await; + create_publication(&direct).await; + + // Non-blocking: RESHARD returns a task id straight away. + let task_id = + run_task_command(&admin, &format!("RESHARD pgdog pgdog_sharded {TEST_PUB}")).await; + + // Registers immediately as a `reshard` task. + wait_for_task(&admin, "reshard task", |t| { + t.id == Some(task_id) && t.kind.starts_with("reshard ") + }) + .await; + + // Schema + data reach both shards (omni table fully copied to each -> 20 per shard). + wait_for_relation_on_shards(&admin, task_id, TEST_TABLE).await; + wait_for_rows_each_shard(&admin, task_id, TEST_TABLE, 20).await; + + // auto_cutover drives into the replication phase on its own. 
+ wait_for_task(&admin, "reshard replicating", |t| { + t.id == Some(task_id) && t.inner_status == "replicating" + }) + .await; + + // Wind down and confirm a terminal state. + let _ = admin.execute(format!("STOP_TASK {task_id}").as_str()).await; + timeout(Duration::from_secs(30), async { + loop { + if Tasks::fetch(&admin) + .await + .find(task_id) + .is_some_and(|t| matches!(t.status.as_str(), "cancelled" | "finished")) + { + return; + } + sleep(POLL).await; + } + }) + .await + .expect("reshard task did not reach a terminal state"); cleanup(&admin, &direct).await; } diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index 2c089e9e8..8c2d45d52 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -637,6 +637,12 @@ impl Config { /// Swap database configs between `source` and `destination`. /// Uses tmp pattern: source -> tmp, destination -> source, tmp -> destination. pub fn cutover(&mut self, source: &str, destination: &str) { + // force setting the database name on cutover to make sure + // the proper database_name is present after the swap. + for db in self.databases.iter_mut() { + db.database_name = db.database_name.take().or(Some(db.name.clone())); + } + let tmp = format!("__tmp_{}__", random_string(12)); crate::swap_field!(self.databases.iter_mut(), name, source, destination, tmp); @@ -1103,6 +1109,57 @@ tables = ["my_table"] ); } + #[test] + fn test_cutover_preserves_physical_database_name() { + // `source_db` relies on the default (physical db == cluster name); + // `destination_db` sets an explicit `database_name`. After cutover the + // physical target of each entry must be unchanged — only the logical + // cluster name swaps. Regression test for entries connecting to a + // nonexistent database after a name swap. 
+ let mut config = Config { + databases: vec![ + Database { + name: "source_db".to_string(), + host: "source-host".to_string(), + port: 5432, + database_name: None, + ..Default::default() + }, + Database { + name: "destination_db".to_string(), + host: "destination-host".to_string(), + port: 5433, + database_name: Some("real_dest".to_string()), + ..Default::default() + }, + ], + ..Default::default() + }; + + config.cutover("source_db", "destination_db"); + + // The entry now named `source_db` carries destination's config; its + // physical target must remain destination's database. + let source = config + .databases + .iter() + .find(|d| d.name == "source_db") + .unwrap(); + assert_eq!(source.host, "destination-host"); + assert_eq!(source.database_name.as_deref(), Some("real_dest")); + + // The entry now named `destination_db` carries source's config. Source + // relied on the default, so its physical target was pinned to the old + // name (`source_db`) — not silently changed to `destination_db`. 
+ let destination = config + .databases + .iter() + .find(|d| d.name == "destination_db") + .unwrap(); + assert_eq!(destination.host, "source-host"); + assert_eq!(destination.database_name.as_deref(), Some("source_db")); + } + #[test] fn test_cutover_visual() { let before = r#" @@ -1183,6 +1240,7 @@ destination_db = "destination_db" let expected_after = r#" [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-0" port = 5432 role = "primary" @@ -1190,6 +1248,7 @@ shard = 0 [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-0-replica" port = 5432 role = "replica" @@ -1197,6 +1256,7 @@ shard = 0 [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-1" port = 5432 role = "primary" @@ -1204,6 +1264,7 @@ shard = 1 [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-1-replica" port = 5432 role = "replica" @@ -1211,6 +1272,7 @@ shard = 1 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-0" port = 5433 role = "primary" @@ -1218,6 +1280,7 @@ shard = 0 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-0-replica" port = 5433 role = "replica" @@ -1225,6 +1288,7 @@ shard = 0 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-1" port = 5433 role = "primary" @@ -1232,6 +1296,7 @@ shard = 1 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-1-replica" port = 5433 role = "replica" diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 175b2caa7..523bc706d 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -23,7 +23,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "std"] tracing-throttle = "0.4" parking_lot = "0.12" thiserror = "2" -derive_more = { version = "2", features = ["display", "error", "from", "from_str"] } +derive_more = { version = "2", 
features = ["display", "debug", "error", "from", "from_str"] } bytes = "1" clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } diff --git a/pgdog/src/admin/copy_data.rs b/pgdog/src/admin/copy_data.rs index 3fa6da9dc..e2b122692 100644 --- a/pgdog/src/admin/copy_data.rs +++ b/pgdog/src/admin/copy_data.rs @@ -3,7 +3,7 @@ use tracing::info; use crate::api::resharding::ReshardTask; -use crate::api::start; +use crate::api::run_task; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -62,7 +62,7 @@ impl Command for CopyData { let slot_name = orchestrator.replication_slot().to_owned(); - let task_id = start(ReshardTask::builder().orchestrator(orchestrator).build()).id(); + let task_id = run_task(ReshardTask::builder().orchestrator(orchestrator).build()).id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()).add(slot_name); diff --git a/pgdog/src/admin/cutover.rs b/pgdog/src/admin/cutover.rs index b8042813b..33c8d72b9 100644 --- a/pgdog/src/admin/cutover.rs +++ b/pgdog/src/admin/cutover.rs @@ -5,7 +5,7 @@ use crate::backend::replication::logical::Error as ReplicationError; use super::prelude::*; pub struct Cutover { - task_id: Option, + task_id: Option, } #[async_trait] @@ -31,7 +31,7 @@ impl Command for Cutover { async fn execute(&self) -> Result, Error> { // With an id, cut over that task; without, the first running one. 
- if !ReplicationTask::cutover(self.task_id.map(AsyncTaskId::from)) { + if !ReplicationTask::trigger_cutover(self.task_id) { return Err(ReplicationError::NotReplication.into()); } diff --git a/pgdog/src/admin/replicate.rs b/pgdog/src/admin/replicate.rs index 3c2d9820d..c72b263ed 100644 --- a/pgdog/src/admin/replicate.rs +++ b/pgdog/src/admin/replicate.rs @@ -3,7 +3,7 @@ use tracing::info; use crate::api::replication::ReplicationTask; -use crate::api::start; +use crate::api::run_task; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -61,7 +61,7 @@ impl Command for Replicate { )?; let waiter = orchestrator.replicate().await?; - let task_id = start(ReplicationTask::builder().waiter(waiter).build()).id(); + let task_id = run_task(ReplicationTask::builder().waiter(waiter).build()).id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()); diff --git a/pgdog/src/admin/reshard.rs b/pgdog/src/admin/reshard.rs index 695ac6e79..b0b5ffe71 100644 --- a/pgdog/src/admin/reshard.rs +++ b/pgdog/src/admin/reshard.rs @@ -3,7 +3,7 @@ use tracing::info; use crate::api::resharding::ReshardTask; -use crate::api::start; +use crate::api::run_task; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -59,7 +59,7 @@ impl Command for Reshard { self.replication_slot.clone(), )?; - let task_id = start( + let task_id = run_task( ReshardTask::builder() .orchestrator(orchestrator) .auto_cutover(true) diff --git a/pgdog/src/admin/schema_sync.rs b/pgdog/src/admin/schema_sync.rs index 8ed27944d..2fe2b6c27 100644 --- a/pgdog/src/admin/schema_sync.rs +++ b/pgdog/src/admin/schema_sync.rs @@ -2,8 +2,8 @@ use tracing::info; +use crate::api::run_task; use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; -use crate::api::start; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -70,7 +70,7 @@ impl Command for SchemaSync { self.replication_slot.clone(), )?; - let task_id = start( + let 
task_id = run_task( SchemaSyncTask::builder() .orchestrator(orchestrator) .phase(self.phase) diff --git a/pgdog/src/admin/show_tasks.rs b/pgdog/src/admin/show_tasks.rs index 65a40c769..73e234f35 100644 --- a/pgdog/src/admin/show_tasks.rs +++ b/pgdog/src/admin/show_tasks.rs @@ -2,7 +2,8 @@ use std::time::SystemTime; use chrono::{DateTime, Local}; -use crate::api::storage; +use crate::api::tasks_storage; +use crate::net::data_row::Data; use crate::util::{format_time, human_duration_display}; use super::prelude::*; @@ -22,10 +23,10 @@ impl Command for ShowTasks { async fn execute(&self) -> Result, Error> { let rd = RowDescription::new(&[ Field::bigint("id"), - Field::bigint("root_id"), Field::text("scope"), Field::text("type"), Field::text("status"), + Field::text("inner_status"), Field::text("started_at"), Field::text("updated_at"), Field::text("elapsed"), @@ -34,35 +35,47 @@ impl Command for ShowTasks { let mut messages = vec![rd.message()?]; let now = SystemTime::now(); - for task in storage().tasks() { + // Most-recent task first (highest id). + for task in tasks_storage().tasks().into_iter().rev() { // A root task plus its subtasks (e.g. the replication child of a - // copy_data/reshard task). Each row carries its own `id` and the - // `root_id` it belongs to — only root tasks are cancellable, so - // STOP_TASK targets `root_id`. Terminal tasks are retained for - // reporting but filtered out here. + // copy_data/reshard task). Only the root row carries the task id as + // `id` — that is the cancellable handle (STOP_TASK targets it); subtask + // rows report NULL there and are told apart by `scope`. Terminal + // tasks stay listed with their final status until pruned from the map.
let root_id = task.id; - let entries = std::iter::once((task.id, true, &task.state)) - .chain(task.subtasks.iter().map(|sub| (sub.id, false, &sub.state))); + let entries = std::iter::once((true, &task.state)) + .chain(task.subtasks.iter().map(|sub| (false, &sub.state))); - for (id, is_root, state) in entries { - if state.is_terminal() { - continue; - } - - let elapsed = now.duration_since(state.started_at).unwrap_or_default(); + for (is_root, state) in entries { + // Terminal tasks are retained after completion; measure their + // elapsed to the final transition (`updated_at`), not `now`, + // so the duration reflects the actual run rather than ticking + // up until the task is pruned. + let end = if state.is_terminal() { + state.updated_at + } else { + now + }; + let elapsed = end.duration_since(state.started_at).unwrap_or_default(); let elapsed_ms = elapsed.as_millis() as i64; let elapsed_str = human_duration_display(elapsed); let started_at_str = format_time(DateTime::::from(state.started_at)); let updated_at_str = format_time(DateTime::::from(state.updated_at)); let status_str = state.status.to_string(); + let inner_str = state.inner_status.clone().unwrap_or_default(); let scope = if is_root { "root" } else { "subtask" }; let mut row = DataRow::new(); - row.add(id) - .add(root_id) - .add(scope) + // Subtasks share their root's id; only the root row carries it. 
+ if is_root { + row.add(root_id); + } else { + row.add(Data::null()); + } + row.add(scope) .add(state.name.as_str()) .add(status_str.as_str()) + .add(inner_str.as_str()) .add(started_at_str.as_str()) .add(updated_at_str.as_str()) .add(elapsed_str.as_str()) diff --git a/pgdog/src/admin/stop_task.rs b/pgdog/src/admin/stop_task.rs index c7f07cac0..7e06211d7 100644 --- a/pgdog/src/admin/stop_task.rs +++ b/pgdog/src/admin/stop_task.rs @@ -1,10 +1,10 @@ use crate::api::async_task::AsyncTaskId; -use crate::api::storage; +use crate::api::tasks_storage; use super::prelude::*; pub struct StopTask { - task_id: u64, + task_id: AsyncTaskId, } #[async_trait] @@ -26,7 +26,7 @@ impl Command for StopTask { } async fn execute(&self) -> Result, Error> { - let cancelled = storage().cancel_task(AsyncTaskId::from(self.task_id)); + let cancelled = tasks_storage().cancel_task(self.task_id); let mut messages = vec![]; diff --git a/pgdog/src/api/async_task.rs b/pgdog/src/api/async_task.rs index 45b2899dd..b97e200b4 100644 --- a/pgdog/src/api/async_task.rs +++ b/pgdog/src/api/async_task.rs @@ -9,37 +9,34 @@ use std::time::{Duration, SystemTime}; use dashmap::DashMap; use parking_lot::RwLock; +use pgdog_postgres_types::ToDataRowColumn; use tokio::select; use tokio::sync::oneshot::{self, Receiver}; use tokio::time::timeout; use tokio_util::sync::CancellationToken; +use tracing::warn; -#[derive(Copy, Clone, Debug, Display, Hash, PartialEq, Eq, PartialOrd, Ord)] +/// Represent the ID of the async task. 
+#[derive(Copy, Clone, Debug, Display, FromStr, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct AsyncTaskId(u64); -impl From for AsyncTaskId { - fn from(id: u64) -> Self { - Self(id) +impl ToDataRowColumn for AsyncTaskId { + fn to_data_row_column(&self) -> pgdog_postgres_types::Data { + self.0.to_data_row_column() } } -impl From for u64 { - fn from(id: AsyncTaskId) -> Self { - id.0 +#[cfg(test)] +impl From for AsyncTaskId { + fn from(value: u64) -> Self { + AsyncTaskId(value) } } /// Status type for tasks that report no intermediate progress. -/// -/// [`Infallible`](std::convert::Infallible) is uninhabited, so a task -/// with this status type can never call -/// [`set_status`](AsyncTaskContext::set_status). pub type Empty = std::convert::Infallible; -/// A composable background task: a value carrying its own arguments, -/// driven to completion by [`run`](Task::run). Launch it top-level with -/// [`AsyncTasksStorage::run`], or nested under a running task with -/// [`AsyncTaskContext::run`]. +/// A composable async task. pub trait Task: Display + Debug + Send + Sync + Sized + 'static { /// Progress-status payload reported through /// [`set_status`](AsyncTaskContext::set_status) while the task @@ -56,19 +53,33 @@ pub trait Task: Display + Debug + Send + Sync + Sized + 'static { Duration::from_secs(5) } + /// Async task main execution logic fn run( self, ctx: AsyncTaskContext, ) -> impl Future> + Send + 'static; } -#[derive(Display, Debug, Clone)] -pub enum TaskStatus { +/// Predefined lifecycle status of a task — a fixed, enumerable set, +/// independent of the task's domain-specific progress (which is tracked +/// separately as the inner status). +#[derive(Display, Debug, Clone, PartialEq, Eq)] +pub enum TaskStatus { + #[display("started")] Started, - Pending(S), + #[display("running")] + Running, + #[display("finished")] Finished, + /// Cancellation has been requested; the task is winding down + /// cooperatively and has not yet reached a terminal state. 
+ #[display("cancelling")] + Cancelling, + #[display("cancelled")] Cancelled, + #[display("failed: {_0}")] Error(String), + #[display("panicked: {_0}")] Panic(String), } @@ -77,7 +88,13 @@ pub enum TaskStatus { #[derive(Debug, Clone)] pub struct TaskState { pub name: String, - pub status: TaskStatus, + /// Predefined lifecycle status (carries the error/panic message + /// for its terminal failure variants). + pub status: TaskStatus, + /// Last inner progress reported by the task, rendered to a string. + /// Preserved across terminal transitions, so a failed or cancelled + /// task still shows its last known progress. + pub inner_status: Option, pub started_at: SystemTime, pub updated_at: SystemTime, } @@ -97,8 +114,10 @@ pub enum TaskError { /// The task itself returned an error. #[display("task failed: {_0}")] Failed(E), + /// The task was cancelled. #[display("task was cancelled")] Cancelled, + /// The task panicked. #[display("task panicked: {_0}")] Panicked(#[error(ignore)] String), /// The task's result was never delivered: the watcher @@ -107,32 +126,17 @@ pub enum TaskError { Abandoned, } -impl TaskStatus { - /// Terminal states are write-once; late writers - /// (e.g. ctx clones outliving the task) are ignored. +impl TaskStatus { + /// Whether the task reached a terminal state. fn is_terminal(&self) -> bool { matches!( self, Self::Finished | Self::Cancelled | Self::Error(_) | Self::Panic(_) ) } - - /// Snapshot for the registry: keep the variant, render `T`. 
- fn stringify(&self) -> TaskStatus - where - S: Display, - { - match self { - Self::Started => TaskStatus::Started, - Self::Pending(status) => TaskStatus::Pending(status.to_string()), - Self::Finished => TaskStatus::Finished, - Self::Cancelled => TaskStatus::Cancelled, - Self::Error(err) => TaskStatus::Error(err.clone()), - Self::Panic(msg) => TaskStatus::Panic(msg.clone()), - } - } } +/// Represent the storage of tasks based on its id #[derive(Default)] struct TasksMap { map: DashMap>, @@ -149,9 +153,16 @@ impl TasksMap { } } +/// Mutable state of the async task that is updated +/// during the execution and status updates. struct AsyncTaskState { updated_at: SystemTime, - status: TaskStatus, + /// Predefined lifecycle status (carries the error/panic message for + /// its terminal failure variants). + status: TaskStatus, + /// Last inner progress reported by the task; kept across terminal + /// transitions so failed/cancelled tasks retain their last progress. + inner_status: Option, } impl AsyncTaskState { @@ -159,32 +170,36 @@ impl AsyncTaskState { Self { updated_at: SystemTime::now(), status: TaskStatus::Started, + inner_status: None, } } } +/// Represent the info about queued task struct AsyncTask { started_at: SystemTime, /// Id of the root task this task belongs to — its own id when it is a - /// root, inherited from the parent otherwise. Stable for the task's - /// lifetime and used to address its cutover signal. + /// root, inherited from the parent otherwise. root_id: AsyncTaskId, + /// Name of task based on [Task] Display implementation name: String, cancellation_token: CancellationToken, /// Set once the task asks for its cancellation token: only - /// then can it react to cancellation, so only then is the - /// cooperative-shutdown grace period worth waiting out. + /// then can it react to cancellation, so only then we'll + /// wait for the cancellation to finish.
cooperative: AtomicBool, + /// Mutable state of the task state: Arc>>, + /// The reference to the root map of tasks subtasks: Arc, } +/// Wrapper trait for [AsyncTask] that is not tied to specific +/// type T that allows to store tasks of different types inside. trait TaskMapEntry: Send + Sync + 'static { fn cancel(&self); fn state(&self) -> TaskState; fn subtasks(&self) -> &TasksMap; - /// Cheap expiry check for pruning: terminal and older than `ttl`, - /// without building a full [`TaskState`]. fn expired(&self, now: SystemTime, ttl: Duration) -> bool; } @@ -198,7 +213,8 @@ impl TaskMapEntry for AsyncTask { TaskState { name: self.name.clone(), - status: state.status.stringify(), + status: state.status.clone(), + inner_status: state.inner_status.as_ref().map(ToString::to_string), started_at: self.started_at, updated_at: state.updated_at, } @@ -217,6 +233,7 @@ impl TaskMapEntry for AsyncTask { } } +/// Context that is passed to the [Task::run] pub struct AsyncTaskContext { task: Arc>, } @@ -231,8 +248,10 @@ impl Clone for AsyncTaskContext { /// Handle to a spawned task. Resolves, as a future, to the /// task's result; also exposes the registry id of the task. +#[derive(derive_more::Debug)] pub struct AsyncTaskWaiter { id: AsyncTaskId, + #[debug(ignore)] waiter: Receiver>>, } @@ -242,14 +261,6 @@ impl AsyncTaskWaiter { } } -impl Debug for AsyncTaskWaiter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AsyncTaskWaiter") - .field("id", &self.id) - .finish_non_exhaustive() - } -} - impl Future for AsyncTaskWaiter { type Output = Result>; @@ -301,6 +312,7 @@ impl TaskSnapshot { /// before being pruned. 
const TASK_RETENTION: Duration = Duration::from_secs(24 * 60 * 60); +/// The main storage for async tasks pub struct AsyncTasksStorage { tasks: Arc, retention: Duration, @@ -312,17 +324,15 @@ impl Default for AsyncTasksStorage { } } -fn run_task( - parent_token: Option<&CancellationToken>, - register_into: &TasksMap, - subtasks: Arc, - root: Option, +fn run_task( + parent_task: Option<&AsyncTask

>, + tasks: &TasksMap, task: T, ) -> AsyncTaskWaiter { // Allocate the id up front so a root task can record its own id as its // root id; descendants inherit the root's id from their parent. - let id = register_into.next_id(); - let root_id = root.unwrap_or(id); + let id = tasks.next_id(); + let root_id = parent_task.map(|p| p.root_id).unwrap_or(id); let state = Arc::new(RwLock::new(AsyncTaskState::new())); @@ -330,20 +340,24 @@ fn run_task( started_at: SystemTime::now(), name: task.to_string(), root_id, - cancellation_token: match parent_token { - Some(token) => token.child_token(), + cancellation_token: match parent_task { + Some(parent) => parent.cancellation_token.child_token(), None => CancellationToken::new(), }, cooperative: AtomicBool::new(false), // Descendants share the root task's registry: every descendant // registers as a direct child of the root. - subtasks, + subtasks: if let Some(parent) = parent_task { + parent.subtasks.clone() + } else { + Arc::new(TasksMap::default()) + }, state: state.clone(), }; let entry = Arc::new(entry); // Make sure we insert task to map before it's actually started. - register_into.insert(id, entry.clone()); + tasks.insert(id, entry.clone()); let ctx = AsyncTaskContext { task: entry.clone(), @@ -357,18 +371,19 @@ fn run_task( tokio::spawn(async move { let res = select! { _ = cancellation_token.cancelled() => { + ctx.transition(TaskStatus::Cancelling); if ctx.task.cooperative.load(Ordering::Relaxed) { match timeout(T::cancel_timeout(), &mut handle).await { Ok(res) => res, Err(_) => { + // The timeout fired: abort the task handle.abort(); handle.await } } } else { // The task never took its cancellation token, so it - // cannot react to it: abort right away instead of - // letting it run on through the grace period. + // cannot react to it: abort immediately. 
handle.abort(); handle.await } @@ -380,20 +395,25 @@ fn run_task( match res { Ok(Ok(res)) => { - ctx.set_inner_status(TaskStatus::Finished); + let status = if cancellation_token.is_cancelled() { + TaskStatus::Cancelled + } else { + TaskStatus::Finished + }; + ctx.transition(status); let _ = sender.send(Ok(res)); } Ok(Err(err)) => { - ctx.set_inner_status(TaskStatus::Error(err.to_string())); + ctx.transition(TaskStatus::Error(err.to_string())); let _ = sender.send(Err(TaskError::Failed(err))); } Err(err) if err.is_cancelled() => { - ctx.set_inner_status(TaskStatus::Cancelled); + ctx.transition(TaskStatus::Cancelled); let _ = sender.send(Err(TaskError::Cancelled)); } Err(err) => { let panic = err.to_string(); - ctx.set_inner_status(TaskStatus::Panic(panic.clone())); + ctx.transition(TaskStatus::Panic(panic.clone())); let _ = sender.send(Err(TaskError::Panicked(panic))); } } @@ -406,7 +426,10 @@ fn run_task( } impl AsyncTaskContext { - fn set_inner_status(&self, status: TaskStatus) { + /// Move the task to a new lifecycle `status` (terminal or the + /// non-terminal `Cancelling`), preserving the last inner progress. + /// No-op once the task is already terminal. + fn transition(&self, status: TaskStatus) { let mut state = self.task.state.write(); if state.status.is_terminal() { return; @@ -415,29 +438,32 @@ impl AsyncTaskContext { state.updated_at = SystemTime::now(); } + /// Update the inner progress status of the current task. pub fn set_status(&self, status: T::Status) { - self.set_inner_status(TaskStatus::Pending(status)); + let mut state = self.task.state.write(); + if state.status.is_terminal() { + return; + } + // Don't regress a cancellation-in-progress back to Running; the task + // may still report inner progress while it winds down. + if state.status != TaskStatus::Cancelling { + state.status = TaskStatus::Running; + } + state.inner_status = Some(status); + state.updated_at = SystemTime::now(); } - /// Hand out this task's cancellation token. 
Taking the token - /// opts the task into cooperative shutdown: on `cancel_task` - /// it gets [`Task::cancel_timeout`] to wind down - /// before being aborted. Tasks that never take it are aborted - /// immediately. Take it early. + /// Hand out this task's cancellation token. + /// Tasks that never take it are aborted immediately. pub fn cancellation_token(&self) -> CancellationToken { self.task.cooperative.store(true, Ordering::Relaxed); self.task.cancellation_token.clone() } + /// Run the new task as a subtask of the current one pub fn run(&self, task: T1) -> AsyncTaskWaiter { - run_task( - Some(&self.task.cancellation_token), - &self.task.subtasks, - self.task.subtasks.clone(), - Some(self.task.root_id), - task, - ) + run_task(Some(&self.task), &self.task.subtasks, task) } /// Id of the root task this task belongs to (its own id when it is a @@ -455,10 +481,11 @@ impl AsyncTasksStorage { } } + /// Schedule the new task as a root task for execution pub fn run(&self, task: T) -> AsyncTaskWaiter { self.prune(); - run_task(None, &self.tasks, Arc::new(TasksMap::default()), None, task) + run_task::(None, &self.tasks, task) } /// Request cancellation of a task. The task winds down @@ -474,6 +501,7 @@ impl AsyncTasksStorage { // status reporting: nothing to cancel. Report as not found so callers // don't claim success or emit cleanup warnings for a finished task. if state.is_terminal() { + warn!("Task: {id} is already in terminal state and cannot be cancelled"); return None; } @@ -527,11 +555,12 @@ mod tests { use super::*; use parking_lot::Mutex; use std::convert::Infallible; + use std::fmt::Debug; use std::sync::Arc; use tokio::sync::Notify; use tokio::task::yield_now; + use tokio::test; use tokio::time::sleep; - use tokio::{join, test}; type State = Arc>; @@ -543,36 +572,27 @@ mod tests { StepTwo, } - /// Sets "started", waits on `notify`, then "finished" and succeeds. 
- #[derive(Display, Debug)] - #[display("mock")] - struct MockSuccessful { - state: State, - notify: Arc, + /// What a [`Mock`] does once notified. + #[derive(Clone, Copy, Debug)] + enum Outcome { + /// Mark "finished" and succeed. + Succeed, + /// Mark "failed" and return an error. + Fail, + /// Panic, leaving the state at "started". + Panic, } - impl Task for MockSuccessful { - type Status = Empty; - type Output = (); - type Error = Infallible; - - async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { - *self.state.lock() = "started"; - self.notify.notified().await; - *self.state.lock() = "finished"; - Ok(()) - } - } - - /// Sets "started", waits on `notify`, then "failed" and errors. + /// Sets "started", waits on `notify`, then resolves per `outcome`. #[derive(Display, Debug)] #[display("mock")] - struct MockFailing { + struct Mock { state: State, notify: Arc, + outcome: Outcome, } - impl Task for MockFailing { + impl Task for Mock { type Status = Empty; type Output = (); type Error = std::io::Error; @@ -580,31 +600,52 @@ mod tests { async fn run(self, _ctx: AsyncTaskContext) -> Result<(), std::io::Error> { *self.state.lock() = "started"; self.notify.notified().await; - *self.state.lock() = "failed"; - Err(std::io::Error::other("mock task failure")) + match self.outcome { + Outcome::Succeed => { + *self.state.lock() = "finished"; + Ok(()) + } + Outcome::Fail => { + *self.state.lock() = "failed"; + Err(std::io::Error::other("mock task failure")) + } + Outcome::Panic => panic!("panicking task"), + } } } macro_rules! mock_successful { ($state:ident, $notify:ident) => { - MockSuccessful { + Mock { state: $state.clone(), notify: $notify.clone(), + outcome: Outcome::Succeed, } }; } macro_rules! mock_failing { ($state:ident, $notify:ident) => { - MockFailing { + Mock { + state: $state.clone(), + notify: $notify.clone(), + outcome: Outcome::Fail, + } + }; + } + + macro_rules! 
mock_panicking { + ($state:ident, $notify:ident) => { + Mock { state: $state.clone(), notify: $notify.clone(), + outcome: Outcome::Panic, } }; } /// Waits on its gate, then succeeds. Never takes its cancellation - /// token, so it is aborted (not wound down) on cancellation. + /// token, so it is aborted immediately on cancellation. #[derive(Display, Debug)] #[display("anonymous")] struct Gate { @@ -637,52 +678,26 @@ mod tests { } } - /// Runs two children concurrently and joins them, then finishes. + /// Runs a single child task of any type (spawned via `ctx.run`), + /// propagating its error if it fails, then waits on `notify` + /// before finishing. #[derive(Display, Debug)] #[display("inner")] - struct InnerJoin { - state: State, - a: MockSuccessful, - b: MockSuccessful, - } - - impl Task for InnerJoin { - type Status = Empty; - type Output = (); - type Error = Infallible; - - async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { - *self.state.lock() = "started"; - - let (a, b) = join!(ctx.run(self.a), ctx.run(self.b)); - a.unwrap(); - b.unwrap(); - - *self.state.lock() = "finished"; - - Ok(()) - } - } - - /// Runs two children concurrently, then waits on `notify`. - #[derive(Display, Debug)] - #[display("inner")] - struct InnerCancel { + struct Inner { state: State, notify: Arc, - a: MockSuccessful, - b: MockSuccessful, + child: C, } - impl Task for InnerCancel { + impl Task for Inner { type Status = Empty; type Output = (); - type Error = Infallible; + type Error = TaskError; - async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + async fn run(self, ctx: AsyncTaskContext) -> Result<(), TaskError> { *self.state.lock() = "started"; - let _ = join!(ctx.run(self.a), ctx.run(self.b)); + ctx.run(self.child).await?; self.notify.notified().await; @@ -692,81 +707,43 @@ mod tests { } } - /// Runs a failing child and asserts it failed. 
- #[derive(Display, Debug)] - #[display("inner")] - struct InnerError { - state: State, - a: MockFailing, - } - - impl Task for InnerError { - type Status = Empty; - type Output = (); - type Error = Infallible; - - async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { - *self.state.lock() = "started"; - - let res = ctx.run(self.a).await; - assert!(matches!(res, Err(TaskError::Failed(_)))); - - *self.state.lock() = "inner_failed"; - - Ok(()) - } - } - - /// Takes its cancellation token and winds down gracefully within - /// the grace window, then succeeds with a value. + /// Takes its cancellation token, then waits on `notify` or + /// cancellation. On cancellation it winds down for `wind_down` + /// before stopping, returning `true`: shorter than the 30s grace + /// window delivers a graceful `Ok(true)`, longer is force-aborted + /// when the grace period expires. #[derive(Display, Debug)] - #[display("graceful")] - struct Graceful { + #[display("cancellable")] + struct Cancellable { state: State, + notify: Arc, + wind_down: Duration, } - impl Task for Graceful { + impl Task for Cancellable { type Status = Empty; - type Output = i32; + type Output = bool; type Error = Infallible; - async fn run(self, ctx: AsyncTaskContext) -> Result { - *self.state.lock() = "started"; - - ctx.cancellation_token().cancelled().await; - - // Cooperative wind-down, well within the 5s grace window. - sleep(Duration::from_secs(1)).await; - - *self.state.lock() = "graceful"; - - Ok(42) + fn cancel_timeout() -> Duration { + Duration::from_secs(30) } - } - - /// Panics after being notified, if still in the "started" state. 
- #[derive(Display, Debug)] - #[display("panicker")] - struct Panicker { - state: State, - notify: Arc, - } - impl Task for Panicker { - type Status = Empty; - type Output = (); - type Error = Infallible; - - async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + async fn run(self, ctx: AsyncTaskContext) -> Result { *self.state.lock() = "started"; - self.notify.notified().await; - - if *self.state.lock() == "started" { - panic!("panicking task"); + let token = ctx.cancellation_token(); + tokio::select! { + _ = self.notify.notified() => {} + _ = token.cancelled() => { + sleep(self.wind_down).await; + *self.state.lock() = "cancelled"; + return Ok(true); + } } - Ok(()) + *self.state.lock() = "finished"; + Ok(true) } } @@ -825,30 +802,6 @@ mod tests { } } - /// Takes its cancellation token (opting into the grace period) but - /// ignores it, with an overridden 30s grace period. - #[derive(Display, Debug)] - #[display("slow_exit")] - struct SlowExit { - notify: Arc, - } - - impl Task for SlowExit { - type Status = Empty; - type Output = (); - type Error = Infallible; - - fn cancel_timeout() -> Duration { - Duration::from_secs(30) - } - - async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { - let _token = ctx.cancellation_token(); - self.notify.notified().await; - Ok(()) - } - } - /// Spawned tasks are only guaranteed to be polled after the /// current task yields; a few rounds cover spawn chains. 
async fn settle() { @@ -925,14 +878,14 @@ mod tests { #[test] async fn test_inner_execution() { + let child_notify = Arc::new(Notify::new()); let notify = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); - let state_b = Arc::new(Mutex::new("initial")); let state_c = Arc::new(Mutex::new("initial")); - let c = InnerJoin { + let c = Inner { state: state_c.clone(), - a: mock_successful!(state_a, notify), - b: mock_successful!(state_b, notify), + notify: notify.clone(), + child: mock_successful!(state_a, child_notify), }; let async_storage = AsyncTasksStorage::default(); @@ -942,15 +895,20 @@ mod tests { settle().await; assert_eq!(*state_a.lock(), "started"); - assert_eq!(*state_b.lock(), "started"); assert_eq!(*state_c.lock(), "started"); - notify.notify_waiters(); + // Release the child; the root then parks on its gate. + child_notify.notify_waiters(); + settle().await; + + assert_eq!(*state_a.lock(), "finished"); + assert_eq!(*state_c.lock(), "started"); + + // Open the gate; the root finishes. + notify.notify_one(); task_c.await.unwrap(); - assert_eq!(*state_a.lock(), "finished"); - assert_eq!(*state_b.lock(), "finished"); assert_eq!(*state_c.lock(), "finished"); } @@ -983,41 +941,80 @@ mod tests { #[test(start_paused = true)] async fn test_graceful_exit_during_cancel_grace() { + let notify = Arc::new(Notify::new()); let state = Arc::new(Mutex::new("initial")); - let a = Graceful { + let a = Cancellable { state: state.clone(), + notify: notify.clone(), + wind_down: Duration::from_secs(1), }; let async_storage = AsyncTasksStorage::default(); let task = async_storage.run(a); + let task_id = task.id(); settle().await; assert_eq!(*state.lock(), "started"); - async_storage.cancel_task(task.id()); + async_storage.cancel_task(task_id); // The task observed the token and finished gracefully: its // result must be delivered, not discarded or lost to a // watcher panic. 
let res = task.await; - assert_eq!(res.unwrap(), 42); + assert!(res.unwrap()); + + assert_eq!(*state.lock(), "cancelled"); + + // Cancellation was requested, so the terminal status is `Cancelled` + // even though the task wound down cleanly to `Ok(true)`. + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); + } + + #[test(start_paused = true)] + async fn test_cancelling_status_visible_during_wind_down() { + let notify = Arc::new(Notify::new()); + let state = Arc::new(Mutex::new("initial")); + let a = Cancellable { + state: state.clone(), + notify, + wind_down: Duration::from_secs(5), + }; + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + let task_id = task.id(); + + settle().await; + async_storage.cancel_task(task_id); + settle().await; + + // Cancellation requested; the task is still winding down (sleeping + // `wind_down`) and must report a non-terminal `Cancelling` status. + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelling)); + assert!(!snapshot.state.is_terminal()); - assert_eq!(*state.lock(), "graceful"); + // Once it finishes winding down, the status settles to terminal. 
+ let res = task.await; + assert!(res.unwrap()); + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); } #[test(start_paused = true)] async fn test_inner_cancel() { let notify = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); - let state_b = Arc::new(Mutex::new("initial")); let state_c = Arc::new(Mutex::new("initial")); - let c = InnerCancel { + let c = Inner { state: state_c.clone(), notify: notify.clone(), - a: mock_successful!(state_a, notify), - b: mock_successful!(state_b, notify), + child: mock_successful!(state_a, notify), }; let async_storage = AsyncTasksStorage::default(); @@ -1027,7 +1024,6 @@ mod tests { settle().await; assert_eq!(*state_a.lock(), "started"); - assert_eq!(*state_b.lock(), "started"); assert_eq!(*state_c.lock(), "started"); async_storage.cancel_task(task_c.id()); @@ -1036,7 +1032,6 @@ mod tests { assert!(matches!(res, Err(TaskError::Cancelled))); assert_eq!(*state_a.lock(), "started"); - assert_eq!(*state_b.lock(), "started"); assert_eq!(*state_c.lock(), "started"); } @@ -1071,9 +1066,10 @@ mod tests { let notify = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); let state_c = Arc::new(Mutex::new("initial")); - let c = InnerError { + let c = Inner { state: state_c.clone(), - a: mock_failing!(state_a, notify), + notify: notify.clone(), + child: mock_failing!(state_a, notify), }; let async_storage = AsyncTasksStorage::default(); @@ -1087,20 +1083,19 @@ mod tests { notify.notify_one(); - task_c.await.unwrap(); + // The child's failure propagates out of the root task. 
+ let res = task_c.await; + assert!(matches!(res, Err(TaskError::Failed(_)))); assert_eq!(*state_a.lock(), "failed"); - assert_eq!(*state_c.lock(), "inner_failed"); + assert_eq!(*state_c.lock(), "started"); } #[test] async fn test_panic() { let notify = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); - let a = Panicker { - state: state_a.clone(), - notify: notify.clone(), - }; + let a = mock_panicking!(state_a, notify); let async_storage = AsyncTasksStorage::default(); @@ -1146,7 +1141,8 @@ mod tests { assert_eq!(snapshot.id, id); assert_eq!(snapshot.state.name, "test_task"); assert!( - matches!(&snapshot.state.status, TaskStatus::Pending(s) if s.as_str() == "StepOne") + snapshot.state.status == TaskStatus::Running + && snapshot.state.inner_status.as_deref() == Some("StepOne") ); // Both the subtask and its grandchild appear as direct @@ -1164,7 +1160,8 @@ mod tests { // Subtask finished, parent moved to the next phase. let snapshot = async_storage.task(id).unwrap(); assert!( - matches!(&snapshot.state.status, TaskStatus::Pending(s) if s.as_str() == "StepTwo") + snapshot.state.status == TaskStatus::Running + && snapshot.state.inner_status.as_deref() == Some("StepTwo") ); for sub in &snapshot.subtasks { assert!(matches!(sub.state.status, TaskStatus::Finished)); @@ -1174,44 +1171,62 @@ mod tests { task.await.unwrap(); - // Terminal status stays observable after completion. + // Terminal status stays observable after completion, and the last + // inner progress is preserved across the terminal transition. 
let snapshot = async_storage.task(id).unwrap(); assert!(matches!(snapshot.state.status, TaskStatus::Finished)); + assert_eq!(snapshot.state.inner_status.as_deref(), Some("StepTwo")); } #[test] - async fn test_prune_expired_tasks() { + async fn test_prune_expired_tasks_and_subtasks() { let notify_a = Arc::new(Notify::new()); - let notify_b = Arc::new(Notify::new()); let state_a = Arc::new(Mutex::new("initial")); - let state_b = Arc::new(Mutex::new("initial")); let a = mock_successful!(state_a, notify_a); - let b = mock_successful!(state_b, notify_b); + + let sub_gate = Arc::new(Notify::new()); + let parent_gate = Arc::new(Notify::new()); // Zero retention: terminal tasks are pruned on next access. let async_storage = AsyncTasksStorage::new(Duration::ZERO); let task_a = async_storage.run(a); - let task_b = async_storage.run(b); let id_a = task_a.id(); - let id_b = task_b.id(); + + let root = async_storage.run(TraverseRoot { + sub_gate: sub_gate.clone(), + parent_gate: parent_gate.clone(), + }); + let root_id = root.id(); settle().await; - // Both running: nothing to prune. + // Both roots running; the root has its two descendants registered. assert_eq!(async_storage.tasks().len(), 2); + assert_eq!(async_storage.task(root_id).unwrap().subtasks.len(), 2); + // Finish the top-level task: it expired and is pruned on next + // access, while the still-running root survives. notify_a.notify_one(); task_a.await.unwrap(); - // The finished task expired; the running one survives. let tasks = async_storage.tasks(); assert_eq!(tasks.len(), 1); - assert_eq!(tasks[0].id, id_b); + assert_eq!(tasks[0].id, root_id); assert!(async_storage.task(id_a).is_none()); - notify_b.notify_one(); - task_b.await.unwrap(); + // Let the subtasks finish; the root parks on `parent_gate`, still running. + sub_gate.notify_one(); + settle().await; + + // The finished subtasks are pruned while the still-running root survives. 
+ let snapshot = async_storage.task(root_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Running)); + assert!(snapshot.subtasks.is_empty()); + + // The root finishes and expires too: the registry empties. + parent_gate.notify_one(); + root.await.unwrap(); assert!(async_storage.tasks().is_empty()); } @@ -1222,8 +1237,10 @@ mod tests { let async_storage = AsyncTasksStorage::default(); - let task = async_storage.run(SlowExit { + let task = async_storage.run(Cancellable { + state: Arc::new(Mutex::new("initial")), notify: notify.clone(), + wind_down: Duration::from_secs(60), }); settle().await; @@ -1266,10 +1283,10 @@ mod tests { #[test] async fn global_storage_runs_and_lists() { - let waiter = crate::api::storage().run(Noop); + let waiter = crate::api::tasks_storage().run(Noop); let id = waiter.id(); waiter.await.unwrap(); - assert!(crate::api::storage().task(id).is_some()); + assert!(crate::api::tasks_storage().task(id).is_some()); } #[test] @@ -1298,36 +1315,4 @@ mod tests { .is_none() ); } - - #[test] - async fn test_prune_expired_subtasks_under_running_root() { - let sub_gate = Arc::new(Notify::new()); - let parent_gate = Arc::new(Notify::new()); - - // Zero retention: terminal tasks are pruned on next access. - let async_storage = AsyncTasksStorage::new(Duration::ZERO); - - let task = async_storage.run(TraverseRoot { - sub_gate: sub_gate.clone(), - parent_gate: parent_gate.clone(), - }); - let root_id = task.id(); - - settle().await; - - // Root running with its two descendants registered. - assert_eq!(async_storage.task(root_id).unwrap().subtasks.len(), 2); - - // Let the subtasks finish; the root parks on `parent_gate`, still running. - sub_gate.notify_one(); - settle().await; - - // The finished subtasks are pruned while the still-running root survives. 
- let snapshot = async_storage.task(root_id).unwrap(); - assert!(matches!(snapshot.state.status, TaskStatus::Pending(_))); - assert!(snapshot.subtasks.is_empty()); - - parent_gate.notify_one(); - task.await.unwrap(); - } } diff --git a/pgdog/src/api/mod.rs b/pgdog/src/api/mod.rs index 81c1631ff..7963896fe 100644 --- a/pgdog/src/api/mod.rs +++ b/pgdog/src/api/mod.rs @@ -19,7 +19,7 @@ pub mod schema_sync; static TASKS: LazyLock = LazyLock::new(AsyncTasksStorage::default); /// Accessor for the process-global task registry. -pub(crate) fn storage() -> &'static AsyncTasksStorage { +pub(crate) fn tasks_storage() -> &'static AsyncTasksStorage { &TASKS } @@ -30,8 +30,8 @@ pub(crate) fn storage() -> &'static AsyncTasksStorage { pub(crate) use async_task::Task; /// Launch `task` as a top-level task in the global registry. -pub(crate) fn start(task: T) -> AsyncTaskWaiter { - storage().run(task) +pub(crate) fn run_task(task: T) -> AsyncTaskWaiter { + tasks_storage().run(task) } /// Error returned by the API migration tasks: either an error from the @@ -39,58 +39,6 @@ pub(crate) fn start(task: T) -> AsyncTaskWaiter { /// (failure, cancellation, panic, or abandonment) surfaced to its parent. #[derive(Debug, Display, Error, From)] pub(crate) enum MigrationError { - #[display("{_0}")] Replication(Error), - #[display("{_0}")] Task(TaskError), } - -/// Flatten a nested migration task's outcome into a single [`MigrationError`], -/// so a composite task (e.g. `reshard`) can run another composite task (e.g. -/// `copy_data`, whose error is already a `MigrationError`) as a child and -/// `?`-propagate its result without double-wrapping. -impl From> for MigrationError { - fn from(err: TaskError) -> Self { - match err { - // The child's own error: surface it directly. - TaskError::Failed(inner) => inner, - // Non-failure outcomes carry no inner error; re-wrap them. 
- TaskError::Cancelled => MigrationError::Task(TaskError::Cancelled), - TaskError::Panicked(msg) => MigrationError::Task(TaskError::Panicked(msg)), - TaskError::Abandoned => MigrationError::Task(TaskError::Abandoned), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn migration_error_wraps_replication_and_task_errors() { - // A replication/orchestrator error converts directly. - let err = MigrationError::from(Error::NoSchema); - assert!(matches!(err, MigrationError::Replication(Error::NoSchema))); - - // A child task's failure is wrapped, preserving the inner error. - let err = MigrationError::from(TaskError::Failed(Error::NoSchema)); - assert!(matches!( - err, - MigrationError::Task(TaskError::Failed(Error::NoSchema)) - )); - - // Non-failure child outcomes are preserved too (not stringified). - let err = MigrationError::from(TaskError::::Cancelled); - assert!(matches!(err, MigrationError::Task(TaskError::Cancelled))); - - // A nested migration task's failure is flattened, not double-wrapped. - let err = MigrationError::from(TaskError::Failed(MigrationError::Replication( - Error::NoSchema, - ))); - assert!(matches!(err, MigrationError::Replication(Error::NoSchema))); - - // A nested non-failure outcome is preserved as a task error. - let err = MigrationError::from(TaskError::::Cancelled); - assert!(matches!(err, MigrationError::Task(TaskError::Cancelled))); - } -} diff --git a/pgdog/src/api/replication.rs b/pgdog/src/api/replication.rs index 3b32f1c8b..d786c11c2 100644 --- a/pgdog/src/api/replication.rs +++ b/pgdog/src/api/replication.rs @@ -6,12 +6,10 @@ //! (delivered through [`ReplicationTask::cutover`]), and otherwise finishes //! when the source slot drains (no cutover on natural drain). With //! `auto_cutover` set (reshard) it cuts over automatically once the -//! destination has caught up. Launch it top-level with [`super::start`], or -//! as a child by spawning it through a parent task's [`AsyncTaskContext`]. +//! 
destination has caught up. use std::collections::HashMap; use std::sync::LazyLock; -use std::time::Duration; use parking_lot::Mutex; use tokio::select; @@ -50,18 +48,17 @@ pub(crate) enum Direction { Reverse, } -/// Drive a [`ReplicationWaiter`] to completion. The caller creates the waiter -/// (via `Orchestrator::replicate`); this task owns only the waiter, not the -/// orchestrator. +/// Run the replication by driving a [`ReplicationWaiter`] to completion. #[derive(Display, Debug, bon::Builder)] -#[display("replication {waiter}")] +#[display("replication {waiter}{}", match direction { + Direction::Forward => "", + Direction::Reverse => " (reverse)", +})] pub(crate) struct ReplicationTask { /// The running replication waiter this task drives to completion. pub waiter: ReplicationWaiter, /// Cut over automatically once the destination has caught up, instead - /// of waiting for an operator `CUTOVER`. Set by the reshard flow, - /// which drives its own cutover; standalone `REPLICATE` and - /// `copy_data` leave it `false` and wait for an external `CUTOVER`. + /// of waiting for an operator `CUTOVER`. #[builder(default)] pub auto_cutover: bool, /// Replication direction. `Reverse` marks the post-cutover stream that @@ -73,13 +70,12 @@ pub(crate) struct ReplicationTask { /// Cutover tokens of the replication tasks currently awaiting an operator /// `CUTOVER`, keyed by the root task id they belong to. A cutover token is /// *separate* from the task's `STOP_TASK` cancellation token — signalling it -/// means "cut over", not "abandon". Registrations are dropped with the task, -/// so a cutover can never outlive its task and leak into a later one. +/// means "cut over", not "abandon". static CUTOVERS: LazyLock>> = LazyLock::new(|| Mutex::new(HashMap::new())); /// Guard held by a running replication task: removes its cutover -/// registration on drop. Awaiting [`requested`](CutoverWaiter::requested) +/// registration on drop. 
Awaiting [CutoverWaiter::requested] /// resolves when an operator `CUTOVER` targets the task. struct CutoverWaiter { root_id: AsyncTaskId, @@ -105,42 +101,26 @@ impl Task for ReplicationTask { type Output = (); type Error = Error; - /// A cutover, once started, must run to completion. `STOP_TASK` during - /// the waiting phase is handled by the `select!` arm below (graceful - /// `waiter.stop()`); a `STOP_TASK` during an in-flight `waiter.cutover()` - /// waits out this (effectively unbounded) grace period instead of - /// force-aborting mid-cutover. - fn cancel_timeout() -> Duration { - Duration::from_secs(24 * 60 * 60) - } - - async fn run(self, ctx: AsyncTaskContext) -> Result<(), Error> { + async fn run(mut self, ctx: AsyncTaskContext) -> Result<(), Error> { let token = ctx.cancellation_token(); - // Operator flow (`REPLICATE`, `copy_data`) registers for an external - // `CUTOVER` addressed to this task; the reshard flow (`auto_cutover`) - // cuts over on its own and registers nothing. - let cutover = (!self.auto_cutover).then(|| Self::register_cutover(ctx.root_id())); - - let mut waiter = self.waiter; ctx.set_status(ReplicationStatus::Replicating); + if self.auto_cutover { + return self.perform_cutover(&ctx, &token).await; + } + + let cutover = Self::register_cutover(ctx.root_id()); + select! { - // STOP_TASK: wind down without cutting over. _ = token.cancelled() => { ctx.set_status(ReplicationStatus::Stopping); - waiter.stop(); + self.waiter.stop(); } - // Operator CUTOVER, or immediately under `auto_cutover`: switch traffic. - _ = async { if let Some(cutover) = &cutover { cutover.requested().await } } => { - ctx.set_status(match self.direction { - Direction::Forward => ReplicationStatus::CuttingOver, - Direction::Reverse => ReplicationStatus::RollingBack, - }); - waiter.cutover().await?; + _ = cutover.requested() => { + self.perform_cutover(&ctx, &token).await?; } - // Source slot drained without a cutover (operator flow only): done. 
- res = waiter.wait(), if cutover.is_some() => { + res = self.waiter.wait() => { res?; } } @@ -150,17 +130,27 @@ impl Task for ReplicationTask { } impl ReplicationTask { - /// Trigger a cutover on a running replication task, returning whether one - /// was there to receive it. `Some(root_id)` targets that task; `None` - /// targets the first (lowest-id) running replication task. The `CUTOVER` - /// admin command rejects with `NotReplication` when this is `false`. - pub(crate) fn cutover(target: Option) -> bool { + /// Perform the actual cutover for running replication. + async fn perform_cutover( + &mut self, + ctx: &AsyncTaskContext, + token: &CancellationToken, + ) -> Result<(), Error> { + ctx.set_status(match self.direction { + Direction::Forward => ReplicationStatus::CuttingOver, + Direction::Reverse => ReplicationStatus::RollingBack, + }); + self.waiter.cutover(token).await + } + + /// Trigger a cutover on a running replication task. + pub(crate) fn trigger_cutover(target: Option) -> bool { let tokens = CUTOVERS.lock(); let token = match target { Some(id) => tokens.get(&id), // No id: cut over the first (lowest-id) running task. - None => tokens.keys().min().copied().and_then(|id| tokens.get(&id)), + None => tokens.keys().min().and_then(|id| tokens.get(id)), }; match token { @@ -192,31 +182,13 @@ mod tests { static CUTOVER_TEST_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| tokio::sync::Mutex::new(())); - #[test] - fn cancel_timeout_far_exceeds_default() { - // Far larger than the 5s default: cutover must not be force-aborted. 
- assert!(ReplicationTask::cancel_timeout() > Duration::from_secs(60)); - } - - #[test] - fn replication_status_renders_distinct_labels() { - let labels = [ - ReplicationStatus::Replicating.to_string(), - ReplicationStatus::CuttingOver.to_string(), - ReplicationStatus::Stopping.to_string(), - ]; - assert!(labels.iter().all(|label| !label.is_empty())); - let unique: std::collections::HashSet<_> = labels.iter().collect(); - assert_eq!(unique.len(), labels.len()); - } - #[tokio::test] async fn cutover_delivers_even_when_buffered() { let _guard = CUTOVER_TEST_LOCK.lock().await; // Cutover lands before the task awaits: still delivered (latches). let waiter = ReplicationTask::register_cutover(AsyncTaskId::from(1)); assert!( - ReplicationTask::cutover(Some(AsyncTaskId::from(1))), + ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(1))), "the named task must receive the cutover" ); @@ -233,7 +205,7 @@ mod tests { let waiter = ReplicationTask::register_cutover(AsyncTaskId::from(7)); assert!( - !ReplicationTask::cutover(Some(AsyncTaskId::from(8))), + !ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(8))), "no task is registered under id 8" ); assert!( @@ -243,7 +215,7 @@ mod tests { "a cutover for a different id leaked to this task" ); - assert!(ReplicationTask::cutover(Some(AsyncTaskId::from(7)))); + assert!(ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(7)))); tokio::time::timeout(Duration::from_secs(1), waiter.requested()) .await .expect("targeted cutover was not delivered"); @@ -258,7 +230,7 @@ mod tests { let second = ReplicationTask::register_cutover(AsyncTaskId::from(9)); assert!( - ReplicationTask::cutover(None), + ReplicationTask::trigger_cutover(None), "the first registered task must be cut over" ); @@ -280,7 +252,7 @@ mod tests { // never reaching the next one. Regression guard for the signal leak. 
{ let first = ReplicationTask::register_cutover(AsyncTaskId::from(1)); - assert!(ReplicationTask::cutover(Some(AsyncTaskId::from(1)))); + assert!(ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(1)))); drop(first); // ends without ever awaiting `requested()` } @@ -297,7 +269,9 @@ mod tests { async fn cutover_with_no_task_is_rejected() { let _guard = CUTOVER_TEST_LOCK.lock().await; // Nothing registered: `CUTOVER` (with or without an id) is rejected. - assert!(!ReplicationTask::cutover(None)); - assert!(!ReplicationTask::cutover(Some(AsyncTaskId::from(404)))); + assert!(!ReplicationTask::trigger_cutover(None)); + assert!(!ReplicationTask::trigger_cutover(Some(AsyncTaskId::from( + 404 + )))); } } diff --git a/pgdog/src/api/resharding.rs b/pgdog/src/api/resharding.rs index bb528e639..c861f33e6 100644 --- a/pgdog/src/api/resharding.rs +++ b/pgdog/src/api/resharding.rs @@ -63,11 +63,8 @@ impl Task for ReshardTask { type Output = (); type Error = MigrationError; - /// Composes long-lived child tasks; match their generous grace so a - /// `STOP_TASK` lets them wind down (and clean up replication slots) before - /// this task returns. fn cancel_timeout() -> Duration { - Duration::from_secs(24 * 60 * 60) + Duration::from_secs(60) } async fn run(self, ctx: AsyncTaskContext) -> Result<(), MigrationError> { @@ -95,7 +92,7 @@ impl Task for ReshardTask { // (created during data_sync, kept until replication takes them over). // Awaiting this guard on every exit drops whatever the publisher still // owns — a no-op once replication has claimed the slots — so a failed or - // aborted migration never leaves them lingering on the source. + // aborted migration doesn't leave them lingering on the source. let guard = orchestrator.publication_guard(); let result: Result<(), MigrationError> = async { // Copy the data, unless replicate-only. @@ -110,6 +107,12 @@ impl Task for ReshardTask { // second half of schema sync, after the bulk load. 
if !self.skip_schema_sync { ctx.set_status(ReshardStatus::FinalizingSchema); + + // The bulk copy above can run for hours; pools may have reloaded + // meanwhile, leaving our cluster refs stale. Re-fetch them before + // touching the destination. + orchestrator.refresh()?; + orchestrator = ctx .run( SchemaSyncTask::builder() @@ -158,21 +161,3 @@ impl Task for ReshardTask { result } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn reshard_status_renders_distinct_labels() { - let labels = [ - ReshardStatus::SchemaSync.to_string(), - ReshardStatus::SyncingData.to_string(), - ReshardStatus::FinalizingSchema.to_string(), - ReshardStatus::Replication.to_string(), - ]; - assert!(labels.iter().all(|label| !label.is_empty())); - let unique: std::collections::HashSet<_> = labels.iter().collect(); - assert_eq!(unique.len(), labels.len()); - } -} diff --git a/pgdog/src/api/schema_sync.rs b/pgdog/src/api/schema_sync.rs index 54b20826b..7d9cbf231 100644 --- a/pgdog/src/api/schema_sync.rs +++ b/pgdog/src/api/schema_sync.rs @@ -62,10 +62,6 @@ impl Task for SchemaSyncTask { type Output = Orchestrator; type Error = Error; - /// Returns the orchestrator with its schema loaded and synced so a parent - /// task can thread it into the next phase. The schema dump is skipped when - /// the orchestrator already carries one (e.g. a parent that runs `Pre` - /// then `Post` on the same orchestrator). #[allow(clippy::print_stdout)] async fn run(self, ctx: AsyncTaskContext) -> Result { let mut orchestrator = self.orchestrator; @@ -75,8 +71,6 @@ impl Task for SchemaSyncTask { orchestrator.load_schema().await?; } - // Dry run prints the SQL this task would execute and stops short of - // touching the destination. The schema load above is required for it. if self.dry_run { let schema = orchestrator.schema()?; for statement in schema.statements(self.phase.into())? 
{ diff --git a/pgdog/src/backend/replication/logical/orchestrator.rs b/pgdog/src/backend/replication/logical/orchestrator.rs index 7e98b98ae..b6610875c 100644 --- a/pgdog/src/backend/replication/logical/orchestrator.rs +++ b/pgdog/src/backend/replication/logical/orchestrator.rs @@ -14,7 +14,8 @@ use tokio::{ sync::Mutex, time::{Instant, interval}, }; -use tracing::{info, warn}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info, warn}; use super::*; @@ -31,9 +32,7 @@ pub(crate) struct Orchestrator { /// A handle to a publication's replication slots, decoupled from the rest of /// the orchestrator. Awaiting [`PublicationGuard::cleanup`] drops every slot /// the publisher still owns — a no-op once `replicate` has handed them off to -/// the streaming tasks. Take one before data sync and await it on every exit -/// so a failed or aborted migration never leaves slots lingering on the source -/// (holding back WAL). +/// the streaming tasks. pub(crate) struct PublicationGuard { publisher: Arc>, } @@ -435,9 +434,31 @@ impl ReplicationWaiter { } /// Perform traffic cutover between source and destination. - pub(crate) async fn cutover(&mut self) -> Result<(), Error> { - self.wait_for_replication().await?; - self.wait_for_cutover().await?; + /// + /// The pre-switch wait (`wait_for_replication`, `wait_for_cutover`) is + /// cancellable: a `STOP_TASK` (via `cancel`) there resumes traffic, stops + /// the replication streams, and returns without moving any traffic. Past + /// the point of no return the switch always runs to completion — cancelling + /// it would leave traffic split between source and destination. + pub(crate) async fn cutover(&mut self, cancel: &CancellationToken) -> Result<(), Error> { + select! { + // Nothing has moved yet (`wait_for_replication` only pauses traffic + // at its very end). Resume traffic (no-op if never paused) and wind + // the streams down, so the aborted cutover leaves nothing running. 
+ _ = cancel.cancelled() => { + maintenance_mode::stop(None); + self.waiter.stop(); + warn!("[cutover] stop requested before the traffic switch, aborting cutover"); + cutover_state(CutoverState::Abort { + error: "stopped before cutover".into(), + }); + return Ok(()); + } + res = async { + self.wait_for_replication().await?; + self.wait_for_cutover().await + } => { res?; } + } // We're going, point of no return. self.orchestrator.publisher.lock().await.request_stop(); @@ -464,16 +485,12 @@ impl ReplicationWaiter { // Fix any reverse replication blockers. ok_or_abort!(self.orchestrator.schema_sync_post_cutover(true).await); - // Create the reverse-replication slot synchronously, while traffic is - // still paused, so its consistent-point LSN captures every write made - // to the new primary after cutover. On failure, resume traffic and - // abort — the forward switch is already committed, so this surfaces as - // an error for the operator (rollback won't be available). + // Create reverse replication in case we need to rollback. let waiter = ok_or_abort!(self.orchestrator.replicate().await); // Drive the running waiter as a background api task so it stays visible // in SHOW TASKS and can be cut over (rollback) or stopped. - crate::api::start( + crate::api::run_task( crate::api::replication::ReplicationTask::builder() .waiter(waiter) .direction(crate::api::replication::Direction::Reverse) @@ -482,6 +499,8 @@ impl ReplicationWaiter { // Slot is established and capturing — now safe to resume traffic. info!("[cutover] complete, resuming traffic"); + + // Point traffic to the other database and resume. maintenance_mode::stop(None); cutover_state(CutoverState::Complete); @@ -495,6 +514,7 @@ macro_rules! 
ok_or_abort { match $expr { Ok(res) => res, Err(err) => { + error!("Orchestrator failed: {err}"); maintenance_mode::stop(None); cutover_state(CutoverState::Abort { error: err.to_string(), diff --git a/pgdog/src/cli.rs b/pgdog/src/cli.rs index 5570a55c1..6667a8550 100644 --- a/pgdog/src/cli.rs +++ b/pgdog/src/cli.rs @@ -8,9 +8,9 @@ use tracing::{info, warn}; use crate::api::Task; use crate::api::resharding::ReshardTask; +use crate::api::run_task; use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; -use crate::api::start; -use crate::api::storage; +use crate::api::tasks_storage; use crate::backend::databases::databases; use crate::backend::replication::orchestrator::Orchestrator; use crate::backend::schema::sync::config::ShardConfig; @@ -287,7 +287,7 @@ async fn run_to_completion(task: T) -> Result { signal?; warn!("interrupt received, cancelling task {id}"); - storage().cancel_task(id); + tasks_storage().cancel_task(id); } } } From 50fef1bd38bc0fea42eee7b4a8b68e6f2f4d94c4 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Thu, 25 Jun 2026 18:39:59 +0000 Subject: [PATCH 7/8] add logging --- pgdog/src/api/async_task.rs | 35 ++++++++++++++++++++++++++++++++--- pgdog/src/api/replication.rs | 5 +++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/pgdog/src/api/async_task.rs b/pgdog/src/api/async_task.rs index b97e200b4..5f638e780 100644 --- a/pgdog/src/api/async_task.rs +++ b/pgdog/src/api/async_task.rs @@ -14,7 +14,7 @@ use tokio::select; use tokio::sync::oneshot::{self, Receiver}; use tokio::time::timeout; use tokio_util::sync::CancellationToken; -use tracing::warn; +use tracing::{Instrument, Span, error, info, info_span, warn}; /// Represent the ID of the async task. 
#[derive(Copy, Clone, Debug, Display, FromStr, Hash, PartialEq, Eq, PartialOrd, Ord)] @@ -192,6 +192,8 @@ struct AsyncTask { state: Arc>>, /// The reference to the root map of tasks subtasks: Arc, + /// The tracing span associated with the task + tracing_span: Span, } /// Wrapper trait for [AsyncTask] that is not tied to specific @@ -335,10 +337,23 @@ fn run_task( let root_id = parent_task.map(|p| p.root_id).unwrap_or(id); let state = Arc::new(RwLock::new(AsyncTaskState::new())); + let name = task.to_string(); + + let span = if let Some(parent_task) = parent_task { + let parent_span = &parent_task.tracing_span; + let parent_name = &parent_task.name; + info!( + "Starting new subtask '{name}' (id: {id}) for parent task '{parent_name}' (root_id: {root_id})" + ); + info_span!(parent: parent_span, "task", %id) + } else { + info!("Starting new task '{name}' (id: {id})"); + info_span!(parent: None, "task", %id) + }; let entry = AsyncTask { started_at: SystemTime::now(), - name: task.to_string(), + name, root_id, cancellation_token: match parent_task { Some(parent) => parent.cancellation_token.child_token(), @@ -353,6 +368,7 @@ fn run_task( Arc::new(TasksMap::default()) }, state: state.clone(), + tracing_span: span.clone(), }; let entry = Arc::new(entry); @@ -363,7 +379,7 @@ fn run_task( task: entry.clone(), }; - let mut handle = tokio::spawn(task.run(ctx.clone())); + let mut handle = tokio::spawn(task.run(ctx.clone()).instrument(span)); let (sender, receiver) = oneshot::channel(); let cancellation_token = entry.cancellation_token.clone(); @@ -404,6 +420,7 @@ fn run_task( let _ = sender.send(Ok(res)); } Ok(Err(err)) => { + error!("task failed: {err}"); ctx.transition(TaskStatus::Error(err.to_string())); let _ = sender.send(Err(TaskError::Failed(err))); } @@ -413,6 +430,7 @@ fn run_task( } Err(err) => { let panic = err.to_string(); + error!("task panicked: {panic}"); ctx.transition(TaskStatus::Panic(panic.clone())); let _ = sender.send(Err(TaskError::Panicked(panic))); } @@ 
-430,25 +448,36 @@ impl AsyncTaskContext { /// non-terminal `Cancelling`), preserving the last inner progress. /// No-op once the task is already terminal. fn transition(&self, status: TaskStatus) { + let _enter = self.task.tracing_span.enter(); + let mut state = self.task.state.write(); if state.status.is_terminal() { return; } + + info!("state transition to: {status}"); + state.status = status; state.updated_at = SystemTime::now(); } /// Update the inner progress status of the current task. pub fn set_status(&self, status: T::Status) { + let _enter = self.task.tracing_span.enter(); + let mut state = self.task.state.write(); if state.status.is_terminal() { return; } + + info!("inner state transition to: {status}"); + // Don't regress a cancellation-in-progress back to Running; the task // may still report inner progress while it winds down. if state.status != TaskStatus::Cancelling { state.status = TaskStatus::Running; } + state.inner_status = Some(status); state.updated_at = SystemTime::now(); } diff --git a/pgdog/src/api/replication.rs b/pgdog/src/api/replication.rs index d786c11c2..3389483e9 100644 --- a/pgdog/src/api/replication.rs +++ b/pgdog/src/api/replication.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use std::sync::LazyLock; +use std::time::Duration; use parking_lot::Mutex; use tokio::select; @@ -101,6 +102,10 @@ impl Task for ReplicationTask { type Output = (); type Error = Error; + fn cancel_timeout() -> Duration { + Duration::from_secs(60) + } + async fn run(mut self, ctx: AsyncTaskContext) -> Result<(), Error> { let token = ctx.cancellation_token(); From 0c2c93ecb056493a5f797ea8178ba1ba2633afc8 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:31:32 +0000 Subject: [PATCH 8/8] extend json logs with spans --- pgdog/src/lib.rs | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/pgdog/src/lib.rs b/pgdog/src/lib.rs index c3f4aad45..f1fc6377d 100644 
--- a/pgdog/src/lib.rs +++ b/pgdog/src/lib.rs @@ -183,35 +183,27 @@ fn init_logger(general: Option<&General>) { .map(|general| general.log_format) .unwrap_or_default(); + let format = fmt::layer() + .with_ansi(std::io::stderr().is_terminal()) + .with_writer(std::io::stderr) + .with_file(false); + #[cfg(not(debug_assertions))] + let format = format.with_target(false); + match log_format { LogFormat::Text => { - let format = fmt::layer() - .with_ansi(std::io::stderr().is_terminal()) - .with_writer(std::io::stderr) - .with_file(false); - #[cfg(not(debug_assertions))] - let format = format.with_target(false); let format = format.with_filter(throttle); - let _ = tracing_subscriber::registry() .with(format) .with(filter) .try_init(); } LogFormat::Json | LogFormat::JsonFlattened => { - let format = fmt::layer() - .json() - .with_ansi(false) - .with_writer(std::io::stderr) - .with_file(false) - .with_current_span(false) - .with_span_list(false); + let format = format.json().with_current_span(false); let format = match log_format { LogFormat::JsonFlattened => format.flatten_event(true), _ => format, }; - #[cfg(not(debug_assertions))] - let format = format.with_target(false); let format = format.with_filter(throttle); let _ = tracing_subscriber::registry()