diff --git a/Cargo.lock b/Cargo.lock index 6aef6cfe..3ed7115e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1254,7 +1254,7 @@ dependencies = [ [[package]] name = "pgsqlite" -version = "0.0.18" +version = "0.0.21" dependencies = [ "anyhow", "arbitrary", diff --git a/Cargo.toml b/Cargo.toml index 5b68a7cf..f9f18c12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgsqlite" -version = "0.0.20" +version = "0.0.21" edition = "2024" [features] diff --git a/docs/superpowers/plans/2026-03-27-pgadmin4-set-compat.md b/docs/superpowers/plans/2026-03-27-pgadmin4-set-compat.md new file mode 100644 index 00000000..a0a0eb97 --- /dev/null +++ b/docs/superpowers/plans/2026-03-27-pgadmin4-set-compat.md @@ -0,0 +1,376 @@ +# pgAdmin4 SET Command Compatibility Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Fix SET command parsing and add `set_config()`/`pg_show_all_settings()` support so pgAdmin4 can connect to pgsqlite. + +**Architecture:** Fix the SET regex to allow `=` without surrounding spaces. Add query preprocessing in the executor to rewrite `pg_show_all_settings()` → `pg_settings` and handle `set_config()` with a synthetic wire response. Align `server_version` across all 4 reporting locations. 
+ +**Tech Stack:** Rust, regex, PostgreSQL wire protocol + +**Spec:** `docs/superpowers/specs/2026-03-27-pgadmin4-set-compat-design.md` + +--- + +## File Structure + +| File | Role | +|------|------| +| `src/query/set_handler.rs` | SET/SHOW command handler — regex fix + version alignment | +| `src/query/executor.rs` | Query executor — add `set_config()` handler + `pg_show_all_settings()` rewrite | +| `src/session/state.rs` | Session state — version alignment | +| `src/functions/system_functions.rs` | SQLite-registered functions — version alignment | + +--- + +### Task 1: Fix SET_PARAMETER_PATTERN regex + +**Files:** +- Modify: `src/query/set_handler.rs:15-17` (regex definition) +- Modify: `src/query/set_handler.rs:192-228` (tests) + +- [ ] **Step 1: Add failing tests for SET without spaces** + +Add these tests to the existing `mod tests` block in `src/query/set_handler.rs` (starts at line 192): + +```rust +#[test] +fn test_set_parameter_pattern_equals_no_spaces() { + // Issue #71: pgAdmin4 sends SET DateStyle=ISO + assert!(SET_PARAMETER_PATTERN.is_match("SET DateStyle=ISO")); + assert!(SET_PARAMETER_PATTERN.is_match("SET client_min_messages=notice")); + assert!(SET_PARAMETER_PATTERN.is_match("SET client_encoding='utf-8'")); +} + +#[test] +fn test_set_parameter_pattern_equals_with_spaces() { + assert!(SET_PARAMETER_PATTERN.is_match("SET DateStyle = ISO")); + assert!(SET_PARAMETER_PATTERN.is_match("SET client_encoding = 'UTF8'")); +} + +#[test] +fn test_set_parameter_pattern_to_keyword() { + assert!(SET_PARAMETER_PATTERN.is_match("SET search_path TO public")); + assert!(SET_PARAMETER_PATTERN.is_match("SET client_encoding TO 'UTF8'")); +} + +#[test] +fn test_set_parameter_pattern_captures() { + let caps = SET_PARAMETER_PATTERN.captures("SET DateStyle=ISO").unwrap(); + assert_eq!(&caps[1], "DateStyle"); + assert_eq!(&caps[2], "ISO"); + + let caps = SET_PARAMETER_PATTERN.captures("SET client_encoding = 'UTF8'").unwrap(); + assert_eq!(&caps[1], "client_encoding"); + 
assert_eq!(&caps[2], "'UTF8'"); + + let caps = SET_PARAMETER_PATTERN.captures("SET search_path TO public").unwrap(); + assert_eq!(&caps[1], "search_path"); + assert_eq!(&caps[2], "public"); +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cargo test --lib set_handler::tests -- --nocapture` +Expected: `test_set_parameter_pattern_equals_no_spaces` FAILS + +- [ ] **Step 3: Fix the regex** + +In `src/query/set_handler.rs`, change line 16 from: + +```rust +Regex::new(r"(?i)^\s*SET\s+(\w+)\s+(?:TO|=)\s+(.+)$").unwrap() +``` + +to: + +```rust +Regex::new(r"(?i)^\s*SET\s+(\w+)(?:\s*=\s*|\s+TO\s+)(.+)$").unwrap() +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cargo test --lib set_handler::tests -- --nocapture` +Expected: All tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/query/set_handler.rs +git commit -m "fix: allow SET without spaces around equals sign (#71)" +``` + +--- + +### Task 2: Add pg_show_all_settings() rewrite and set_config() handler + +**Files:** +- Modify: `src/query/executor.rs` — add static regexes, `preprocess_query` function, `set_config()` handler in `execute_single_statement`, and tests in existing `mod tests` block (line 2727) + +- [ ] **Step 1: Write failing tests** + +Add to the existing `mod tests` block in `src/query/executor.rs` (starts at line 2727): + +```rust +#[test] +fn test_pg_show_all_settings_rewrite() { + let query = "SELECT set_config('bytea_output','hex',false) FROM pg_show_all_settings() WHERE name = 'bytea_output'"; + let rewritten = preprocess_query(query); + assert!(rewritten.contains("pg_settings")); + assert!(!rewritten.contains("pg_show_all_settings()")); +} + +#[test] +fn test_pg_show_all_settings_case_insensitive() { + let query = "SELECT * FROM PG_SHOW_ALL_SETTINGS() WHERE name = 'timezone'"; + let rewritten = preprocess_query(query); + assert!(rewritten.contains("pg_settings")); +} + +#[test] +fn test_no_rewrite_when_not_present() { + let query = "SELECT * FROM pg_settings 
WHERE name = 'timezone'"; + let rewritten = preprocess_query(query); + assert_eq!(rewritten, query); +} + +#[test] +fn test_set_config_detection() { + let query = "SELECT set_config('bytea_output','hex',false) FROM pg_settings WHERE name = 'bytea_output'"; + assert!(SET_CONFIG_PATTERN.is_match(query)); +} + +#[test] +fn test_set_config_captures() { + let query = "SELECT set_config('bytea_output','hex',false)"; + let caps = SET_CONFIG_PATTERN.captures(query).unwrap(); + assert_eq!(&caps[1], "bytea_output"); + assert_eq!(&caps[2], "hex"); + assert_eq!(&caps[3], "false"); +} + +#[test] +fn test_set_config_empty_value() { + let query = "SELECT set_config('application_name','',false)"; + let caps = SET_CONFIG_PATTERN.captures(query).unwrap(); + assert_eq!(&caps[1], "application_name"); + assert_eq!(&caps[2], ""); + assert_eq!(&caps[3], "false"); +} + +#[test] +fn test_set_config_with_spaces() { + let query = "SELECT set_config( 'timezone' , 'UTC' , true )"; + let caps = SET_CONFIG_PATTERN.captures(query).unwrap(); + assert_eq!(&caps[1], "timezone"); + assert_eq!(&caps[2], "UTC"); + assert_eq!(&caps[3], "true"); +} + +#[test] +fn test_pgadmin4_full_query_preprocessing() { + // The exact pgAdmin4 query after semicolon splitting yields this statement: + let query = "SELECT set_config('bytea_output','hex',false) FROM pg_show_all_settings() WHERE name = 'bytea_output'"; + + // Step 1: preprocess rewrites pg_show_all_settings() → pg_settings + let rewritten = preprocess_query(query); + assert_eq!( + rewritten, + "SELECT set_config('bytea_output','hex',false) FROM pg_settings WHERE name = 'bytea_output'" + ); + + // Step 2: set_config pattern matches on the rewritten query + let caps = SET_CONFIG_PATTERN.captures(&rewritten).unwrap(); + assert_eq!(&caps[1], "bytea_output"); + assert_eq!(&caps[2], "hex"); + assert_eq!(&caps[3], "false"); +} +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cargo test --lib executor::tests -- --nocapture 2>&1 | grep -E "FAIL|not 
found|error"` +Expected: FAIL — `preprocess_query` and `SET_CONFIG_PATTERN` do not exist + +- [ ] **Step 3: Implement the static regexes and preprocess_query function** + +Add to `src/query/executor.rs`, after the existing static declarations (after line 23, before the struct definitions): + +```rust +static PG_SHOW_ALL_SETTINGS_PATTERN: Lazy<Regex> = Lazy::new(|| { + Regex::new(r"(?i)pg_show_all_settings\(\s*\)").unwrap() +}); + +static SET_CONFIG_PATTERN: Lazy<Regex> = Lazy::new(|| { + Regex::new(r"(?i)set_config\(\s*'([^']+)'\s*,\s*'([^']*)'\s*,\s*(true|false)\s*\)").unwrap() +}); + +fn preprocess_query(query: &str) -> String { + if PG_SHOW_ALL_SETTINGS_PATTERN.is_match(query) { + PG_SHOW_ALL_SETTINGS_PATTERN.replace_all(query, "pg_settings").to_string() + } else { + query.to_string() + } +} +``` + +- [ ] **Step 4: Run regex and preprocessing tests** + +Run: `cargo test --lib executor::tests -- --nocapture 2>&1 | grep -E "test_pg_show|test_set_config|test_pgadmin"` +Expected: All matching tests PASS + +- [ ] **Step 5: Wire preprocessing and set_config handler into execute_single_statement** + +In `execute_single_statement`, after the failed-transaction check (after line 263's closing `}`) and before the ultra-fast path comment (line 264), add: + +```rust +// Preprocess query: rewrite pg_show_all_settings() → pg_settings +let query = preprocess_query(query); +let query: &str = query.as_str(); + +// Handle set_config() function calls +if let Some(caps) = SET_CONFIG_PATTERN.captures(query) { + let param_name = caps[1].to_string(); + let param_value = caps[2].to_string(); + // is_local (caps[3]) is ignored — pgsqlite doesn't support transaction-scoped settings + + debug!("Handling set_config('{}', '{}', ...)", param_name, param_value); + + // Set the parameter in the session + let mut params = session.parameters.write().await; + params.insert(param_name.to_uppercase(), param_value.clone()); + drop(params); + + // Send synthetic response: RowDescription + DataRow + CommandComplete 
+ let field = FieldDescription { + name: "set_config".to_string(), + table_oid: 0, + column_id: 1, + type_oid: PgType::Text.to_oid(), + type_size: -1, + type_modifier: -1, + format: 0, + }; + framed.send(BackendMessage::RowDescription(vec![field])).await + .map_err(PgSqliteError::Io)?; + + let row = vec![Some(param_value.as_bytes().to_vec())]; + framed.send(BackendMessage::DataRow(row)).await + .map_err(PgSqliteError::Io)?; + + framed.send(BackendMessage::CommandComplete { + tag: "SELECT 1".to_string() + }).await.map_err(PgSqliteError::Io)?; + + return Ok(()); +} +``` + +Note on parameter key casing: The `to_uppercase()` call is consistent with the existing SET handler (`set_handler.rs:85`). Session parameters initialized in `state.rs` use mixed case (`"DateStyle"`, `"server_version"`), while SET/SHOW operations uppercase. This is a pre-existing inconsistency — not introduced by this change. + +- [ ] **Step 6: Run full test suite** + +Run: `cargo test` +Expected: All PASS + +Run: `cargo clippy` +Expected: No new warnings + +- [ ] **Step 7: Commit** + +```bash +git add src/query/executor.rs +git commit -m "feat: add pg_show_all_settings() rewrite and set_config() support (#71)" +``` + +--- + +### Task 3: Align server_version across all locations + +**Files:** +- Modify: `src/query/set_handler.rs:110-111` +- Modify: `src/session/state.rs:58` +- Modify: `src/functions/system_functions.rs:16` + +- [ ] **Step 1: Update set_handler.rs** + +In `src/query/set_handler.rs`, change lines 110-111 from: + +```rust +"SERVER_VERSION" => "15.0".to_string(), +"SERVER_VERSION_NUM" => "150000".to_string(), +``` + +to: + +```rust +"SERVER_VERSION" => "16.0".to_string(), +"SERVER_VERSION_NUM" => "160000".to_string(), +``` + +- [ ] **Step 2: Update state.rs** + +In `src/session/state.rs`, change line 58 from: + +```rust +parameters.insert("server_version".to_string(), "14.0 (SQLite wrapper)".to_string()); +``` + +to: + +```rust +parameters.insert("server_version".to_string(), 
"16.0".to_string()); +``` + +- [ ] **Step 3: Update system_functions.rs** + +In `src/functions/system_functions.rs`, change line 16 from: + +```rust +Ok(format!("PostgreSQL 15.0 (pgsqlite {}) on x86_64-pc-linux-gnu, compiled by rustc, 64-bit", +``` + +to: + +```rust +Ok(format!("PostgreSQL 16.0 (pgsqlite {}) on x86_64-pc-linux-gnu, compiled by rustc, 64-bit", +``` + +- [ ] **Step 4: Run full test suite** + +Run: `cargo test` +Expected: All PASS + +Run: `cargo clippy` +Expected: No new warnings + +- [ ] **Step 5: Commit** + +```bash +git add src/query/set_handler.rs src/session/state.rs src/functions/system_functions.rs +git commit -m "fix: align server_version to 16.0 across all reporting locations" +``` + +--- + +### Task 4: Final verification + +- [ ] **Step 1: Run pre-commit checklist** + +```bash +cargo check && cargo clippy && cargo build && cargo test +``` + +Expected: All pass with no errors or warnings. + +- [ ] **Step 2: Verify all pgAdmin4 compound query statements are handled** + +The compound query from issue #71 is split by `;` in the executor. After our changes, each statement should succeed: + +1. `SET DateStyle=ISO` — matches fixed regex (Task 1) ✓ +2. `SET client_min_messages=notice` — matches fixed regex (Task 1) ✓ +3. `SELECT set_config('bytea_output','hex',false) FROM pg_show_all_settings() WHERE name = 'bytea_output'` — `pg_show_all_settings()` rewritten to `pg_settings`, then `set_config()` handler intercepts and returns synthetic response (Task 2) ✓ +4. `SET client_encoding='utf-8'` — matches fixed regex (Task 1) ✓ + +The `test_pgadmin4_full_query_preprocessing` test in Task 2 validates the combined rewrite + intercept flow for statement 3. 
diff --git a/docs/superpowers/plans/2026-03-27-start-transaction.md b/docs/superpowers/plans/2026-03-27-start-transaction.md new file mode 100644 index 00000000..b13062e5 --- /dev/null +++ b/docs/superpowers/plans/2026-03-27-start-transaction.md @@ -0,0 +1,262 @@ +# START TRANSACTION / END TRANSACTION Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Recognize `START TRANSACTION` as `BEGIN` and `END`/`END TRANSACTION` as `COMMIT` so postgres_fdw can connect. + +**Architecture:** Add `START` and `END` keyword detection in `QueryTypeDetector` (both fast path and fallback) and in the extended query protocol's transaction dispatch in `extended.rs`. No new translation or preprocessing needed — existing transaction handlers already call session methods without passing raw SQL to SQLite. + +**Tech Stack:** Rust + +**Spec:** `docs/superpowers/specs/2026-03-27-start-transaction-design.md` + +--- + +## File Structure + +| File | Role | +|------|------| +| `src/query/query_type_detection.rs` | Query type classifier — add `START` → `Begin`, `END` → `Commit` | +| `src/query/extended.rs` | Extended query protocol — add `START`/`END` to transaction dispatch and handler | + +--- + +### Task 1: Add START TRANSACTION detection in QueryTypeDetector + +**Files:** +- Modify: `src/query/query_type_detection.rs:37-43` (fast path 5-byte block) +- Modify: `src/query/query_type_detection.rs:85-86` (fallback path) +- Modify: `src/query/query_type_detection.rs:197-243` (tests) + +- [ ] **Step 1: Add failing tests** + +Add to the existing `mod tests` block in `src/query/query_type_detection.rs`, inside `test_query_type_detection`: + +```rust +// START TRANSACTION (issue #70) +assert_eq!(QueryTypeDetector::detect_query_type("START TRANSACTION"), QueryType::Begin); 
+assert_eq!(QueryTypeDetector::detect_query_type("start transaction"), QueryType::Begin); +assert_eq!(QueryTypeDetector::detect_query_type("Start Transaction"), QueryType::Begin); +assert_eq!(QueryTypeDetector::detect_query_type("START TRANSACTION ISOLATION LEVEL REPEATABLE READ"), QueryType::Begin); +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo test --lib query_type_detection::tests -- --nocapture` +Expected: FAIL — `START TRANSACTION` returns `QueryType::Other` + +- [ ] **Step 3: Add START to fast path** + +In `src/query/query_type_detection.rs`, add `START` to the existing `if bytes.len() >= 5` block at line 37-43: + +```rust +if bytes.len() >= 5 { + match &bytes[0..5] { + b"ALTER" | b"alter" | b"Alter" => return QueryType::Alter, + b"BEGIN" | b"begin" | b"Begin" => return QueryType::Begin, + b"START" | b"start" | b"Start" => return QueryType::Begin, + _ => {} + } +} +``` + +- [ ] **Step 4: Add START to fallback path** + +After the existing `BEGIN` check at line 85-86, add: + +```rust +} else if trimmed.len() >= 5 && trimmed[..5].eq_ignore_ascii_case("START") { + QueryType::Begin +``` + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `cargo test --lib query_type_detection::tests -- --nocapture` +Expected: All PASS + +- [ ] **Step 6: Commit** + +```bash +git add src/query/query_type_detection.rs +git commit -m "feat: recognize START TRANSACTION as BEGIN (#70)" +``` + +--- + +### Task 2: Add END / END TRANSACTION detection in QueryTypeDetector + +**Files:** +- Modify: `src/query/query_type_detection.rs:37-65` (fast path — new 3-byte block) +- Modify: `src/query/query_type_detection.rs:85-90` (fallback path) +- Modify: `src/query/query_type_detection.rs:197-243` (tests) + +- [ ] **Step 1: Add failing tests** + +Add to `test_query_type_detection`: + +```rust +// END / END TRANSACTION +assert_eq!(QueryTypeDetector::detect_query_type("END"), QueryType::Commit); +assert_eq!(QueryTypeDetector::detect_query_type("end"), QueryType::Commit); 
+assert_eq!(QueryTypeDetector::detect_query_type("END TRANSACTION"), QueryType::Commit); +assert_eq!(QueryTypeDetector::detect_query_type("end transaction"), QueryType::Commit); +// Word boundary guard — END must not match identifiers starting with "END" +assert_ne!(QueryTypeDetector::detect_query_type("ENDLESS"), QueryType::Commit); +assert_ne!(QueryTypeDetector::detect_query_type("ENDTABLE"), QueryType::Commit); +``` + +Also add a new test for the `is_transaction` helper: + +```rust +#[test] +fn test_is_transaction_with_synonyms() { + assert!(QueryTypeDetector::is_transaction("BEGIN")); + assert!(QueryTypeDetector::is_transaction("START TRANSACTION")); + assert!(QueryTypeDetector::is_transaction("COMMIT")); + assert!(QueryTypeDetector::is_transaction("END")); + assert!(QueryTypeDetector::is_transaction("END TRANSACTION")); + assert!(QueryTypeDetector::is_transaction("ROLLBACK")); + assert!(!QueryTypeDetector::is_transaction("SELECT 1")); +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo test --lib query_type_detection::tests -- --nocapture` +Expected: FAIL — `END` returns `QueryType::Other` + +- [ ] **Step 3: Add END to fast path with word boundary guard** + +Add a new block after the existing `if bytes.len() >= 5` block (after line 43), before `if bytes.len() >= 6`: + +```rust +if bytes.len() >= 3 { + let first3 = &bytes[0..3]; + if (first3 == b"END" || first3 == b"end" || first3 == b"End") + && (bytes.len() == 3 || bytes[3].is_ascii_whitespace()) + { + return QueryType::Commit; + } +} +``` + +- [ ] **Step 4: Add END to fallback path with word boundary guard** + +After the `START` check added in Task 1, add: + +```rust +} else if trimmed.len() >= 3 && trimmed[..3].eq_ignore_ascii_case("END") + && (trimmed.len() == 3 || trimmed.as_bytes()[3].is_ascii_whitespace()) { + QueryType::Commit +``` + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `cargo test --lib query_type_detection::tests -- --nocapture` +Expected: All PASS + +- [ ] 
**Step 6: Commit** + +```bash +git add src/query/query_type_detection.rs +git commit -m "feat: recognize END / END TRANSACTION as COMMIT (#70)" +``` + +--- + +### Task 3: Add START/END to extended query protocol + +**Files:** +- Modify: `src/query/extended.rs:1865-1868` (dispatch block) +- Modify: `src/query/extended.rs:5635-5647` (execute_transaction handler) + +- [ ] **Step 1: Update dispatch block** + +In `src/query/extended.rs`, change lines 1865-1868 from: + +```rust +} else if query_starts_with_ignore_case(&final_query, "BEGIN") + || query_starts_with_ignore_case(&final_query, "COMMIT") + || query_starts_with_ignore_case(&final_query, "ROLLBACK") { + Self::execute_transaction(framed, db, session, &final_query).await?; +``` + +to: + +```rust +} else if query_starts_with_ignore_case(&final_query, "BEGIN") + || query_starts_with_ignore_case(&final_query, "START") + || query_starts_with_ignore_case(&final_query, "COMMIT") + || query_starts_with_ignore_case(&final_query, "END") + || query_starts_with_ignore_case(&final_query, "ROLLBACK") { + Self::execute_transaction(framed, db, session, &final_query).await?; +``` + +- [ ] **Step 2: Update execute_transaction handler** + +In `src/query/extended.rs`, change lines 5635-5643 from: + +```rust +if query_starts_with_ignore_case(query, "BEGIN") { + db.begin_with_session(&session.id).await?; + framed.send(BackendMessage::CommandComplete { tag: "BEGIN".to_string() }).await + .map_err(PgSqliteError::Io)?; +} else if query_starts_with_ignore_case(query, "COMMIT") { + db.commit_with_session(&session.id).await?; + framed.send(BackendMessage::CommandComplete { tag: "COMMIT".to_string() }).await + .map_err(PgSqliteError::Io)?; +``` + +to: + +```rust +if query_starts_with_ignore_case(query, "BEGIN") + || query_starts_with_ignore_case(query, "START") { + db.begin_with_session(&session.id).await?; + framed.send(BackendMessage::CommandComplete { tag: "BEGIN".to_string() }).await + .map_err(PgSqliteError::Io)?; +} else if 
query_starts_with_ignore_case(query, "COMMIT") + || query_starts_with_ignore_case(query, "END") { + db.commit_with_session(&session.id).await?; + framed.send(BackendMessage::CommandComplete { tag: "COMMIT".to_string() }).await + .map_err(PgSqliteError::Io)?; +``` + +- [ ] **Step 3: Run full test suite** + +Run: `cargo test` +Expected: All PASS + +Run: `cargo clippy` +Expected: No new warnings + +- [ ] **Step 4: Commit** + +```bash +git add src/query/extended.rs +git commit -m "feat: add START/END transaction support to extended protocol (#70)" +``` + +--- + +### Task 4: Final verification + +- [ ] **Step 1: Run pre-commit checklist** + +```bash +cargo check && cargo clippy && cargo build && cargo test +``` + +Expected: All pass with no errors. + +- [ ] **Step 2: Verify the postgres_fdw query is handled** + +The query from issue #70: +```sql +START TRANSACTION ISOLATION LEVEL REPEATABLE READ +``` + +After our changes: +1. `QueryTypeDetector::detect_query_type("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")` → `QueryType::Begin` (Task 1) +2. Routes to `execute_transaction` which calls `db.begin_with_session()` — isolation level clause never reaches SQLite +3. Extended protocol also handles `START` via dispatch block (Task 3) diff --git a/docs/superpowers/plans/2026-03-27-wire-type-oid-fix.md b/docs/superpowers/plans/2026-03-27-wire-type-oid-fix.md new file mode 100644 index 00000000..7094499f --- /dev/null +++ b/docs/superpowers/plans/2026-03-27-wire-type-oid-fix.md @@ -0,0 +1,480 @@ +# Wire Protocol Type OID Fix Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Fix wire protocol type OIDs so pre-existing SQLite tables return correct OIDs (int4, float8, bool, etc.) instead of always returning OID 25 (TEXT). 
+ +**Architecture:** Add a `sqlite_type_to_pg_type_name` helper and a `get_column_types_from_pragma` DbHandler method that runs `PRAGMA table_info` to get SQLite declared types. Use this as a fallback in the schema_types population loops in executor.rs and extended.rs when `__pgsqlite_schema` has no entry. + +**Tech Stack:** Rust + +**Spec:** `docs/superpowers/specs/2026-03-27-wire-type-oid-fix-design.md` + +--- + +## File Structure + +| File | Role | +|------|------| +| `src/types/sqlite_type_info.rs` | Add `sqlite_type_to_pg_type_name` — maps SQLite declared types to PG type name strings | +| `src/session/db_handler.rs` | Add `get_column_types_from_pragma` — fetches PRAGMA table_info and returns column→PG type name map | +| `src/query/executor.rs` | Add PRAGMA fallback in schema_types population loop | +| `src/query/extended.rs` | Add PRAGMA fallback in 3 locations (schema_types x2, inferred_types x1) | + +--- + +### Task 1: Add `sqlite_type_to_pg_type_name` helper + +**Files:** +- Modify: `src/types/sqlite_type_info.rs:62-116` + +- [ ] **Step 1: Add failing tests** + +Add a `#[cfg(test)]` module at the end of `src/types/sqlite_type_info.rs`: + +```rust +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sqlite_type_to_pg_type_name() { + assert_eq!(sqlite_type_to_pg_type_name("INTEGER"), "integer"); + assert_eq!(sqlite_type_to_pg_type_name("integer"), "integer"); + assert_eq!(sqlite_type_to_pg_type_name("INT"), "integer"); + assert_eq!(sqlite_type_to_pg_type_name("BIGINT"), "bigint"); + assert_eq!(sqlite_type_to_pg_type_name("INT8"), "bigint"); + assert_eq!(sqlite_type_to_pg_type_name("SMALLINT"), "smallint"); + assert_eq!(sqlite_type_to_pg_type_name("INT2"), "smallint"); + assert_eq!(sqlite_type_to_pg_type_name("REAL"), "double precision"); + assert_eq!(sqlite_type_to_pg_type_name("FLOAT"), "double precision"); + assert_eq!(sqlite_type_to_pg_type_name("DOUBLE PRECISION"), "double precision"); + assert_eq!(sqlite_type_to_pg_type_name("BOOLEAN"), 
"boolean"); + assert_eq!(sqlite_type_to_pg_type_name("BOOL"), "boolean"); + assert_eq!(sqlite_type_to_pg_type_name("TEXT"), "text"); + assert_eq!(sqlite_type_to_pg_type_name("VARCHAR(50)"), "text"); + assert_eq!(sqlite_type_to_pg_type_name("BLOB"), "bytea"); + assert_eq!(sqlite_type_to_pg_type_name("DATE"), "date"); + assert_eq!(sqlite_type_to_pg_type_name("TIMESTAMP"), "timestamp"); + assert_eq!(sqlite_type_to_pg_type_name("NUMERIC"), "numeric"); + assert_eq!(sqlite_type_to_pg_type_name("DECIMAL"), "numeric"); + assert_eq!(sqlite_type_to_pg_type_name("UUID"), "uuid"); + assert_eq!(sqlite_type_to_pg_type_name("JSON"), "json"); + assert_eq!(sqlite_type_to_pg_type_name("SOMETHING_UNKNOWN"), "text"); + } +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo test --lib sqlite_type_info::tests -- --nocapture` +Expected: FAIL — `sqlite_type_to_pg_type_name` not found + +- [ ] **Step 3: Implement `sqlite_type_to_pg_type_name`** + +Add after `sqlite_type_to_pg_oid` (after line 116) in `src/types/sqlite_type_info.rs`: + +```rust +/// Convert SQLite type declaration to PostgreSQL type name string +pub fn sqlite_type_to_pg_type_name(sqlite_type: &str) -> &'static str { + let type_upper = sqlite_type.to_uppercase(); + + if type_upper.contains("BLOB") { + return "bytea"; + } + + if type_upper.contains("REAL") || type_upper.contains("FLOAT") || type_upper.contains("DOUBLE") { + return "double precision"; + } + + if type_upper.contains("INT") { + if type_upper.contains("INT2") || type_upper.contains("SMALLINT") { + return "smallint"; + } else if type_upper.contains("INT8") || type_upper.contains("BIGINT") { + return "bigint"; + } else { + return "integer"; + } + } + + if type_upper.contains("BOOL") { + return "boolean"; + } + + if type_upper.contains("DATE") && !type_upper.contains("TIME") { + return "date"; + } + + if type_upper.contains("TIME") && !type_upper.contains("STAMP") { + return "time"; + } + + if type_upper.contains("TIMESTAMP") { + return 
"timestamp"; + } + + if type_upper.contains("NUMERIC") || type_upper.contains("DECIMAL") { + return "numeric"; + } + + if type_upper.contains("UUID") { + return "uuid"; + } + + if type_upper.contains("JSON") { + return "json"; + } + + "text" +} +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cargo test --lib sqlite_type_info::tests -- --nocapture` +Expected: All PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/types/sqlite_type_info.rs +git commit -m "feat: add sqlite_type_to_pg_type_name helper (#68)" +``` + +--- + +### Task 2: Add `get_column_types_from_pragma` to DbHandler + +**Files:** +- Modify: `src/session/db_handler.rs:2272` (after `get_schema_type_with_session`) + +- [ ] **Step 1: Add the method** + +Add after the closing `}` of `get_schema_type_with_session` (after line 2272) in `src/session/db_handler.rs`: + +```rust + /// Get column types from PRAGMA table_info for tables without __pgsqlite_schema metadata. + /// Returns a HashMap mapping column names to PostgreSQL type name strings. 
+ pub async fn get_column_types_from_pragma( + &self, + session_id: &Uuid, + table_name: &str, + ) -> Result<std::collections::HashMap<String, String>, PgSqliteError> { + let table_name = table_name.to_string(); + self.with_session_connection(session_id, move |conn| { + let mut result = std::collections::HashMap::new(); + let query = format!("PRAGMA table_info(\"{}\")", table_name); + let mut stmt = conn.prepare(&query)?; + let mut rows = stmt.query_map([], |row| { + let col_name: String = row.get(1)?; + let col_type: String = row.get(2)?; + Ok((col_name, col_type)) + })?; + while let Some(Ok((col_name, col_type))) = rows.next() { + let pg_type_name = crate::types::sqlite_type_info::sqlite_type_to_pg_type_name(&col_type); + result.insert(col_name, pg_type_name.to_string()); + } + Ok(result) + }).await + } +``` + +- [ ] **Step 2: Run build to verify it compiles** + +Run: `cargo check` +Expected: No errors (warnings OK) + +- [ ] **Step 3: Commit** + +```bash +git add src/session/db_handler.rs +git commit -m "feat: add get_column_types_from_pragma to DbHandler (#68)" +``` + +--- + +### Task 3: Add PRAGMA fallback in executor.rs + +**Files:** +- Modify: `src/query/executor.rs:1088-1155` (schema_types population loop) + +- [ ] **Step 1: Add PRAGMA cache and fallback** + +In `src/query/executor.rs`, at line 1088, change the `schema_types` population block from: + +```rust + if let Some(ref table) = table_name { + debug!("Type inference: Found table name '{}', looking up schema for {} columns", table, response.columns.len()); + + // Extract column mappings from query if possible + let column_mappings = extract_column_mappings_from_query(query, table); + + // Fetch types for actual columns + for col_name in &response.columns { + // Try direct lookup first + if let Ok(Some(pg_type)) = db.get_schema_type_with_session(&session.id, table, col_name).await { +``` + +to: + +```rust + if let Some(ref table) = table_name { + debug!("Type inference: Found table name '{}', looking up schema for {} columns", table, 
response.columns.len()); + + // Pre-fetch PRAGMA table_info as fallback for tables without __pgsqlite_schema + let pragma_types = db.get_column_types_from_pragma(&session.id, table).await.unwrap_or_default(); + + // Extract column mappings from query if possible + let column_mappings = extract_column_mappings_from_query(query, table); + + // Fetch types for actual columns + for col_name in &response.columns { + // Try direct lookup first + if let Ok(Some(pg_type)) = db.get_schema_type_with_session(&session.id, table, col_name).await { +``` + +Then add a PRAGMA fallback at the end of the `for col_name` loop body — after line 1163 (close of the `else` block) and before line 1164 (close of the `for` loop). This placement is at the `for` loop level so it catches columns that went through ANY branch (direct lookup, column_mappings, alias resolution) and still weren't found: + +```rust + // PRAGMA table_info fallback for pre-existing SQLite tables + if !schema_types.contains_key(col_name) { + if let Some(pg_type) = pragma_types.get(col_name) { + debug!("Type inference: Found PRAGMA type for '{}.{}' -> {}", table, col_name, pg_type); + schema_types.insert(col_name.clone(), pg_type.clone()); + } + } +``` + +This goes between lines 1163 and 1164 (just before the `for` loop's closing `}`). 
+ +- [ ] **Step 2: Run tests** + +Run: `cargo test` +Expected: All pass + +- [ ] **Step 3: Commit** + +```bash +git add src/query/executor.rs +git commit -m "feat: add PRAGMA type fallback in simple query protocol (#68)" +``` + +--- + +### Task 4: Add PRAGMA fallback in extended.rs schema_types (2 locations) + +**Files:** +- Modify: `src/query/extended.rs:563-640` (Parse-time schema_types) +- Modify: `src/query/extended.rs:4487-4511` (Execute-time schema_types) + +- [ ] **Step 1: Add PRAGMA fallback to Parse-time schema_types (line 563)** + +In `src/query/extended.rs`, at line 564, change: + +```rust + if let Some(ref table) = table_name { + info!("PARSE: Fetching schema types for table '{}'", table); + // For aliased columns, try to find the source column + for col_name in &response.columns { +``` + +to: + +```rust + if let Some(ref table) = table_name { + info!("PARSE: Fetching schema types for table '{}'", table); + let pragma_types = db.get_column_types_from_pragma(&session.id, table).await.unwrap_or_default(); + // For aliased columns, try to find the source column + for col_name in &response.columns { +``` + +Then add a PRAGMA fallback at the end of the `for col_name` loop body — after line 648 (close of the `else` block) and before line 649 (close of the `for` loop). This placement is at the `for` loop level: + +```rust + // PRAGMA table_info fallback + if !schema_types.contains_key(col_name) { + if let Some(pg_type) = pragma_types.get(col_name) { + info!("PARSE: Found PRAGMA type for '{}.{}' -> {}", table, col_name, pg_type); + schema_types.insert(col_name.clone(), pg_type.clone()); + } + } +``` + +This goes between lines 648 and 649 (just before the `for` loop's closing `}`). 
+ +- [ ] **Step 2: Add PRAGMA fallback to Execute-time schema_types (line 4487)** + +In `src/query/extended.rs`, at line 4488, change: + +```rust + if let Some(ref table) = table_name { + for col_name in &response.columns { +``` + +to: + +```rust + if let Some(ref table) = table_name { + let pragma_types = db.get_column_types_from_pragma(&session.id, table).await.unwrap_or_default(); + for col_name in &response.columns { +``` + +Then add a PRAGMA fallback at the end of the `for col_name` loop body — before line 4511 (the `for` loop's closing `}`): + +```rust + // PRAGMA table_info fallback + if !schema_types.contains_key(col_name) { + if let Some(pg_type) = pragma_types.get(col_name) { + info!("Found PRAGMA type for '{}.{}' -> {}", table, col_name, pg_type); + schema_types.insert(col_name.clone(), pg_type.clone()); + } + } +``` + +- [ ] **Step 3: Run tests** + +Run: `cargo test` +Expected: All pass + +- [ ] **Step 4: Commit** + +```bash +git add src/query/extended.rs +git commit -m "feat: add PRAGMA type fallback in extended protocol schema_types (#68)" +``` + +--- + +### Task 5: Add PRAGMA fallback in extended.rs inferred_types + +**Files:** +- Modify: `src/query/extended.rs:836-903` (Parse-time inferred_types fallback) + +- [ ] **Step 1: Add PRAGMA lookup at inferred_types fallback points** + +The `inferred_types` construction (lines 654-912) has 4 TEXT fallback points where a table name is known but `get_schema_type_with_session` returned `None`. At each of these, insert a PRAGMA lookup before defaulting to TEXT. + +First, add a `pragma_types` cache at the start of the `inferred_types` loop. Find line 654: + +```rust + let mut inferred_types = Vec::new(); +``` + +Add after it: + +```rust + let pragma_types_for_inferred = if let Some(ref table) = table_name { + db.get_column_types_from_pragma(&session.id, table).await.unwrap_or_default() + } else { + std::collections::HashMap::new() + }; +``` + +Then change the 4 fallback points. 
At lines 848-856 (after `extract_source_table_column_for_alias` lookup), change: + +```rust + Ok(None) => { + info!("Column '{}': no schema type found for '{}.{}', defaulting to text", + col_name, source_table, source_col); + inferred_types.push(PgType::Text.to_oid()); + } + Err(_) => { + // Schema lookup error, defaulting to text + inferred_types.push(PgType::Text.to_oid()); + } +``` + +to: + +```rust + Ok(None) => { + if let Some(pg_type) = pragma_types_for_inferred.get(col_name) { + let type_oid = crate::types::SchemaTypeMapper::pg_type_string_to_oid(pg_type); + info!("Column '{}': resolved type from PRAGMA -> {} (OID {})", col_name, pg_type, type_oid); + inferred_types.push(type_oid); + } else { + info!("Column '{}': no schema type found for '{}.{}', defaulting to text", + col_name, source_table, source_col); + inferred_types.push(PgType::Text.to_oid()); + } + } + Err(_) => { + if let Some(pg_type) = pragma_types_for_inferred.get(col_name) { + let type_oid = crate::types::SchemaTypeMapper::pg_type_string_to_oid(pg_type); + inferred_types.push(type_oid); + } else { + inferred_types.push(PgType::Text.to_oid()); + } + } +``` + +At lines 896-904 (after `extract_table_name_from_select` + `get_schema_type_with_session`), change: + +```rust + Ok(None) => { + info!("Column '{}': no schema type found for '{}.{}', defaulting to text", + col_name, table_name, col_name); + inferred_types.push(PgType::Text.to_oid()); + } + Err(_) => { + // Schema lookup error, defaulting to text + inferred_types.push(PgType::Text.to_oid()); + } +``` + +to: + +```rust + Ok(None) => { + if let Some(pg_type) = pragma_types_for_inferred.get(col_name) { + let type_oid = crate::types::SchemaTypeMapper::pg_type_string_to_oid(pg_type); + info!("Column '{}': resolved type from PRAGMA -> {} (OID {})", col_name, pg_type, type_oid); + inferred_types.push(type_oid); + } else { + info!("Column '{}': no schema type found for '{}.{}', defaulting to text", + col_name, table_name, col_name); + 
inferred_types.push(PgType::Text.to_oid()); + } + } + Err(_) => { + if let Some(pg_type) = pragma_types_for_inferred.get(col_name) { + let type_oid = crate::types::SchemaTypeMapper::pg_type_string_to_oid(pg_type); + inferred_types.push(type_oid); + } else { + inferred_types.push(PgType::Text.to_oid()); + } + } +``` + +- [ ] **Step 2: Run tests** + +Run: `cargo test` +Expected: All pass + +- [ ] **Step 3: Commit** + +```bash +git add src/query/extended.rs +git commit -m "feat: add PRAGMA type fallback in extended protocol inferred_types (#68)" +``` + +--- + +### Task 6: Final verification + +- [ ] **Step 1: Run pre-commit checklist** + +```bash +cargo check && cargo clippy && cargo build && cargo test +``` + +Expected: All pass with no errors. + +- [ ] **Step 2: Verify the fix conceptually** + +After our changes, for a pre-existing SQLite table with `CREATE TABLE test (id INTEGER, name TEXT, score REAL, active BOOLEAN)`: + +1. `get_column_types_from_pragma` returns `{"id": "integer", "name": "text", "score": "double precision", "active": "boolean"}` +2. In executor.rs, `schema_types` gets populated from PRAGMA when `__pgsqlite_schema` has no entry +3. `pg_type_string_to_oid("integer")` → 23, `pg_type_string_to_oid("double precision")` → 701, `pg_type_string_to_oid("boolean")` → 16 +4. FieldDescription sends correct OIDs on the wire diff --git a/docs/superpowers/plans/2026-05-08-column-name-sanitization.md b/docs/superpowers/plans/2026-05-08-column-name-sanitization.md new file mode 100644 index 00000000..2f61986e --- /dev/null +++ b/docs/superpowers/plans/2026-05-08-column-name-sanitization.md @@ -0,0 +1,377 @@ +# Column Name Sanitization Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 
+ +**Goal:** Strip parenthesized suffixes from SQLite result column names to match PostgreSQL behavior (e.g., `version()` → `version`). + +**Architecture:** Add a `sanitize_column_name()` function that strips everything from the first `(` onward when present. Apply it at every `stmt.column_name(i)` collection site so that `DbResponse.columns` and all downstream code (FieldDescription, type lookups) receive PostgreSQL-compatible column names. + +**Tech Stack:** Rust, existing pgsqlite codebase patterns. + +--- + +### Task 1: Create the `sanitize_column_name` module + +**Files:** +- Create: `src/query/column_sanitizer.rs` +- Modify: `src/query/mod.rs` + +- [ ] **Step 1: Create the sanitizer module** + +```rust +// src/query/column_sanitizer.rs + +/// Strip parenthesized arguments from column names to match PostgreSQL behavior. +/// +/// PostgreSQL returns just the function name as the column name for function calls: +/// SELECT version() → column "version" +/// SELECT count(*) → column "count" +/// SELECT max(id) → column "max" +/// +/// SQLite returns the full expression: +/// SELECT version() → column "version()" +/// SELECT count(*) → column "count(*)" +/// SELECT max(id) → column "max(id)" +/// +/// This function normalizes SQLite's behavior to match PostgreSQL. 
+pub fn sanitize_column_name(name: &str) -> &str { + if let Some(pos) = name.find('(') { + &name[..pos] + } else { + name + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_function_call() { + assert_eq!(sanitize_column_name("version()"), "version"); + } + + #[test] + fn test_function_with_star_arg() { + assert_eq!(sanitize_column_name("count(*)"), "count"); + } + + #[test] + fn test_function_with_column_arg() { + assert_eq!(sanitize_column_name("max(id)"), "max"); + } + + #[test] + fn test_nested_function_call() { + assert_eq!(sanitize_column_name("COALESCE(max(id), 0)"), "COALESCE"); + } + + #[test] + fn test_no_parens() { + assert_eq!(sanitize_column_name("current_timestamp"), "current_timestamp"); + } + + #[test] + fn test_regular_column() { + assert_eq!(sanitize_column_name("my_column"), "my_column"); + } + + #[test] + fn test_schema_qualified() { + assert_eq!(sanitize_column_name("pg_catalog.version()"), "pg_catalog.version"); + } + + #[test] + fn test_empty_string() { + assert_eq!(sanitize_column_name(""), ""); + } + + #[test] + fn test_just_parens() { + assert_eq!(sanitize_column_name("()"), ""); + } +} +``` + +- [ ] **Step 2: Register the module in mod.rs** + +Add to `src/query/mod.rs`: + +```rust +pub mod column_sanitizer; +``` + +- [ ] **Step 3: Run tests to verify module compiles and tests pass** + +Run: `cargo test --lib column_sanitizer` +Expected: All 9 tests pass. + +- [ ] **Step 4: Commit** + +```bash +git add src/query/column_sanitizer.rs src/query/mod.rs +git commit -m "feat: add column name sanitizer to strip function parentheses" +``` + +--- + +### Task 2: Apply sanitization at all `stmt.column_name(i)` sites in db_handler.rs + +**Files:** +- Modify: `src/session/db_handler.rs` + +There are 9 sites in db_handler.rs where `stmt.column_name(i)` is called. Each follows the pattern `columns.push(stmt.column_name(i)?.to_string())` or `column_names.push(stmt.column_name(i).unwrap_or("").to_string())`. 
Change each site so that the column name is passed through `sanitize_column_name()` before `.to_string()` is called. + +- [ ] **Step 1: Add import at top of db_handler.rs** + +Add to the imports section of `src/session/db_handler.rs`: + +```rust +use crate::query::column_sanitizer::sanitize_column_name; +``` + +- [ ] **Step 2: Apply at line ~529 — SELECT path** + +Change: +```rust +columns.push(stmt.column_name(i)?.to_string()); +``` +To: +```rust +columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); +``` + +- [ ] **Step 3: Apply at line ~664 — another SELECT path** + +Same change as Step 2. + +- [ ] **Step 4: Apply at line ~777 — catalog query path** + +Same change. + +- [ ] **Step 5: Apply at line ~895 — another path** + +Same change. + +- [ ] **Step 6: Apply at line ~1242 — another path** + +Same change. + +- [ ] **Step 7: Apply at line ~1453 — another path** + +Same change. + +- [ ] **Step 8: Apply at line ~1722 — another path** + +Same change. + +- [ ] **Step 9: Apply at line ~2326 — type inference path (uses unwrap_or)** + +Change: +```rust +column_names.push(stmt.column_name(i).unwrap_or("").to_string()); +``` +To: +```rust +column_names.push(sanitize_column_name(stmt.column_name(i).unwrap_or("")).to_string()); +``` + +- [ ] **Step 10: Apply at line ~2424 — type inference path (uses unwrap_or)** + +Same change as Step 9. + +- [ ] **Step 11: Run cargo check** + +Run: `cargo check` +Expected: No errors. + +- [ ] **Step 12: Commit** + +```bash +git add src/session/db_handler.rs +git commit -m "feat: apply column name sanitization in db_handler" +``` + +--- + +### Task 3: Apply sanitization in fast_path.rs + +**Files:** +- Modify: `src/query/fast_path.rs` + +There are 3 sites, all following the pattern `columns.push(stmt.column_name(i)?.to_string())`. 
+ +- [ ] **Step 1: Add import** + +```rust +use crate::query::column_sanitizer::sanitize_column_name; +``` + +- [ ] **Step 2: Apply at line ~377** + +Change: +```rust +columns.push(stmt.column_name(i)?.to_string()); +``` +To: +```rust +columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); +``` + +- [ ] **Step 3: Apply at line ~619** + +Same change. + +- [ ] **Step 4: Apply at line ~737** + +Same change. + +- [ ] **Step 5: Run cargo check** + +Run: `cargo check` +Expected: No errors. + +- [ ] **Step 6: Commit** + +```bash +git add src/query/fast_path.rs +git commit -m "feat: apply column name sanitization in fast_path" +``` + +--- + +### Task 4: Apply sanitization in remaining files + +**Files:** +- Modify: `src/catalog/query_interceptor.rs` +- Modify: `src/cache/statement_pool.rs` +- Modify: `src/optimization/read_only_optimizer.rs` + +- [ ] **Step 1: Apply in query_interceptor.rs (2 sites: ~274, ~415)** + +Add import: +```rust +use crate::query::column_sanitizer::sanitize_column_name; +``` + +Change both sites from: +```rust +columns.push(stmt.column_name(i)?.to_string()); +``` +To: +```rust +columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); +``` + +- [ ] **Step 2: Apply in statement_pool.rs (1 site: ~197)** + +Add import: +```rust +use crate::query::column_sanitizer::sanitize_column_name; +``` + +Change: +```rust +column_names.push(stmt.column_name(i)?.to_string()); +``` +To: +```rust +column_names.push(sanitize_column_name(stmt.column_name(i)?).to_string()); +``` + +- [ ] **Step 3: Apply in read_only_optimizer.rs (1 site: ~262)** + +Add import: +```rust +use crate::query::column_sanitizer::sanitize_column_name; +``` + +Change: +```rust +columns.push(stmt.column_name(i)?.to_string()); +``` +To: +```rust +columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); +``` + +- [ ] **Step 4: Run cargo check** + +Run: `cargo check` +Expected: No errors. 
+ +- [ ] **Step 5: Commit** + +```bash +git add src/catalog/query_interceptor.rs src/cache/statement_pool.rs src/optimization/read_only_optimizer.rs +git commit -m "feat: apply column name sanitization in remaining files" +``` + +--- + +### Task 5: Check extended query path for column name handling + +**Files:** +- Modify: `src/query/extended.rs` (if needed) + +The extended query path may get column names from cached `StatementMetadata` (from statement_pool.rs, which we already fixed) or from `DbResponse.columns`. Read the relevant code to determine if any additional sanitization is needed. + +- [ ] **Step 1: Inspect extended.rs for column name sources** + +Search for how column names enter the extended path. They come from: +1. `DbResponse.columns` — already sanitized via Tasks 2–4 +2. `StatementMetadata.column_names` — already sanitized via statement_pool.rs in Task 4 +3. Hardcoded column names in catalog queries — these don't need sanitization (they're already correct PostgreSQL names) + +No additional changes expected in extended.rs. + +- [ ] **Step 2: Run cargo check** + +Run: `cargo check` +Expected: No errors. + +--- + +### Task 6: Run full test suite and verify + +- [ ] **Step 1: Run cargo test** + +Run: `cargo test` +Expected: All tests pass. + +- [ ] **Step 2: Run cargo clippy** + +Run: `cargo clippy` +Expected: No new warnings. + +- [ ] **Step 3: Build** + +Run: `cargo build` +Expected: Successful build. + +- [ ] **Step 4: Manual integration test — verify version() column name** + +Start pgsqlite, connect with psql, and run: + +```sql +SELECT version(); +``` + +Verify the column name in the response is `version` (not `version()`). + +- [ ] **Step 5: Manual integration test — verify other function calls** + +```sql +SELECT count(*) FROM sqlite_master; +SELECT 1, current_timestamp; +SELECT max(1); +``` + +Verify the column names are `count`, `current_timestamp`, and `max` respectively; the literal `1` column in the second query contains no parentheses and is not affected by sanitization. 
+ +- [ ] **Step 6: Final commit if any fixes needed** + +```bash +git add -A +git commit -m "fix: any fixes from integration testing" +``` \ No newline at end of file diff --git a/docs/superpowers/specs/2026-03-27-pgadmin4-set-compat-design.md b/docs/superpowers/specs/2026-03-27-pgadmin4-set-compat-design.md new file mode 100644 index 00000000..6958bad1 --- /dev/null +++ b/docs/superpowers/specs/2026-03-27-pgadmin4-set-compat-design.md @@ -0,0 +1,116 @@ +# pgAdmin4 SET Command Compatibility + +**Issue**: [#71](https://github.com/erans/pgsqlite/issues/71) - SET without spaces around equal sign +**Date**: 2026-03-27 +**Scope**: Fix SET parsing + add `set_config()` and `pg_show_all_settings()` for pgAdmin4 compatibility + +## Problem + +pgAdmin4 sends this compound query on connection: + +```sql +SET DateStyle=ISO; SET client_min_messages=notice; SELECT set_config('bytea_output','hex',false) FROM pg_show_all_settings() WHERE name = 'bytea_output'; SET client_encoding='utf-8'; +``` + +Three things fail: +1. `SET DateStyle=ISO` — regex requires spaces around `=` +2. `pg_show_all_settings()` — function not recognized +3. `set_config(...)` — function not implemented + +## Design + +### 1. SET Regex Fix + +**File**: `src/query/set_handler.rs` + +Current regex (`SET_PARAMETER_PATTERN`): +``` +(?i)^\s*SET\s+(\w+)\s+(?:TO|=)\s+(.+)$ +``` + +Fixed regex: +``` +(?i)^\s*SET\s+(\w+)(?:\s*=\s*|\s+TO\s+)(.+)$ +``` + +- `=` allows optional whitespace on both sides (`\s*=\s*`) +- `TO` still requires whitespace on both sides (`\s+TO\s+`) to prevent ambiguity with parameter names + +Covers all PostgreSQL-valid forms: `SET x=y`, `SET x = y`, `SET x TO y`. + +### 2. pg_show_all_settings() Rewrite + +**Location**: Query preprocessing in `src/query/executor.rs`, before SQL parsing + +`pg_show_all_settings()` returns the same data as `pg_settings`. 
Rewrite the function call to the table name via case-insensitive text replacement: + +``` +pg_show_all_settings() → pg_settings +``` + +After rewriting, the query routes through the existing `PgSettingsHandler`. + +**Limitation**: This is a plain text substitution. It would incorrectly rewrite the function name inside string literals (e.g., `'pg_show_all_settings()'`). In practice this never occurs in real client queries. + +### 3. set_config() Function + +**Location**: Handled as a special case in `src/query/executor.rs`, at the `execute_single_statement` level (before normal query routing). + +Detection regex: +``` +(?i)set_config\(\s*'([^']+)'\s*,\s*'([^']*)'\s*,\s*(true|false)\s*\) +``` + +Note: The value capture uses `[^']*` (star, not plus) to allow empty string values like `set_config('application_name', '', false)`. + +Processing: +1. Detect that the query is a SELECT containing `set_config(...)` +2. Extract setting name, new value, and is_local flag from regex captures +3. Set the parameter in the session state (same as SET handler) +4. Send a synthetic response directly on the wire: + - `RowDescription`: one column named `set_config`, type Text + - `DataRow`: one row containing the new value + - `CommandComplete`: tag `"SELECT 1"` +5. Return early — do **not** pass through to `PgSettingsHandler` or normal query execution + +This avoids the problem of `PgSettingsHandler` not knowing how to handle string literal projections. The `set_config()` handler owns the full response. + +**`is_local` flag**: Treated identically to `false` (session-level SET). pgsqlite does not support transaction-scoped settings. This is a known limitation — the parameter is always set at session scope regardless of the flag value. + +**Error handling**: If `set_config()` arguments can't be parsed, the query falls through to normal execution. + +### 4. 
server_version Consistency (Drive-by Fix) + +Four locations report different PostgreSQL version numbers: + +| Location | Current Value | Action | +|----------|--------------|--------| +| `src/catalog/pg_settings.rs:102` | `16.0` | Keep (canonical) | +| `src/query/set_handler.rs:110` | `15.0` | Update to `16.0` | +| `src/session/state.rs:58` | `14.0 (SQLite wrapper)` | Update to `16.0` | +| `src/functions/system_functions.rs:16` | `15.0` | Update to `16.0` | + +Align all to `16.0` so pgAdmin4 sees a consistent version across all code paths. + +## Files Changed + +| File | Change | +|------|--------| +| `src/query/set_handler.rs` | Fix `SET_PARAMETER_PATTERN` regex; update `server_version` to `16.0` | +| `src/query/executor.rs` | Add `pg_show_all_settings()` rewrite and `set_config()` handler in preprocessing | +| `src/session/state.rs` | Update `server_version` to `16.0` | +| `src/functions/system_functions.rs` | Update version string to `16.0` | + +## Testing + +- Unit tests for SET regex: `SET x=y`, `SET x = y`, `SET x TO y` variants +- Unit test for `pg_show_all_settings()` rewrite +- Unit test for `set_config()` detection, parameter extraction, and empty-value handling +- Integration test with the exact pgAdmin4 compound query + +## Known Limitations + +- `set_config()` with `is_local=true` behaves as session-level (no transaction-scoped settings) +- `set_config()` detection is regex-based, not AST-based — handles the standard `SELECT set_config(...)` pattern but not deeply nested usage +- `set_config()` regex does not handle escaped single quotes in values (e.g., `'my''app'`); unlikely in practice for settings values +- `pg_show_all_settings()` rewrite is plain text substitution, not AST-aware; would incorrectly rewrite occurrences inside string literals diff --git a/docs/superpowers/specs/2026-03-27-start-transaction-design.md b/docs/superpowers/specs/2026-03-27-start-transaction-design.md new file mode 100644 index 00000000..7c804474 --- /dev/null +++ 
b/docs/superpowers/specs/2026-03-27-start-transaction-design.md @@ -0,0 +1,110 @@ +# START TRANSACTION / END TRANSACTION Support + +**Issue**: [#70](https://github.com/erans/pgsqlite/issues/70) - Support for START TRANSACTION +**Date**: 2026-03-27 +**Scope**: Recognize PostgreSQL transaction synonyms: `START TRANSACTION` → `BEGIN`, `END` → `COMMIT` + +## Problem + +PostgreSQL's `START TRANSACTION` is a synonym for `BEGIN`, and `END` / `END TRANSACTION` is a synonym for `COMMIT`. The `postgres_fdw` extension sends: + +```sql +START TRANSACTION ISOLATION LEVEL REPEATABLE READ +``` + +pgsqlite's `QueryTypeDetector` doesn't recognize `START` as a transaction keyword, so the raw SQL is passed to SQLite, which fails because `START` isn't a SQLite keyword. The same issue exists for `END TRANSACTION`. + +Additionally, the extended query protocol path in `extended.rs` has its own transaction dispatch and `execute_transaction` that only check for `BEGIN`, `COMMIT`, and `ROLLBACK` — not `START` or `END`. + +## Design + +### 1. Detection changes in QueryTypeDetector + +**File**: `src/query/query_type_detection.rs` + +**Fast path** — add `START` to the existing `if bytes.len() >= 5` block (lines 37-43) which already matches `ALTER` and `BEGIN`: +```rust +b"START" | b"start" | b"Start" => return QueryType::Begin, +``` + +For `END`, add a new 3-byte check with word boundary guard: +```rust +if bytes.len() >= 3 { + let first3 = &bytes[0..3]; + if (first3 == b"END" || first3 == b"end" || first3 == b"End") + && (bytes.len() == 3 || bytes[3].is_ascii_whitespace()) { + return QueryType::Commit; + } +} +``` + +**Fallback path** — add alongside existing checks: +```rust +} else if trimmed.len() >= 5 && trimmed[..5].eq_ignore_ascii_case("START") { + QueryType::Begin +} else if trimmed.len() >= 3 && trimmed[..3].eq_ignore_ascii_case("END") + && (trimmed.len() == 3 || trimmed.as_bytes()[3].is_ascii_whitespace()) { + QueryType::Commit +``` + +### 2. 
Extended query protocol dispatch + +**File**: `src/query/extended.rs` + +**Dispatch block** (~line 1865): Add `START` and `END` to the transaction dispatch: +```rust +} else if query_starts_with_ignore_case(&final_query, "BEGIN") + || query_starts_with_ignore_case(&final_query, "START") + || query_starts_with_ignore_case(&final_query, "COMMIT") + || query_starts_with_ignore_case(&final_query, "END") + || query_starts_with_ignore_case(&final_query, "ROLLBACK") { + Self::execute_transaction(framed, db, session, &final_query).await?; +``` + +**execute_transaction** (~line 5635): Add `START` and `END` handling: +```rust +if query_starts_with_ignore_case(query, "BEGIN") + || query_starts_with_ignore_case(query, "START") { + db.begin_with_session(&session.id).await?; + framed.send(BackendMessage::CommandComplete { tag: "BEGIN".to_string() }).await + .map_err(PgSqliteError::Io)?; +} else if query_starts_with_ignore_case(query, "COMMIT") + || query_starts_with_ignore_case(query, "END") { + db.commit_with_session(&session.id).await?; + framed.send(BackendMessage::CommandComplete { tag: "COMMIT".to_string() }).await + .map_err(PgSqliteError::Io)?; +} else if query_starts_with_ignore_case(query, "ROLLBACK") { +``` + +Note: `query_starts_with_ignore_case` checks case-insensitively and only matches at word boundaries, so `END` won't match `ENDTABLE`. + +### Why this is sufficient + +Once detected as `QueryType::Begin` or `QueryType::Commit`, the existing `execute_transaction` handlers call `db.begin_with_session()` or `db.commit_with_session()` — they never pass the raw SQL to SQLite. The `ISOLATION LEVEL REPEATABLE READ` clause is naturally ignored because SQLite only supports one isolation level (serializable). + +The ultra-simple query fast path (`simple_query_detector.rs`) only matches DML/SELECT queries, so `START TRANSACTION` and `END` are unaffected — they always fall through to the normal detection path. 
+ +Note: `query_router.rs` already handles `START` and `END` classification (lines 170-173), so no changes are needed there. + +## Files Changed + +| File | Change | +|------|--------| +| `src/query/query_type_detection.rs` | Add `START` → `Begin` and `END` → `Commit` in both detection paths + tests | +| `src/query/extended.rs` | Add `START` and `END` to transaction dispatch and `execute_transaction` | + +## Testing + +- `START TRANSACTION` → `QueryType::Begin` +- `start transaction` → `QueryType::Begin` (case insensitive) +- `START TRANSACTION ISOLATION LEVEL REPEATABLE READ` → `QueryType::Begin` +- `END` → `QueryType::Commit` +- `END TRANSACTION` → `QueryType::Commit` +- `end transaction` → `QueryType::Commit` (case insensitive) +- Verify `END` does NOT match identifiers starting with "END" (word boundary guard) +- Verify `START TRANSACTION` in a failed transaction returns the "commands ignored" error + +## Known Limitations + +- SQLite only supports serializable isolation, so `ISOLATION LEVEL REPEATABLE READ` (or any other level) is silently accepted but not enforced at that level +- `SET TRANSACTION ISOLATION LEVEL` is not handled (separate command, not yet reported as needed) diff --git a/docs/superpowers/specs/2026-05-08-column-name-sanitization-design.md b/docs/superpowers/specs/2026-05-08-column-name-sanitization-design.md new file mode 100644 index 00000000..6a0f42f1 --- /dev/null +++ b/docs/superpowers/specs/2026-05-08-column-name-sanitization-design.md @@ -0,0 +1,103 @@ +# Column Name Sanitization: Strip Function Parentheses from Result Column Names + +## Problem + +When PostgreSQL clients execute `SELECT version()`, PostgreSQL returns a result column named `version`. SQLite returns `version()` (or `function_name(arg1, arg2)` for functions with arguments). This breaks clients like pgAdmin4 that expect column names to match PostgreSQL behavior — they look for a `version` column and find `version()` instead. 
+ +The existing `FunctionParenthesesTranslator` only strips `()` from `current_user()` and `session_user()` in query text before execution. It does not address column names in result sets, and it does not handle general function call column names like `version()`, `count(*)`, `max(id)`, etc. + +## Solution + +Add a `sanitize_column_name()` function that strips parenthesized suffixes from column names. Apply it at the boundary where SQLite column names enter the pgsqlite protocol layer — specifically when building `FieldDescription` objects or `DbResponse.columns`. + +## Design + +### New Function + +A small utility function in `src/query/column_sanitizer.rs` (new file): + +```rust +/// Strip parenthesized arguments from column names to match PostgreSQL behavior. +/// PostgreSQL: SELECT version() -> column name "version" +/// SQLite: SELECT version() -> column name "version()" +/// This function normalizes SQLite's behavior to match PostgreSQL. +pub fn sanitize_column_name(name: &str) -> &str { + if let Some(pos) = name.find('(') { + &name[..pos] + } else { + name + } +} +``` + +### Where to Apply + +Sanitize column names **after** internal type lookups (datetime, boolean, enum, schema lookups) but **before** they reach the protocol layer. This means: + +1. **`DbResponse.columns`** — When building the response from `stmt.column_name(i)`, sanitize the names. Since type lookups in the executor also use `response.columns`, we need to ensure the sanitization is applied consistently. + + **Strategy:** Sanitize at the point where `DbResponse.columns` is constructed (the `stmt.column_name(i)` collection sites), BUT keep raw names accessible for type lookups. The cleanest approach: sanitize in `DbResponse.columns` and update type-lookup logic to use sanitized names too (since the schema cache keys use real column names, not function-call expressions). + +2. **Direct `stmt.column_name(i)` calls** in fast_path, extended, and other executor paths — sanitize uniformly. 
+ +### Application Sites + +**`src/session/db_handler.rs`:** +- Line ~529: `columns.push(stmt.column_name(i)?.to_string())` — SELECT path +- Line ~664: same pattern — another SELECT path +- Line ~777: same — catalog query path +- Line ~895: same +- Line ~1242: same +- Line ~1453: same +- Line ~1722: same +- Line ~2324-2326: column collection for type inference + +**`src/query/fast_path.rs`:** +- Line 377: `columns.push(stmt.column_name(i)?.to_string())` +- Line 619: same +- Line 737: same + +**`src/query/extended.rs`:** +- Lines creating `column_names` from `stmt.column_name(i)` — sanitize similarly + +**`src/query/executor.rs`:** +- Type-lookup sites that use column names from `response.columns` — these need to work with sanitized names since that's what flows through the system. + +### Type Lookup Compatibility + +Internal type lookups (datetime, boolean, enum, schema) use column names to query `__pgsqlite_schema`. Since real table columns never contain `(`, sanitized names won't break these lookups. Function-call column names like `version()` or `max(id)` aren't in `__pgsqlite_schema`, so stripping parens is safe — the lookup simply returns `None` as it does today. + +For the executor's `column_mappings` (alias → real column), these are built from `AS` aliases in the query. If the alias is a function call like `max(id)`, stripping parens gives `max`, which also won't match schema lookups — same behavior as today (fallback to text type). + +### Edge Cases + +| SQLite Column Name | After Sanitization | PostgreSQL Behavior | Match? 
| +|---|---|---|---| +| `version()` | `version` | `version` | Yes | +| `count(*)` | `count` | `count` | Yes | +| `max(id)` | `max` | `max` | Yes | +| `current_timestamp` | `current_timestamp` | `current_timestamp` | Yes | +| `pg_catalog.version()` | `pg_catalog.version` | N/A (qualified) | Acceptable | +| `my_alias` | `my_alias` | `my_alias` | Yes | +| `1` | `1` | `?column?` | Acceptable | + +### What About the Existing FunctionParenthesesTranslator? + +The `FunctionParenthesesTranslator` strips the trailing `()` from `current_user()` and `session_user()` in the query text. This is still needed because SQLite doesn't recognize `current_user()` as a function call — the parentheses must be removed from the query itself. But it does not solve the problem described here: + +1. It doesn't handle `version()` in the query (unnecessary since SQLite supports it) +2. It doesn't affect result column names (the actual bug) + +The column name sanitizer is orthogonal: it fixes result column names regardless of what the query text looks like. No changes to `FunctionParenthesesTranslator` are needed for this fix. + +## Testing + +Add unit tests in `column_sanitizer.rs` covering: +- Basic function call: `version()` → `version` +- Function with args: `count(*)` → `count`, `max(id)` → `max` +- No parens: `current_timestamp` → `current_timestamp` +- Nested parens: `COALESCE(max(id), 0)` → `COALESCE` +- Already clean: `my_column` → `my_column` +- Schema-qualified: `pg_catalog.version()` → `pg_catalog.version` + +Integration test: connect via psql/psycopg, execute `SELECT version()`, verify column name is `version`.
\ No newline at end of file diff --git a/src/cache/statement_pool.rs b/src/cache/statement_pool.rs index 95af55f1..35d91727 100644 --- a/src/cache/statement_pool.rs +++ b/src/cache/statement_pool.rs @@ -3,6 +3,7 @@ use std::sync::Mutex; use rusqlite::{Connection, Statement, Params}; use once_cell::sync::Lazy; use crate::config::CONFIG; +use crate::query::column_sanitizer::sanitize_column_name; /// A pool of prepared SQLite statements for reuse /// This avoids the overhead of preparing the same statement multiple times @@ -194,7 +195,7 @@ impl StatementPool { let mut column_types = Vec::new(); for i in 0..column_count { - column_names.push(stmt.column_name(i)?.to_string()); + column_names.push(sanitize_column_name(stmt.column_name(i)?).to_string()); // We can't easily get PostgreSQL types here, so we'll leave them as None // They can be filled in later by the caller if needed column_types.push(None); diff --git a/src/catalog/query_interceptor.rs b/src/catalog/query_interceptor.rs index e1f90546..96189268 100644 --- a/src/catalog/query_interceptor.rs +++ b/src/catalog/query_interceptor.rs @@ -3,6 +3,7 @@ use uuid::Uuid; use crate::session::SessionState; use crate::PgSqliteError; use crate::translator::{RegexTranslator, SchemaPrefixTranslator}; +use crate::query::column_sanitizer::sanitize_column_name; use sqlparser::ast::{Statement, TableFactor, Select, SetExpr, SelectItem, Expr, FunctionArg, FunctionArgExpr}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; @@ -271,7 +272,7 @@ impl CatalogInterceptor { let column_count = stmt.column_count(); let mut columns = Vec::new(); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows_result: rusqlite::Result>>>> = stmt.query_map([], |row| { @@ -412,7 +413,7 @@ impl CatalogInterceptor { let column_count = stmt.column_count(); let mut columns = Vec::new(); for i in 0..column_count { - 
columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows_result: rusqlite::Result>>>> = stmt.query_map([], |row| { diff --git a/src/optimization/read_only_optimizer.rs b/src/optimization/read_only_optimizer.rs index 12f71558..de689239 100644 --- a/src/optimization/read_only_optimizer.rs +++ b/src/optimization/read_only_optimizer.rs @@ -5,6 +5,7 @@ use rusqlite::Connection; use crate::cache::SchemaCache; use crate::session::db_handler::DbResponse; use crate::query::QueryComplexity; +use crate::query::column_sanitizer::sanitize_column_name; use tracing::{debug, info}; /// Direct read-only access optimizer for SELECT queries @@ -259,7 +260,7 @@ impl ReadOnlyOptimizer { // Get column names let mut columns = Vec::new(); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } // Get column types from schema cache diff --git a/src/query/column_sanitizer.rs b/src/query/column_sanitizer.rs new file mode 100644 index 00000000..3f3253d0 --- /dev/null +++ b/src/query/column_sanitizer.rs @@ -0,0 +1,66 @@ +pub fn sanitize_column_name(name: &str) -> &str { + match name.find('(') { + Some(pos) => &name[..pos], + None => name, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_function() { + assert_eq!(sanitize_column_name("version()"), "version"); + } + + #[test] + fn test_count_star() { + assert_eq!(sanitize_column_name("count(*)"), "count"); + } + + #[test] + fn test_max_with_arg() { + assert_eq!(sanitize_column_name("max(id)"), "max"); + } + + #[test] + fn test_coalesce_nested() { + assert_eq!(sanitize_column_name("COALESCE(max(id), 0)"), "COALESCE"); + } + + #[test] + fn test_current_timestamp() { + assert_eq!(sanitize_column_name("current_timestamp"), "current_timestamp"); + } + + #[test] + fn test_plain_column() { + assert_eq!(sanitize_column_name("my_column"), "my_column"); + } 
+ + #[test] + fn test_empty_string() { + assert_eq!(sanitize_column_name(""), ""); + } + + #[test] + fn test_just_parentheses() { + assert_eq!(sanitize_column_name("()"), ""); + } + + #[test] + fn test_open_paren_only() { + assert_eq!(sanitize_column_name("("), ""); + } + + #[test] + fn test_multiple_parens() { + assert_eq!(sanitize_column_name("func(arg1, arg2)"), "func"); + } + + #[test] + fn test_nested_parens() { + assert_eq!(sanitize_column_name("outer(inner())"), "outer"); + } +} \ No newline at end of file diff --git a/src/query/fast_path.rs b/src/query/fast_path.rs index e50558b3..dacebd35 100644 --- a/src/query/fast_path.rs +++ b/src/query/fast_path.rs @@ -3,6 +3,7 @@ use regex::Regex; use once_cell::sync::Lazy; use crate::cache::SchemaCache; use crate::session::db_handler::DbResponse; +use crate::query::column_sanitizer::sanitize_column_name; use std::collections::HashMap; use std::sync::Mutex; @@ -374,7 +375,7 @@ pub fn query_fast_path( // Get column names let mut columns = Vec::new(); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } // Check for boolean columns in the schema using cache @@ -616,7 +617,7 @@ fn execute_fast_select_with_params( // Get column names let mut columns = Vec::new(); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } // Check for boolean columns in the schema using cache @@ -734,7 +735,7 @@ fn execute_fast_select( // Get column names let mut columns = Vec::new(); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } // Check for boolean columns in the schema using cache diff --git a/src/query/mod.rs b/src/query/mod.rs index 6e7f0c80..7badf655 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -15,6 +15,7 @@ pub mod 
unified_processor; pub mod pattern_optimizer; pub mod query_handler; pub mod join_type_inference; +pub mod column_sanitizer; pub use executor::QueryExecutor; pub use query_handler::{QueryHandler, QueryHandlerImpl}; diff --git a/src/session/db_handler.rs b/src/session/db_handler.rs index 9d2ff85b..2b4b816e 100644 --- a/src/session/db_handler.rs +++ b/src/session/db_handler.rs @@ -14,6 +14,7 @@ use crate::session::ConnectionManager; use crate::ddl::CommentDdlHandler; use crate::PgSqliteError; use crate::security::{events, SqlInjectionDetector}; +use crate::query::column_sanitizer::sanitize_column_name; use tracing::{debug, info, error, warn}; /// Security limits for query validation @@ -526,7 +527,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::with_capacity(column_count); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows: Result, _> = stmt.query_map(rusqlite::params_from_iter(values.iter()), |row| { @@ -661,7 +662,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::new(); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows_result: rusqlite::Result>>>> = stmt.query_map([], |row| { @@ -774,7 +775,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::with_capacity(column_count); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows: Result, _> = stmt.query_map([], |row| { @@ -892,7 +893,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::with_capacity(column_count); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let 
rows: Result, _> = stmt.query_map([], |row| { @@ -1239,7 +1240,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::with_capacity(column_count); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows: Result, _> = stmt.query_map([], |row| { @@ -1450,7 +1451,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::with_capacity(column_count); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows: Result, _> = stmt.query_map([], |row| { @@ -1719,7 +1720,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut columns = Vec::with_capacity(column_count); for i in 0..column_count { - columns.push(stmt.column_name(i)?.to_string()); + columns.push(sanitize_column_name(stmt.column_name(i)?).to_string()); } let rows: Result, _> = stmt.query_map([], |row| { @@ -2323,7 +2324,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut column_names = Vec::with_capacity(column_count); for i in 0..column_count { - column_names.push(stmt.column_name(i).unwrap_or("").to_string()); + column_names.push(sanitize_column_name(stmt.column_name(i).unwrap_or("")).to_string()); } // Build datetime column info for conversion @@ -2421,7 +2422,7 @@ impl DbHandler { let column_count = stmt.column_count(); let mut column_names = Vec::with_capacity(column_count); for i in 0..column_count { - column_names.push(stmt.column_name(i).unwrap_or("").to_string()); + column_names.push(sanitize_column_name(stmt.column_name(i).unwrap_or("")).to_string()); } // Build datetime column info for conversion diff --git a/src/types/schema_type_mapper.rs b/src/types/schema_type_mapper.rs index 5aafafe1..7c030e78 100644 --- a/src/types/schema_type_mapper.rs +++ b/src/types/schema_type_mapper.rs @@ -336,9 +336,42 @@ impl 
SchemaTypeMapper { ) -> Option { let upper = function_name.to_uppercase(); - // Handle aliased columns - if it's just a simple name, skip function detection - // This prevents false positives for columns named "year_col", "hour_trunc", etc. + // Handle bare function names (without parentheses). + // Column names may arrive without parentheses after sanitization strips them, + // e.g. "json_extract" instead of "json_extract(data, '$[0]')". + // Match known function names by exact equality to avoid false positives + // on columns named like "year_col" or "hour_trunc". if !function_name.contains('(') && !function_name.contains(' ') { + match upper.as_str() { + "COUNT" => return Some(PgType::Int8.to_oid()), + "SUM" | "AVG" => return Some(PgType::Numeric.to_oid()), + "MAX" | "MIN" => { + // Need schema lookup for proper type — fall through to query-based detection below + } + "JSON_ARRAY_LENGTH" => return Some(PgType::Int4.to_oid()), + "JSON_GROUP_ARRAY" | "JSON_ARRAY" | "JSON_OBJECT" | "JSON_EXTRACT" + | "JSON_AGG" | "JSON_OBJECT_AGG" | "JSONB_AGG" | "JSONB_OBJECT_AGG" | "ROW_TO_JSON" + | "JSON_EXTRACT_PATH" | "JSON_EXTRACT_PATH_TEXT" => return Some(PgType::Text.to_oid()), + "ARRAY_LENGTH" | "ARRAY_UPPER" | "ARRAY_LOWER" | "ARRAY_NDIMS" | "ARRAY_POSITION" => return Some(PgType::Int4.to_oid()), + "ARRAY_APPEND" | "ARRAY_PREPEND" | "ARRAY_CAT" | "ARRAY_REMOVE" + | "ARRAY_REPLACE" | "ARRAY_SLICE" | "STRING_TO_ARRAY" + | "ARRAY_POSITIONS" | "ARRAY_TO_STRING" | "UNNEST" + | "ARRAY_AGG" => return Some(PgType::Text.to_oid()), + "ARRAY_CONTAINS" | "ARRAY_CONTAINED" | "ARRAY_OVERLAP" => return Some(PgType::Bool.to_oid()), + "NOW" => return Some(PgType::Timestamptz.to_oid()), + "CURRENT_TIMESTAMP" => return Some(PgType::Timestamptz.to_oid()), + "CURRENT_DATE" => return Some(PgType::Text.to_oid()), + "CURRENT_TIME" => return Some(PgType::Time.to_oid()), + "EXTRACT" => return Some(PgType::Float8.to_oid()), + "DATE_TRUNC" | "TO_TIMESTAMP" => return Some(PgType::Timestamp.to_oid()), 
+ "MAKE_DATE" => return Some(PgType::Date.to_oid()), + "MAKE_TIME" => return Some(PgType::Time.to_oid()), + "AGE" => return Some(PgType::Interval.to_oid()), + "EPOCH" => return Some(PgType::Timestamp.to_oid()), + "DECIMAL_ADD" | "DECIMAL_SUB" | "DECIMAL_MUL" | "DECIMAL_DIV" | "DECIMAL_FROM_TEXT" => return Some(PgType::Numeric.to_oid()), + _ => {} + } + // If we have the query, try to find what function produces this alias if let Some(q) = query { // Look for patterns like "sum(...) AS function_name" or "avg(...) AS function_name" diff --git a/tests/test_batch_7b.py b/tests/test_batch_7b.py new file mode 100644 index 00000000..1c2ddead --- /dev/null +++ b/tests/test_batch_7b.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +"""Quick batch test for 7B model""" +import time +import requests +import asyncio +import aiohttp + +API_URL = "http://localhost:8000/v1/chat/completions" + +async def send_request(session, idx): + payload = { + "model": "qwen2.5-7b-instruct", + "messages": [{"role": "user", "content": "What is Rust?"}], + "max_tokens": 50 + } + start = time.time() + async with session.post(API_URL, json=payload) as resp: + result = await resp.json() + elapsed = time.time() - start + tokens = result["usage"]["completion_tokens"] + return elapsed, tokens + +async def test_concurrent(): + print("Testing 3 concurrent requests...") + start = time.time() + + async with aiohttp.ClientSession() as session: + tasks = [send_request(session, i) for i in range(3)] + results = await asyncio.gather(*tasks) + + total_time = time.time() - start + + print(f"\nResults:") + for i, (elapsed, tokens) in enumerate(results): + print(f" Request {i+1}: {elapsed:.2f}s ({tokens} tokens)") + + total_tokens = sum(r[1] for r in results) + aggregate_throughput = total_tokens / total_time + + print(f"\nTotal time: {total_time:.2f}s") + print(f"Total tokens: {total_tokens}") + print(f"Aggregate throughput: {aggregate_throughput:.1f} tokens/sec") + +if __name__ == "__main__": + 
asyncio.run(test_concurrent())