diff --git a/Cargo.lock b/Cargo.lock index 648ca143b..7c9cc8083 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -850,6 +850,31 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "bon" +version = "3.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a602c73c7b0148ec6d12af6fd5cc7a46e2eacc8878271a999abac56eed12f561" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dee98b0db6a962de883bf5d20362dee4d7ca0d12fe39a7c6c73c844e1cd7c1f" +dependencies = [ + "darling 0.23.0", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.118", +] + [[package]] name = "borsh" version = "1.6.1" @@ -3293,6 +3318,7 @@ dependencies = [ "azure_identity", "base64 0.22.1", "bit-vec 0.8.0", + "bon", "brunch", "bytes", "cc", diff --git a/Cargo.toml b/Cargo.toml index a6db5313b..dc553ed15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ edition = "2024" [workspace.dependencies] pgdog-plugin = { path = "./pgdog-plugin", version = "0.3.0" } pgdog-config = { path = "./pgdog-config", version = "0.1.0" } +bon = "3.9" schemars = { version = "1.2.1", features = ["uuid1"] } serde_json = "1.0" indexmap = { version = "2.14", features = ["serde"] } diff --git a/docs/ASYNC_TASKS.md b/docs/ASYNC_TASKS.md new file mode 100644 index 000000000..68b15f5a8 --- /dev/null +++ b/docs/ASYNC_TASKS.md @@ -0,0 +1,159 @@ +# Async Task Framework + +Long-running operations (resharding, copy, replication, schema sync) run as background *tasks* in +`crate::api` ([`pgdog/src/api/`](../pgdog/src/api/)). The admin SQL API and the `pgdog` CLI both +start the same task through the same registry; only how they consume the result differs. 
+ +## The `Task` trait + +A task is any type implementing `Task` ([`api/async_task.rs`](../pgdog/src/api/async_task.rs)): + +```rust +pub trait Task: Display + Debug + Send + Sync + Sized + 'static { + type Status: Display + Send + Sync + 'static; // inner progress; Empty = none + type Output: Send + 'static; + type Error: std::error::Error + Send + 'static; + + fn cancel_timeout() -> Duration { Duration::from_secs(5) } + + fn run(self, ctx: AsyncTaskContext) + -> impl Future> + Send + 'static; +} +``` + +`run` is the whole task. Everything else — spawning, ids, status storage, cancellation, +retention — is handled by the framework around it. + +## Starting a task + +`crate::api::run_task(task)` ([`api/mod.rs`](../pgdog/src/api/mod.rs)) calls +`tasks_storage().run(task)`, which delegates to the private `run_task` in `async_task.rs`. That +function: + +1. allocates an id from the registry (`tasks.next_id()`); a root task's `root_id` is its own id, +2. builds the `AsyncTask` entry (cancellation token, state, tracing span) and inserts it into the + map *before* spawning, +3. spawns the task future: `tokio::spawn(task.run(ctx).instrument(span))`, +4. spawns a second watcher future that `select!`s the task handle against its cancellation token + and records the terminal status, +5. returns an `AsyncTaskWaiter { id, waiter }`. + +The id is known before `run` does any work, so the caller can address the task immediately. + +```rust +pub struct AsyncTaskWaiter { + id: AsyncTaskId, + waiter: Receiver>>, // oneshot +} +``` + +`AsyncTaskWaiter` is a `Future`; awaiting it yields the task's `Result`. A dropped sender (watcher +gone) maps to `Err(TaskError::Abandoned)`. `.id()` returns the id without awaiting. + +The registry is process-global: + +```rust +static TASKS: LazyLock = LazyLock::new(AsyncTasksStorage::default); +``` + +So a CLI task and an admin task land in the same `AsyncTasksStorage`, both visible to `SHOW TASKS` +and cancellable by `STOP_TASK`. 
+ +## Status + +Two separate axes. The lifecycle status is a fixed enum: + +```rust +pub enum TaskStatus { + Started, Running, Cancelling, // non-terminal + Finished, Cancelled, Error(String), Panic(String), // terminal +} +``` + +The domain-specific progress is `Task::Status`, reported by the task itself via +`ctx.set_status(...)` and surfaced separately as `inner_status`. `set_status` flips the lifecycle +to `Running` (but won't regress out of `Cancelling`) and stores the rendered inner status. + +The registry stores both in a type-erased `TaskState` (`name`, `status`, `inner_status`, +`started_at`, `updated_at`), so `SHOW TASKS` reads it without knowing `T`. + +Transitions are write-once at the terminal boundary: `transition` and `set_status` both bail early +if `state.status.is_terminal()`, so a context clone that outlives the task can't clobber a +recorded outcome. The watcher sets the terminal status based on the `select!` arm that won: + +- task returned `Ok` → `Finished` (or `Cancelled` if the token was already cancelled), +- task returned `Err(e)` → `Error(e)`, waiter gets `TaskError::Failed(e)`, +- join handle cancelled → `Cancelled` / `TaskError::Cancelled`, +- join handle panicked → `Panic(msg)` / `TaskError::Panicked(msg)`. + +## Cancellation + +Every task holds a `CancellationToken`; a child's token is `parent.child_token()`, so cancelling a +task cancels its whole subtree. + +Cooperative vs. not is decided by one call. 
`ctx.cancellation_token()` sets `cooperative = true` as +a side effect and hands back the token: + +```rust +pub fn cancellation_token(&self) -> CancellationToken { + self.task.cooperative.store(true, Ordering::Relaxed); + self.task.cancellation_token.clone() +} +``` + +The watcher branches on that flag when the token fires: + +```rust +ctx.transition(TaskStatus::Cancelling); +if cooperative { + // grace period, then force-abort + match timeout(T::cancel_timeout(), &mut handle).await { + Ok(res) => res, + Err(_) => { handle.abort(); handle.await } + } +} else { + handle.abort(); handle.await // never took the token → abort now +} +``` + +A cooperative task typically `select!`s its work against the token and runs its own shutdown +before `run` returns, within `cancel_timeout()` (default 5s). A non-cooperative task never takes +the token and is aborted immediately — fine when dropping the future already tears the work down +(e.g. in-flight units in a `JoinSet`). + +`STOP_TASK` calls `cancel_task(id)`, which returns `None` for an unknown or already-terminal id +(so callers don't claim success or emit cleanup warnings for a finished task) and otherwise calls +`entry.cancel()`. + +## Composition + +`ctx.run(child)` spawns a subtask: + +```rust +pub fn run(&self, task: T1) -> AsyncTaskWaiter { + run_task(Some(&self.task), &self.task.subtasks, task) +} +``` + +All descendants register in the *root's* `subtasks` map, so `TaskSnapshot.subtasks` is a flat list +of every descendant ordered by id (their own `subtasks` are always empty). `SHOW TASKS` renders +the root and its running children as separate rows. `STOP_TASK` only addresses roots; the token +hierarchy propagates the cancel downward. + +## Retention + +`AsyncTasksStorage` prunes on every `run`/`tasks`/`task` call. `prune()` drops entries whose +terminal state is older than `retention` (`TASK_RETENTION = 24h`); running tasks are never dropped. 
+So an id stays addressable, with its final status and last `inner_status`, for 24h after it +finishes. + +## The two callers + +Both go through `AsyncTaskWaiter`; the difference is what they do with it. + +- **Admin** ([`pgdog/src/admin/`](../pgdog/src/admin/)): fire-and-forget. Take `.id()`, drop the + waiter, return the id to the client. The client polls `SHOW TASKS` and runs `STOP_TASK <id>`. +- **CLI** ([`cli.rs`](../pgdog/src/cli.rs)): await the waiter in a loop. On Ctrl-C, call into the + registry to cancel, then keep awaiting until the task winds down before exiting. + +Same task, same options, same status transitions, same error path either way. diff --git a/docs/RESHARDING.md b/docs/RESHARDING.md index 259a5b49b..24d5149bc 100644 --- a/docs/RESHARDING.md +++ b/docs/RESHARDING.md @@ -167,8 +167,10 @@ traffic immediately via `maintenance_mode::stop()` and returns an error. Steps i 1. `publisher.request_stop()` + `waiter.wait()` — stops the replication stream; drains remaining WAL. 2. `schema_sync_cutover()` — applies `SyncState::Cutover` operations (e.g. drops sequences that won't be used in the sharded cluster). -3. `cutover(source_db, dest_db)` in [`pgdog/src/backend/databases.rs`](../pgdog/src/backend/databases.rs) — atomically swaps source and - destination in the in-memory routing table. +3. `cutover(source_db, dest_db)` in [`pgdog/src/backend/databases.rs`](../pgdog/src/backend/databases.rs) — + atomically swaps the two clusters' logical identity in the routing table (and config refs via + `Config::cutover`/`Users::cutover`); no data moves. Persisted to disk when + `cutover_save_config = true`. 4. `orchestrator.refresh()` — re-fetches both clusters from `databases()` so the orchestrator now treats the new cluster as source for reverse replication. 5. `schema_sync_post_cutover()` — applies `SyncState::PostCutover` (removes blockers that would @@ -214,19 +216,19 @@ returned by any task surfaces via `table?` and aborts the manager's loop.
Remain completion or abort via their own `AbortSignal`, but their results are ignored once the channel is dropped. -### Temporary vs permanent replication slots +On a failed or aborted migration, `ReshardTask::run` ([`api/resharding.rs`](../pgdog/src/api/resharding.rs)) +obtains a guard via `Orchestrator::publication_guard()` and calls `PublicationGuard::cleanup()`, which +locks the publisher and has `Publisher::cleanup()` drop the permanent WAL slot via +`DROP_REPLICATION_SLOT "name" WAIT`. On success the slot is kept so reverse replication can roll +back. If the process crashes before this runs, the slot survives and keeps accumulating WAL on the +source — drop it manually before retrying. + +### Temporary replication slots Per-table slots created in [`Table::data_sync()`](../pgdog/src/backend/replication/logical/publisher/table.rs) are `TEMPORARY` — PostgreSQL drops them automatically when the replication connection closes, including on error or panic. A failed copy task leaves no orphaned per-table slot. -The `Publisher`'s named replication slot (the one used for the WAL streaming phase) is permanent. -[`Publisher::cleanup()`](../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) drops it by calling `slot.drop_slot()`, which issues -`DROP_REPLICATION_SLOT "name" WAIT` over the replication protocol connection. `cleanup()` is an -explicit method on `Orchestrator` — it is not called automatically inside `replicate_and_cutover()`. -If the orchestrator is dropped after Step 5 begins but before `cleanup()` is called (e.g. a -process crash), the permanent slot survives and continues accumulating WAL on the source. 
- ### The `ok_or_abort!` macro — guaranteed traffic resumption after cutover starts ```rust diff --git a/docs/issues/replication.md b/docs/issues/replication.md new file mode 100644 index 000000000..6c3f46a5d --- /dev/null +++ b/docs/issues/replication.md @@ -0,0 +1,550 @@ +# Replication Issues + +--- + +## 🚧 Issue 1 — Lag inflation when source databases share a PostgreSQL instance + +### Description + +The replication lag metric used to gate cutover is computed as: + +```sql +SELECT pg_current_wal_lsn() - confirmed_flush_lsn +FROM pg_replication_slots +WHERE slot_name = '...'; +``` + +`pg_current_wal_lsn()` returns the current write-ahead log position for the **entire PostgreSQL instance**, not for a specific database or publication. When multiple source shards are hosted on the same PostgreSQL instance (different databases, one slot per database), `pg_current_wal_lsn()` advances with every write to any database on that instance. + +A logical replication slot only decodes and delivers changes that belong to its publication. Changes from other databases are physically present in the WAL stream but are invisible to the slot's decoder — `confirmed_flush_lsn` only advances when the client acknowledges a decoded logical message (i.e., a `Commit` record from the publication). Once a slot has replayed all of its publication's data, `confirmed_flush_lsn` stagnates at the LSN of the last commit in that publication. It will never advance past WAL records from other databases, regardless of whether writes have stopped. + +This means the lag metric permanently overstates the remaining work. On a three-shard benchmark where all source databases are on one instance, the observed lag was ~3.5 GB per slot even after each slot had replayed all of its own publication's data. The lag never dropped below the cutover threshold, so `wait_for_replication()` looped indefinitely and cutover never fired. + +The destination is a second, often worse, source of the same inflation. 
When the +destination shards live on the **same PostgreSQL instance** as the source (a +single-instance dev/test setup, or any deployment that co-locates them), every +row pgdog copies or replicates *into* the destination is itself a write to that +instance and advances `pg_current_wal_lsn()`. So the very act of catching the +destination up pushes the instance WAL position forward, while the source slot's +`confirmed_flush_lsn` only tracks the source publication — the measured lag rises +as replication makes progress instead of falling. Every source slot on a shared +instance reads back a near-identical, ever-growing lag, because they all subtract +their publication-scoped `confirmed_flush_lsn` from the same instance-wide +`pg_current_wal_lsn()`. + +### Cause + +`pg_current_wal_lsn()` is instance-scoped. `confirmed_flush_lsn` is publication-scoped. Their difference is only meaningful when a single database accounts for all writes to the instance. + +### Code references + +| Symbol | File | +|---|---| +| `ReplicationSlot::replication_lag()` — the lag query | [`pgdog/src/backend/replication/logical/publisher/slot.rs`](../../pgdog/src/backend/replication/logical/publisher/slot.rs) | +| `ReplicationWaiter::wait_for_replication()` — the cutover gate | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | +| Keepalive handler / `flush_lsn` reply (proposed `data_since_keepalive` flag) | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +### Fix + +The PostgreSQL walsender sends a keepalive message after exhausting all decoded changes available for the publication. The keepalive carries `wal_end` — the server's current WAL write position. 
When a keepalive arrives and no xlog data was received since the previous keepalive, the slot has drained: there is nothing for the publication between `committed_lsn` and `wal_end`, so that gap consists entirely of other databases' WAL. + +In this state the client can safely reply with `flush_lsn = wal_end`. PostgreSQL sets `confirmed_flush_lsn` on the slot to `wal_end`, and the lag query returns ~0. + +Reporting `wal_end` is only valid when the slot is caught up. During active streaming — where the server is sending transactions and keepalives may arrive between commits — `flush_lsn` must remain at `committed_lsn`. Reporting `wal_end` prematurely would advance `confirmed_flush_lsn` past unapplied commits; if the connection dropped, the server would restart from `wal_end` and those commits would be lost. + +The proposed guard: a `data_since_keepalive` flag. Set to `true` when any xlog data message arrives; cleared to `false` when a keepalive arrives. A keepalive is the catch-up signal only when this flag is `false` — meaning no data arrived between the last keepalive and this one. + +``` +keepalive received + data_since_keepalive = true -> reply flush_lsn = committed_lsn (active, mid-stream) + data_since_keepalive = false -> reply flush_lsn = wal_end (caught up) +``` + +### PostgreSQL references + +- [Streaming Replication Protocol](https://www.postgresql.org/docs/current/protocol-replication.html) — Standby Status Update message (`flush_lsn` field); Primary Keepalive message (`wal_end` field). +- [`pg_replication_slots`](https://www.postgresql.org/docs/current/view-pg-replication-slots.html) — `confirmed_flush_lsn` column: the last LSN confirmed received by the standby/subscriber. +- [Logical Replication](https://www.postgresql.org/docs/current/logical-replication.html) — publication filtering and how the walsender decodes only changes relevant to the subscriber's publication. 
+ +--- + +## ✅ Issue 2 — Stop signal only unblocked one task instead of all (resolved) + +> **Resolved.** The `Arc<Notify>` stop signal was replaced with a +> [`CancellationToken`](https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html) +> held by both `Publisher` and `Waiter` +> ([`publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs)). +> `Waiter::stop()` calls `stop.cancel()`, which latches permanently and wakes +> every per-shard stream task — whether already parked on `stop.cancelled()` or +> polling later — so one call unblocks all N tasks. The token (not the +> `watch::Sender<bool>` proposed below) is the primitive that shipped; both +> satisfy the "persistent value, wakes every receiver" requirement. Each stream +> latches the signal with a `stopping` flag, then drains its slot to completion. + +### Description + +When cutover initiates, `Waiter::stop()` must signal all N per-shard replication tasks to terminate. Each task blocks in a `select!` loop waiting on either incoming WAL data or a stop signal. All N tasks must receive the signal and break out of their loop before `Waiter::wait()` can return. + +The original implementation used `Arc<Notify>` for the stop signal. `Notify::notify_one()` wakes exactly one waiting task. `Notify::notify_waiters()` wakes only tasks that are *currently parked* on the future at the moment of the call — any task that polls `notified()` after `notify_waiters()` returns will park again and never wake. With N tasks the result was: one task exited, the others remained blocked indefinitely. `Waiter::wait()` joined all task handles and hung. + +### Cause + +`Notify` does not persist state. A permit stored by `notify_one()` is consumed by the first task that calls `notified().await`; subsequent tasks see no permit and park. `notify_waiters()` is a snapshot operation: it only affects tasks already parked at the instant of the call.
+ +### Code references + +| Symbol | File | +|---|---| +| `Publisher::stop` / `Waiter::stop` — the stop channel | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +| `Waiter::stop()` — sends the signal | same file | +| `stop_rx.changed()` arm in per-shard task `select!` | same file | +### Fix + +Replace `Arc<Notify>` with `Arc<watch::Sender<bool>>` / `watch::Receiver<bool>`. `watch::Sender::send(true)` persists the value in the channel. Every receiver — whether currently parked or polling later — resolves `changed()` as soon as it observes the new value. One `send(true)` call unblocks all N tasks regardless of scheduling order. + +```rust +// Publisher and Waiter both hold Arc<watch::Sender<bool>>. +// Each task subscribes before spawning: +let mut stop_rx = self.stop.subscribe(); // watch::Receiver<bool> + +// Inside the task select!: +_ = stop_rx.changed() => { + slot.stop_replication().await?; + break; +} + +// Waiter::stop() -- one call, all tasks unblock: +pub fn stop(&self) { + let _ = self.stop.send(true); +} +``` + +### Tokio references + +- [`tokio::sync::Notify`](https://docs.rs/tokio/latest/tokio/sync/struct.Notify.html) — `notify_one()` stores at most one permit; `notify_waiters()` is not persistent. +- [`tokio::sync::watch`](https://docs.rs/tokio/latest/tokio/sync/watch/index.html) — `Sender::send()` updates the shared value; all receivers observe it on their next `changed()` poll. + +--- + +## 🚧 Issue 3 — Premature cutover when lag map is empty at startup + +### Description + +The orchestrator computes the current replication lag as the maximum value across all per-shard lag entries: + +```rust +lag.values().copied().max().unwrap_or(i64::MAX) as u64 +``` + +Per-shard lag values are written into `replication_lag: HashMap` by each task's `check_lag.tick()` interval. This interval fires once per second, but the first tick does not fire until one second after the task spawns.
+ +The orchestrator's `wait_for_replication()` loop runs immediately after `replicate()` returns. If it evaluates `replication_lag()` before any task's first tick, the map is empty. `Iterator::max()` on an empty iterator returns `None`. The original code used `unwrap_or_default()`, which returns `0`. The orchestrator saw lag = 0 bytes, concluded replication was already caught up, entered maintenance mode, and triggered cutover before any data had been replicated. + +### Cause + +Two independent races: +1. `unwrap_or_default()` on an empty map returned 0, which is indistinguishable from a legitimately zero lag. +2. No synchronization between task startup and the first lag measurement. + +### Code references + +| Symbol | File | +|---|---| +| `Publisher::replicate()` — pre-population loop for `replication_lag` map | [`pgdog/src/backend/replication/logical/publisher/publisher_impl.rs`](../../pgdog/src/backend/replication/logical/publisher/publisher_impl.rs) | +| `Publisher::replication_lag()` — reads the map | same file | +| `Orchestrator::replication_lag()` — takes the max across shards | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | +### Fix + +Pre-populate the map with `i64::MAX` for every shard index before spawning any tasks. Use `or_insert` so a real measurement written by a task is never overwritten by the sentinel. Change `unwrap_or_default()` to `unwrap_or(i64::MAX)` so an empty map (which should not occur after pre-population) is also treated as maximally lagged rather than zero. + +```rust +// Before spawning tasks: +let mut guard = self.replication_lag.lock(); +for number in 0..n_sources { + guard.entry(number).or_insert(i64::MAX); +} + +// Orchestrator reads: +lag.values().copied().max().unwrap_or(i64::MAX) as u64 +``` + +The sentinel is overwritten by the first real measurement from each task's `check_lag.tick()`. 
Until that point the orchestrator treats every uninitialized shard as having infinite lag and will not proceed. + +--- + +## ✅ Issue 4 — Divergent code paths for the same operation (resolved) + +> **Resolved.** The divergent paths were consolidated onto the `crate::api` +> background-task framework. [`ReshardTask`](../../pgdog/src/api/resharding.rs) +> is the single composite flow (pre-data schema sync → data copy → post-data +> schema sync → replication), with `auto_cutover` toggling the final cutover; +> [`CopyDataTask`](../../pgdog/src/api/copy_data.rs) is the bulk data-copy leaf +> it composes; and `REPLICATE`/`CUTOVER`/`STOP_TASK` all drive one +> [`ReplicationTask`](../../pgdog/src/api/replication.rs). The old +> `backend/replication/logical/admin.rs` (`AsyncTasks`, `TaskType`) and +> `Orchestrator::replicate_and_cutover()` referenced below no longer exist. +> +> The `notify_one()` race is gone: a cutover is delivered to the specific +> running task through a latching `CancellationToken` held in a per-task +> `CUTOVERS` map keyed by task id (`ReplicationTask::cutover`), so it cannot be +> lost or leak into a later task. Natural slot drain intentionally does *not* +> auto-cut-over in the operator flow (`REPLICATE`/`copy_data`); cutover is +> explicit (operator `CUTOVER`) or automatic only under reshard's `auto_cutover`. 
+ +### Description + +The resharding pipeline — schema sync, data sync, replication, cutover — can be initiated through four independent entry points: + +| Entry point | Phase coverage | +|---|---| +| `RESHARD` admin command | full flow: schema sync, data sync, replication, cutover | +| CLI `replicate-and-cutover` | same full flow | +| `REPLICATE` admin command | replication only; schema/data sync must already be done | +| `REPLICATE` + `CUTOVER` admin commands | replication started as background task; cutover triggered externally | + +Each path composes the same underlying primitives (`Orchestrator`, `ReplicationWaiter`, `Publisher`) but assembles them differently and makes different assumptions about which phases have already run and which signals will arrive. When one path is fixed, the fix is not applied to the others because there is no single place that owns the shared contract. + +The concrete instance that surfaced during debugging: the background task registered by `REPLICATE` exits without performing cutover when the slot drains naturally. It waits in: + +```rust +select! { + _ = abort_rx => { waiter.stop(); } + _ = cutover.notified() => { waiter.cutover().await; } + result = waiter.wait() => { /* log error only */ } +} +``` + +When the source slot drains (`CopyDone`), `waiter.wait()` returns and the task exits via the third arm. `waiter.cutover()` is never called: `wait_for_replication()`, `wait_for_cutover()`, `schema_sync_post_cutover()`, and `databases::cutover()` (which flips traffic) all go unexecuted. The destination is fully populated and replication has stopped, but pgdog still routes to the source. The direct paths (`RESHARD`, CLI) always call `cutover()` at the end — they do not have this gap. + +A secondary consequence of path divergence: `AsyncTasks::cutover()` still calls `notify_one()` on `Arc<Notify>` (the primitive replaced in Issue 2). The fix in `publisher_impl.rs` was not propagated to the admin layer.
This creates a race in the `REPLICATE` + `CUTOVER` path: if the slot drains at the same instant the operator sends `CUTOVER`, `select!` picks one arm non-deterministically and the notification may be silently discarded. + +### Cause + +Each entry point was built to satisfy a specific operational need without consolidating around a shared flow. There is no contract enforcing which phases run in which order, so paths accumulate independent deviations over time. + +### Code references + +| Symbol | File | +|---|---| +| `TaskType::Replication` `select!` / `waiter.wait()` arm (historical, line ~150) | now `ReplicationTask::run` in [`pgdog/src/api/replication.rs`](../../pgdog/src/api/replication.rs) | +| `AsyncTasks::cutover()` / `notify_one()` (historical, line ~75) | now the cutover signal in [`pgdog/src/api/async_task.rs`](../../pgdog/src/api/async_task.rs) | +| `Replicate::execute()` | [`pgdog/src/admin/replicate.rs`](../../pgdog/src/admin/replicate.rs) | +| `Reshard::execute()` | [`pgdog/src/admin/reshard.rs`](../../pgdog/src/admin/reshard.rs) | +| `Orchestrator::replicate_and_cutover()` — the canonical flow | [`pgdog/src/backend/replication/logical/orchestrator.rs`](../../pgdog/src/backend/replication/logical/orchestrator.rs) | +### Fix + +Two independent fixes are needed. + +#### Immediate fix — cutover on natural drain + +The `waiter.wait()` arm in the background task `select!` currently only logs errors; on `Ok(())` it exits silently without performing cutover. The slot has fully drained, the destination is populated, but pgdog still routes traffic to the source. The fix is to call `waiter.cutover()` on successful completion: + +```rust +result = waiter.wait() => { + match result { + Ok(()) => { + // Slot drained naturally — still perform cutover. 
+ if let Err(err) = waiter.cutover().await { + error!(...); + } + } + Err(err) => error!(...), + } +} +``` + +#### Secondary fix — `notify_one()` race in `AsyncTasks::cutover()` + +`AsyncTasks::cutover()` still calls `notify_one()` on `Arc<Notify>` (the primitive replaced in Issue 2 for the stop signal). The fix from Issue 2 was not propagated to the admin layer. This creates a race in the `REPLICATE` + `CUTOVER` path: if the slot drains at the same instant the operator sends `CUTOVER`, `select!` picks one arm non-deterministically and the notification may be silently discarded. + +Replace `Arc<Notify>` with `watch::Sender<bool>` in `AsyncTasks` so that `cutover.send(true)` persists the value and any task polling `cutover.changed()` — whether parked or not at the moment of the call — sees it. + +#### Structural fix + +Make `Orchestrator::replicate_and_cutover()` the single canonical implementation of the full flow and have the background-task path call it rather than assembling phases independently. The background-task model's only responsibility should be *when* to trigger cutover (immediately, on external signal, on timeout) — not *how* to execute it. + +### References + +- [Tokio `select!` macro](https://docs.rs/tokio/latest/tokio/macro.select.html) — when multiple branches are ready simultaneously, one is chosen pseudo-randomly; no branch is guaranteed to execute. + +--- + +## 🚧 Issue 5 — `AbortSignal` is not an abort signal; it is a coordinator-gone detector + +### Description + +`AbortSignal` is used inside the parallel table-copy path to interrupt in-flight `COPY` loops when the sync coordinator exits. The name implies an active cancellation primitive, but the mechanism is entirely passive: it wraps an `UnboundedSender` and calls `tx.closed().await`, which resolves only when the corresponding `rx` (owned by `ParallelSyncManager::run()`) is dropped. + +There is no `abort()` method. Nothing sends a signal.
The only way the future resolves is if the receiver end of the channel is dropped — which happens as a side effect of the manager returning, not as an intentional cancellation act. + +This creates four concrete problems. + +**1. `rx` is only dropped when `manager.run()` returns.** + +`ParallelSyncManager::run()` is called inside a `tokio::spawn`ed task. That task runs independently of its caller. Dropping or cancelling `Orchestrator::data_sync()` or `Publisher::data_sync()` mid-await does not cancel the spawned task, does not drop `rx`, and does not fire the abort signal in any worker. Workers keep running until `manager.run()` finishes on its own. + +**2. The only trigger for `rx` dropping mid-run is a worker error propagating through `?`.** + +Inside `run()`: + +```rust +while let Some(table) = rx.recv().await { + tables.push(table?); // ← error here short-circuits the loop +}; +``` + +When one worker sends `Err(...)`, the `?` unwinds `run()`, `rx` drops, and `tx.closed()` resolves in every remaining worker — all concurrent `COPY` loops abort simultaneously. This is the only operational abort path. There is no way to cancel a single table's copy without bringing down every other table in the same manager. + +**3. The `tx.is_closed()` guard fires after permit acquisition, not before.** + +```rust +let _permit = Arc::clone(&self.permit) + .acquire_owned() + .await + .map_err(|_| Error::ParallelConnection)?; + +if self.tx.is_closed() { // ← checked here + return Err(Error::DataSyncAborted); +} +``` + +Tasks that are queued behind the semaphore when the coordinator dies continue to wait for a permit. They do not observe that the coordinator is gone until after they acquire a permit. Under high concurrency this means every queued task wakes, acquires a permit, checks `is_closed()`, and immediately returns an error — burning a permit round-trip for each one. + +**4. 
The name is a lie told to the next reader.** + +A reader seeing `AbortSignal` at a call site infers that an active abort can be issued. The implementation has no such capability. The name suppresses the question "who calls abort?" and makes the passive channel-closed mechanism invisible. This is the most dangerous property: future code that tries to use `AbortSignal` to implement a real abort — a timeout, a graceful stop, a per-table cancel — will find no mechanism to do so and is likely to add a second, parallel cancellation path instead of understanding the existing one. + +### Cause + +The signal was built to satisfy a narrow requirement — stop in-flight copies when any copy fails — and named optimistically. The implementation accidentally works for that single case because error propagation through `?` drops `rx` as a side effect. The gap between the name and the mechanism was not visible until the wider behaviour of `rx` (its lifetime, its relationship to `tokio::spawn`, the absence of an explicit abort path) was examined. 
+ +### Code references + +| Symbol | File | +---|---| +| `AbortSignal` — full definition | [`pgdog/src/backend/replication/logical/publisher/abort.rs`](../../pgdog/src/backend/replication/logical/publisher/abort.rs) | +| `ParallelSync::run_with_retry()` — constructs `AbortSignal` per attempt | [`pgdog/src/backend/replication/logical/publisher/parallel_sync.rs`](../../pgdog/src/backend/replication/logical/publisher/parallel_sync.rs) | +| `Table::data_sync()` — `select!` arm polling `abort.aborted()` | [`pgdog/src/backend/replication/logical/publisher/table.rs`](../../pgdog/src/backend/replication/logical/publisher/table.rs) | +| `ParallelSyncManager::run()` — owns `rx`; drops it on exit | [`pgdog/src/backend/replication/logical/publisher/parallel_sync.rs`](../../pgdog/src/backend/replication/logical/publisher/parallel_sync.rs) | + +### Fix + +Replace `AbortSignal` with a `CancellationToken` (from `tokio-util`) or a `watch::Sender` with an explicit `cancel()` method. The cancellation handle must be cloneable, sendable to workers before they start, and triggerable by an external caller — not just by the death of a channel receiver. + +```rust +// tokio-util approach: +use tokio_util::sync::CancellationToken; + +let cancel = CancellationToken::new(); + +// Pass a clone to each worker: +let worker_cancel = cancel.clone(); + +// Inside data_sync COPY loop: +select! { + _ = worker_cancel.cancelled() => { + return Err(Error::CopyAborted(self.table.clone())); + } + result = copy_sub.copy_data(data_row) => { ... } +} + +// To stop all workers at any time: +cancel.cancel(); +``` + +With this shape: +- The manager can cancel all workers explicitly without dying first. +- A timeout or external stop signal can call `cancel.cancel()` without needing to propagate an error through the channel. +- Per-table cancellation is possible by giving each worker its own child token: `cancel.child_token()`. 
+- The `tx.is_closed()` pre-flight check becomes `cancel.is_cancelled()`, which is honest about what it is testing. + +The `AbortSignal` type should be deleted. It carries no state that cannot be replaced by the token directly, and its existence perpetuates the misleading name. + +### References + +- [`tokio_util::sync::CancellationToken`](https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html) — cooperative cancellation token with child-token support and `cancelled().await`. +- [`tokio::sync::watch`](https://docs.rs/tokio/latest/tokio/sync/watch/index.html) — persistent value channel; used for the stop-signal fix in Issue 2. + +--- + +## ✅ Issue 6 — `STOP_TASK` during cutover removed the task but left it running (resolved) + +> **Resolved.** In the `crate::api` framework `STOP_TASK` only *requests* +> cancellation through a `CancellationToken`; it never removes the registry +> entry. The entry is dropped only after the task future actually reaches a +> terminal state, so the registry always reflects real execution state. A +> cutover that has passed its point of no return is allowed to finish rather +> than being torn down mid-traffic-swap: `ReplicationTask` overrides +> `cancel_timeout()` with a 60 s grace +> ([`replication.rs`](../../pgdog/src/api/replication.rs)) — comfortably longer +> than the committed swap (WAL drain + schema-sync DDL + reverse-replication +> setup), so `STOP_TASK` lets it run to completion instead of aborting it after +> the 5 s default. A genuinely hung swap is still reaped once the grace expires. 
+ +### Description (old implementation) + +The old `AsyncTasks` registry (deleted `backend/replication/logical/admin.rs`) +tracked each task in a `DashMap` keyed by task id, with one `TaskInfo` per entry: + +```rust +struct TaskInfo { + abort_tx: oneshot::Sender<()>, // dropped => abort_rx resolves + cutover: Arc<Notify>, + task_kind: TaskKind, + started_at: SystemTime, +} +``` + +A replication task was driven by a detached `tokio::spawn`ed future selecting +over three arms: + +```rust +spawn(async move { + select! { + _ = abort_rx => { waiter.stop(); } // STOP_TASK + _ = cutover.notified() => { waiter.cutover().await; } // CUTOVER + result = waiter.wait() => { /* slot drained */ } + } + AsyncTasks::get().tasks.remove(&id); // self-remove on exit +}); +``` + +`STOP_TASK` called `AsyncTasks::remove(id)`, which did +`tasks.remove(&id)` — synchronously deleting the map entry (and dropping +`abort_tx`). Two things made the cutover-then-stop sequence broken: + +1. **Removal was immediate and unconditional.** The entry vanished from the map + the instant `remove` was called, and `remove` returned the `TaskKind` as if + it had succeeded. Nothing tied the entry's presence to the task having + actually stopped — the map said "gone" while the spawned future was still + running. There was no terminal state retained; `SHOW TASKS` simply stopped + listing it. + +2. **The abort arm could not win once cutover had started.** Dropping + `abort_tx` resolves `abort_rx`, but `select!` had already committed to the + `cutover.notified()` arm, so `waiter.cutover().await` ran to completion + regardless. Unlike the `SchemaSync`/`CopyData` variants — which held an + `abort_handle` and called `abort_handle.abort()` — the replication variant + had no hard-abort path at all; `STOP_TASK` could only ever request a + graceful `waiter.stop()`, and only if it won the race. 
+ +So after `CUTOVER` then `STOP_TASK`, the registry reported the task removed and +stopped, while the detached future kept flipping traffic, running post-data +schema sync, and setting up reverse replication — mutating cluster state with +no visibility and no way to stop it. + +### How the new implementation works + +The registry entry's lifetime is tied to the task future through the supervisor +in `run_task` ([`async_task.rs`](../../pgdog/src/api/async_task.rs)), not to +admin-side bookkeeping: + +1. **`STOP_TASK` requests; it does not remove.** + `AsyncTasksStorage::cancel_task` only calls `entry.cancel()` on the task's + `CancellationToken` and leaves the entry in the map. The entry transitions + to a terminal status (`Cancelled`/`Finished`) only when the spawned future + returns; `prune()` then drops it after the 24h retention. A running task is + never removed, so `SHOW TASKS` keeps showing it as `Stopping`/`CuttingOver` + until it has genuinely finished — the map always reflects real state. + +2. **A started cutover runs to completion.** `ReplicationTask::run` + ([`replication.rs`](../../pgdog/src/api/replication.rs)) `select!`s the + cutover arm (`waiter.cutover().await`) against `token.cancelled()`. Once the + `CUTOVER` signal fires and that arm is chosen, `select!` is committed to + awaiting it; a later `STOP_TASK` cancels the token but no longer has an arm + to win. The supervisor sees the cancelled token and — because + `ReplicationTask` is cooperative (it took its token) — waits + `cancel_timeout()` before aborting. `ReplicationTask` overrides + `cancel_timeout()` to **60 s** (vs. the 5 s trait default in + [`async_task.rs`](../../pgdog/src/api/async_task.rs)), comfortably longer than + the committed swap, so the traffic switch finishes instead of being aborted + mid-flight. Only a swap that hangs past the grace is force-aborted. + +3. 
**The reverse order winds down without cutover.** If `STOP_TASK` lands + first, the `token.cancelled()` arm sets status `Stopping`, calls + `waiter.stop()` (graceful slot drain), and returns — no traffic switch. A + subsequent `CUTOVER` finds the `CutoverWaiter` already dropped from the + `CUTOVERS` map, so `ReplicationTask::cutover` returns `false` and the admin + command rejects with `NotReplication`. + +Net: `STOP_TASK` gracefully stops a task that has not yet cut over, and is a +deliberate no-op against an in-flight cutover — which completes within its 60 s +grace — and in both cases the task stays visible in the registry until it has +actually stopped. + +### Code references + +| Symbol | File | +|---|---| +| `AsyncTasksStorage::cancel_task` — cancels the token, keeps the entry | [`pgdog/src/api/async_task.rs`](../../pgdog/src/api/async_task.rs) | +| `AsyncTasksStorage::prune` — drops only terminal entries past retention | same file | +| `run_task` supervisor — cooperative grace via `cancel_timeout`; sets terminal status on completion | same file | +| `ReplicationTask::run` / `cancel_timeout` — cutover-vs-stop `select!` | [`pgdog/src/api/replication.rs`](../../pgdog/src/api/replication.rs) | + +--- + +## ✅ Issue 7 — Cutover swapped the cluster name but not `database_name`, repointing entries at a nonexistent database (resolved) + +> **Resolved.** `Config::cutover` now pins the effective `database_name` on the +> two clusters being swapped *before* exchanging their logical `name`, so the +> name swap only changes routing identity and never the physical database an +> entry connects to. + +### Description + +Traffic cutover swaps the *logical* `name` of the source and destination +clusters (`Config::cutover`, `pgdog-config/src/core.rs`) — clients keep +connecting to the same name, which now resolves to the other cluster's backends. 
+ +The physical Postgres database an entry connects to, however, is resolved as +`database_name` **if set, otherwise the cluster `name`** (`Address::from`, +`pgdog/src/backend/pool/address.rs`). `Config::cutover` rewrote `name` but left +`database_name` untouched. So any database entry that relied on the default +(no explicit `database_name`, i.e. cluster name == physical DB name) was, after +the swap, silently repointed at a physical database named after the *other* +cluster — which usually does not exist. + +Two symptoms, one cause: + +- The pool monitor logs `FATAL: 3D000 database "<name>" does not exist` (with + `<name>` being the other cluster's name) and + retries forever against the dead config. +- The post-cutover schema restore (`schema_sync_post_cutover` → + `PgDumpOutput::restore`) checks out a connection from that pool; since it can + never connect, the checkout waits out `checkout_timeout` and the cutover task + fails with `schema: checkout timeout`, past the point of no return. + +Configs that set `database_name` explicitly on every entry (e.g. the +`integration/resharding` suite, where all entries use `database_name = "postgres"`) +were unaffected, which is why this stayed hidden — the swap is a no-op for the +physical target when `database_name` is already pinned. + +### Cause + +`database_name` defaults to the cluster `name` at connection time, and +`Config::cutover` changed `name` without first materializing that default. The +two fields are only equivalent until the name is swapped. 
+ +### Code references + +| Symbol | File | +|---|---| +| `Config::cutover` — swaps `name`; now pins `database_name` first | [`pgdog-config/src/core.rs`](../../pgdog-config/src/core.rs) | +| `Address::from` — `database_name` falls back to cluster `name` | [`pgdog/src/backend/pool/address.rs`](../../pgdog/src/backend/pool/address.rs) | +| `databases::cutover` — applies the config swap and relaunches pools | [`pgdog/src/backend/databases.rs`](../../pgdog/src/backend/databases.rs) | +| `test_cutover_preserves_physical_database_name` — regression test | [`pgdog-config/src/core.rs`](../../pgdog-config/src/core.rs) | + +### Fix + +Before the name swap, set `database_name` to the current `name` on any entry of +the two clusters where it is unset: + +```rust +for db in self.databases.iter_mut() { + if (db.name == source || db.name == destination) && db.database_name.is_none() { + db.database_name = Some(db.name.clone()); + } +} +``` + +The swap then moves only the logical name; the physical connection target is +preserved. Covered by `test_cutover_preserves_physical_database_name` (asymmetric: +one cluster relies on the default, the other sets `database_name` explicitly). diff --git a/integration/rust/tests/integration/admin/mod.rs b/integration/rust/tests/integration/admin/mod.rs new file mode 100644 index 000000000..8f2ee06b3 --- /dev/null +++ b/integration/rust/tests/integration/admin/mod.rs @@ -0,0 +1,110 @@ +//! Integration tests asserting admin command output over the wire. +//! +//! Each submodule connects to the live PgDog admin database (`rust::setup::admin_sqlx`). +pub mod show_config; +pub mod show_version; +pub mod tasks; + +use sqlx::{Column, Executor, Pool, Postgres, Row, TypeInfo}; + +/// Wire layout expected from `SHOW TASKS`. 
+const SHOW_TASKS_LAYOUT: &[(&str, &str)] = &[
+    ("id", "INT8"),
+    ("scope", "TEXT"),
+    ("type", "TEXT"),
+    ("status", "TEXT"),
+    ("inner_status", "TEXT"),
+    ("started_at", "TEXT"),
+    ("updated_at", "TEXT"),
+    ("elapsed", "TEXT"),
+    ("elapsed_ms", "INT8"),
+];
+
+/// A parsed `SHOW TASKS` row, one field per wire column.
+///
+/// [`Tasks::fetch`] is the constructor that checks the per-row invariants
+/// (non-empty timestamps and status, non-negative `elapsed_ms`, known `scope`).
+#[derive(Debug, Clone)]
+pub struct Task {
+    /// `id` column; `None` when the column is NULL.
+    pub id: Option<i64>,
+    /// `scope` column; [`Tasks::fetch`] only accepts `"root"` or `"subtask"`.
+    pub scope: String,
+    /// `type` column (`type` is a Rust keyword, hence the rename).
+    pub kind: String,
+    pub status: String,
+    pub inner_status: String,
+    pub started_at: String,
+    pub updated_at: String,
+    pub elapsed: String,
+    pub elapsed_ms: i64,
+}
+
+/// Parsed result of `SHOW TASKS`; query with [`Tasks::find`] or [`Tasks::rows`].
+#[derive(Debug, Clone)]
+pub struct Tasks {
+    pub rows: Vec<Task>,
+}
+
+impl Tasks {
+    /// Run `SHOW TASKS`, assert the wire layout and per-row invariants, and
+    /// return the parsed rows.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the query fails or on any layout/type/invariant violation.
+    pub async fn fetch(pool: &Pool<Postgres>) -> Self {
+        let raw = pool.fetch_all("SHOW TASKS").await.unwrap();
+
+        // assert_layout needs a row; an empty result is valid (no tasks).
+        if !raw.is_empty() {
+            assert_layout(&raw, SHOW_TASKS_LAYOUT);
+        }
+
+        let rows = raw
+            .iter()
+            .map(|row| {
+                let id: Option<i64> = row.get("id");
+                let scope: String = row.get("scope");
+                let status: String = row.get("status");
+                let started_at: String = row.get("started_at");
+                let updated_at: String = row.get("updated_at");
+                let elapsed: String = row.get("elapsed");
+                let elapsed_ms: i64 = row.get("elapsed_ms");
+
+                assert!(!started_at.is_empty(), "task {id:?}: started_at is empty");
+                assert!(!updated_at.is_empty(), "task {id:?}: updated_at is empty");
+                assert!(!elapsed.is_empty(), "task {id:?}: elapsed is empty");
+                assert!(!status.is_empty(), "task {id:?}: status is empty");
+                assert!(elapsed_ms >= 0, "task {id:?}: elapsed_ms is negative");
+                assert!(
+                    scope == "root" || scope == "subtask",
+                    "task {id:?}: unexpected scope {scope:?}"
+                );
+
+                Task {
+                    id,
+                    scope,
+                    kind: row.get("type"),
+                    status,
+                    inner_status: row.get("inner_status"),
+                    started_at,
+                    updated_at,
+                    elapsed,
+                    elapsed_ms,
+                }
+            })
+            .collect();
+
+        Self { rows }
+    }
+
+    /// Return the first row whose `id` equals `id`, if present. Rows with a
+    /// NULL id never match; `scope` is not consulted.
+    pub fn find(&self, id: i64) -> Option<&Task> {
+        self.rows.iter().find(|t| t.id == Some(id))
+    }
+
+    /// Whether `SHOW TASKS` returned no rows.
+    pub fn is_empty(&self) -> bool {
+        self.rows.is_empty()
+    }
+}
+
+/// Assert the first row's column layout (name, wire type) matches `expected` exactly.
+pub fn assert_layout(rows: &[sqlx::postgres::PgRow], expected: &[(&str, &str)]) { + assert!(!rows.is_empty(), "expected at least one row"); + let actual: Vec<(&str, &str)> = rows[0] + .columns() + .iter() + .map(|col| (col.name(), col.type_info().name())) + .collect(); + assert_eq!(actual, expected, "column layout mismatch"); +} diff --git a/integration/rust/tests/integration/admin/show_config.rs b/integration/rust/tests/integration/admin/show_config.rs new file mode 100644 index 000000000..902c8f6bc --- /dev/null +++ b/integration/rust/tests/integration/admin/show_config.rs @@ -0,0 +1,25 @@ +use std::collections::HashMap; + +use rust::setup::admin_sqlx; +use sqlx::{Executor, Row}; + +use super::assert_layout; + +/// `SHOW CONFIG` returns `name`/`value` TEXT columns, one row per setting. +#[tokio::test] +async fn test_show_config_reports_settings() { + let admin = admin_sqlx().await; + let rows = admin.fetch_all("SHOW CONFIG").await.unwrap(); + + assert_layout(&rows, &[("name", "TEXT"), ("value", "TEXT")]); + + let settings: HashMap = rows + .iter() + .map(|row| (row.get("name"), row.get("value"))) + .collect(); + + assert_eq!(settings["host"], "0.0.0.0"); + assert_eq!(settings["port"], "6432"); + + admin.close().await; +} diff --git a/integration/rust/tests/integration/admin/show_version.rs b/integration/rust/tests/integration/admin/show_version.rs new file mode 100644 index 000000000..ad7f9f073 --- /dev/null +++ b/integration/rust/tests/integration/admin/show_version.rs @@ -0,0 +1,22 @@ +use rust::setup::admin_sqlx; +use sqlx::{Executor, Row}; + +use super::assert_layout; + +/// `SHOW VERSION` returns a single `version` TEXT row carrying the banner. 
+#[tokio::test] +async fn test_show_version_reports_banner() { + let admin = admin_sqlx().await; + let rows = admin.fetch_all("SHOW VERSION").await.unwrap(); + + assert_eq!(rows.len(), 1, "SHOW VERSION should return exactly one row"); + assert_layout(&rows, &[("version", "TEXT")]); + + let version: &str = rows[0].get("version"); + assert!( + version.starts_with("PgDog v"), + "version should start with the PgDog banner, got: {version:?}" + ); + + admin.close().await; +} diff --git a/integration/rust/tests/integration/admin/tasks.rs b/integration/rust/tests/integration/admin/tasks.rs new file mode 100644 index 000000000..436e75592 --- /dev/null +++ b/integration/rust/tests/integration/admin/tasks.rs @@ -0,0 +1,590 @@ +use std::time::Duration; + +use rust::setup::{admin_sqlx, connection_sqlx_direct, connection_sqlx_direct_db}; +use sqlx::{Executor, Pool, Postgres, Row}; +use tokio::time::{sleep, timeout}; + +use super::Tasks; + +// ─── Constants ────────────────────────────────────────────────────────────── + +/// Source table propagated to the shards; tests run serially and own it exclusively. +const TEST_TABLE: &str = "_pgdog_test_task"; + +const TEST_PUB: &str = "pgdog_test_pub"; + +/// Matches every replication slot these tests create — all auto-named by +/// COPY_DATA / RESHARD / REPLICATE. Used by [`cleanup`]. +const SLOT_FILTER: &str = "slot_name LIKE '__pgdog_repl_%'"; + +/// Interval between status polls in the `wait_*` / `cleanup` loops below; each +/// loop is bounded by a `timeout(window, …)` rather than a fixed iteration count. +const POLL: Duration = Duration::from_millis(200); + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +/// Drop `table` and any orphaned row-type from `pool` (idempotent). 
+async fn drop_table(pool: &Pool<Postgres>, table: &str) {
+    // Best-effort: both statements are `IF EXISTS`, and any error is ignored.
+    let _ = pool
+        .execute(format!("DROP TABLE IF EXISTS {table} CASCADE").as_str())
+        .await;
+    let _ = pool
+        .execute(format!("DROP TYPE IF EXISTS {table} CASCADE").as_str())
+        .await;
+}
+
+/// Drop `table` from the source and both shards; per-shard failures are ignored.
+async fn drop_table_everywhere(table: &str, direct: &Pool<Postgres>) {
+    drop_table(direct, table).await;
+    for db in &["shard_0", "shard_1"] {
+        drop_table(&connection_sqlx_direct_db(db).await, table).await;
+    }
+}
+
+/// Idempotent cleanup, safe to run before and after each test.
+///
+/// Order matters: tasks are drained first, then the config is reloaded, then
+/// replication slots, the publication and the test table are removed.
+///
+/// # Panics
+///
+/// Panics if tasks do not reach a terminal state within 60 s, or if the test
+/// replication slots are still active after 10 s.
+async fn cleanup(admin: &Pool<Postgres>, direct: &Pool<Postgres>) {
+    // Drain every task to a terminal state *before* touching anything else. An
+    // in-flight cutover holds backend pools and, past its point of no return,
+    // runs to completion (then spawns a reverse-replication task) — a `RELOAD`
+    // mid-cutover would shut those pools down and fail it with "pool is shut
+    // down". STOP_TASK every still-running task each pass (the reverse-replication
+    // task spawned during winddown is caught on a later pass) until all are terminal.
+    timeout(Duration::from_secs(60), async {
+        let is_terminal = |status: &str| {
+            matches!(status, "finished" | "cancelled")
+                || status.starts_with("failed")
+                || status.starts_with("panicked")
+        };
+        loop {
+            let tasks = Tasks::fetch(admin).await;
+            if tasks.rows.iter().all(|t| is_terminal(t.status.as_str())) {
+                break;
+            }
+            // Only stop tasks that are still running; terminal ones are left as-is
+            // (re-stopping a reverse-replication task spawned during winddown is
+            // handled on the next pass once it appears as running).
+            for task in &tasks.rows {
+                if is_terminal(task.status.as_str()) {
+                    continue;
+                }
+                if let Some(id) = task.id {
+                    let _ = admin.execute(format!("STOP_TASK {id}").as_str()).await;
+                }
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await
+    .expect("tasks did not drain to a terminal state");
+
+    // No task is mid-cutover now, so it is safe to restore the topology: a
+    // completed auto_cutover swaps the db configs in memory (not persisted —
+    // `cutover_save_config` is off), so RELOAD reloads the pristine on-disk config.
+    let _ = admin.execute("RELOAD").await;
+    sleep(Duration::from_millis(500)).await;
+
+    // Terminate WAL senders on any test slot.
+    let _ = direct
+        .execute(
+            format!(
+                "SELECT pg_terminate_backend(active_pid) \
+                 FROM pg_replication_slots \
+                 WHERE ({SLOT_FILTER}) AND active_pid IS NOT NULL"
+            )
+            .as_str(),
+        )
+        .await;
+
+    // Wait for those slots to deactivate. A failed query or a NULL `bool_or`
+    // (no matching slots) both count as "nothing active".
+    timeout(Duration::from_secs(10), async {
+        loop {
+            let any_active = direct
+                .fetch_optional(sqlx::query(&format!(
+                    "SELECT bool_or(active) AS active FROM pg_replication_slots WHERE {SLOT_FILTER}"
+                )))
+                .await
+                .ok()
+                .flatten()
+                .and_then(|row: sqlx::postgres::PgRow| row.get::<Option<bool>, _>("active"))
+                .unwrap_or(false);
+            if !any_active {
+                break;
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await
+    .expect("replication slots did not deactivate");
+
+    // Drop inactive test slots.
+    let _ = direct
+        .execute(
+            format!(
+                "SELECT pg_drop_replication_slot(slot_name) \
+                 FROM pg_replication_slots \
+                 WHERE ({SLOT_FILTER}) AND NOT active"
+            )
+            .as_str(),
+        )
+        .await;
+
+    // Drop the test publication.
+    let _ = direct
+        .execute(format!("DROP PUBLICATION IF EXISTS {TEST_PUB}").as_str())
+        .await;
+
+    // Drop the shared test table everywhere.
+    drop_table_everywhere(TEST_TABLE, direct).await;
+}
+
+/// Start `pgdog` -> `pgdog_sharded` replication; return the task id once it
+/// appears in `SHOW TASKS`.
+async fn start_replication(admin: &Pool<Postgres>, direct: &Pool<Postgres>) -> i64 {
+    // Start from the on-disk config, then give the reload a moment to settle.
+    admin.execute("RELOAD").await.unwrap();
+    sleep(Duration::from_millis(500)).await;
+
+    direct
+        .execute(format!("CREATE PUBLICATION {TEST_PUB} FOR ALL TABLES").as_str())
+        .await
+        .unwrap();
+
+    let row = admin
+        .fetch_one(format!("REPLICATE pgdog pgdog_sharded {TEST_PUB}").as_str())
+        .await
+        .unwrap();
+    // `task_id` comes back as text, so it is parsed rather than decoded as INT8.
+    let task_id: i64 = row.get::<String, _>("task_id").parse().unwrap();
+
+    let appeared = timeout(Duration::from_secs(10), async {
+        loop {
+            if Tasks::fetch(admin)
+                .await
+                .find(task_id)
+                .is_some_and(|t| t.kind == "replication pgdog -> pgdog_sharded")
+            {
+                return;
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await;
+    assert!(
+        appeared.is_ok(),
+        "replication task {task_id} did not appear in SHOW TASKS in time"
+    );
+
+    task_id
+}
+
+/// Poll until `task_id` reaches `status` in `SHOW TASKS` (up to 30 s).
+///
+/// # Panics
+///
+/// Panics as soon as the task reports a `failed…`/`panicked…` status, or when
+/// the 30 s window elapses; the timeout message includes the last observed
+/// status.
+async fn wait_for_task_status(admin: &Pool<Postgres>, task_id: i64, status: &str) {
+    let result = timeout(Duration::from_secs(30), async {
+        loop {
+            if let Some(task) = Tasks::fetch(admin).await.find(task_id) {
+                if task.status == status {
+                    return;
+                }
+                // Fail fast on an unexpected error state — the status text carries
+                // the failure message — instead of waiting out the window.
+                if task.status.starts_with("failed") || task.status.starts_with("panicked") {
+                    panic!(
+                        "task {task_id} errored while waiting for {status:?}: {} (inner_status {:?})",
+                        task.status, task.inner_status
+                    );
+                }
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await;
+    if result.is_err() {
+        // Report where the task actually got stuck, like the other wait helpers.
+        panic!(
+            "task {task_id} did not reach status {status:?} in SHOW TASKS in time ({})",
+            task_status_line(admin, task_id).await
+        );
+    }
+}
+
+/// Current `SHOW TASKS` status line for `task_id`, for timeout diagnostics.
+async fn task_status_line(admin: &Pool<Postgres>, task_id: i64) -> String {
+    match Tasks::fetch(admin).await.find(task_id) {
+        Some(t) => format!("status {:?}, inner_status {:?}", t.status, t.inner_status),
+        None => "task absent from SHOW TASKS".to_string(),
+    }
+}
+
+/// Panic with the task's status (which carries the error message) if `task_id`
+/// reached an error state. A task that is absent from `SHOW TASKS` is not
+/// treated as an error.
+async fn fail_if_task_errored(admin: &Pool<Postgres>, task_id: i64) {
+    if let Some(t) = Tasks::fetch(admin).await.find(task_id)
+        && (t.status.starts_with("failed") || t.status.starts_with("panicked"))
+    {
+        panic!(
+            "task {task_id} errored: {} (inner_status {:?})",
+            t.status, t.inner_status
+        );
+    }
+}
+
+/// Whether relation `name` exists in the database behind `pool` (resolved via
+/// the connection's search_path).
+///
+/// `name` is interpolated into the SQL text unescaped, so pass trusted test
+/// constants only.
+async fn relation_present(pool: &Pool<Postgres>, name: &str) -> bool {
+    pool.fetch_one(format!("SELECT to_regclass('{name}') IS NOT NULL AS present").as_str())
+        .await
+        .unwrap()
+        .get::<bool, _>("present")
+}
+
+/// Poll until relation `name` exists on both shards (up to 30 s), failing fast
+/// with the task's status if `task_id` errors meanwhile.
+async fn wait_for_relation_on_shards(admin: &Pool<Postgres>, task_id: i64, name: &str) {
+    // One pool per shard, opened once and reused for every poll.
+    let shard_0 = connection_sqlx_direct_db("shard_0").await;
+    let shard_1 = connection_sqlx_direct_db("shard_1").await;
+    let result = timeout(Duration::from_secs(30), async {
+        loop {
+            fail_if_task_errored(admin, task_id).await;
+            if relation_present(&shard_0, name).await && relation_present(&shard_1, name).await {
+                return;
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await;
+    if result.is_err() {
+        panic!(
+            "relation {name} did not propagate to all shards in time ({})",
+            task_status_line(admin, task_id).await
+        );
+    }
+}
+
+/// Rows of `table` on shard `db`, queried directly (bypassing pgdog); an absent
+/// table counts 0.
+async fn shard_row_count(db: &str, table: &str) -> i64 {
+    // Opens a fresh pool per call; fine for one-off reads, too costly in a poll loop.
+    let pool = connection_sqlx_direct_db(db).await;
+    pool_row_count(&pool, table).await
+}
+
+/// Rows of `table` in the database behind `pool`; an absent table counts 0.
+async fn pool_row_count(pool: &Pool<Postgres>, table: &str) -> i64 {
+    if !relation_present(pool, table).await {
+        return 0;
+    }
+    pool.fetch_one(format!("SELECT COUNT(*)::bigint AS n FROM {table}").as_str())
+        .await
+        .unwrap()
+        .get::<i64, _>("n")
+}
+
+/// Poll until every destination shard holds exactly `expected` rows of `table`
+/// (up to 30 s). `table` is omni (not in `sharded_tables`), so each shard ends
+/// up with a full copy — the reference-table semantics also asserted by
+/// `check_omni_each_shard` in the resharding suite. Fails fast on task error.
+async fn wait_for_rows_each_shard(
+    admin: &Pool<Postgres>,
+    task_id: i64,
+    table: &str,
+    expected: i64,
+) {
+    // Open one pool per shard up front and reuse it for every poll, as
+    // `wait_for_relation_on_shards` does, instead of reconnecting each `POLL`.
+    let shard_0 = connection_sqlx_direct_db("shard_0").await;
+    let shard_1 = connection_sqlx_direct_db("shard_1").await;
+    let result = timeout(Duration::from_secs(30), async {
+        loop {
+            fail_if_task_errored(admin, task_id).await;
+            if pool_row_count(&shard_0, table).await == expected
+                && pool_row_count(&shard_1, table).await == expected
+            {
+                return;
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await;
+    if result.is_err() {
+        // One-shot diagnostics: read the counts over fresh connections.
+        panic!(
+            "table {table} did not reach {expected} rows on each shard in time \
+             (shard_0={}, shard_1={}, {})",
+            shard_row_count("shard_0", table).await,
+            shard_row_count("shard_1", table).await,
+            task_status_line(admin, task_id).await
+        );
+    }
+}
+
+/// Poll `SHOW TASKS` until a row satisfies `pred` (up to 30 s).
+///
+/// `desc` is only used in the panic message raised on timeout.
+async fn wait_for_task(admin: &Pool<Postgres>, desc: &str, pred: impl Fn(&super::Task) -> bool) {
+    let result = timeout(Duration::from_secs(30), async {
+        loop {
+            if Tasks::fetch(admin).await.rows.iter().any(&pred) {
+                return;
+            }
+            sleep(POLL).await;
+        }
+    })
+    .await;
+    if result.is_err() {
+        panic!("no task matching {desc:?} appeared in SHOW TASKS in time");
+    }
+}
+
+/// Create the test table on the source.
+async fn create_test_table(direct: &Pool<Postgres>) {
+    direct
+        .execute(format!("CREATE TABLE {TEST_TABLE} (id BIGSERIAL PRIMARY KEY, val TEXT)").as_str())
+        .await
+        .unwrap();
+}
+
+/// Insert `n` rows into the test table.
+async fn seed_rows(direct: &Pool<Postgres>, n: i64) {
+    // Values are 'v1'..'v<n>'; ids come from the table's BIGSERIAL.
+    direct
+        .execute(
+            format!(
+                "INSERT INTO {TEST_TABLE} (val) SELECT 'v' || g FROM generate_series(1, {n}) g"
+            )
+            .as_str(),
+        )
+        .await
+        .unwrap();
+}
+
+/// Create the test publication (`FOR TABLE TEST_TABLE`) on the source.
+async fn create_publication(direct: &Pool<Postgres>) {
+    direct
+        .execute(format!("CREATE PUBLICATION {TEST_PUB} FOR TABLE {TEST_TABLE}").as_str())
+        .await
+        .unwrap();
+}
+
+/// Run an admin command that returns a `task_id` column; parse and return it.
+///
+/// # Panics
+///
+/// Panics, naming `command`, if the command fails or its `task_id` is not an
+/// integer.
+async fn run_task_command(admin: &Pool<Postgres>, command: &str) -> i64 {
+    let row = admin
+        .fetch_one(command)
+        .await
+        .unwrap_or_else(|err| panic!("admin command {command:?} failed: {err}"));
+    // `task_id` comes back as text, so it is parsed rather than decoded as INT8.
+    let raw: String = row.get("task_id");
+    raw.parse().unwrap_or_else(|err| {
+        panic!("admin command {command:?} returned non-integer task_id {raw:?}: {err}")
+    })
+}
+
+// ─── Tests ──────────────────────────────────────────────────────────────────
+
+/// `STOP_TASK` on an unknown id returns `"task not found"`.
+#[tokio::test]
+async fn test_stop_nonexistent_task() {
+    let admin = admin_sqlx().await;
+
+    let row = admin.fetch_one("STOP_TASK 999999999").await.unwrap();
+    assert_eq!(row.get::<String, _>("stop_task"), "task not found");
+}
+
+/// `CUTOVER` with no replication task errors but leaves the pool usable.
+#[tokio::test]
+async fn test_cutover_without_replication_task() {
+    let direct = connection_sqlx_direct().await;
+    let admin = admin_sqlx().await;
+    cleanup(&admin, &direct).await;
+
+    let err = admin.fetch_one("CUTOVER").await.unwrap_err();
+    assert!(
+        matches!(err, sqlx::Error::Database(_)),
+        "expected a database error, got: {err:?}"
+    );
+    // The admin connection must still answer after the rejected command.
+    admin.fetch_one("SHOW VERSION").await.unwrap();
+}
+
+/// `STOP_TASK <id>` cancels a replication task (returns `"OK"`, status -> `cancelled`).
+#[tokio::test] +async fn test_stop_task() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + let task_id = start_replication(&admin, &direct).await; + + let row = admin + .fetch_one(format!("STOP_TASK {task_id}").as_str()) + .await + .unwrap(); + assert_eq!(row.get::("stop_task"), "OK"); + + wait_for_task_status(&admin, task_id, "cancelled").await; + cleanup(&admin, &direct).await; +} + +/// Operator `CUTOVER` drives a running replication through the full traffic +/// switch to completion. The replication is established with `COPY_DATA` (which +/// loads the schema and leaves a task awaiting an operator cutover — a bare +/// `REPLICATE` never loads the schema, so its cutover would fail in +/// `schema_sync_cutover`). `CUTOVER` returns `"OK"` and the task reaches +/// `finished` once the swap, post-cutover schema sync, and reverse-replication +/// setup all complete. `cleanup` RELOADs to revert the resulting config swap. +#[tokio::test] +async fn test_cutover() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + create_test_table(&direct).await; + seed_rows(&direct, 20).await; + create_publication(&direct).await; + + let task_id = + run_task_command(&admin, &format!("COPY_DATA pgdog pgdog_sharded {TEST_PUB}")).await; + + // Wait until the task reaches its replication phase: schema is loaded and it + // is registered to receive an operator CUTOVER. + wait_for_task(&admin, "copy_data replicating", |t| { + t.id == Some(task_id) && t.inner_status == "replicating" + }) + .await; + + // Issue CUTOVER; retry briefly to cover the gap between the replication + // phase being reported and the task registering for operator cutover. 
+ let cutover_ok = timeout(Duration::from_secs(10), async { + loop { + if let Ok(row) = admin.fetch_one("CUTOVER").await + && row.get::("cutover") == "OK" + { + return; + } + sleep(POLL).await; + } + }) + .await; + assert!( + cutover_ok.is_ok(), + "CUTOVER never returned OK ({})", + task_status_line(&admin, task_id).await + ); + + // The switch runs to completion: the task ends in `finished`. + wait_for_task_status(&admin, task_id, "finished").await; + cleanup(&admin, &direct).await; +} + +/// `SCHEMA_SYNC pre` creates the table on both shards. +#[tokio::test] +async fn test_schema_sync_pre() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + create_test_table(&direct).await; + create_publication(&direct).await; + + let task_id = run_task_command( + &admin, + &format!("SCHEMA_SYNC pre pgdog pgdog_sharded {TEST_PUB}"), + ) + .await; + + // cleanup dropped the table pre-flight, so its presence proves pre created it. + wait_for_relation_on_shards(&admin, task_id, TEST_TABLE).await; + + cleanup(&admin, &direct).await; +} + +/// `SCHEMA_SYNC post` adds the secondary index, which `pre` does not. +#[tokio::test] +async fn test_schema_sync_post() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + let secondary_index = format!("{TEST_TABLE}_val_idx"); + create_test_table(&direct).await; + direct + .execute(format!("CREATE INDEX {secondary_index} ON {TEST_TABLE} (val)").as_str()) + .await + .unwrap(); + create_publication(&direct).await; + + // pre creates the table but not the secondary index (that's post-data). + let pre_task_id = run_task_command( + &admin, + &format!("SCHEMA_SYNC pre pgdog pgdog_sharded {TEST_PUB}"), + ) + .await; + wait_for_relation_on_shards(&admin, pre_task_id, TEST_TABLE).await; + + // post adds the secondary index on both shards. 
+ let task_id = run_task_command( + &admin, + &format!("SCHEMA_SYNC post pgdog pgdog_sharded {TEST_PUB}"), + ) + .await; + + wait_for_relation_on_shards(&admin, task_id, &secondary_index).await; + + cleanup(&admin, &direct).await; +} + +/// `COPY_DATA` syncs schema, copies data, then replicates; assert the table and +/// its rows reach both shards. +#[tokio::test] +async fn test_copy_data() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + create_test_table(&direct).await; + seed_rows(&direct, 20).await; + create_publication(&direct).await; + + let row = admin + .fetch_one(format!("COPY_DATA pgdog pgdog_sharded {TEST_PUB}").as_str()) + .await + .unwrap(); + let task_id: i64 = row.get::("task_id").parse().unwrap(); + let slot_name: String = row.get("replication_slot"); + assert!(!slot_name.is_empty(), "replication_slot must be non-empty"); + + // Schema phase creates the table; the omni table is fully copied to every + // shard, so each holds all 20 rows. + wait_for_relation_on_shards(&admin, task_id, TEST_TABLE).await; + wait_for_rows_each_shard(&admin, task_id, TEST_TABLE, 20).await; + + cleanup(&admin, &direct).await; +} + +/// `RESHARD` (auto_cutover) is non-blocking — returns a task id immediately and +/// registers a `reshard` task — copies schema + data to the shards, and enters +/// its auto-cutover replication phase. +#[tokio::test] +async fn test_reshard() { + let direct = connection_sqlx_direct().await; + let admin = admin_sqlx().await; + cleanup(&admin, &direct).await; + + create_test_table(&direct).await; + seed_rows(&direct, 20).await; + create_publication(&direct).await; + + // Non-blocking: RESHARD returns a task id straight away. + let task_id = + run_task_command(&admin, &format!("RESHARD pgdog pgdog_sharded {TEST_PUB}")).await; + + // Registers immediately as a `reshard` task. 
+ wait_for_task(&admin, "reshard task", |t| { + t.id == Some(task_id) && t.kind.starts_with("reshard ") + }) + .await; + + // Schema + data reach both shards (omni table fully copied to each -> 20 per shard). + wait_for_relation_on_shards(&admin, task_id, TEST_TABLE).await; + wait_for_rows_each_shard(&admin, task_id, TEST_TABLE, 20).await; + + // auto_cutover drives into the replication phase on its own. + wait_for_task(&admin, "reshard replicating", |t| { + t.id == Some(task_id) && t.inner_status == "replicating" + }) + .await; + + // Wind down and confirm a terminal state. + let _ = admin.execute(format!("STOP_TASK {task_id}").as_str()).await; + timeout(Duration::from_secs(30), async { + loop { + if Tasks::fetch(&admin) + .await + .find(task_id) + .is_some_and(|t| matches!(t.status.as_str(), "cancelled" | "finished")) + { + return; + } + sleep(POLL).await; + } + }) + .await + .expect("reshard task did not reach a terminal state"); + + cleanup(&admin, &direct).await; +} diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index ce9d2f971..60132d930 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -1,3 +1,4 @@ +pub mod admin; pub mod admin_termination; pub mod auth; pub mod auto_id; diff --git a/integration/users.toml b/integration/users.toml index 3a552e640..d430e022a 100644 --- a/integration/users.toml +++ b/integration/users.toml @@ -3,6 +3,13 @@ name = "pgdog" database = "pgdog" password = "pgdog" +[[users]] +name = "pgdog_migrator" +database = "pgdog" +password = "pgdog" +server_user = "pgdog" +schema_admin = true + [[users]] name = "pgdog_hashed" database = "pgdog" diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index 2c089e9e8..8c2d45d52 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -637,6 +637,12 @@ impl Config { /// Swap database configs between `source` and `destination`. 
/// Uses tmp pattern: source -> tmp, destination -> source, tmp -> destination. pub fn cutover(&mut self, source: &str, destination: &str) { + // force setting the database name on cutover to make sure + // the proper database_name is present after the swap. + for db in self.databases.iter_mut() { + db.database_name = db.database_name.take().or(Some(db.name.clone())); + } + let tmp = format!("__tmp_{}__", random_string(12)); crate::swap_field!(self.databases.iter_mut(), name, source, destination, tmp); @@ -1103,6 +1109,57 @@ tables = ["my_table"] ); } + #[test] + fn test_cutover_preserves_physical_database_name() { + // `source_db` relies on the default (physical db == cluster name); + // `destination_db` sets an explicit `database_name`. After cutover the + // physical target of each entry must be unchanged — only the logical + // cluster name swaps. Regression test for entries connecting to a + // nonexistent database after a name swap. + let mut config = Config { + databases: vec![ + Database { + name: "source_db".to_string(), + host: "source-host".to_string(), + port: 5432, + database_name: None, + ..Default::default() + }, + Database { + name: "destination_db".to_string(), + host: "destination-host".to_string(), + port: 5433, + database_name: Some("real_dest".to_string()), + ..Default::default() + }, + ], + ..Default::default() + }; + + config.cutover("source_db", "destination_db"); + + // The entry now named `source_db` carries destination's config; its + // physical target must remain destination's database. + let source = config + .databases + .iter() + .find(|d| d.name == "source_db") + .unwrap(); + assert_eq!(source.host, "destination-host"); + assert_eq!(source.database_name.as_deref(), Some("real_dest")); + + // The entry now named `destination_db` carries source's config. Source + // relied on the default, so its physical target was pinned to the old + // name (`source_db`) — not silently changed to `destination_db`. 
+ let destination = config + .databases + .iter() + .find(|d| d.name == "destination_db") + .unwrap(); + assert_eq!(destination.host, "source-host"); + assert_eq!(destination.database_name.as_deref(), Some("source_db")); + } + #[test] fn test_cutover_visual() { let before = r#" @@ -1183,6 +1240,7 @@ destination_db = "destination_db" let expected_after = r#" [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-0" port = 5432 role = "primary" @@ -1190,6 +1248,7 @@ shard = 0 [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-0-replica" port = 5432 role = "replica" @@ -1197,6 +1256,7 @@ shard = 0 [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-1" port = 5432 role = "primary" @@ -1204,6 +1264,7 @@ shard = 1 [[databases]] name = "destination_db" +database_name = "source_db" host = "source-host-1-replica" port = 5432 role = "replica" @@ -1211,6 +1272,7 @@ shard = 1 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-0" port = 5433 role = "primary" @@ -1218,6 +1280,7 @@ shard = 0 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-0-replica" port = 5433 role = "replica" @@ -1225,6 +1288,7 @@ shard = 0 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-1" port = 5433 role = "primary" @@ -1232,6 +1296,7 @@ shard = 1 [[databases]] name = "source_db" +database_name = "destination_db" host = "destination-host-1-replica" port = 5433 role = "replica" diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index e306afc29..523bc706d 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -15,6 +15,7 @@ tui = ["ratatui"] new_parser = ["pg_raw_parse"] [dependencies] +bon.workspace = true pin-project = "1" tokio = { version = "1", features = ["full"] } tracing = "0.1" @@ -22,7 +23,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "std"] 
tracing-throttle = "0.4" parking_lot = "0.12" thiserror = "2" -derive_more = { version = "2", features = ["display", "error"] } +derive_more = { version = "2", features = ["display", "debug", "error", "from", "from_str"] } bytes = "1" clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } @@ -97,6 +98,7 @@ tempfile = "3.23.0" stats_alloc = "0.1.10" brunch = "0.5" wiremock = "0.6" +tokio = { version = "1", features = ["full", "test-util"] } [[bench]] name = "comment_parser" diff --git a/pgdog/src/admin/copy_data.rs b/pgdog/src/admin/copy_data.rs index 869064edc..e2b122692 100644 --- a/pgdog/src/admin/copy_data.rs +++ b/pgdog/src/admin/copy_data.rs @@ -1,10 +1,9 @@ //! COPY_DATA command. -use tokio::spawn; use tracing::info; -use crate::backend::replication::AsyncTasks; -use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::api::resharding::ReshardTask; +use crate::api::run_task; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -54,7 +53,7 @@ impl Command for CopyData { self.from_database, self.to_database, self.publication ); - let mut orchestrator = Orchestrator::new( + let orchestrator = Orchestrator::new( &self.from_database, &self.to_database, &self.publication, @@ -63,20 +62,7 @@ impl Command for CopyData { let slot_name = orchestrator.replication_slot().to_owned(); - let task_id = Task::register(TaskType::CopyData(spawn(async move { - orchestrator.load_schema().await?; - orchestrator.schema_sync_pre(true).await?; - orchestrator.data_sync().await?; - // data_sync can run for hours; any pool reload during the copy marks self.source - // offline. Re-fetch live cluster refs from databases() before starting replication. 
- orchestrator.refresh()?; - - AsyncTasks::insert(TaskType::Replication(Box::new( - orchestrator.replicate().await?, - ))); - - Ok(()) - }))); + let task_id = run_task(ReshardTask::builder().orchestrator(orchestrator).build()).id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()).add(slot_name); diff --git a/pgdog/src/admin/cutover.rs b/pgdog/src/admin/cutover.rs index c5f8f8435..33c8d72b9 100644 --- a/pgdog/src/admin/cutover.rs +++ b/pgdog/src/admin/cutover.rs @@ -1,8 +1,12 @@ -use crate::backend::replication::logical::admin::AsyncTasks; +use crate::api::async_task::AsyncTaskId; +use crate::api::replication::ReplicationTask; +use crate::backend::replication::logical::Error as ReplicationError; use super::prelude::*; -pub struct Cutover; +pub struct Cutover { + task_id: Option, +} #[async_trait] impl Command for Cutover { @@ -14,13 +18,22 @@ impl Command for Cutover { let parts: Vec<&str> = sql.split_whitespace().collect(); match parts[..] { - ["cutover"] => Ok(Cutover), + ["cutover"] => Ok(Cutover { task_id: None }), + ["cutover", id] => { + let task_id = id.parse().map_err(|_| Error::Syntax)?; + Ok(Cutover { + task_id: Some(task_id), + }) + } _ => Err(Error::Syntax), } } async fn execute(&self) -> Result, Error> { - AsyncTasks::cutover()?; + // With an id, cut over that task; without, the first running one. 
+ if !ReplicationTask::trigger_cutover(self.task_id) { + return Err(ReplicationError::NotReplication.into()); + } let mut dr = DataRow::new(); dr.add("OK"); diff --git a/pgdog/src/admin/parser.rs b/pgdog/src/admin/parser.rs index 692ec874b..b878a405d 100644 --- a/pgdog/src/admin/parser.rs +++ b/pgdog/src/admin/parser.rs @@ -277,7 +277,13 @@ mod tests { #[test] fn parses_cutover_command() { - let result = Parser::parse("CUTOVER"); - assert!(matches!(result, Ok(ParseResult::Cutover(_)))); + assert!(matches!( + Parser::parse("CUTOVER"), + Ok(ParseResult::Cutover(_)) + )); + assert!(matches!( + Parser::parse("CUTOVER 1"), + Ok(ParseResult::Cutover(_)) + )); } } diff --git a/pgdog/src/admin/replicate.rs b/pgdog/src/admin/replicate.rs index 9586366f3..c72b263ed 100644 --- a/pgdog/src/admin/replicate.rs +++ b/pgdog/src/admin/replicate.rs @@ -2,7 +2,8 @@ use tracing::info; -use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::api::replication::ReplicationTask; +use crate::api::run_task; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -60,7 +61,7 @@ impl Command for Replicate { )?; let waiter = orchestrator.replicate().await?; - let task_id = Task::register(TaskType::Replication(Box::new(waiter))); + let task_id = run_task(ReplicationTask::builder().waiter(waiter).build()).id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()); diff --git a/pgdog/src/admin/reshard.rs b/pgdog/src/admin/reshard.rs index eb9bf2da4..b0b5ffe71 100644 --- a/pgdog/src/admin/reshard.rs +++ b/pgdog/src/admin/reshard.rs @@ -2,6 +2,8 @@ use tracing::info; +use crate::api::resharding::ReshardTask; +use crate::api::run_task; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; @@ -50,15 +52,27 @@ impl Command for Reshard { r#"resharding "{}" to "{}", publication="{}""#, self.from_database, self.to_database, self.publication ); - let mut orchestrator = Orchestrator::new( + let orchestrator = 
Orchestrator::new( &self.from_database, &self.to_database, &self.publication, self.replication_slot.clone(), )?; - orchestrator.replicate_and_cutover().await?; + let task_id = run_task( + ReshardTask::builder() + .orchestrator(orchestrator) + .auto_cutover(true) + .build(), + ) + .id(); - Ok(vec![]) + let mut dr = DataRow::new(); + dr.add(task_id.to_string()); + + Ok(vec![ + RowDescription::new(&[Field::text("task_id")]).message()?, + dr.message()?, + ]) } } diff --git a/pgdog/src/admin/schema_sync.rs b/pgdog/src/admin/schema_sync.rs index 2b4947756..2fe2b6c27 100644 --- a/pgdog/src/admin/schema_sync.rs +++ b/pgdog/src/admin/schema_sync.rs @@ -1,19 +1,13 @@ //! SCHEMA_SYNC command. -use tokio::spawn; use tracing::info; -use crate::backend::replication::logical::admin::{Task, TaskType}; +use crate::api::run_task; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; use crate::backend::replication::orchestrator::Orchestrator; use super::prelude::*; -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum SchemaSyncPhase { - Pre, - Post, -} - pub struct SchemaSync { pub from_database: String, pub to_database: String, @@ -25,10 +19,7 @@ pub struct SchemaSync { #[async_trait] impl Command for SchemaSync { fn name(&self) -> String { - match self.phase { - SchemaSyncPhase::Pre => "SCHEMA_SYNC PRE".into(), - SchemaSyncPhase::Post => "SCHEMA_SYNC POST".into(), - } + format!("SCHEMA_SYNC {}", self.phase).to_uppercase() } fn parse(sql: &str) -> Result { @@ -46,7 +37,7 @@ impl Command for SchemaSync { to_database: to_database.to_owned(), publication: publication.to_owned(), replication_slot: None, - phase: parse_phase(phase)?, + phase: phase.parse().map_err(|_| Error::Syntax)?, }), [ "schema_sync", @@ -60,43 +51,33 @@ impl Command for SchemaSync { to_database: to_database.to_owned(), publication: publication.to_owned(), replication_slot: Some(replication_slot.to_owned()), - phase: parse_phase(phase)?, + phase: phase.parse().map_err(|_| Error::Syntax)?, }), _ => 
Err(Error::Syntax), } } async fn execute(&self) -> Result, Error> { - let phase_name = match self.phase { - SchemaSyncPhase::Pre => "pre", - SchemaSyncPhase::Post => "post", - }; - info!( r#"schema_sync {} "{}" to "{}", publication="{}""#, - phase_name, self.from_database, self.to_database, self.publication + self.phase, self.from_database, self.to_database, self.publication ); - let mut orchestrator = Orchestrator::new( + let orchestrator = Orchestrator::new( &self.from_database, &self.to_database, &self.publication, self.replication_slot.clone(), )?; - let phase = self.phase; - let handle = spawn(async move { - orchestrator.load_schema().await?; - - match phase { - SchemaSyncPhase::Pre => orchestrator.schema_sync_pre(true).await, - SchemaSyncPhase::Post => orchestrator.schema_sync_post(true).await, - }?; - - Ok(()) - }); - - let task_id = Task::register(TaskType::SchemaSync(handle)); + let task_id = run_task( + SchemaSyncTask::builder() + .orchestrator(orchestrator) + .phase(self.phase) + .ignore_errors(true) + .build(), + ) + .id(); let mut dr = DataRow::new(); dr.add(task_id.to_string()); @@ -107,11 +88,3 @@ impl Command for SchemaSync { ]) } } - -fn parse_phase(phase: &str) -> Result { - match phase { - "pre" => Ok(SchemaSyncPhase::Pre), - "post" => Ok(SchemaSyncPhase::Post), - _ => Err(Error::Syntax), - } -} diff --git a/pgdog/src/admin/show_tasks.rs b/pgdog/src/admin/show_tasks.rs index 3b56d6aaa..73e234f35 100644 --- a/pgdog/src/admin/show_tasks.rs +++ b/pgdog/src/admin/show_tasks.rs @@ -2,7 +2,8 @@ use std::time::SystemTime; use chrono::{DateTime, Local}; -use crate::backend::replication::logical::admin::AsyncTasks; +use crate::api::tasks_storage; +use crate::net::data_row::Data; use crate::util::{format_time, human_duration_display}; use super::prelude::*; @@ -22,28 +23,65 @@ impl Command for ShowTasks { async fn execute(&self) -> Result, Error> { let rd = RowDescription::new(&[ Field::bigint("id"), + Field::text("scope"), Field::text("type"), + 
+ Field::text("status"), + Field::text("inner_status"), Field::text("started_at"), + Field::text("updated_at"), Field::text("elapsed"), Field::bigint("elapsed_ms"), ]); let mut messages = vec![rd.message()?]; let now = SystemTime::now(); - for (id, task_kind, started_at) in AsyncTasks::get().iter() { - let elapsed = now.duration_since(started_at).unwrap_or_default(); - let elapsed_ms = elapsed.as_millis() as i64; - let elapsed_str = human_duration_display(elapsed); + // Most-recent task first (highest id). + for task in tasks_storage().tasks().into_iter().rev() { + // A root task plus its subtasks (e.g. the replication child of a + // copy_data/reshard task). Only the root row reports the task id as + // `id` — that is the cancellable handle (STOP_TASK targets it); subtask + // rows carry NULL there, and the `scope` column tells them apart. Terminal + // tasks stay listed with their final status until pruned from the map. + let root_id = task.id; + let entries = std::iter::once((true, &task.state)) + .chain(task.subtasks.iter().map(|sub| (false, &sub.state))); - let started_at_str = format_time(DateTime::::from(started_at)); + for (is_root, state) in entries { + // Terminal tasks are retained after completion; measure their + // elapsed to the final transition (`updated_at`), not `now`, + // so the duration reflects the actual run rather than ticking + // up until the task is pruned. 
+ let end = if state.is_terminal() { + state.updated_at + } else { + now + }; + let elapsed = end.duration_since(state.started_at).unwrap_or_default(); + let elapsed_ms = elapsed.as_millis() as i64; + let elapsed_str = human_duration_display(elapsed); + let started_at_str = format_time(DateTime::::from(state.started_at)); + let updated_at_str = format_time(DateTime::::from(state.updated_at)); + let status_str = state.status.to_string(); + let inner_str = state.inner_status.clone().unwrap_or_default(); + let scope = if is_root { "root" } else { "subtask" }; - let mut row = DataRow::new(); - row.add(id as i64) - .add(task_kind.to_string().as_str()) - .add(started_at_str.as_str()) - .add(elapsed_str.as_str()) - .add(elapsed_ms); - messages.push(row.message()?); + let mut row = DataRow::new(); + // Subtasks share their root's id; only the root row carries it. + if is_root { + row.add(root_id); + } else { + row.add(Data::null()); + } + row.add(scope) + .add(state.name.as_str()) + .add(status_str.as_str()) + .add(inner_str.as_str()) + .add(started_at_str.as_str()) + .add(updated_at_str.as_str()) + .add(elapsed_str.as_str()) + .add(elapsed_ms); + messages.push(row.message()?); + } } Ok(messages) diff --git a/pgdog/src/admin/stop_task.rs b/pgdog/src/admin/stop_task.rs index 70be0b857..7e06211d7 100644 --- a/pgdog/src/admin/stop_task.rs +++ b/pgdog/src/admin/stop_task.rs @@ -1,10 +1,10 @@ -use crate::backend::replication::logical::admin::{AsyncTasks, TaskKind}; -use crate::net::messages::{ErrorResponse, NoticeResponse}; +use crate::api::async_task::AsyncTaskId; +use crate::api::tasks_storage; use super::prelude::*; pub struct StopTask { - task_id: u64, + task_id: AsyncTaskId, } #[async_trait] @@ -26,23 +26,14 @@ impl Command for StopTask { } async fn execute(&self) -> Result, Error> { - let task_kind = AsyncTasks::remove(self.task_id); + let cancelled = tasks_storage().cancel_task(self.task_id); let mut messages = vec![]; - if task_kind == Some(TaskKind::CopyData) { - let 
notice = NoticeResponse::from(ErrorResponse { - severity: "WARNING".into(), - code: "01000".into(), - message: "replication slot was not dropped and requires manual cleanup".into(), - ..Default::default() - }); - messages.push(notice.message()?); - } - - let result = match task_kind { - Some(_) => "OK", - None => "task not found", + let result = if cancelled.is_some() { + "OK" + } else { + "task not found" }; let mut dr = DataRow::new(); diff --git a/pgdog/src/api/async_task.rs b/pgdog/src/api/async_task.rs new file mode 100644 index 000000000..5f638e780 --- /dev/null +++ b/pgdog/src/api/async_task.rs @@ -0,0 +1,1347 @@ +use std::fmt::Debug; +use std::fmt::Display; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::task::{Context, Poll}; +use std::time::{Duration, SystemTime}; + +use dashmap::DashMap; +use parking_lot::RwLock; +use pgdog_postgres_types::ToDataRowColumn; +use tokio::select; +use tokio::sync::oneshot::{self, Receiver}; +use tokio::time::timeout; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, Span, error, info, info_span, warn}; + +/// Represent the ID of the async task. +#[derive(Copy, Clone, Debug, Display, FromStr, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct AsyncTaskId(u64); + +impl ToDataRowColumn for AsyncTaskId { + fn to_data_row_column(&self) -> pgdog_postgres_types::Data { + self.0.to_data_row_column() + } +} + +#[cfg(test)] +impl From for AsyncTaskId { + fn from(value: u64) -> Self { + AsyncTaskId(value) + } +} + +/// Status type for tasks that report no intermediate progress. +pub type Empty = std::convert::Infallible; + +/// A composable async task. +pub trait Task: Display + Debug + Send + Sync + Sized + 'static { + /// Progress-status payload reported through + /// [`set_status`](AsyncTaskContext::set_status) while the task + /// runs; the [`Empty`] type reports no intermediate progress. 
+ type Status: Display + Send + Sync + 'static; + /// Value the task resolves to on success. + type Output: Send + 'static; + /// Error the task may fail with. + type Error: std::error::Error + Send + 'static; + + /// Grace period for cooperative shutdown after cancellation; + /// once it expires, the task is force-aborted. + fn cancel_timeout() -> Duration { + Duration::from_secs(5) + } + + /// Async task main execution logic + fn run( + self, + ctx: AsyncTaskContext, + ) -> impl Future> + Send + 'static; +} + +/// Predefined lifecycle status of a task — a fixed, enumerable set, +/// independent of the task's domain-specific progress (which is tracked +/// separately as the inner status). +#[derive(Display, Debug, Clone, PartialEq, Eq)] +pub enum TaskStatus { + #[display("started")] + Started, + #[display("running")] + Running, + #[display("finished")] + Finished, + /// Cancellation has been requested; the task is winding down + /// cooperatively and has not yet reached a terminal state. + #[display("cancelling")] + Cancelling, + #[display("cancelled")] + Cancelled, + #[display("failed: {_0}")] + Error(String), + #[display("panicked: {_0}")] + Panic(String), +} + +/// Type-erased snapshot of a task's current state, +/// readable through the registry without knowing `T`. +#[derive(Debug, Clone)] +pub struct TaskState { + pub name: String, + /// Predefined lifecycle status (carries the error/panic message + /// for its terminal failure variants). + pub status: TaskStatus, + /// Last inner progress reported by the task, rendered to a string. + /// Preserved across terminal transitions, so a failed or cancelled + /// task still shows its last known progress. + pub inner_status: Option, + pub started_at: SystemTime, + pub updated_at: SystemTime, +} + +impl TaskState { + /// Whether the task reached a terminal state (finished, cancelled, + /// errored, or panicked) and is only retained for status reporting. 
+ pub fn is_terminal(&self) -> bool { + self.status.is_terminal() + } +} + +/// Why a task did not complete, delivered to the waiter +/// as the error half of its `Result`. +#[derive(Debug, Display, Error)] +pub enum TaskError { + /// The task itself returned an error. + #[display("task failed: {_0}")] + Failed(E), + /// The task was cancelled. + #[display("task was cancelled")] + Cancelled, + /// The task panicked. + #[display("task panicked: {_0}")] + Panicked(#[error(ignore)] String), + /// The task's result was never delivered: the watcher + /// died without reporting (e.g. runtime shutdown). + #[display("task result was never delivered")] + Abandoned, +} + +impl TaskStatus { + /// Whether the task reached a terminal state. + fn is_terminal(&self) -> bool { + matches!( + self, + Self::Finished | Self::Cancelled | Self::Error(_) | Self::Panic(_) + ) + } +} + +/// Storage of tasks keyed by their id, plus the id allocator. +#[derive(Default)] +struct TasksMap { + map: DashMap>, + counter: AtomicU64, +} + +impl TasksMap { + fn next_id(&self) -> AsyncTaskId { + AsyncTaskId(self.counter.fetch_add(1, Ordering::Relaxed)) + } + + fn insert(&self, id: AsyncTaskId, value: Arc) { + self.map.insert(id, value); + } +} + +/// Mutable state of the async task that is updated +/// during the execution and status updates. +struct AsyncTaskState { + updated_at: SystemTime, + /// Predefined lifecycle status (carries the error/panic message for + /// its terminal failure variants). + status: TaskStatus, + /// Last inner progress reported by the task; kept across terminal + /// transitions so failed/cancelled tasks retain their last progress. 
+ inner_status: Option, +} + +impl AsyncTaskState { + fn new() -> Self { + Self { + updated_at: SystemTime::now(), + status: TaskStatus::Started, + inner_status: None, + } + } +} + +/// Registry entry holding the bookkeeping for a single task. +struct AsyncTask { + started_at: SystemTime, + /// Id of the root task this task belongs to — its own id when it is a + /// root, inherited from the parent otherwise. + root_id: AsyncTaskId, + /// Name of the task, taken from its [Task] Display implementation + name: String, + cancellation_token: CancellationToken, + /// Set once the task asks for its cancellation token: only + /// then can it react to cancellation, so only then we'll + /// wait for the cancellation to finish. + cooperative: AtomicBool, + /// Mutable state of the task + state: Arc>>, + /// Subtask registry of the root task, shared by the root and its descendants + subtasks: Arc, + /// The tracing span associated with the task + tracing_span: Span, +} + +/// Wrapper trait for [AsyncTask] that is not tied to a specific +/// type T, so that tasks of different types can be stored together. 
+trait TaskMapEntry: Send + Sync + 'static { + fn cancel(&self); + fn state(&self) -> TaskState; + fn subtasks(&self) -> &TasksMap; + fn expired(&self, now: SystemTime, ttl: Duration) -> bool; +} + +impl TaskMapEntry for AsyncTask { + fn cancel(&self) { + self.cancellation_token.cancel(); + } + + fn state(&self) -> TaskState { + let state = self.state.read(); + + TaskState { + name: self.name.clone(), + status: state.status.clone(), + inner_status: state.inner_status.as_ref().map(ToString::to_string), + started_at: self.started_at, + updated_at: state.updated_at, + } + } + + fn subtasks(&self) -> &TasksMap { + &self.subtasks + } + + fn expired(&self, now: SystemTime, ttl: Duration) -> bool { + let state = self.state.read(); + state.status.is_terminal() + && now + .duration_since(state.updated_at) + .is_ok_and(|age| age >= ttl) + } +} + +/// Context that is passed to the [Task::run] +pub struct AsyncTaskContext { + task: Arc>, +} + +impl Clone for AsyncTaskContext { + fn clone(&self) -> Self { + Self { + task: self.task.clone(), + } + } +} + +/// Handle to a spawned task. Resolves, as a future, to the +/// task's result; also exposes the registry id of the task. +#[derive(derive_more::Debug)] +pub struct AsyncTaskWaiter { + id: AsyncTaskId, + #[debug(ignore)] + waiter: Receiver>>, +} + +impl AsyncTaskWaiter { + pub fn id(&self) -> AsyncTaskId { + self.id + } +} + +impl Future for AsyncTaskWaiter { + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.get_mut().waiter).poll(cx).map(|res| { + // A dropped sender means the watcher died without + // reporting; surface that as a task error instead of + // leaking the channel's RecvError. + res.unwrap_or_else(|_| Err(TaskError::Abandoned)) + }) + } +} + +/// Point-in-time view of a root task and its subtasks. 
+/// +/// All descendants register with the root task, so `subtasks` is a +/// flat list of every descendant (children and grandchildren alike), +/// ordered by id; their own `subtasks` are always empty. +#[derive(Debug, Clone)] +pub struct TaskSnapshot { + pub id: AsyncTaskId, + pub state: TaskState, + pub subtasks: Vec, +} + +impl TaskSnapshot { + fn new(id: AsyncTaskId, entry: &dyn TaskMapEntry) -> Self { + let mut subtasks: Vec<_> = entry + .subtasks() + .map + .iter() + .map(|sub| TaskSnapshot { + id: *sub.key(), + state: sub.value().state(), + subtasks: Vec::new(), + }) + .collect(); + subtasks.sort_unstable_by_key(|task| task.id); + + Self { + id, + state: entry.state(), + subtasks, + } + } +} + +/// Finished tasks stay visible in the registry for this long +/// before being pruned. +const TASK_RETENTION: Duration = Duration::from_secs(24 * 60 * 60); + +/// The main storage for async tasks +pub struct AsyncTasksStorage { + tasks: Arc, + retention: Duration, +} + +impl Default for AsyncTasksStorage { + fn default() -> Self { + Self::new(TASK_RETENTION) + } +} + +fn run_task( + parent_task: Option<&AsyncTask

>, + tasks: &TasksMap, + task: T, +) -> AsyncTaskWaiter { + // Allocate the id up front so a root task can record its own id as its + // root id; descendants inherit the root's id from their parent. + let id = tasks.next_id(); + let root_id = parent_task.map(|p| p.root_id).unwrap_or(id); + + let state = Arc::new(RwLock::new(AsyncTaskState::new())); + let name = task.to_string(); + + let span = if let Some(parent_task) = parent_task { + let parent_span = &parent_task.tracing_span; + let parent_name = &parent_task.name; + info!( + "Starting new subtask '{name}' (id: {id}) for parent task '{parent_name}' (root_id: {root_id})" + ); + info_span!(parent: parent_span, "task", %id) + } else { + info!("Starting new task '{name}' (id: {id})"); + info_span!(parent: None, "task", %id) + }; + + let entry = AsyncTask { + started_at: SystemTime::now(), + name, + root_id, + cancellation_token: match parent_task { + Some(parent) => parent.cancellation_token.child_token(), + None => CancellationToken::new(), + }, + cooperative: AtomicBool::new(false), + // Descendants share the root task's registry: every descendant + // registers as a direct child of the root. + subtasks: if let Some(parent) = parent_task { + parent.subtasks.clone() + } else { + Arc::new(TasksMap::default()) + }, + state: state.clone(), + tracing_span: span.clone(), + }; + + let entry = Arc::new(entry); + // Make sure we insert task to map before it's actually started. + tasks.insert(id, entry.clone()); + + let ctx = AsyncTaskContext { + task: entry.clone(), + }; + + let mut handle = tokio::spawn(task.run(ctx.clone()).instrument(span)); + let (sender, receiver) = oneshot::channel(); + + let cancellation_token = entry.cancellation_token.clone(); + + tokio::spawn(async move { + let res = select! 
{ + _ = cancellation_token.cancelled() => { + ctx.transition(TaskStatus::Cancelling); + if ctx.task.cooperative.load(Ordering::Relaxed) { + match timeout(T::cancel_timeout(), &mut handle).await { + Ok(res) => res, + Err(_) => { + // The timeout fired: abort the task + handle.abort(); + handle.await + } + } + } else { + // The task never took its cancellation token, so it + // cannot react to it: abort immediately. + handle.abort(); + handle.await + } + } + res = &mut handle => { + res + } + }; + + match res { + Ok(Ok(res)) => { + let status = if cancellation_token.is_cancelled() { + TaskStatus::Cancelled + } else { + TaskStatus::Finished + }; + ctx.transition(status); + let _ = sender.send(Ok(res)); + } + Ok(Err(err)) => { + error!("task failed: {err}"); + ctx.transition(TaskStatus::Error(err.to_string())); + let _ = sender.send(Err(TaskError::Failed(err))); + } + Err(err) if err.is_cancelled() => { + ctx.transition(TaskStatus::Cancelled); + let _ = sender.send(Err(TaskError::Cancelled)); + } + Err(err) => { + let panic = err.to_string(); + error!("task panicked: {panic}"); + ctx.transition(TaskStatus::Panic(panic.clone())); + let _ = sender.send(Err(TaskError::Panicked(panic))); + } + } + }); + + AsyncTaskWaiter { + id, + waiter: receiver, + } +} + +impl AsyncTaskContext { + /// Move the task to a new lifecycle `status` (terminal or the + /// non-terminal `Cancelling`), preserving the last inner progress. + /// No-op once the task is already terminal. + fn transition(&self, status: TaskStatus) { + let _enter = self.task.tracing_span.enter(); + + let mut state = self.task.state.write(); + if state.status.is_terminal() { + return; + } + + info!("state transition to: {status}"); + + state.status = status; + state.updated_at = SystemTime::now(); + } + + /// Update the inner progress status of the current task. 
+ pub fn set_status(&self, status: T::Status) { + let _enter = self.task.tracing_span.enter(); + + let mut state = self.task.state.write(); + if state.status.is_terminal() { + return; + } + + info!("inner state transition to: {status}"); + + // Don't regress a cancellation-in-progress back to Running; the task + // may still report inner progress while it winds down. + if state.status != TaskStatus::Cancelling { + state.status = TaskStatus::Running; + } + + state.inner_status = Some(status); + state.updated_at = SystemTime::now(); + } + + /// Hand out this task's cancellation token. + /// Tasks that never take it are aborted immediately on cancellation. + pub fn cancellation_token(&self) -> CancellationToken { + self.task.cooperative.store(true, Ordering::Relaxed); + + self.task.cancellation_token.clone() + } + + /// Run the new task as a subtask of the current one. + pub fn run(&self, task: T1) -> AsyncTaskWaiter { + run_task(Some(&self.task), &self.task.subtasks, task) + } + + /// Id of the root task this task belongs to (its own id when it is a + /// root). Used to address the task's cutover signal. + pub fn root_id(&self) -> AsyncTaskId { + self.task.root_id + } +} + +impl AsyncTasksStorage { + pub fn new(retention: Duration) -> Self { + Self { + tasks: Arc::default(), + retention, + } + } + + /// Schedule the new task as a root task for execution. + pub fn run(&self, task: T) -> AsyncTaskWaiter { + self.prune(); + + run_task::(None, &self.tasks, task) + } + + /// Request cancellation of a task. The task winds down + /// cooperatively (or is aborted after the grace period) and + /// stays in the registry with a terminal status until pruned. + /// Returns the state the task was in when cancellation was requested, + /// or `None` for an unknown or already-terminal id. 
+ pub fn cancel_task(&self, id: AsyncTaskId) -> Option { + let entry = self.tasks.map.get(&id)?; + let state = entry.state(); + + // Already terminal (finished/cancelled/errored) but still retained for + // status reporting: nothing to cancel. Report as not found so callers + // don't claim success or emit cleanup warnings for a finished task. + if state.is_terminal() { + warn!("Task: {id} is already in terminal state and cannot be cancelled"); + return None; + } + + entry.cancel(); + + Some(state) + } + + /// Drop every task that reached a terminal state more than + /// `retention` ago; running tasks are never dropped. + fn prune(&self) { + let now = SystemTime::now(); + + self.tasks.map.retain(|_, entry| { + entry + .subtasks() + .map + .retain(|_, sub| !sub.expired(now, self.retention)); + + !entry.expired(now, self.retention) + }); + } + + /// Snapshot every root task with its subtasks, ordered by id. + pub fn tasks(&self) -> Vec { + self.prune(); + + let mut tasks: Vec<_> = self + .tasks + .map + .iter() + .map(|entry| TaskSnapshot::new(*entry.key(), entry.value().as_ref())) + .collect(); + tasks.sort_unstable_by_key(|task| task.id); + tasks + } + + /// Snapshot a single root task by id. + pub fn task(&self, id: AsyncTaskId) -> Option { + self.prune(); + + self.tasks + .map + .get(&id) + .map(|entry| TaskSnapshot::new(id, entry.value().as_ref())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use parking_lot::Mutex; + use std::convert::Infallible; + use std::fmt::Debug; + use std::sync::Arc; + use tokio::sync::Notify; + use tokio::task::yield_now; + use tokio::test; + use tokio::time::sleep; + + type State = Arc>; + + #[derive(Display, Debug)] + enum TestTaskStatus { + #[display("StepOne")] + StepOne, + #[display("StepTwo")] + StepTwo, + } + + /// What a [`Mock`] does once notified. + #[derive(Clone, Copy, Debug)] + enum Outcome { + /// Mark "finished" and succeed. + Succeed, + /// Mark "failed" and return an error. 
+ Fail, + /// Panic, leaving the state at "started". + Panic, + } + + /// Sets "started", waits on `notify`, then resolves per `outcome`. + #[derive(Display, Debug)] + #[display("mock")] + struct Mock { + state: State, + notify: Arc, + outcome: Outcome, + } + + impl Task for Mock { + type Status = Empty; + type Output = (); + type Error = std::io::Error; + + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), std::io::Error> { + *self.state.lock() = "started"; + self.notify.notified().await; + match self.outcome { + Outcome::Succeed => { + *self.state.lock() = "finished"; + Ok(()) + } + Outcome::Fail => { + *self.state.lock() = "failed"; + Err(std::io::Error::other("mock task failure")) + } + Outcome::Panic => panic!("panicking task"), + } + } + } + + macro_rules! mock_successful { + ($state:ident, $notify:ident) => { + Mock { + state: $state.clone(), + notify: $notify.clone(), + outcome: Outcome::Succeed, + } + }; + } + + macro_rules! mock_failing { + ($state:ident, $notify:ident) => { + Mock { + state: $state.clone(), + notify: $notify.clone(), + outcome: Outcome::Fail, + } + }; + } + + macro_rules! mock_panicking { + ($state:ident, $notify:ident) => { + Mock { + state: $state.clone(), + notify: $notify.clone(), + outcome: Outcome::Panic, + } + }; + } + + /// Waits on its gate, then succeeds. Never takes its cancellation + /// token, so it is aborted immediately on cancellation. + #[derive(Display, Debug)] + #[display("anonymous")] + struct Gate { + gate: Arc, + } + + impl Task for Gate { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + self.gate.notified().await; + Ok(()) + } + } + + /// Immediately succeeds. 
+ #[derive(Display, Debug)] + #[display("noop")] + struct Noop; + + impl Task for Noop { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, _ctx: AsyncTaskContext) -> Result<(), Infallible> { + Ok(()) + } + } + + /// Runs a single child task of any type (spawned via `ctx.run`), + /// propagating its error if it fails, then waits on `notify` + /// before finishing. + #[derive(Display, Debug)] + #[display("inner")] + struct Inner { + state: State, + notify: Arc, + child: C, + } + + impl Task for Inner { + type Status = Empty; + type Output = (); + type Error = TaskError; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), TaskError> { + *self.state.lock() = "started"; + + ctx.run(self.child).await?; + + self.notify.notified().await; + + *self.state.lock() = "finished"; + + Ok(()) + } + } + + /// Takes its cancellation token, then waits on `notify` or + /// cancellation. On cancellation it winds down for `wind_down` + /// before stopping, returning `true`: shorter than the 30s grace + /// window delivers a graceful `Ok(true)`, longer is force-aborted + /// when the grace period expires. + #[derive(Display, Debug)] + #[display("cancellable")] + struct Cancellable { + state: State, + notify: Arc, + wind_down: Duration, + } + + impl Task for Cancellable { + type Status = Empty; + type Output = bool; + type Error = Infallible; + + fn cancel_timeout() -> Duration { + Duration::from_secs(30) + } + + async fn run(self, ctx: AsyncTaskContext) -> Result { + *self.state.lock() = "started"; + + let token = ctx.cancellation_token(); + tokio::select! { + _ = self.notify.notified() => {} + _ = token.cancelled() => { + sleep(self.wind_down).await; + *self.state.lock() = "cancelled"; + return Ok(true); + } + } + + *self.state.lock() = "finished"; + Ok(true) + } + } + + /// Subtask that spawns a grandchild gate (registered with the root) + /// and waits for it. 
+ #[derive(Display, Debug)] + #[display("anonymous")] + struct Sub { + sub_gate: Arc, + } + + impl Task for Sub { + type Status = Empty; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + ctx.run(Gate { + gate: self.sub_gate, + }) + .await + .unwrap(); + + Ok(()) + } + } + + /// Root task that reports `TestTaskStatus`, runs a subtask, then + /// advances its status and waits. + #[derive(Display, Debug)] + #[display("test_task")] + struct TraverseRoot { + sub_gate: Arc, + parent_gate: Arc, + } + + impl Task for TraverseRoot { + type Status = TestTaskStatus; + type Output = (); + type Error = Infallible; + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), Infallible> { + ctx.set_status(TestTaskStatus::StepOne); + + let sub = ctx.run(Sub { + sub_gate: self.sub_gate, + }); + + sub.await.unwrap(); + + ctx.set_status(TestTaskStatus::StepTwo); + + self.parent_gate.notified().await; + + Ok(()) + } + } + + /// Spawned tasks are only guaranteed to be polled after the + /// current task yields; a few rounds cover spawn chains. 
+ async fn settle() { + for _ in 0..5 { + yield_now().await; + } + } + + #[test] + async fn test_single_execution() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + notify.notify_one(); + + task.await.unwrap(); + + assert_eq!(*state_a.lock(), "finished"); + } + + #[test] + async fn test_multiple_execution() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_b = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + let b = mock_successful!(state_b, notify); + let c = mock_successful!(state_c, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task_a = async_storage.run(a); + let task_b = async_storage.run(b); + + let info = async_storage.task(task_a.id()).unwrap(); + + assert_eq!(async_storage.tasks().len(), 2); + + assert!(matches!(info.state.status, TaskStatus::Started)); + + let info = async_storage.task(task_b.id()).unwrap(); + + assert!(matches!(info.state.status, TaskStatus::Started)); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_b.lock(), "started"); + assert_eq!(*state_c.lock(), "initial"); + + notify.notify_one(); + notify.notify_one(); + + let task_c = async_storage.run(c); + + notify.notify_one(); + + task_c.await.unwrap(); + + assert_eq!(*state_a.lock(), "finished"); + assert_eq!(*state_b.lock(), "finished"); + assert_eq!(*state_c.lock(), "finished"); + } + + #[test] + async fn test_inner_execution() { + let child_notify = Arc::new(Notify::new()); + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let c = Inner { + state: state_c.clone(), + 
notify: notify.clone(), + child: mock_successful!(state_a, child_notify), + }; + + let async_storage = AsyncTasksStorage::default(); + + let task_c = async_storage.run(c); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + + // Release the child; the root then parks on its gate. + child_notify.notify_waiters(); + settle().await; + + assert_eq!(*state_a.lock(), "finished"); + assert_eq!(*state_c.lock(), "started"); + + // Open the gate; the root finishes. + notify.notify_one(); + + task_c.await.unwrap(); + + assert_eq!(*state_c.lock(), "finished"); + } + + #[test(start_paused = true)] + async fn test_single_cancel() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + let task_id = task.id(); + async_storage.cancel_task(task_id); + + let res = task.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + assert_eq!(*state_a.lock(), "started"); + + // Cancelled tasks stay visible with a terminal status. + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); + } + + #[test(start_paused = true)] + async fn test_graceful_exit_during_cancel_grace() { + let notify = Arc::new(Notify::new()); + let state = Arc::new(Mutex::new("initial")); + let a = Cancellable { + state: state.clone(), + notify: notify.clone(), + wind_down: Duration::from_secs(1), + }; + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + let task_id = task.id(); + + settle().await; + + assert_eq!(*state.lock(), "started"); + + async_storage.cancel_task(task_id); + + // The task observed the token and finished gracefully: its + // result must be delivered, not discarded or lost to a + // watcher panic. 
+ let res = task.await; + assert!(res.unwrap()); + + assert_eq!(*state.lock(), "cancelled"); + + // Cancellation was requested, so the terminal status is `Cancelled` + // even though the task wound down cleanly to `Ok(true)`. + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); + } + + #[test(start_paused = true)] + async fn test_cancelling_status_visible_during_wind_down() { + let notify = Arc::new(Notify::new()); + let state = Arc::new(Mutex::new("initial")); + let a = Cancellable { + state: state.clone(), + notify, + wind_down: Duration::from_secs(5), + }; + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + let task_id = task.id(); + + settle().await; + async_storage.cancel_task(task_id); + settle().await; + + // Cancellation requested; the task is still winding down (sleeping + // `wind_down`) and must report a non-terminal `Cancelling` status. + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelling)); + assert!(!snapshot.state.is_terminal()); + + // Once it finishes winding down, the status settles to terminal. 
+ let res = task.await; + assert!(res.unwrap()); + let snapshot = async_storage.task(task_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Cancelled)); + } + + #[test(start_paused = true)] + async fn test_inner_cancel() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let c = Inner { + state: state_c.clone(), + notify: notify.clone(), + child: mock_successful!(state_a, notify), + }; + + let async_storage = AsyncTasksStorage::default(); + + let task_c = async_storage.run(c); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + + async_storage.cancel_task(task_c.id()); + + let res = task_c.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + assert_eq!(*state_a.lock(), "started"); + assert_eq!(*state_c.lock(), "started"); + } + + #[test] + async fn test_single_error() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_failing!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + notify.notify_one(); + + let task_id = task.id(); + let res = task.await; + assert!(matches!(res, Err(TaskError::Failed(_)))); + + assert_eq!(*state_a.lock(), "failed"); + + let info = async_storage.task(task_id).unwrap(); + assert!(matches!(info.state.status, TaskStatus::Error(_))); + } + + #[test] + async fn test_inner_error() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let state_c = Arc::new(Mutex::new("initial")); + let c = Inner { + state: state_c.clone(), + notify: notify.clone(), + child: mock_failing!(state_a, notify), + }; + + let async_storage = AsyncTasksStorage::default(); + + let task_c = async_storage.run(c); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + 
assert_eq!(*state_c.lock(), "started"); + + notify.notify_one(); + + // The child's failure propagates out of the root task. + let res = task_c.await; + assert!(matches!(res, Err(TaskError::Failed(_)))); + + assert_eq!(*state_a.lock(), "failed"); + assert_eq!(*state_c.lock(), "started"); + } + + #[test] + async fn test_panic() { + let notify = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_panicking!(state_a, notify); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(a); + + settle().await; + + assert_eq!(*state_a.lock(), "started"); + + notify.notify_one(); + + let task_id = task.id(); + let res = task.await; + assert!(matches!(res, Err(TaskError::Panicked(_)))); + + assert_eq!(*state_a.lock(), "started"); + + let info = async_storage.task(task_id).unwrap(); + assert!(matches!(info.state.status, TaskStatus::Panic(_))); + } + + #[test] + async fn test_traverse_statuses() { + let sub_gate = Arc::new(Notify::new()); + let parent_gate = Arc::new(Notify::new()); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(TraverseRoot { + sub_gate: sub_gate.clone(), + parent_gate: parent_gate.clone(), + }); + + let id = task.id(); + + settle().await; + + // Top-level listing: one named task with a live status. + let tasks = async_storage.tasks(); + assert_eq!(tasks.len(), 1); + + let snapshot = &tasks[0]; + assert_eq!(snapshot.id, id); + assert_eq!(snapshot.state.name, "test_task"); + assert!( + snapshot.state.status == TaskStatus::Running + && snapshot.state.inner_status.as_deref() == Some("StepOne") + ); + + // Both the subtask and its grandchild appear as direct + // children of the root, flat. 
+ assert_eq!(snapshot.subtasks.len(), 2); + for sub in &snapshot.subtasks { + assert_eq!(sub.state.name, "anonymous"); + assert!(matches!(sub.state.status, TaskStatus::Started)); + assert!(sub.subtasks.is_empty()); + } + + sub_gate.notify_one(); + settle().await; + + // Subtask finished, parent moved to the next phase. + let snapshot = async_storage.task(id).unwrap(); + assert!( + snapshot.state.status == TaskStatus::Running + && snapshot.state.inner_status.as_deref() == Some("StepTwo") + ); + for sub in &snapshot.subtasks { + assert!(matches!(sub.state.status, TaskStatus::Finished)); + } + + parent_gate.notify_one(); + + task.await.unwrap(); + + // Terminal status stays observable after completion, and the last + // inner progress is preserved across the terminal transition. + let snapshot = async_storage.task(id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Finished)); + assert_eq!(snapshot.state.inner_status.as_deref(), Some("StepTwo")); + } + + #[test] + async fn test_prune_expired_tasks_and_subtasks() { + let notify_a = Arc::new(Notify::new()); + let state_a = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state_a, notify_a); + + let sub_gate = Arc::new(Notify::new()); + let parent_gate = Arc::new(Notify::new()); + + // Zero retention: terminal tasks are pruned on next access. + let async_storage = AsyncTasksStorage::new(Duration::ZERO); + + let task_a = async_storage.run(a); + let id_a = task_a.id(); + + let root = async_storage.run(TraverseRoot { + sub_gate: sub_gate.clone(), + parent_gate: parent_gate.clone(), + }); + let root_id = root.id(); + + settle().await; + + // Both roots running; the root has its two descendants registered. + assert_eq!(async_storage.tasks().len(), 2); + assert_eq!(async_storage.task(root_id).unwrap().subtasks.len(), 2); + + // Finish the top-level task: it expired and is pruned on next + // access, while the still-running root survives. 
+ notify_a.notify_one(); + task_a.await.unwrap(); + + let tasks = async_storage.tasks(); + assert_eq!(tasks.len(), 1); + assert_eq!(tasks[0].id, root_id); + assert!(async_storage.task(id_a).is_none()); + + // Let the subtasks finish; the root parks on `parent_gate`, still running. + sub_gate.notify_one(); + settle().await; + + // The finished subtasks are pruned while the still-running root survives. + let snapshot = async_storage.task(root_id).unwrap(); + assert!(matches!(snapshot.state.status, TaskStatus::Running)); + assert!(snapshot.subtasks.is_empty()); + + // The root finishes and expires too: the registry empties. + parent_gate.notify_one(); + root.await.unwrap(); + + assert!(async_storage.tasks().is_empty()); + } + + #[test(start_paused = true)] + async fn test_cancel_timeout_override() { + let notify = Arc::new(Notify::new()); + + let async_storage = AsyncTasksStorage::default(); + + let task = async_storage.run(Cancellable { + state: Arc::new(Mutex::new("initial")), + notify: notify.clone(), + wind_down: Duration::from_secs(60), + }); + + settle().await; + + let started = tokio::time::Instant::now(); + + async_storage.cancel_task(task.id()); + + let res = task.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + // The paused clock advances exactly by the overridden + // grace period, not the default 5s. + assert_eq!(started.elapsed(), Duration::from_secs(30)); + } + + #[test(start_paused = true)] + async fn test_non_cooperative_cancel_aborts_immediately() { + let notify = Arc::new(Notify::new()); + + let async_storage = AsyncTasksStorage::default(); + + // `Gate` never takes the cancellation token: no grace period. + let task = async_storage.run(Gate { + gate: notify.clone(), + }); + + settle().await; + + let started = tokio::time::Instant::now(); + + async_storage.cancel_task(task.id()); + + let res = task.await; + assert!(matches!(res, Err(TaskError::Cancelled))); + + // Aborted right away: the paused clock did not advance. 
+ assert_eq!(started.elapsed(), Duration::ZERO); + } + + #[test] + async fn global_storage_runs_and_lists() { + let waiter = crate::api::tasks_storage().run(Noop); + let id = waiter.id(); + waiter.await.unwrap(); + assert!(crate::api::tasks_storage().task(id).is_some()); + } + + #[test] + async fn test_cancel_finished_task_is_not_found() { + let notify = Arc::new(Notify::new()); + let state = Arc::new(Mutex::new("initial")); + let a = mock_successful!(state, notify); + + let async_storage = AsyncTasksStorage::default(); + let task = async_storage.run(a); + let id = task.id(); + + notify.notify_one(); + task.await.unwrap(); + + // The task is terminal (finished) but still retained for reporting: + // cancelling it now reports not-found, so STOP_TASK won't claim it + // stopped a completed task (nor emit a bogus cleanup warning). + assert!(async_storage.task(id).is_some()); + assert!(async_storage.cancel_task(id).is_none()); + + // An unknown id is not-found too. + assert!( + async_storage + .cancel_task(AsyncTaskId::from(99_u64)) + .is_none() + ); + } +} diff --git a/pgdog/src/api/copy_data.rs b/pgdog/src/api/copy_data.rs new file mode 100644 index 000000000..9dbbe337f --- /dev/null +++ b/pgdog/src/api/copy_data.rs @@ -0,0 +1,42 @@ +//! Copy-data leaf task: bulk-copies table data from a source to a target. +//! +//! This task only copies data. The schema sync (pre-data tables, post-data +//! indexes) and replication around it are composed by +//! [`ReshardTask`](crate::api::resharding::ReshardTask). + +use tokio::select; + +use crate::api::Task; +use crate::api::async_task::{AsyncTaskContext, Empty}; +use crate::backend::replication::logical::Error; +use crate::backend::replication::logical::orchestrator::Orchestrator; + +/// Bulk-copy table data from a source database to a target, returning the +/// orchestrator so the composing task can thread it into the next phase. 
+#[derive(Display, Debug, bon::Builder)] +#[display("copy_data {orchestrator}")] +pub(crate) struct CopyDataTask { + pub orchestrator: Orchestrator, +} + +impl Task for CopyDataTask { + type Status = Empty; + type Output = Orchestrator; + type Error = Error; + + async fn run(self, ctx: AsyncTaskContext) -> Result { + let token = ctx.cancellation_token(); + let orchestrator = self.orchestrator; + + select! { + res = orchestrator.data_sync() => res?, + // Cancellation drops the `data_sync()` future, whose internal + // `JoinSet` aborts every in-flight shard copy; closing those + // connections releases the temporary data-sync slots. The + // composing task drops the persistent replication slots afterward. + _ = token.cancelled() => return Err(Error::DataSyncAborted), + } + + Ok(orchestrator) + } +} diff --git a/pgdog/src/api/mod.rs b/pgdog/src/api/mod.rs new file mode 100644 index 000000000..7963896fe --- /dev/null +++ b/pgdog/src/api/mod.rs @@ -0,0 +1,44 @@ +//! PgDog API handlers. +//! +//! The interfaces that call the API: +//! - pgdog CLI +//! - admin DB API + +use std::sync::LazyLock; + +use crate::backend::replication::logical::Error; +use async_task::{AsyncTaskWaiter, AsyncTasksStorage, TaskError}; + +pub mod async_task; +pub mod copy_data; +pub mod replication; +pub mod resharding; +pub mod schema_sync; + +/// Process-global task registry shared by all `crate::api` task modules. +static TASKS: LazyLock = LazyLock::new(AsyncTasksStorage::default); + +/// Accessor for the process-global task registry. +pub(crate) fn tasks_storage() -> &'static AsyncTasksStorage { + &TASKS +} + +/// A composable background task: implement [`Task`] (see +/// [`async_task`]) to define one, then launch it as a top-level task +/// with [`run_task`] or nested under a running task through its +/// [`AsyncTaskContext`](async_task::AsyncTaskContext). +pub(crate) use async_task::Task; + +/// Launch `task` as a top-level task in the global registry. 
+pub(crate) fn run_task(task: T) -> AsyncTaskWaiter { + tasks_storage().run(task) +} + +/// Error returned by the API migration tasks: either an error from the +/// replication/orchestrator machinery, or a child task's [`TaskError`] +/// (failure, cancellation, panic, or abandonment) surfaced to its parent. +#[derive(Debug, Display, Error, From)] +pub(crate) enum MigrationError { + Replication(Error), + Task(TaskError), +} diff --git a/pgdog/src/api/replication.rs b/pgdog/src/api/replication.rs new file mode 100644 index 000000000..3389483e9 --- /dev/null +++ b/pgdog/src/api/replication.rs @@ -0,0 +1,282 @@ +//! Logical-replication background task. +//! +//! Drives a `ReplicationWaiter` to completion. Without `auto_cutover` +//! (standalone `REPLICATE`, `copy_data`) it stops on cancellation +//! (`STOP_TASK`), cuts over on an operator `CUTOVER` addressed to this task +//! (delivered through [`ReplicationTask::cutover`]), and otherwise finishes +//! when the source slot drains (no cutover on natural drain). With +//! `auto_cutover` set (reshard) it cuts over automatically once the +//! destination has caught up. + +use std::collections::HashMap; +use std::sync::LazyLock; +use std::time::Duration; + +use parking_lot::Mutex; +use tokio::select; +use tokio_util::sync::CancellationToken; + +use crate::api::Task; +use crate::api::async_task::{AsyncTaskContext, AsyncTaskId}; +use crate::backend::replication::logical::Error; +use crate::backend::replication::logical::orchestrator::ReplicationWaiter; + +/// Stages of logical replication, reported as the task's status. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum ReplicationStatus { + /// Streaming changes to catch the destination up. + #[display("replicating")] + Replicating, + /// Cutting traffic over to the destination. + #[display("cutting over")] + CuttingOver, + /// Cutting traffic back to the original after a prior cutover (rollback). 
+ #[display("rolling back")] + RollingBack, + /// Winding down on a stop request. + #[display("stopping")] + Stopping, +} + +/// Direction of a replication task: the initial migration (`Forward`) or the +/// post-cutover reverse stream that backs a rollback (`Reverse`). A `CUTOVER` +/// on a `Reverse` task is therefore a rollback. Affects reported status only, +/// not control flow. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub(crate) enum Direction { + #[default] + Forward, + Reverse, +} + +/// Run the replication by driving a [`ReplicationWaiter`] to completion. +#[derive(Display, Debug, bon::Builder)] +#[display("replication {waiter}{}", match direction { + Direction::Forward => "", + Direction::Reverse => " (reverse)", +})] +pub(crate) struct ReplicationTask { + /// The running replication waiter this task drives to completion. + pub waiter: ReplicationWaiter, + /// Cut over automatically once the destination has caught up, instead + /// of waiting for an operator `CUTOVER`. + #[builder(default)] + pub auto_cutover: bool, + /// Replication direction. `Reverse` marks the post-cutover stream that + /// backs a rollback; it only affects reported status, not control flow. + #[builder(default)] + pub direction: Direction, +} + +/// Cutover tokens of the replication tasks currently awaiting an operator +/// `CUTOVER`, keyed by the root task id they belong to. A cutover token is +/// *separate* from the task's `STOP_TASK` cancellation token — signalling it +/// means "cut over", not "abandon". +static CUTOVERS: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +/// Guard held by a running replication task: removes its cutover +/// registration on drop. Awaiting [CutoverWaiter::requested] +/// resolves when an operator `CUTOVER` targets the task. +struct CutoverWaiter { + root_id: AsyncTaskId, + token: CancellationToken, +} + +impl CutoverWaiter { + /// Wait until a cutover is requested for this task. 
The token latches, so + /// a cutover that arrived earlier is delivered immediately. + async fn requested(&self) { + self.token.cancelled().await; + } +} + +impl Drop for CutoverWaiter { + fn drop(&mut self) { + CUTOVERS.lock().remove(&self.root_id); + } +} + +impl Task for ReplicationTask { + type Status = ReplicationStatus; + type Output = (); + type Error = Error; + + fn cancel_timeout() -> Duration { + Duration::from_secs(60) + } + + async fn run(mut self, ctx: AsyncTaskContext) -> Result<(), Error> { + let token = ctx.cancellation_token(); + + ctx.set_status(ReplicationStatus::Replicating); + + if self.auto_cutover { + return self.perform_cutover(&ctx, &token).await; + } + + let cutover = Self::register_cutover(ctx.root_id()); + + select! { + _ = token.cancelled() => { + ctx.set_status(ReplicationStatus::Stopping); + self.waiter.stop(); + } + _ = cutover.requested() => { + self.perform_cutover(&ctx, &token).await?; + } + res = self.waiter.wait() => { + res?; + } + } + + Ok(()) + } +} + +impl ReplicationTask { + /// Perform the actual cutover for running replication. + async fn perform_cutover( + &mut self, + ctx: &AsyncTaskContext, + token: &CancellationToken, + ) -> Result<(), Error> { + ctx.set_status(match self.direction { + Direction::Forward => ReplicationStatus::CuttingOver, + Direction::Reverse => ReplicationStatus::RollingBack, + }); + self.waiter.cutover(token).await + } + + /// Trigger a cutover on a running replication task. + pub(crate) fn trigger_cutover(target: Option) -> bool { + let tokens = CUTOVERS.lock(); + + let token = match target { + Some(id) => tokens.get(&id), + // No id: cut over the first (lowest-id) running task. + None => tokens.keys().min().and_then(|id| tokens.get(id)), + }; + + match token { + Some(token) => { + token.cancel(); + true + } + None => false, + } + } + + /// Register this task (by its `root_id`) to receive operator cutovers for + /// as long as the returned guard is held. 
+ fn register_cutover(root_id: AsyncTaskId) -> CutoverWaiter { + let token = CancellationToken::new(); + CUTOVERS.lock().insert(root_id, token.clone()); + CutoverWaiter { root_id, token } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + // Serialize tests that touch the process-global `CUTOVERS` map so they + // never observe each other's registrations under a multi-threaded harness. + static CUTOVER_TEST_LOCK: std::sync::LazyLock> = + std::sync::LazyLock::new(|| tokio::sync::Mutex::new(())); + + #[tokio::test] + async fn cutover_delivers_even_when_buffered() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // Cutover lands before the task awaits: still delivered (latches). + let waiter = ReplicationTask::register_cutover(AsyncTaskId::from(1)); + assert!( + ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(1))), + "the named task must receive the cutover" + ); + + tokio::time::timeout(Duration::from_secs(1), waiter.requested()) + .await + .expect("buffered cutover was not delivered"); + } + + #[tokio::test] + async fn cutover_targets_only_the_named_task() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // A cutover for one id must never disturb a task registered under a + // different id — the whole point of keying by task id. 
+ let waiter = ReplicationTask::register_cutover(AsyncTaskId::from(7)); + + assert!( + !ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(8))), + "no task is registered under id 8" + ); + assert!( + tokio::time::timeout(Duration::from_millis(200), waiter.requested()) + .await + .is_err(), + "a cutover for a different id leaked to this task" + ); + + assert!(ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(7)))); + tokio::time::timeout(Duration::from_secs(1), waiter.requested()) + .await + .expect("targeted cutover was not delivered"); + } + + #[tokio::test] + async fn cutover_without_id_targets_the_first_task() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // No id: the lowest-id (first) registered task is cut over, and only + // it. + let first = ReplicationTask::register_cutover(AsyncTaskId::from(3)); + let second = ReplicationTask::register_cutover(AsyncTaskId::from(9)); + + assert!( + ReplicationTask::trigger_cutover(None), + "the first registered task must be cut over" + ); + + tokio::time::timeout(Duration::from_secs(1), first.requested()) + .await + .expect("the first task was not cut over"); + assert!( + tokio::time::timeout(Duration::from_millis(200), second.requested()) + .await + .is_err(), + "cutover(None) disturbed a task other than the first" + ); + } + + #[tokio::test] + async fn cutover_does_not_leak_to_the_next_task() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // A cutover to a task that never consumes it must die with that task, + // never reaching the next one. Regression guard for the signal leak. 
+ { + let first = ReplicationTask::register_cutover(AsyncTaskId::from(1)); + assert!(ReplicationTask::trigger_cutover(Some(AsyncTaskId::from(1)))); + drop(first); // ends without ever awaiting `requested()` + } + + let next = ReplicationTask::register_cutover(AsyncTaskId::from(2)); + assert!( + tokio::time::timeout(Duration::from_millis(200), next.requested()) + .await + .is_err(), + "stale cutover leaked into the next replication task" + ); + } + + #[tokio::test] + async fn cutover_with_no_task_is_rejected() { + let _guard = CUTOVER_TEST_LOCK.lock().await; + // Nothing registered: `CUTOVER` (with or without an id) is rejected. + assert!(!ReplicationTask::trigger_cutover(None)); + assert!(!ReplicationTask::trigger_cutover(Some(AsyncTaskId::from( + 404 + )))); + } +} diff --git a/pgdog/src/api/resharding.rs b/pgdog/src/api/resharding.rs new file mode 100644 index 000000000..c861f33e6 --- /dev/null +++ b/pgdog/src/api/resharding.rs @@ -0,0 +1,163 @@ +//! Reshard / migration composer task. +//! +//! Composes the full migration from a source database to a target: pre-data +//! schema sync, the [`CopyDataTask`] bulk copy, post-data schema sync, then +//! replication. With `auto_cutover` it cuts over automatically once replication +//! has caught up (reshard); otherwise it waits for an operator `CUTOVER` +//! (`copy_data`). The admin `COPY_DATA`/`RESHARD` commands and the CLI +//! `data_sync` all run this task, differing only in their options. 
+ +use std::time::Duration; + +use tracing::warn; + +use crate::api::async_task::AsyncTaskContext; +use crate::api::copy_data::CopyDataTask; +use crate::api::replication::ReplicationTask; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; +use crate::api::{MigrationError, Task}; +use crate::backend::replication::logical::orchestrator::Orchestrator; + +/// Run the full migration from a source database to a target: schema sync +/// (pre-data tables, then post-data indexes around the bulk copy), data copy, +/// then replication. With `auto_cutover` it also performs the cutover. +#[derive(Display, Debug, bon::Builder)] +#[display("reshard {orchestrator}")] +pub(crate) struct ReshardTask { + pub orchestrator: Orchestrator, + /// Skip the pre- and post-data schema sync. + #[builder(default)] + pub skip_schema_sync: bool, + /// Only replicate; skip the initial data copy. + #[builder(default)] + pub replicate_only: bool, + /// Only copy data; skip replication. + #[builder(default)] + pub sync_only: bool, + /// Cut over automatically once replication has caught up, instead of + /// waiting for an operator `CUTOVER`. Set by the reshard flow. + #[builder(default)] + pub auto_cutover: bool, +} + +/// Stages of the migration, reported as the task's status. The fine-grained +/// schema-sync, copy, and replication stages live on the child tasks. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum ReshardStatus { + /// Running the pre-data schema-sync child task. + #[display("syncing schema")] + SchemaSync, + /// Running the data-copy child task. + #[display("syncing data")] + SyncingData, + /// Running the post-data schema-sync child task (indexes, constraints). + #[display("finalizing schema")] + FinalizingSchema, + /// Running the replication child task. 
+ #[display("replicating")] + Replication, +} + +impl Task for ReshardTask { + type Status = ReshardStatus; + type Output = (); + type Error = MigrationError; + + fn cancel_timeout() -> Duration { + Duration::from_secs(60) + } + + async fn run(self, ctx: AsyncTaskContext) -> Result<(), MigrationError> { + // Take the cancellation token so a `STOP_TASK` winds the children down + // cooperatively (they'd otherwise outlive this task). + let _token = ctx.cancellation_token(); + let mut orchestrator = self.orchestrator; + + // Pre-data schema sync, unless skipped. It runs before any replication + // slots exist, so it stays outside the cleanup guard below. + if !self.skip_schema_sync { + ctx.set_status(ReshardStatus::SchemaSync); + orchestrator = ctx + .run( + SchemaSyncTask::builder() + .orchestrator(orchestrator) + .phase(SchemaSyncPhase::Pre) + .ignore_errors(true) + .build(), + ) + .await?; + } + + // From the data copy onward the orchestrator may hold replication slots + // (created during data_sync, kept until replication takes them over). + // On failure, cleaning up through this guard drops whatever the publisher + // still owns — a no-op once replication has claimed the slots — so a failed + // or aborted migration doesn't leave them lingering on the source. + let guard = orchestrator.publication_guard(); + let result: Result<(), MigrationError> = async { + // Copy the data, unless replicate-only. + if !self.replicate_only { + ctx.set_status(ReshardStatus::SyncingData); + orchestrator = ctx + .run(CopyDataTask::builder().orchestrator(orchestrator).build()) + .await?; + } + + // Post-data schema sync (secondary indexes, constraints): the + // second half of schema sync, after the bulk load. + if !self.skip_schema_sync { + ctx.set_status(ReshardStatus::FinalizingSchema); + + // The bulk copy above can run for hours; pools may have reloaded + // meanwhile, leaving our cluster refs stale. Re-fetch them before + // touching the destination. 
+ orchestrator.refresh()?; + + orchestrator = ctx + .run( + SchemaSyncTask::builder() + .orchestrator(orchestrator) + .phase(SchemaSyncPhase::Post) + .ignore_errors(true) + .build(), + ) + .await?; + } + + // Replication, unless sync-only. + if !self.sync_only { + ctx.set_status(ReshardStatus::Replication); + + // data_sync / schema sync can run for hours; pools may have + // reloaded. Re-fetch live cluster refs before replicating. + orchestrator.refresh()?; + + // `auto_cutover` (reshard) cuts over on its own; otherwise the + // task runs until an operator `CUTOVER`/`STOP_TASK`. Both of + // those resolve to `Ok`, so awaiting surfaces only a genuine + // replication failure. + let waiter = orchestrator.replicate().await?; + ctx.run( + ReplicationTask::builder() + .waiter(waiter) + .auto_cutover(self.auto_cutover) + .build(), + ) + .await?; + } + + Ok(()) + } + .await; + + // Drop any replication slots the publisher still owns only when the + // migration failed or was aborted mid-copy. + if result.is_err() + && let Err(err) = guard.cleanup().await + { + warn!("failed to clean up replication slots after migration: {err}"); + } + + result + } +} diff --git a/pgdog/src/api/schema_sync.rs b/pgdog/src/api/schema_sync.rs new file mode 100644 index 000000000..7d9cbf231 --- /dev/null +++ b/pgdog/src/api/schema_sync.rs @@ -0,0 +1,135 @@ +//! Schema-sync background task (pre-data, post-data, or cutover). 
+ +use std::ops::Deref; + +use crate::api::Task; +use crate::api::async_task::AsyncTaskContext; +use crate::backend::replication::logical::Error; +use crate::backend::replication::logical::orchestrator::Orchestrator; +use crate::backend::schema::sync::pg_dump::SyncState; + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, FromStr)] +pub enum SchemaSyncPhase { + #[display("pre")] + Pre, + #[display("post")] + Post, + #[display("cutover")] + Cutover, +} + +impl From for SyncState { + fn from(phase: SchemaSyncPhase) -> Self { + match phase { + SchemaSyncPhase::Pre => SyncState::PreData, + SchemaSyncPhase::Post => SyncState::PostData, + SchemaSyncPhase::Cutover => SyncState::Cutover, + } + } +} + +/// Stages of a schema sync, reported as the task's status. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +pub(crate) enum SchemaSyncStatus { + /// Dumping the schema from the source. + #[display("loading schema")] + LoadingSchema, + /// Restoring tables on the destination (pre-data). + #[display("syncing tables")] + SyncingTables, + /// Creating indexes and constraints on the destination (post-data). + #[display("creating indexes")] + CreatingIndexes, + /// Restoring cutover-time schema on the destination. + #[display("syncing cutover schema")] + Cutover, +} + +/// Sync the schema (pre-data, post-data, or cutover) from a source database to a target. 
+#[derive(Display, Debug, bon::Builder)] +#[display("schema_sync({phase}) {orchestrator}")] +pub(crate) struct SchemaSyncTask { + pub orchestrator: Orchestrator, + pub phase: SchemaSyncPhase, + #[builder(default)] + pub ignore_errors: bool, + #[builder(default)] + pub dry_run: bool, +} + +impl Task for SchemaSyncTask { + type Status = SchemaSyncStatus; + type Output = Orchestrator; + type Error = Error; + + #[allow(clippy::print_stdout)] + async fn run(self, ctx: AsyncTaskContext) -> Result { + let mut orchestrator = self.orchestrator; + + if orchestrator.schema().is_err() { + ctx.set_status(SchemaSyncStatus::LoadingSchema); + orchestrator.load_schema().await?; + } + + if self.dry_run { + let schema = orchestrator.schema()?; + for statement in schema.statements(self.phase.into())? { + println!("{}", statement.deref()); + } + return Ok(orchestrator); + } + + match self.phase { + SchemaSyncPhase::Pre => { + ctx.set_status(SchemaSyncStatus::SyncingTables); + orchestrator.schema_sync_pre(self.ignore_errors).await?; + } + SchemaSyncPhase::Post => { + ctx.set_status(SchemaSyncStatus::CreatingIndexes); + orchestrator.schema_sync_post(self.ignore_errors).await?; + } + SchemaSyncPhase::Cutover => { + ctx.set_status(SchemaSyncStatus::Cutover); + orchestrator.schema_sync_cutover(self.ignore_errors).await?; + } + } + + Ok(orchestrator) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn schema_sync_status_renders_distinct_labels() { + let labels = [ + SchemaSyncStatus::LoadingSchema.to_string(), + SchemaSyncStatus::SyncingTables.to_string(), + SchemaSyncStatus::CreatingIndexes.to_string(), + SchemaSyncStatus::Cutover.to_string(), + ]; + assert!(labels.iter().all(|label| !label.is_empty())); + let unique: std::collections::HashSet<_> = labels.iter().collect(); + assert_eq!(unique.len(), labels.len()); + } + + #[test] + fn schema_sync_phase_parses_and_displays() { + for (text, phase) in [ + ("pre", SchemaSyncPhase::Pre), + ("post", SchemaSyncPhase::Post), + 
("cutover", SchemaSyncPhase::Cutover), + ] { + assert_eq!(text.parse::().unwrap(), phase); + assert_eq!(phase.to_string(), text); + } + // Parsing is case-insensitive; unknown phases are rejected. + assert_eq!( + "CUTOVER".parse::().unwrap(), + SchemaSyncPhase::Cutover + ); + assert!("bogus".parse::().is_err()); + } +} diff --git a/pgdog/src/backend/replication/logical/admin.rs b/pgdog/src/backend/replication/logical/admin.rs deleted file mode 100644 index 13a5ce046..000000000 --- a/pgdog/src/backend/replication/logical/admin.rs +++ /dev/null @@ -1,351 +0,0 @@ -use std::{ - fmt, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, - time::SystemTime, -}; - -use crate::backend::replication::orchestrator::ReplicationWaiter; - -use super::Error; -use dashmap::DashMap; -use once_cell::sync::Lazy; -use tokio::{ - select, spawn, - sync::{Notify, oneshot}, - task::JoinHandle, -}; -use tracing::error; - -static TASKS: Lazy = Lazy::new(AsyncTasks::default); - -pub struct Task; - -impl Task { - pub(crate) fn register(task: TaskType) -> u64 { - AsyncTasks::insert(task) - } -} - -pub enum TaskType { - SchemaSync(JoinHandle>), - CopyData(JoinHandle>), - Replication(Box), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TaskKind { - SchemaSync, - CopyData, - Replication, -} - -impl fmt::Display for TaskKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TaskKind::SchemaSync => write!(f, "schema_sync"), - TaskKind::CopyData => write!(f, "copy_data"), - TaskKind::Replication => write!(f, "replication"), - } - } -} - -pub struct TaskInfo { - #[allow(dead_code)] - abort_tx: oneshot::Sender<()>, - cutover: Arc, - pub task_kind: TaskKind, - pub started_at: SystemTime, -} - -#[derive(Clone, Default)] -pub struct AsyncTasks { - tasks: Arc>, - counter: Arc, -} - -impl AsyncTasks { - pub fn get() -> Self { - TASKS.clone() - } - - /// Perform cutover. 
- pub fn cutover() -> Result<(), Error> { - let this = Self::get(); - let task = this - .tasks - .iter() - .find(|t| t.task_kind == TaskKind::Replication) - .ok_or(Error::NotReplication)?; - - task.cutover.notify_one(); - - Ok(()) - } - - pub fn insert(task: TaskType) -> u64 { - let this = Self::get(); - let id = this.counter.fetch_add(1, Ordering::SeqCst); - let (abort_tx, abort_rx) = oneshot::channel(); - - match task { - TaskType::SchemaSync(handle) => { - this.tasks.insert( - id, - TaskInfo { - abort_tx, - cutover: Arc::new(Notify::new()), - task_kind: TaskKind::SchemaSync, - started_at: SystemTime::now(), - }, - ); - let abort_handle = handle.abort_handle(); - spawn(async move { - select! { - _ = abort_rx => { - abort_handle.abort(); - } - result = handle => { - match result { - Ok(Ok(())) => {} - Ok(Err(err)) => error!("[task: {}] {}", id, err), - Err(err) => error!("[task: {}] {}", id, err), - } - } - } - AsyncTasks::get().tasks.remove(&id); - }); - } - - TaskType::CopyData(handle) => { - this.tasks.insert( - id, - TaskInfo { - abort_tx, - cutover: Arc::new(Notify::new()), - task_kind: TaskKind::CopyData, - started_at: SystemTime::now(), - }, - ); - let abort_handle = handle.abort_handle(); - spawn(async move { - select! { - _ = abort_rx => { - abort_handle.abort(); - } - result = handle => { - match result { - Ok(Ok(())) => {} - Ok(Err(err)) => error!("[task: {}] {}", id, err), - Err(err) => error!("[task: {}] {}", id, err), - } - } - } - AsyncTasks::get().tasks.remove(&id); - }); - } - - TaskType::Replication(mut waiter) => { - let cutover = Arc::new(Notify::new()); - - this.tasks.insert( - id, - TaskInfo { - abort_tx, - cutover: cutover.clone(), - task_kind: TaskKind::Replication, - started_at: SystemTime::now(), - }, - ); - - spawn(async move { - select! 
{ - _ = abort_rx => { - waiter.stop(); - } - - _ = cutover.notified() => { - if let Err(err) = waiter.cutover().await { - error!("[task: {}] {}", id, err); - } - } - - result = waiter.wait() => { - if let Err(err) = result { - error!("[task: {}] {}", id, err); - } - } - } - - AsyncTasks::get().tasks.remove(&id); - }); - } - } - - id - } - - pub fn remove(id: u64) -> Option { - // Dropping the sender signals abort to the waiting task - Self::get() - .tasks - .remove(&id) - .map(|(_, info)| info.task_kind) - } - - pub fn iter(&self) -> impl Iterator + '_ { - self.tasks - .iter() - .map(|e| (*e.key(), e.value().task_kind, e.value().started_at)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::time::Duration; - use tokio::time::sleep; - - #[test] - fn test_task_kind_display() { - assert_eq!(TaskKind::SchemaSync.to_string(), "schema_sync"); - assert_eq!(TaskKind::CopyData.to_string(), "copy_data"); - assert_eq!(TaskKind::Replication.to_string(), "replication"); - } - - #[tokio::test] - async fn test_task_registration_and_removal() { - // Create a task that completes immediately - let handle = spawn(async { Ok::<(), Error>(()) }); - let id = Task::register(TaskType::SchemaSync(handle)); - - // Task should be visible briefly - // Give it a moment to register - sleep(Duration::from_millis(10)).await; - - // Try to remove it - it may already be gone if it completed - let result = AsyncTasks::remove(id); - // Either we removed it, or it already completed and removed itself - assert!(result.is_none() || result == Some(TaskKind::SchemaSync)); - } - - #[tokio::test] - async fn test_task_abort_via_remove() { - // Create a long-running task - let handle = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - let id = Task::register(TaskType::CopyData(handle)); - - // Give it time to register - sleep(Duration::from_millis(10)).await; - - // Remove should abort the task - let result = AsyncTasks::remove(id); - assert_eq!(result, 
Some(TaskKind::CopyData)); - - // Task should be gone now - sleep(Duration::from_millis(50)).await; - let result = AsyncTasks::remove(id); - assert_eq!(result, None); - } - - #[tokio::test] - async fn test_task_iter() { - // Create multiple tasks - let handle1 = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - let handle2 = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - - let id1 = Task::register(TaskType::SchemaSync(handle1)); - let id2 = Task::register(TaskType::CopyData(handle2)); - - sleep(Duration::from_millis(10)).await; - - // Should see both tasks - let tasks: Vec<_> = AsyncTasks::get().iter().collect(); - let task_ids: Vec<_> = tasks.iter().map(|(id, _, _)| *id).collect(); - assert!(task_ids.contains(&id1)); - assert!(task_ids.contains(&id2)); - - // Verify task kinds - for (id, kind, _) in &tasks { - if *id == id1 { - assert_eq!(*kind, TaskKind::SchemaSync); - } else if *id == id2 { - assert_eq!(*kind, TaskKind::CopyData); - } - } - - // Cleanup - AsyncTasks::remove(id1); - AsyncTasks::remove(id2); - } - - #[tokio::test] - async fn test_task_auto_cleanup_on_completion() { - // Create a task that completes quickly - let handle = spawn(async { - sleep(Duration::from_millis(10)).await; - Ok::<(), Error>(()) - }); - let id = Task::register(TaskType::SchemaSync(handle)); - - // Wait for task to complete and cleanup - sleep(Duration::from_millis(100)).await; - - // Task should have removed itself - let result = AsyncTasks::remove(id); - assert_eq!(result, None); - } - - #[tokio::test] - async fn test_cutover_fails_without_replication_task() { - // Create a non-replication task - let handle = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - let id = Task::register(TaskType::SchemaSync(handle)); - sleep(Duration::from_millis(10)).await; - - // Cutover should fail because there's no replication task - let result = AsyncTasks::cutover(); - 
assert!(matches!(result, Err(Error::NotReplication)), "{:?}", result); - - // Cleanup - AsyncTasks::remove(id); - } - - #[tokio::test] - async fn test_cutover_returns_not_found_when_no_replication_task() { - // Register several non-replication tasks - let mut task_ids = vec![]; - for _ in 0..5 { - let handle = spawn(async { - sleep(Duration::from_secs(60)).await; - Ok::<(), Error>(()) - }); - task_ids.push(Task::register(TaskType::SchemaSync(handle))); - } - - sleep(Duration::from_millis(10)).await; - - // With only non-replication tasks, cutover should return TaskNotFound - let result = AsyncTasks::cutover(); - assert!(matches!(result, Err(Error::NotReplication))); - - // Cleanup - for id in task_ids { - AsyncTasks::remove(id); - } - } -} diff --git a/pgdog/src/backend/replication/logical/mod.rs b/pgdog/src/backend/replication/logical/mod.rs index f8cbf5445..3bc799b33 100644 --- a/pgdog/src/backend/replication/logical/mod.rs +++ b/pgdog/src/backend/replication/logical/mod.rs @@ -1,4 +1,3 @@ -pub mod admin; pub mod copy_statement; pub mod ee; pub mod error; @@ -7,7 +6,6 @@ pub mod publisher; pub mod status; pub mod subscriber; -pub use admin::*; pub use copy_statement::CopyStatement; pub use error::*; diff --git a/pgdog/src/backend/replication/logical/orchestrator.rs b/pgdog/src/backend/replication/logical/orchestrator.rs index 572100583..b6610875c 100644 --- a/pgdog/src/backend/replication/logical/orchestrator.rs +++ b/pgdog/src/backend/replication/logical/orchestrator.rs @@ -14,7 +14,8 @@ use tokio::{ sync::Mutex, time::{Instant, interval}, }; -use tracing::{info, warn}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info, warn}; use super::*; @@ -28,6 +29,21 @@ pub(crate) struct Orchestrator { replication_slot: String, } +/// A handle to a publication's replication slots, decoupled from the rest of +/// the orchestrator. 
Awaiting [`PublicationGuard::cleanup`] drops every slot +/// the publisher still owns — a no-op once `replicate` has handed them off to +/// the streaming tasks. +pub(crate) struct PublicationGuard { + publisher: Arc>, +} + +impl PublicationGuard { + /// Drop any replication slots the publisher still owns. + pub(crate) async fn cleanup(self) -> Result<(), Error> { + self.publisher.lock().await.cleanup().await + } +} + impl Orchestrator { /// Create new orchestrator. pub(crate) fn new( @@ -145,6 +161,13 @@ impl Orchestrator { Ok(()) } + /// Take a [`PublicationGuard`] over this orchestrator's replication slots. + pub(crate) fn publication_guard(&self) -> PublicationGuard { + PublicationGuard { + publisher: self.publisher.clone(), + } + } + /// Replicate forever. /// /// Useful for CLI interface only, since this will never stop. @@ -162,35 +185,6 @@ impl Orchestrator { }) } - /// Request replication stop. - pub(crate) async fn request_stop(&self) { - self.publisher.lock().await.request_stop(); - } - - /// Perform the entire flow in one swoop. - pub(crate) async fn replicate_and_cutover(&mut self) -> Result<(), Error> { - // Load the schema from source. - self.load_schema().await?; - - // Sync the schema to destination. - self.schema_sync_pre(true).await?; - - // Sync the data to destination. - self.data_sync().await?; - - // Create secondary indexes on destination. - self.schema_sync_post(true).await?; - - // Start replication to catch up and cutover once done. - // Refresh cluster references: data_sync can take hours and the pools - // may have been reloaded (e.g. by a client DDL) in the meantime. 
- self.refresh()?; - - self.replicate().await?.cutover().await?; - - Ok(()) - } - pub(crate) async fn schema_sync_post(&mut self, ignore_errors: bool) -> Result<(), Error> { let schema = self.schema.as_ref().ok_or(Error::NoSchema)?; @@ -221,16 +215,21 @@ impl Orchestrator { let lag = self.publisher.lock().await.replication_lag(); lag.values().copied().max().unwrap_or_default() as u64 } +} - pub(crate) async fn cleanup(&mut self) -> Result<(), Error> { - let mut guard = self.publisher.lock().await; - guard.cleanup().await?; - - Ok(()) +impl Display for Orchestrator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} -> {}", + self.source.identifier().database, + self.destination.identifier().database + ) } } -#[derive(Debug)] +#[derive(Debug, Display)] +#[display("{orchestrator}")] pub struct ReplicationWaiter { orchestrator: Orchestrator, waiter: Waiter, @@ -315,7 +314,7 @@ impl ReplicationWaiter { maintenance_mode::start(None); // Cancel any running queries. - cancel_all(&self.orchestrator.source.identifier().database).await?; + ok_or_abort!(cancel_all(&self.orchestrator.source.identifier().database).await); break; // TODO: wait for clients to all stop. @@ -394,7 +393,7 @@ impl ReplicationWaiter { // In case replication breaks now. res = self.waiter.wait() => { - res?; + ok_or_abort!(res); } } @@ -435,9 +434,31 @@ impl ReplicationWaiter { } /// Perform traffic cutover between source and destination. - pub(crate) async fn cutover(&mut self) -> Result<(), Error> { - self.wait_for_replication().await?; - self.wait_for_cutover().await?; + /// + /// The pre-switch wait (`wait_for_replication`, `wait_for_cutover`) is + /// cancellable: a `STOP_TASK` (via `cancel`) there resumes traffic, stops + /// the replication streams, and returns without moving any traffic. Past + /// the point of no return the switch always runs to completion — cancelling + /// it would leave traffic split between source and destination. 
+ pub(crate) async fn cutover(&mut self, cancel: &CancellationToken) -> Result<(), Error> { + select! { + // Nothing has moved yet (`wait_for_replication` only pauses traffic + // at its very end). Resume traffic (no-op if never paused) and wind + // the streams down, so the aborted cutover leaves nothing running. + _ = cancel.cancelled() => { + maintenance_mode::stop(None); + self.waiter.stop(); + warn!("[cutover] stop requested before the traffic switch, aborting cutover"); + cutover_state(CutoverState::Abort { + error: "stopped before cutover".into(), + }); + return Ok(()); + } + res = async { + self.wait_for_replication().await?; + self.wait_for_cutover().await + } => { res?; } + } // We're going, point of no return. self.orchestrator.publisher.lock().await.request_stop(); @@ -467,10 +488,16 @@ impl ReplicationWaiter { // Create reverse replication in case we need to rollback. let waiter = ok_or_abort!(self.orchestrator.replicate().await); - // Let it run in the background. - AsyncTasks::insert(TaskType::Replication(Box::new(waiter))); + // Drive the running waiter as a background api task so it stays visible + // in SHOW TASKS and can be cut over (rollback) or stopped. + crate::api::run_task( + crate::api::replication::ReplicationTask::builder() + .waiter(waiter) + .direction(crate::api::replication::Direction::Reverse) + .build(), + ); - // It's not safe to resume traffic. + // Slot is established and capturing — now safe to resume traffic. info!("[cutover] complete, resuming traffic"); // Point traffic to the other database and resume. @@ -487,6 +514,7 @@ macro_rules! 
ok_or_abort { match $expr { Ok(res) => res, Err(err) => { + error!("Orchestrator failed: {err}"); maintenance_mode::stop(None); cutover_state(CutoverState::Abort { error: err.to_string(), diff --git a/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs b/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs index dd7065696..f79c656ad 100644 --- a/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs +++ b/pgdog/src/backend/replication/logical/publisher/publisher_impl.rs @@ -4,11 +4,11 @@ use std::time::Duration; use parking_lot::Mutex; use pgdog_config::QueryParserEngine; -use tokio::sync::Notify; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; use tokio::time::{Instant, sleep}; use tokio::try_join; use tokio::{select, spawn, time::interval}; +use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; use super::super::{Error, ensure_validation, publisher::Table}; @@ -55,7 +55,7 @@ pub struct Publisher { /// Last transaction. last_transaction: Arc>>, /// Stop signal. - stop: Arc, + stop: CancellationToken, /// Slot name. slot_name: String, } @@ -72,7 +72,7 @@ impl Publisher { slots: HashMap::new(), query_parser_engine, replication_lag: Arc::new(Mutex::new(HashMap::new())), - stop: Arc::new(Notify::new()), + stop: CancellationToken::new(), last_transaction: Arc::new(Mutex::new(None)), slot_name, } @@ -234,10 +234,15 @@ impl Publisher { let max_attempts = dest.resharding_replication_retry_max_attempts(); let delay = dest.resharding_replication_retry_min_delay(); let mut attempt = 0usize; + // Latches on the first cancellation so the `cancelled()` arm fires + // once (it stays ready forever after `cancel()`); the drain below + // then runs to completion. + let mut stopping = false; loop { select! { - _ = stop.notified() => { + _ = stop.cancelled(), if !stopping => { slot.stop_replication().await?; + stopping = true; } // This is cancellation-safe. 
@@ -337,7 +342,7 @@ impl Publisher { /// Request the publisher to stop replication. pub fn request_stop(&self) { - self.stop.notify_one(); + self.stop.cancel(); } /// Get current replication lag. @@ -371,10 +376,14 @@ impl Publisher { ensure_validation!(validation_errors); // Create replication slots only after validation passes — a slot - // created before valid() would be orphaned on a NoIdentityColumns error. + // created before valid() would be orphaned on validation errors. self.create_slots(source).await?; - let mut handles = vec![]; + // A JoinSet aborts every in-flight shard copy when this future is + // dropped (e.g. the owning task is cancelled), so cancellation actually + // stops the copy instead of leaving detached syncs running in the + // background. + let mut set: JoinSet), Error>> = JoinSet::new(); for (number, shard) in source.shards().iter().enumerate() { let tables = self @@ -411,16 +420,16 @@ impl Publisher { }; let dest = dest.clone(); - handles.push(spawn(async move { + set.spawn(async move { let manager = ParallelSyncManager::new(tables, replicas, dest)?; let tables = manager.run().await?; - Ok::, Error>(tables) - })); + Ok::<(usize, Vec), Error>((number, tables)) + }); } - for (number, handle) in handles.into_iter().enumerate() { - let tables = handle.await??; + while let Some(joined) = set.join_next().await { + let (number, tables) = joined??; info!( "table sync for {} tables complete [{}, shard: {}]", @@ -436,13 +445,20 @@ impl Publisher { Ok(()) } - /// Cleanup after replication. + /// Drop the replication slots created during data sync. + /// + /// Idempotent: the slot map is taken out up front, so repeated calls — or a + /// call after replication already took the slots over — are no-ops. Every + /// slot is attempted even if one fails; the first error is returned. 
pub async fn cleanup(&mut self) -> Result<(), Error> { - for slot in self.slots.values_mut() { - slot.drop_slot().await?; + let mut error = None; + for (_, mut slot) in std::mem::take(&mut self.slots) { + if let Err(err) = slot.drop_slot().await { + error.get_or_insert(err); + } } - Ok(()) + error.map_or(Ok(()), Err) } } @@ -460,12 +476,12 @@ impl Publisher { #[derive(Debug)] pub struct Waiter { streams: Vec>>, - stop: Arc, + stop: CancellationToken, } impl Waiter { pub fn stop(&self) { - self.stop.notify_one(); + self.stop.cancel(); } pub async fn wait(&mut self) -> Result<(), Error> { @@ -482,7 +498,7 @@ impl Waiter { pub fn new_test() -> Self { Self { streams: vec![], - stop: Arc::new(Notify::new()), + stop: CancellationToken::new(), } } } diff --git a/pgdog/src/backend/schema/sync/pg_dump.rs b/pgdog/src/backend/schema/sync/pg_dump.rs index 7a15e20d2..b6d628f39 100644 --- a/pgdog/src/backend/schema/sync/pg_dump.rs +++ b/pgdog/src/backend/schema/sync/pg_dump.rs @@ -111,7 +111,7 @@ fn should_convert_to_bigint<'a>( is_integer_type(sval.as_str()) } -use tokio::{process::Command, spawn}; +use tokio::{process::Command, task::JoinSet}; #[derive(Debug, Clone)] pub struct PgDump { @@ -1041,7 +1041,11 @@ impl PgDumpOutput { }) .collect::>(), )); - let mut handles = vec![]; + // A JoinSet aborts every in-flight per-shard sync when it is dropped, + // so cancelling the task (dropping this future) actually stops the + // schema apply instead of leaving detached spawns running in the + // background against the destination. 
+ let mut set: JoinSet> = JoinSet::new(); for (num, shard) in dest.shards().iter().enumerate() { let mut primary = shard.primary(&Request::default()).await?; @@ -1056,7 +1060,7 @@ impl PgDumpOutput { let trackers = trackers.clone(); let output = self.clone(); - handles.push(spawn(async move { + set.spawn(async move { let stmts = output.statements(state)?; let mut progress = Progress::new(stmts.len()); @@ -1107,11 +1111,11 @@ impl PgDumpOutput { } Ok::<(), Error>(()) - })); + }); } - for handle in handles { - handle.await??; + while let Some(joined) = set.join_next().await { + joined??; } Ok(()) diff --git a/pgdog/src/cli.rs b/pgdog/src/cli.rs index 9b59bbb6f..6667a8550 100644 --- a/pgdog/src/cli.rs +++ b/pgdog/src/cli.rs @@ -1,18 +1,19 @@ -use std::ops::Deref; use std::path::PathBuf; -use std::time::Duration; use clap::{Parser, Subcommand}; use std::fs::read_to_string; use thiserror::Error; -use tokio::time::sleep; use tokio::{select, signal::ctrl_c}; use tracing::{info, warn}; +use crate::api::Task; +use crate::api::resharding::ReshardTask; +use crate::api::run_task; +use crate::api::schema_sync::{SchemaSyncPhase, SchemaSyncTask}; +use crate::api::tasks_storage; use crate::backend::databases::databases; use crate::backend::replication::orchestrator::Orchestrator; use crate::backend::schema::sync::config::ShardConfig; -use crate::backend::schema::sync::pg_dump::SyncState; use crate::config::{Config, Users}; use crate::frontend::router::cli::RouterCli; @@ -279,6 +280,28 @@ pub fn config_check( } } +/// Run an api task to completion in the foreground, cancelling it on Ctrl-C so +/// it can wind down (e.g. stop replication) instead of the process being +/// hard-killed. Returns the task output, or its error/cancellation outcome. +async fn run_to_completion(task: T) -> Result> +where + T::Error: std::error::Error + 'static, +{ + let mut waiter = run_task(task); + let id = waiter.id(); + + loop { + select! 
{ + result = &mut waiter => return Ok(result?), + signal = ctrl_c() => { + signal?; + warn!("interrupt received, cancelling task {id}"); + tasks_storage().cancel_task(id); + } + } + } +} + /// FOR TESTING PURPOSES ONLY. pub async fn replicate_and_cutover(commands: Commands) -> Result<(), Box> { if let Commands::ReplicateAndCutover { @@ -288,22 +311,26 @@ pub async fn replicate_and_cutover(commands: Commands) -> Result<(), Box Result<(), Box> { - use crate::backend::replication::logical::Error; - if let Commands::DataSync { from_database, to_database, @@ -314,65 +341,23 @@ pub async fn data_sync(commands: Commands) -> Result<(), Box { - result?; - } - - _ = ctrl_c() => { - warn!("abort signal received, waiting 5 seconds and performing cleanup"); - sleep(Duration::from_secs(5)).await; - - orchestrator.cleanup().await?; - - return Err(Error::DataSyncAborted.into()); - } - } - } - - if !sync_only { - let mut waiter = orchestrator.replicate().await?; - - select! { - result = waiter.wait() => { - result?; - } - - _ = ctrl_c() => { - warn!("abort signal received"); - - orchestrator.request_stop().await; - - info!("waiting for replication to stop"); - - waiter.wait().await?; - orchestrator.cleanup().await?; - - return Err(Error::DataSyncAborted.into()); - } - } - } - } else { - return Ok(()); + let orchestrator = + Orchestrator::new(&from_database, &to_database, &publication, replication_slot)?; + + run_to_completion( + ReshardTask::builder() + .orchestrator(orchestrator) + .skip_schema_sync(skip_schema_sync) + .replicate_only(replicate_only) + .sync_only(sync_only) + .build(), + ) + .await?; } Ok(()) } -#[allow(clippy::print_stdout)] pub async fn schema_sync(commands: Commands) -> Result<(), Box> { if let Commands::SchemaSync { from_database, @@ -384,32 +369,25 @@ pub async fn schema_sync(commands: Commands) -> Result<(), Box) { .map(|general| general.log_format) .unwrap_or_default(); + let format = fmt::layer() + .with_ansi(std::io::stderr().is_terminal()) + 
.with_writer(std::io::stderr) + .with_file(false); + #[cfg(not(debug_assertions))] + let format = format.with_target(false); + match log_format { LogFormat::Text => { - let format = fmt::layer() - .with_ansi(std::io::stderr().is_terminal()) - .with_writer(std::io::stderr) - .with_file(false); - #[cfg(not(debug_assertions))] - let format = format.with_target(false); let format = format.with_filter(throttle); - let _ = tracing_subscriber::registry() .with(format) .with(filter) .try_init(); } LogFormat::Json | LogFormat::JsonFlattened => { - let format = fmt::layer() - .json() - .with_ansi(false) - .with_writer(std::io::stderr) - .with_file(false) - .with_current_span(false) - .with_span_list(false); + let format = format.json().with_current_span(false); let format = match log_format { LogFormat::JsonFlattened => format.flatten_event(true), _ => format, }; - #[cfg(not(debug_assertions))] - let format = format.with_target(false); let format = format.with_filter(throttle); let _ = tracing_subscriber::registry()