From 220641005898cc648befe8d8ee18e27f6eff4b64 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Fri, 1 May 2026 14:10:51 +0100 Subject: [PATCH 1/8] Refactor libodbc bindings --- .github/workflows/build.yaml | 3 - .github/workflows/release-jupyter.yml | 23 +- .github/workflows/release-packages.yml | 56 +- Cargo.lock | 37 +- Cargo.toml | 2 +- ggsql-cli/Cargo.toml | 2 - ggsql-jupyter/src/connection.rs | 194 ++-- ggsql-jupyter/src/data_explorer.rs | 26 +- ggsql-jupyter/src/lib.rs | 2 - ggsql-jupyter/src/main.rs | 2 - ggsql-jupyter/src/util.rs | 25 - src/Cargo.toml | 4 +- src/reader/mod.rs | 135 ++- src/reader/odbc.rs | 848 -------------- src/reader/odbc/ffi.rs | 383 +++++++ src/reader/odbc/mod.rs | 1464 ++++++++++++++++++++++++ src/reader/odbc/snowflake.rs | 248 ++++ src/reader/odbc/wrapper.rs | 674 +++++++++++ src/reader/snowflake.rs | 35 - src/reader/sqlite.rs | 83 +- 20 files changed, 3043 insertions(+), 1203 deletions(-) delete mode 100644 ggsql-jupyter/src/util.rs delete mode 100644 src/reader/odbc.rs create mode 100644 src/reader/odbc/ffi.rs create mode 100644 src/reader/odbc/mod.rs create mode 100644 src/reader/odbc/snowflake.rs create mode 100644 src/reader/odbc/wrapper.rs delete mode 100644 src/reader/snowflake.rs diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 7d1aef25..0937dc62 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -32,9 +32,6 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm - - name: Install ODBC - run: sudo apt-get install -y unixodbc-dev - - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/release-jupyter.yml b/.github/workflows/release-jupyter.yml index e07ad615..6ba9c92f 100644 --- a/.github/workflows/release-jupyter.yml +++ b/.github/workflows/release-jupyter.yml @@ -55,21 +55,10 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --auditwheel=repair + args: --release --out dist --auditwheel=skip working-directory: ggsql-jupyter manylinux: 2_28 docker-options: -e GGSQL_SKIP_GENERATE=1 - before-script-linux: | - if command -v yum >/dev/null 2>&1; then - yum install -y unixODBC-devel - elif command -v apt-get >/dev/null 2>&1; then - apt-get update && apt-get install -y unixodbc-dev - fi - - - name: Fix wheel library load paths - uses: ./.github/workflows/actions/fix-wheel-libs - with: - dist-dir: ggsql-jupyter/dist - uses: actions/upload-artifact@v4 with: @@ -96,21 +85,13 @@ jobs: name: tree-sitter-generated path: tree-sitter-ggsql/src/ - - name: Install ODBC - run: brew install unixodbc - - name: Build wheels uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --auditwheel=repair + args: --release --out dist --auditwheel=skip working-directory: ggsql-jupyter - - name: Fix wheel library load paths - uses: ./.github/workflows/actions/fix-wheel-libs - with: - dist-dir: ggsql-jupyter/dist - - uses: actions/upload-artifact@v4 with: name: jupyter-wheels-macos-${{ matrix.target }} diff --git a/.github/workflows/release-packages.yml b/.github/workflows/release-packages.yml index 694343bd..7a22cc19 100644 --- a/.github/workflows/release-packages.yml +++ b/.github/workflows/release-packages.yml @@ -102,12 +102,6 @@ jobs: - name: Install tree-sitter-cli run: npm install -g tree-sitter-cli - - name: Install ODBC - run: brew install unixodbc - - - name: Install dylibbundler - run: brew install dylibbundler - - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable 
with:
@@ -148,19 +142,10 @@ jobs:
 - name: Build ggsql binary (x86_64)
 run: cargo build --release --bin ggsql --bin ggsql-jupyter
- - name: Bundle dynamic library dependencies
- run: |
- dylibbundler -cd -of -b -x target/release/ggsql -d ./libs/ -p "@executable_path/../lib/ggsql$VERSION/"
- dylibbundler -cd -of -b -x target/release/ggsql-jupyter -d ./libs/ -p "@executable_path/../lib/ggsql$VERSION/"
-
- - name: Sign binaries and dylibs (Developer ID Application)
+ - name: Sign binaries (Developer ID Application)
 env:
 SIGN_ID: "Developer ID Application: ${{ secrets.GWS_APPLE_SIGN_IDENTITY }}"
 run: |
- # Sign bundled dylibs first (inside-out), replacing dylibbundler's ad-hoc sigs
- find ./libs -type f \( -name "*.dylib" -o -name "*.so" \) -print0 | \
- xargs -0 -I{} codesign --force --options runtime --timestamp --sign "$SIGN_ID" "{}"
- # Then sign the executables with hardened runtime + entitlements
 codesign --force --options runtime --timestamp \
 --entitlements entitlements.plist \
 --sign "$SIGN_ID" target/release/ggsql
@@ -177,14 +162,13 @@
 APPLE_API_ISSUER: ${{ secrets.GWS_APPLE_API_ISSUER }}
 run: |
 PKG_NAME="ggsql_${VERSION}_x86_64.pkg"
- mkdir -p pkg-payload/usr/local/bin "pkg-payload/usr/local/lib/ggsql$VERSION"
+ mkdir -p pkg-payload/usr/local/bin
 cp target/release/ggsql pkg-payload/usr/local/bin/
 cp target/release/ggsql-jupyter pkg-payload/usr/local/bin/
- cp -R ./libs/. "pkg-payload/usr/local/lib/ggsql$VERSION/"
 mkdir -p pkg-scripts
 cat > pkg-scripts/postinstall < pkg-scripts/postinstall <
diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs
+/// Returns (levels to skip, default catalog, default schema) for backends
+/// that do not expose a catalog and/or schema level.
+fn depth_offset(reader: &dyn Reader) -> (usize, String, String) {
+ let catalogs = reader.list_catalogs().unwrap_or_default();
+ if catalogs.is_empty() {
+ let schemas = reader.list_schemas("").unwrap_or_default();
+ if schemas.is_empty() {
+ (2, String::new(), String::new())
+ } else {
+ (1, String::new(), String::new())
+ }
+ } else {
+ (0, String::new(), String::new())
+ }
+}
+
+/// List objects at the given path depth, skipping empty hierarchy levels.
 pub fn list_objects(reader: &dyn Reader, path: &[String]) -> Result<Vec<ObjectSchema>, String> {
- match path.len() {
+ let (offset, default_catalog, default_schema) = depth_offset(reader);
+ let effective = path.len() + offset;
+
+ match effective {
 0 => list_catalogs(reader),
- 1 => list_schemas(reader, &path[0]),
- 2 => list_tables(reader, &path[0], &path[1]),
+ 1 => {
+ let catalog = if offset >= 1 { &default_catalog } else { &path[0] };
+ list_schemas(reader, catalog)
+ }
+ 2 => {
+ let (catalog, schema) = match offset {
+ 2 => (&default_catalog, &default_schema),
+ 1 => (&default_catalog, &path[0]),
+ _ => (&path[0], &path[1]),
+ };
+ list_tables(reader, catalog, schema)
+ }
 _ => Ok(vec![]),
 }
 }
 /// List fields (columns) for the object at the given path.
-///
-/// - `[catalog, schema, table]` → list columns
 pub fn list_fields(reader: &dyn Reader, path: &[String]) -> Result<Vec<FieldSchema>, String> {
- if path.len() == 3 {
- list_columns(reader, &path[0], &path[1], &path[2])
- } else {
- Ok(vec![])
+ let (offset, default_catalog, default_schema) = depth_offset(reader);
+ let effective = path.len() + offset;
+
+ if effective != 3 {
+ return Ok(vec![]);
 }
+
+ let (catalog, schema, table) = match offset {
+ 2 => (default_catalog.as_str(), default_schema.as_str(), path[0].as_str()),
+ 1 => (default_catalog.as_str(), path[0].as_str(), path[1].as_str()),
+ _ => (path[0].as_str(), path[1].as_str(), path[2].as_str()),
+ };
+
+ list_columns(reader, catalog, schema, table)
 }
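For intuition, here is a small standalone sketch (hypothetical code, not part of the patch) of how the depth offset shifts path interpretation. A backend with no catalog level gets offset 1, so a one-element path already names a schema and the empty string stands in for the default catalog:

// Mirrors the `match path.len() + offset` dispatch in list_objects above.
fn describe(path: &[&str], offset: usize) -> String {
    match path.len() + offset {
        0 => "list catalogs".to_string(),
        1 => format!("list schemas in catalog {:?}", if offset >= 1 { "" } else { path[0] }),
        2 => {
            let (catalog, schema) = match offset {
                2 => ("", ""),
                1 => ("", path[0]),
                _ => (path[0], path[1]),
            };
            format!("list tables in {:?}.{:?}", catalog, schema)
        }
        _ => "no children".to_string(),
    }
}

fn main() {
    // SQLite-style backend (no catalogs): offset = 1.
    assert_eq!(describe(&["main"], 1), "list tables in \"\".\"main\"");
    // Full three-level backend: offset = 0.
    assert_eq!(describe(&["db", "main"], 0), "list tables in \"db\".\"main\"");
}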
 /// Whether the path points to an object that contains data (table or view).
@@ -59,47 +92,31 @@ pub fn contains_data(path: &[Value]) -> bool {
 }
 fn list_catalogs(reader: &dyn Reader) -> Result<Vec<ObjectSchema>, String> {
- let sql = reader.dialect().sql_list_catalogs();
- let df = reader
- .execute_sql(&sql)
+ let catalogs = reader
+ .list_catalogs()
 .map_err(|e| format!("Failed to list catalogs: {}", e))?;
- let col = find_column(&df, &["catalog_name", "name"])
- .map_err(|e| format!("Missing catalog_name/name column: {}", e))?;
-
- let mut catalogs = Vec::new();
- for i in 0..df.height() {
- let name = ggsql::array_util::value_to_string(col, i)
- .trim_matches('"')
- .to_string();
- catalogs.push(ObjectSchema {
+ Ok(catalogs
+ .into_iter()
+ .map(|name| ObjectSchema {
 name,
 kind: "catalog".to_string(),
- });
- }
- Ok(catalogs)
+ })
+ .collect())
 }
 fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result<Vec<ObjectSchema>, String> {
- let sql = reader.dialect().sql_list_schemas(catalog);
- let df = reader
- .execute_sql(&sql)
+ let schemas = reader
+ .list_schemas(catalog)
 .map_err(|e| format!("Failed to list schemas: {}", e))?;
- let col = find_column(&df, &["schema_name", "name"])
- .map_err(|e| format!("Missing schema_name/name column: {}", e))?;
-
- let mut schemas = Vec::new();
- for i in 0..df.height() {
- let name = ggsql::array_util::value_to_string(col, i)
- .trim_matches('"')
- .to_string();
- schemas.push(ObjectSchema {
+ Ok(schemas
+ .into_iter()
+ .map(|name| ObjectSchema {
 name,
 kind: "schema".to_string(),
- });
- }
- Ok(schemas)
+ })
+ .collect())
 }
 fn list_tables(
@@ -107,40 +124,27 @@ fn list_tables(
 catalog: &str,
 schema: &str,
 ) -> Result<Vec<ObjectSchema>, String> {
- let sql = reader.dialect().sql_list_tables(catalog, schema);
- let df = reader
- .execute_sql(&sql)
+ let tables = reader
+ .list_tables(catalog, schema)
 .map_err(|e| format!("Failed to list tables: {}", e))?;
- let name_col = find_column(&df, &["table_name", "name"])
- .map_err(|e| format!("Missing table_name/name column: {}", e))?;
- let type_col = find_column(&df, &["table_type", "kind"])
- .map_err(|e| format!("Missing table_type/kind column: {}", e))?;
-
- let mut objects = Vec::new();
- for i in 0..df.height() {
- let name = ggsql::array_util::value_to_string(name_col, i)
- .trim_matches('"')
- .to_string();
- let table_type = ggsql::array_util::value_to_string(type_col, i)
- .trim_matches('"')
- .to_uppercase();
- let kind = if table_type.contains("VIEW") {
- "view"
- } else if table_type == "TABLE"
- || table_type == "BASE TABLE"
- || table_type.contains("TABLE")
- {
- "table"
- } else {
- continue; // Skip non-table/view objects (stages, procedures, etc.)
- };
- objects.push(ObjectSchema {
- name,
- kind: kind.to_string(),
- });
- }
- Ok(objects)
+ Ok(tables
+ .into_iter()
+ .filter_map(|t| {
+ let upper = t.table_type.to_uppercase();
+ let kind = if upper.contains("VIEW") {
+ "view"
+ } else if upper == "TABLE" || upper == "BASE TABLE" || upper.contains("TABLE") {
+ "table"
+ } else {
+ return None;
+ };
+ Some(ObjectSchema {
+ name: t.name,
+ kind: kind.to_string(),
+ })
+ })
+ .collect())
 }
 fn list_columns(
@@ -149,27 +153,17 @@ fn list_columns(
 schema: &str,
 table: &str,
 ) -> Result<Vec<FieldSchema>, String> {
- let sql = reader.dialect().sql_list_columns(catalog, schema, table);
- let df = reader
- .execute_sql(&sql)
+ let columns = reader
+ .list_columns(catalog, schema, table)
 .map_err(|e| format!("Failed to list columns: {}", e))?;
- let name_col = find_column(&df, &["column_name"])
- .map_err(|e| format!("Missing column_name column: {}", e))?;
- let type_col =
- find_column(&df, &["data_type"]).map_err(|e| format!("Missing data_type column: {}", e))?;
-
- let mut fields = Vec::new();
- for i in 0..df.height() {
- let name = ggsql::array_util::value_to_string(name_col, i)
- .trim_matches('"')
- .to_string();
- let dtype = ggsql::array_util::value_to_string(type_col, i)
- .trim_matches('"')
- .to_string();
- fields.push(FieldSchema { name, dtype });
- }
- Ok(fields)
+ Ok(columns
+ .into_iter()
+ .map(|c| FieldSchema {
+ name: c.name,
+ dtype: c.data_type,
+ })
+ .collect())
 }
 #[cfg(test)]
diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs
index 94f7e2fe..8e1795de 100644
--- a/ggsql-jupyter/src/data_explorer.rs
+++ b/ggsql-jupyter/src/data_explorer.rs
@@ -3,7 +3,6 @@
 //! Implements the `positron.dataExplorer` comm protocol, providing SQL-backed
 //! paginated data access.
-use crate::util::find_column;
 use ggsql::reader::Reader;
 use serde_json::{json, Value};
@@ -95,29 +94,16 @@ impl DataExplorerState {
 })
 .unwrap_or(0);
- // Get column metadata from information_schema
- let columns_sql = reader.dialect().sql_list_columns(catalog, schema, table);
- let columns_df = reader
- .execute_sql(&columns_sql)
+ let column_infos = reader
+ .list_columns(catalog, schema, table)
 .map_err(|e| format!("Failed to list columns: {}", e))?;
- let name_col = find_column(&columns_df, &["column_name"])
- .map_err(|e| format!("Missing column_name: {}", e))?;
- let type_col = find_column(&columns_df, &["data_type"])
- .map_err(|e| format!("Missing data_type: {}", e))?;
- let mut columns = Vec::new();
- for i in 0..columns_df.height() {
- let name = ggsql::array_util::value_to_string(name_col, i)
- .trim_matches('"')
- .to_string();
- let raw_type = ggsql::array_util::value_to_string(type_col, i)
- .trim_matches('"')
- .to_string();
- let type_display = sql_type_to_display(&raw_type).to_string();
- let type_name = clean_type_name(&raw_type);
+ for col in &column_infos {
+ let type_display = sql_type_to_display(&col.data_type).to_string();
+ let type_name = clean_type_name(&col.data_type);
 columns.push(ColumnInfo {
- name,
+ name: col.name.clone(),
 type_name,
 type_display,
 });
diff --git a/ggsql-jupyter/src/lib.rs b/ggsql-jupyter/src/lib.rs
index e7c56850..c0e9c42f 100644
--- a/ggsql-jupyter/src/lib.rs
+++ b/ggsql-jupyter/src/lib.rs
@@ -7,8 +7,6 @@ pub mod data_explorer;
 pub mod display;
 pub mod executor;
 pub mod message;
-pub mod util;
-
 // Re-export commonly used types
 pub use display::format_display_data;
 pub use executor::{ExecutionResult, QueryExecutor};
diff --git a/ggsql-jupyter/src/main.rs b/ggsql-jupyter/src/main.rs
index
73ba93d8..fccb8906 100644 --- a/ggsql-jupyter/src/main.rs +++ b/ggsql-jupyter/src/main.rs @@ -8,8 +8,6 @@ mod display; mod executor; mod kernel; mod message; -mod util; - use anyhow::{Context, Result}; use clap::Parser; use message::ConnectionInfo; diff --git a/ggsql-jupyter/src/util.rs b/ggsql-jupyter/src/util.rs deleted file mode 100644 index b722677b..00000000 --- a/ggsql-jupyter/src/util.rs +++ /dev/null @@ -1,25 +0,0 @@ -use arrow::array::ArrayRef; -use ggsql::DataFrame; - -/// Find a DataFrame column by name, trying multiple names and falling back to -/// case-insensitive matching. This handles ODBC drivers that return uppercase -/// column names (e.g. `TABLE_NAME` instead of `table_name`). -pub fn find_column<'a>(df: &'a DataFrame, names: &[&str]) -> Result<&'a ArrayRef, String> { - // Try exact match first - for name in names { - if let Ok(col) = df.column(name) { - return Ok(col); - } - } - // Fall back to case-insensitive match - let col_names = df.get_column_names(); - for name in names { - let lower = name.to_lowercase(); - for cn in &col_names { - if cn.to_lowercase() == lower { - return df.column(cn).map_err(|e| e.to_string()); - } - } - } - Err(format!("Missing column (tried: {:?})", names)) -} diff --git a/src/Cargo.toml b/src/Cargo.toml index c2b9cb6a..91452030 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -27,8 +27,8 @@ arrow = { workspace = true } # Readers duckdb = { workspace = true, optional = true } rusqlite = { workspace = true, optional = true } -odbc-api = { workspace = true, optional = true } toml_edit = { workspace = true, optional = true } +libloading = { workspace = true, optional = true } parquet = { workspace = true, optional = true } bytes = { workspace = true } @@ -57,7 +57,7 @@ default = ["duckdb", "sqlite", "vegalite", "parquet", "builtin-data", "odbc"] duckdb = ["dep:duckdb"] parquet = ["dep:parquet"] sqlite = ["dep:rusqlite"] -odbc = ["dep:odbc-api", "dep:toml_edit"] +odbc = ["dep:toml_edit", "dep:libloading"] vegalite = [] builtin-data = [] all-readers = ["duckdb", "sqlite", "odbc"] diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 6646ada1..fbbedd97 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -94,48 +94,6 @@ pub trait SqlDialect { } } - // ========================================================================= - // Schema introspection queries (for Connections pane) - // ========================================================================= - - /// SQL to list catalog names. Returns rows with column `catalog_name`. - fn sql_list_catalogs(&self) -> String { - "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name".into() - } - - /// SQL to list schema names within a catalog. Returns rows with column `schema_name`. - fn sql_list_schemas(&self, catalog: &str) -> String { - format!( - "SELECT DISTINCT schema_name FROM information_schema.schemata \ - WHERE catalog_name = '{}' ORDER BY schema_name", - catalog.replace('\'', "''") - ) - } - - /// SQL to list tables/views within a catalog and schema. - /// Returns rows with columns `table_name` and `table_type`. - fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { - format!( - "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ - WHERE table_catalog = '{}' AND table_schema = '{}' ORDER BY table_name", - catalog.replace('\'', "''"), - schema.replace('\'', "''") - ) - } - - /// SQL to list columns in a table. - /// Returns rows with columns `column_name` and `data_type`. 
- fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String {
- format!(
- "SELECT column_name, data_type FROM information_schema.columns \
- WHERE table_catalog = '{}' AND table_schema = '{}' AND table_name = '{}' \
- ORDER BY ordinal_position",
- catalog.replace('\'', "''"),
- schema.replace('\'', "''"),
- table.replace('\'', "''")
- )
- }
-
 /// Scalar MAX across any number of SQL expressions.
 fn sql_greatest(&self, exprs: &[&str]) -> String {
 let mut result = exprs[0].to_string();
@@ -303,9 +261,6 @@
 pub mod sqlite;
 #[cfg(feature = "odbc")]
 pub mod odbc;
-#[cfg(feature = "odbc")]
-pub mod snowflake;
-
 pub mod connection;
 pub mod data;
 mod spec;
@@ -510,6 +465,96 @@ pub trait Reader {
 fn dialect(&self) -> &dyn SqlDialect {
 &AnsiDialect
 }
+
+ // =========================================================================
+ // Schema introspection
+ // =========================================================================
+
+ /// List catalog names visible to this connection.
+ fn list_catalogs(&self) -> Result<Vec<String>> {
+ let df = self.execute_sql(
+ "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name",
+ )?;
+ let col = df.column("catalog_name")?;
+ let mut results = Vec::with_capacity(df.height());
+ for i in 0..df.height() {
+ if !col.is_null(i) {
+ results.push(crate::array_util::value_to_string(col, i));
+ }
+ }
+ Ok(results)
+ }
+
+ /// List schema names within a catalog.
+ fn list_schemas(&self, catalog: &str) -> Result<Vec<String>> {
+ let df = self.execute_sql(&format!(
+ "SELECT DISTINCT schema_name FROM information_schema.schemata \
+ WHERE catalog_name = '{}' ORDER BY schema_name",
+ catalog.replace('\'', "''")
+ ))?;
+ let col = df.column("schema_name")?;
+ let mut results = Vec::with_capacity(df.height());
+ for i in 0..df.height() {
+ if !col.is_null(i) {
+ results.push(crate::array_util::value_to_string(col, i));
+ }
+ }
+ Ok(results)
+ }
+
+ /// List tables and views within a catalog and schema.
+ fn list_tables(&self, catalog: &str, schema: &str) -> Result<Vec<TableInfo>> {
+ let df = self.execute_sql(&format!(
+ "SELECT DISTINCT table_name, table_type FROM information_schema.tables \
+ WHERE table_catalog = '{}' AND table_schema = '{}' ORDER BY table_name",
+ catalog.replace('\'', "''"),
+ schema.replace('\'', "''")
+ ))?;
+ let name_col = df.column("table_name")?;
+ let type_col = df.column("table_type")?;
+ let mut results = Vec::with_capacity(df.height());
+ for i in 0..df.height() {
+ if !name_col.is_null(i) {
+ results.push(TableInfo {
+ name: crate::array_util::value_to_string(name_col, i),
+ table_type: crate::array_util::value_to_string(type_col, i),
+ });
+ }
+ }
+ Ok(results)
+ }
+
+ /// List columns in a table, in ordinal order.
+ fn list_columns(&self, catalog: &str, schema: &str, table: &str) -> Result<Vec<ColumnInfo>> {
+ let df = self.execute_sql(&format!(
+ "SELECT column_name, data_type FROM information_schema.columns \
+ WHERE table_catalog = '{}' AND table_schema = '{}' AND table_name = '{}' \
+ ORDER BY ordinal_position",
+ catalog.replace('\'', "''"),
+ schema.replace('\'', "''"),
+ table.replace('\'', "''")
+ ))?;
+ let name_col = df.column("column_name")?;
+ let type_col = df.column("data_type")?;
+ let mut results = Vec::with_capacity(df.height());
+ for i in 0..df.height() {
+ if !name_col.is_null(i) {
+ results.push(ColumnInfo {
+ name: crate::array_util::value_to_string(name_col, i),
+ data_type: crate::array_util::value_to_string(type_col, i),
+ });
+ }
+ }
+ Ok(results)
+ }
+}
+
+/// A table or view in the schema.
+pub struct TableInfo {
+ pub name: String,
+ pub table_type: String,
+}
+
+/// A column in a table.
+pub struct ColumnInfo {
+ pub name: String,
+ pub data_type: String,
+}
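As a usage sketch, any consumer (such as the Connections pane) can now walk the hierarchy without knowing the backend's SQL. This assumes the crate's `Result` alias is re-exported at the root, as the imports elsewhere in this patch suggest; error handling is condensed:

use ggsql::reader::Reader;

// Walk the full object tree via the new introspection defaults. Backends
// can override these methods, e.g. to use driver-level catalog calls.
fn dump_tree(reader: &dyn Reader) -> ggsql::Result<()> {
    for catalog in reader.list_catalogs()? {
        for schema in reader.list_schemas(&catalog)? {
            for table in reader.list_tables(&catalog, &schema)? {
                println!("{catalog}.{schema}.{} ({})", table.name, table.table_type);
                for col in reader.list_columns(&catalog, &schema, &table.name)? {
                    println!("  {}: {}", col.name, col.data_type);
                }
            }
        }
    }
    Ok(())
}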
 /// Execute a ggsql query using any reader
diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs
deleted file mode 100644
index 427467e4..00000000
--- a/src/reader/odbc.rs
+++ /dev/null
@@ -1,848 +0,0 @@
-//! Generic ODBC data source implementation
-//!
-//! Provides a reader for any ODBC-compatible database (Snowflake, PostgreSQL,
-//! SQL Server, etc.) using the `odbc-api` crate.
-
-use crate::reader::Reader;
-use crate::{naming, DataFrame, GgsqlError, Result};
-use arrow::array::*;
-use arrow::datatypes::DataType;
-use odbc_api::sys::{Date as OdbcDate, Time as OdbcTime, Timestamp as OdbcTimestamp};
-use odbc_api::{
- buffers::{AnyBuffer, AnySlice, BufferDesc, ColumnarBuffer},
- ConnectionOptions, Cursor, DataType as OdbcDataType, Environment,
-};
-use std::cell::RefCell;
-use std::collections::HashSet;
-use std::sync::{Arc, OnceLock};
-
-/// Global ODBC environment (must be a singleton per process).
-fn odbc_env() -> &'static Environment {
- static ENV: OnceLock<Environment> = OnceLock::new();
- ENV.get_or_init(|| Environment::new().expect("Failed to create ODBC environment"))
-}
-
-/// Detect the backend SQL dialect from an ODBC connection string.
-///
-/// Returns a dialect matching the detected backend (e.g. Snowflake, SQLite,
-/// DuckDB, or ANSI for generic/unknown backends).
-fn detect_dialect(conn_str: &str) -> Box<dyn super::SqlDialect> {
- let lower = conn_str.to_lowercase();
- if lower.contains("driver=snowflake") {
- Box::new(super::snowflake::SnowflakeDialect)
- } else if lower.contains("driver=sqlite") || lower.contains("driver={sqlite") {
- #[cfg(feature = "sqlite")]
- {
- Box::new(super::sqlite::SqliteDialect)
- }
- #[cfg(not(feature = "sqlite"))]
- {
- Box::new(super::AnsiDialect)
- }
- } else if lower.contains("driver=duckdb") || lower.contains("driver={duckdb") {
- #[cfg(feature = "duckdb")]
- {
- Box::new(super::duckdb::DuckDbDialect)
- }
- #[cfg(not(feature = "duckdb"))]
- {
- Box::new(super::AnsiDialect)
- }
- } else {
- Box::new(super::AnsiDialect)
- }
-}
-
-/// Generic ODBC reader implementing the `Reader` trait.
-pub struct OdbcReader {
- connection: odbc_api::Connection<'static>,
- dialect: Box<dyn super::SqlDialect>,
- registered_tables: RefCell<HashSet<String>>,
-}
-
-// Safety: odbc_api::Connection is Send when we ensure single-threaded access.
-// The Reader trait requires &self (immutable) for execute_sql, and ODBC
-// connections are safe to use from one thread at a time.
-unsafe impl Send for OdbcReader {}
-
-impl OdbcReader {
- /// Create a new ODBC reader from a `odbc://` connection URI.
- ///
- /// The URI format is `odbc://` followed by the raw ODBC connection string.
- pub fn from_connection_string(uri: &str) -> Result<Self> {
- let conn_str = uri
- .strip_prefix("odbc://")
- .ok_or_else(|| GgsqlError::ReaderError("ODBC URI must start with odbc://".into()))?;
-
- let mut conn_str = conn_str.to_string();
-
- // Snowflake ConnectionName resolution from connections.toml
- if is_snowflake(&conn_str) {
- if let Some(resolved) = resolve_connection_name(&conn_str) {
- conn_str = resolved;
- }
- }
-
- // Snowflake Workbench credential detection
- if is_snowflake(&conn_str) && !has_token(&conn_str) {
- if let Some(token) = detect_workbench_token() {
- conn_str = inject_snowflake_token(&conn_str, &token);
- }
- }
-
- // Detect backend dialect from connection string
- let dialect = detect_dialect(&conn_str);
-
- let env = odbc_env();
- let connection = env
- .connect_with_connection_string(&conn_str, ConnectionOptions::default())
- .map_err(|e| GgsqlError::ReaderError(format!("ODBC connection failed: {}", e)))?;
-
- Ok(Self {
- connection,
- dialect,
- registered_tables: RefCell::new(HashSet::new()),
- })
- }
-}
-
-impl Reader for OdbcReader {
- fn execute_sql(&self, sql: &str) -> Result<DataFrame> {
- // Execute the query (3rd arg = query timeout, None = no timeout)
- let cursor = self
- .connection
- .execute(sql, (), None)
- .map_err(|e| GgsqlError::ReaderError(format!("ODBC execute failed: {}", e)))?;
-
- let Some(cursor) = cursor else {
- // DDL or non-query statement — return empty DataFrame
- return Ok(DataFrame::empty());
- };
-
- cursor_to_dataframe(cursor)
- }
-
- fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> {
- super::validate_table_name(name)?;
-
- if replace {
- let drop_sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name));
- // Ignore errors from DROP — table may not exist
- let _ = self.connection.execute(&drop_sql, (), None);
- }
-
- // Build CREATE TEMP TABLE with typed columns
- let schema = df.schema();
- let col_defs: Vec<String> = schema
- .fields()
- .iter()
- .map(|field| {
- format!(
- "{} {}",
- naming::quote_ident(field.name()),
- arrow_dtype_to_sql(field.data_type())
- )
- })
- .collect();
- let create_sql = format!(
- "CREATE TEMPORARY TABLE {} ({})",
- naming::quote_ident(name),
- col_defs.join(", ")
- );
- self.connection
- .execute(&create_sql, (), None)
- .map_err(|e| {
- GgsqlError::ReaderError(format!("Failed to create temp table '{}': {}", name, e))
- })?;
-
- // Insert data using ODBC bulk text inserter
- let num_rows = df.height();
- if num_rows > 0 {
- let num_cols = df.width();
- let placeholders: Vec<&str> = vec!["?"; num_cols];
- let insert_sql = format!(
- "INSERT INTO {} VALUES ({})",
- naming::quote_ident(name),
- placeholders.join(", ")
- );
-
- // Convert all columns to string representation for text insertion
- let columns = df.get_columns();
- let string_columns: Vec<Vec<Option<String>>> = columns
- .iter()
- .map(|col| {
- (0..num_rows)
- .map(|row| {
- if col.is_null(row) {
- None
- } else {
- Some(crate::array_util::value_to_string(col, row))
- }
- })
- .collect()
- })
- .collect();
-
- // Determine max string length per column for buffer allocation
- let max_str_lens: Vec<usize> = string_columns
- .iter()
- .map(|col| {
- col.iter()
- .filter_map(|v| v.as_ref().map(|s| s.len()))
- .max()
- .unwrap_or(1)
- .max(1) // minimum buffer size of 1
- })
- .collect();
-
- const BATCH_SIZE: usize = 1024;
- let prepared = self.connection.prepare(&insert_sql).map_err(|e| {
- GgsqlError::ReaderError(format!("Failed to prepare INSERT for '{}': {}", name, e))
- })?;
-
- let batch_capacity = num_rows.min(BATCH_SIZE);
- let mut inserter = prepared
- .into_text_inserter(batch_capacity, max_str_lens)
- .map_err(|e| {
- GgsqlError::ReaderError(format!(
- "Failed to create bulk inserter for '{}': {}",
- name, e
- ))
- })?;
-
- let mut rows_in_batch = 0;
- for row_idx in 0..num_rows {
- let row_values: Vec<Option<&[u8]>> = string_columns
- .iter()
- .map(|col| col[row_idx].as_ref().map(|s| s.as_bytes()))
- .collect();
-
- inserter.append(row_values.into_iter()).map_err(|e| {
- GgsqlError::ReaderError(format!(
- "Failed to append row {} to '{}': {}",
- row_idx, name, e
- ))
- })?;
- rows_in_batch += 1;
-
- if rows_in_batch >= BATCH_SIZE {
- inserter.execute().map_err(|e| {
- GgsqlError::ReaderError(format!(
- "Failed to execute batch insert into '{}': {}",
- name, e
- ))
- })?;
- inserter.clear();
- rows_in_batch = 0;
- }
- }
-
- // Execute final partial batch
- if rows_in_batch > 0 {
- inserter.execute().map_err(|e| {
- GgsqlError::ReaderError(format!(
- "Failed to execute final batch insert into '{}': {}",
- name, e
- ))
- })?;
- }
- }
-
- self.registered_tables.borrow_mut().insert(name.to_string());
- Ok(())
- }
-
- fn unregister(&self, name: &str) -> Result<()> {
- if !self.registered_tables.borrow().contains(name) {
- return Err(GgsqlError::ReaderError(format!(
- "Table '{}' was not registered via this reader",
- name
- )));
- }
-
- let sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name));
- self.connection.execute(&sql, (), None).map_err(|e| {
- GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e))
- })?;
-
- self.registered_tables.borrow_mut().remove(name);
- Ok(())
- }
-
- fn execute(&self, query: &str) -> Result {
- super::execute_with_reader(self, query)
- }
-
- fn dialect(&self) -> &dyn super::SqlDialect {
- &*self.dialect
- }
-}
-
-/// Map an Arrow data type to a SQL column type string.
-fn arrow_dtype_to_sql(dtype: &DataType) -> &'static str {
- match dtype {
- DataType::Boolean => "BOOLEAN",
- DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => "BIGINT",
- DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => "BIGINT",
- DataType::Float32 | DataType::Float64 => "DOUBLE PRECISION",
- DataType::Date32 => "DATE",
- DataType::Timestamp(_, _) => "TIMESTAMP",
- DataType::Time64(_) => "TIME",
- _ => "TEXT",
- }
-}
-
-/// Column builder that accumulates typed values across batches.
-enum ColumnBuilder {
- Int8(Vec<Option<i8>>),
- Int16(Vec<Option<i16>>),
- Int32(Vec<Option<i32>>),
- Int64(Vec<Option<i64>>),
- Float32(Vec<Option<f32>>),
- Float64(Vec<Option<f64>>),
- Boolean(Vec<Option<bool>>),
- Date(Vec<Option<i32>>),
- Time(Vec<Option<i64>>),
- Timestamp(Vec<Option<i64>>),
- Text(Vec<Option<String>>),
-}
-
-impl ColumnBuilder {
- fn from_odbc_type(data_type: &OdbcDataType) -> Self {
- match data_type {
- OdbcDataType::TinyInt => Self::Int8(Vec::new()),
- OdbcDataType::SmallInt => Self::Int16(Vec::new()),
- OdbcDataType::Integer => Self::Int32(Vec::new()),
- OdbcDataType::BigInt => Self::Int64(Vec::new()),
- OdbcDataType::Real | OdbcDataType::Float { precision: 0..=24 } => {
- Self::Float32(Vec::new())
- }
- OdbcDataType::Double | OdbcDataType::Float { .. } => Self::Float64(Vec::new()),
- OdbcDataType::Numeric {
- scale: 0,
- precision,
- }
- | OdbcDataType::Decimal {
- scale: 0,
- precision,
- } => {
- if *precision < 10 {
- Self::Int32(Vec::new())
- } else if *precision < 19 {
- Self::Int64(Vec::new())
- } else {
- Self::Float64(Vec::new())
- }
- }
- OdbcDataType::Numeric { .. } | OdbcDataType::Decimal { .. } => {
- Self::Float64(Vec::new())
- }
- OdbcDataType::Bit => Self::Boolean(Vec::new()),
- OdbcDataType::Date => Self::Date(Vec::new()),
- OdbcDataType::Time { .. } => Self::Time(Vec::new()),
- OdbcDataType::Timestamp { .. } => Self::Timestamp(Vec::new()),
- _ => Self::Text(Vec::new()),
- }
- }
-
- fn append_from_slice(&mut self, slice: AnySlice<'_>) -> std::result::Result<(), String> {
- match (self, slice) {
- (Self::Int8(v), AnySlice::NullableI8(s)) => {
- v.extend(s.map(|opt| opt.copied()));
- }
- (Self::Int16(v), AnySlice::NullableI16(s)) => {
- v.extend(s.map(|opt| opt.copied()));
- }
- (Self::Int32(v), AnySlice::NullableI32(s)) => {
- v.extend(s.map(|opt| opt.copied()));
- }
- (Self::Int64(v), AnySlice::NullableI64(s)) => {
- v.extend(s.map(|opt| opt.copied()));
- }
- (Self::Float32(v), AnySlice::NullableF32(s)) => {
- v.extend(s.map(|opt| opt.copied()));
- }
- (Self::Float64(v), AnySlice::NullableF64(s)) => {
- v.extend(s.map(|opt| opt.copied()));
- }
- (Self::Boolean(v), AnySlice::NullableBit(s)) => {
- v.extend(s.map(|opt| opt.map(|b| b.as_bool())));
- }
- (Self::Date(v), AnySlice::NullableDate(s)) => {
- v.extend(s.map(|opt| opt.and_then(odbc_date_to_days)));
- }
- (Self::Time(v), AnySlice::NullableTime(s)) => {
- v.extend(s.map(|opt| opt.map(odbc_time_to_nanos)));
- }
- (Self::Timestamp(v), AnySlice::NullableTimestamp(s)) => {
- v.extend(s.map(|opt| opt.and_then(odbc_timestamp_to_micros)));
- }
- (Self::Text(v), AnySlice::Text(view)) => {
- v.extend(view.iter().map(|opt| {
- opt.and_then(|bytes| std::str::from_utf8(bytes).ok().map(|s| s.to_string()))
- }));
- }
- (Self::Text(v), AnySlice::WText(view)) => {
- v.extend(
- view.iter()
- .map(|opt| opt.map(|chars| String::from_utf16_lossy(chars.into()))),
- );
- }
- // Decimal/Numeric with scale > 0 bound as text → parse to f64
- (Self::Float64(v), AnySlice::Text(view)) => {
- v.extend(view.iter().map(|opt| {
- opt.and_then(|bytes| {
- std::str::from_utf8(bytes)
- .ok()
- .and_then(|s| s.parse::<f64>().ok())
- })
- }));
- }
- // Decimal with scale=0 bound as i32/i64 text fallback
- (Self::Int32(v), AnySlice::Text(view)) => {
- v.extend(view.iter().map(|opt| {
- opt.and_then(|bytes| {
- std::str::from_utf8(bytes)
- .ok()
- .and_then(|s| s.parse::<i32>().ok())
- })
- }));
- }
- (Self::Int64(v), AnySlice::Text(view)) => {
- v.extend(view.iter().map(|opt| {
- opt.and_then(|bytes| {
- std::str::from_utf8(bytes)
- .ok()
- .and_then(|s| s.parse::<i64>().ok())
- })
- }));
- }
- (builder, _slice) => {
- let builder_type = match builder {
- Self::Int8(_) => "Int8",
- Self::Int16(_) => "Int16",
- Self::Int32(_) => "Int32",
- Self::Int64(_) => "Int64",
- Self::Float32(_) => "Float32",
- Self::Float64(_) => "Float64",
- Self::Boolean(_) => "Boolean",
- Self::Date(_) => "Date",
- Self::Time(_) => "Time",
- Self::Timestamp(_) => "Timestamp",
- Self::Text(_) => "Text",
- };
- return Err(format!(
- "ODBC type mismatch: expected {builder_type} buffer but driver returned a different type"
- ));
- }
- }
- Ok(())
- }
-
- fn into_named_array(self, name: &str) -> (String, ArrayRef) {
- let array: ArrayRef = match self {
- Self::Int8(v) => Arc::new(Int8Array::from(v)),
- Self::Int16(v) => Arc::new(Int16Array::from(v)),
- Self::Int32(v) => Arc::new(Int32Array::from(v)),
- Self::Int64(v) => Arc::new(Int64Array::from(v)),
- Self::Float32(v) => Arc::new(Float32Array::from(v)),
- Self::Float64(v) => Arc::new(Float64Array::from(v)),
- Self::Boolean(v) => Arc::new(BooleanArray::from(v)),
- Self::Date(v) => Arc::new(Date32Array::from(v)),
- Self::Time(v) => Arc::new(Time64NanosecondArray::from(v)),
- Self::Timestamp(v) => Arc::new(TimestampMicrosecondArray::from(v)),
- Self::Text(v) => {
- let refs: Vec<Option<&str>> = v.iter().map(|s| s.as_deref()).collect();
- Arc::new(StringArray::from(refs))
- }
- };
- (name.to_string(), array)
- }
-}
-
-fn odbc_date_to_days(d: &OdbcDate) -> Option<i32> {
- chrono::NaiveDate::from_ymd_opt(d.year as i32, d.month as u32, d.day as u32).map(|date| {
- let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
- (date - epoch).num_days() as i32
- })
-}
-
-fn odbc_time_to_nanos(t: &OdbcTime) -> i64 {
- let h = t.hour as i64;
- let m = t.minute as i64;
- let s = t.second as i64;
- (h * 3600 + m * 60 + s) * 1_000_000_000
-}
-
-fn odbc_timestamp_to_micros(ts: &OdbcTimestamp) -> Option<i64> {
- chrono::NaiveDate::from_ymd_opt(ts.year as i32, ts.month as u32, ts.day as u32)
- .and_then(|date| {
- date.and_hms_nano_opt(
- ts.hour as u32,
- ts.minute as u32,
- ts.second as u32,
- ts.fraction,
- )
- })
- .map(|dt| dt.and_utc().timestamp_micros())
-}
-
-/// Convert an ODBC cursor to a DataFrame using typed buffers.
-fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
- let col_count = cursor
- .num_result_cols()
- .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column count: {}", e)))?
- as usize;
-
- if col_count == 0 {
- return Ok(DataFrame::empty());
- }
-
- // Collect column names and types, build buffer descriptors
- let mut col_names = Vec::with_capacity(col_count);
- let mut col_types = Vec::with_capacity(col_count);
- let mut descs = Vec::with_capacity(col_count);
-
- let text_fallback = BufferDesc::Text { max_str_len: 65536 };
-
- for i in 1..=col_count as u16 {
- let name = cursor.col_name(i).map_err(|e| {
- GgsqlError::ReaderError(format!("Failed to get column {} name: {}", i, e))
- })?;
- let data_type = cursor.col_data_type(i).map_err(|e| {
- GgsqlError::ReaderError(format!("Failed to get column {} type: {}", i, e))
- })?;
-
- let desc = BufferDesc::from_data_type(data_type, true).unwrap_or(text_fallback);
-
- col_names.push(name);
- col_types.push(data_type);
- descs.push(desc);
- }
-
- // Create typed columnar buffer and column builders
- let batch_size = 1000;
- let mut builders: Vec<ColumnBuilder> = col_types
- .iter()
- .map(ColumnBuilder::from_odbc_type)
- .collect();
-
- let mut buffer = ColumnarBuffer::<AnyBuffer>::from_descs(batch_size, descs);
-
- let mut block_cursor = cursor
- .bind_buffer(&mut buffer)
- .map_err(|e| GgsqlError::ReaderError(format!("Failed to bind buffer: {}", e)))?;
-
- while let Some(batch) = block_cursor
- .fetch()
- .map_err(|e| GgsqlError::ReaderError(format!("Failed to fetch batch: {}", e)))?
- {
- for (col_idx, builder) in builders.iter_mut().enumerate() {
- let slice = batch.column(col_idx);
- builder.append_from_slice(slice).map_err(|e| {
- GgsqlError::ReaderError(format!("Column '{}': {}", col_names[col_idx], e))
- })?;
- }
- }
-
- // Convert builders to named arrays
- let named_arrays: Vec<(String, ArrayRef)> = col_names
- .iter()
- .zip(builders)
- .map(|(name, builder)| builder.into_named_array(name))
- .collect();
-
- DataFrame::new(named_arrays)
-}
-
-// ============================================================================
-// Snowflake Workbench credential detection
-// ============================================================================
-
-fn is_snowflake(conn_str: &str) -> bool {
- conn_str.to_lowercase().contains("driver=snowflake")
-}
-
-fn has_token(conn_str: &str) -> bool {
- conn_str.to_lowercase().contains("token=")
-}
-
-fn home_dir() -> Option<std::path::PathBuf> {
- #[cfg(target_os = "windows")]
- {
- std::env::var("USERPROFILE")
- .ok()
- .map(std::path::PathBuf::from)
- }
- #[cfg(not(target_os = "windows"))]
- {
- std::env::var("HOME").ok().map(std::path::PathBuf::from)
- }
-}
-
-/// Find the Snowflake connections.toml file, checking standard locations.
-fn find_snowflake_connections_toml() -> Option<std::path::PathBuf> {
- use std::path::PathBuf;
-
- // 1. $SNOWFLAKE_HOME/connections.toml
- if let Ok(snowflake_home) = std::env::var("SNOWFLAKE_HOME") {
- let p = PathBuf::from(&snowflake_home).join("connections.toml");
- if p.exists() {
- return Some(p);
- }
- }
-
- // 2. ~/.snowflake/connections.toml
- if let Some(home) = home_dir() {
- let p = home.join(".snowflake").join("connections.toml");
- if p.exists() {
- return Some(p);
- }
- }
-
- // 3. Platform-specific paths
- if let Some(home) = home_dir() {
- #[cfg(target_os = "macos")]
- {
- let p = home.join("Library/Application Support/snowflake/connections.toml");
- if p.exists() {
- return Some(p);
- }
- }
-
- #[cfg(target_os = "linux")]
- {
- let xdg = std::env::var("XDG_CONFIG_HOME")
- .map(PathBuf::from)
- .unwrap_or_else(|_| home.join(".config"));
- let p = xdg.join("snowflake").join("connections.toml");
- if p.exists() {
- return Some(p);
- }
- }
-
- #[cfg(target_os = "windows")]
- {
- let p = home.join("AppData/Local/snowflake/connections.toml");
- if p.exists() {
- return Some(p);
- }
- }
- }
-
- None
-}
-
-/// Resolve a `ConnectionName=` parameter in a Snowflake ODBC connection
-/// string by reading the named entry from `~/.snowflake/connections.toml` and
-/// building a full ODBC connection string from it.
-fn resolve_connection_name(conn_str: &str) -> Option<String> {
- // Extract ConnectionName value (case-insensitive)
- let lower = conn_str.to_lowercase();
- let cn_key = "connectionname=";
- let cn_start = lower.find(cn_key)?;
- let value_start = cn_start + cn_key.len();
-
- let rest = &conn_str[value_start..];
- let value_end = rest.find(';').unwrap_or(rest.len());
- let connection_name = rest[..value_end].trim();
-
- if connection_name.is_empty() {
- return None;
- }
-
- // Read and parse connections.toml
- let toml_path = find_snowflake_connections_toml()?;
- let content = std::fs::read_to_string(&toml_path).ok()?;
- let doc = content.parse::<toml_edit::DocumentMut>().ok()?;
-
- let entry = doc.get(connection_name)?;
- if !entry.is_table() && !entry.is_inline_table() {
- return None;
- }
-
- // Build ODBC connection string from TOML entry fields
- let get_str = |key: &str| -> Option<String> { entry.get(key)?.as_str().map(|s| s.to_string()) };
-
- let account = get_str("account")?;
- let mut parts = vec![
- "Driver=Snowflake".to_string(),
- format!("Server={}.snowflakecomputing.com", account),
- ];
-
- if let Some(user) = get_str("user") {
- parts.push(format!("UID={}", user));
- }
- if let Some(password) = get_str("password") {
- parts.push(format!("PWD={}", password));
- }
- if let Some(authenticator) = get_str("authenticator") {
- parts.push(format!("Authenticator={}", authenticator));
- }
- if let Some(token) = get_str("token") {
- parts.push(format!("Token={}", token));
- }
- if let Some(warehouse) = get_str("warehouse") {
- parts.push(format!("Warehouse={}", warehouse));
- }
- if let Some(database) = get_str("database") {
- parts.push(format!("Database={}", database));
- }
- if let Some(schema) = get_str("schema") {
- parts.push(format!("Schema={}", schema));
- }
- if let Some(role) = get_str("role") {
- parts.push(format!("Role={}", role));
- }
-
- Some(parts.join(";"))
-}
-
-/// Detect Posit Workbench Snowflake OAuth token.
-///
-/// Checks `SNOWFLAKE_HOME` for a Workbench-managed `connections.toml` file
-/// containing OAuth credentials.
-fn detect_workbench_token() -> Option<String> {
- let snowflake_home = std::env::var("SNOWFLAKE_HOME").ok()?;
-
- // Only use Workbench credentials if the path indicates Workbench management
- if !snowflake_home.contains("posit-workbench") {
- return None;
- }
-
- let toml_path = std::path::Path::new(&snowflake_home).join("connections.toml");
- let content = std::fs::read_to_string(&toml_path).ok()?;
-
- let doc = content.parse::<toml_edit::DocumentMut>().ok()?;
- let token = doc.get("workbench")?.get("token")?.as_str()?.to_string();
-
- if token.is_empty() {
- None
- } else {
- Some(token)
- }
-}
-
-/// Inject OAuth token into a Snowflake ODBC connection string.
-fn inject_snowflake_token(conn_str: &str, token: &str) -> String { - // Append authenticator and token parameters - let mut result = conn_str.trim_end_matches(';').to_string(); - result.push_str(";Authenticator=oauth;Token="); - result.push_str(token); - result -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_snowflake() { - assert!(is_snowflake( - "Driver=Snowflake;Server=foo.snowflakecomputing.com" - )); - assert!(!is_snowflake("Driver={PostgreSQL};Server=localhost")); - } - - #[test] - fn test_has_token() { - assert!(has_token("Driver=Snowflake;Token=abc123")); - assert!(!has_token("Driver=Snowflake;Server=foo")); - } - - #[test] - fn test_detect_dialect() { - // Snowflake uses SHOW commands - let dialect = detect_dialect("Driver=Snowflake;Server=foo"); - assert!(dialect.sql_list_catalogs().contains("SHOW")); - - // PostgreSQL uses information_schema (ANSI default) - let dialect = detect_dialect("Driver={PostgreSQL};Server=localhost"); - assert!(dialect.sql_list_catalogs().contains("information_schema")); - - // Generic uses information_schema (ANSI default) - let dialect = detect_dialect("Driver=SomeOther;Server=localhost"); - assert!(dialect.sql_list_catalogs().contains("information_schema")); - } - - #[test] - fn test_inject_snowflake_token() { - let result = inject_snowflake_token( - "Driver=Snowflake;Server=foo.snowflakecomputing.com", - "mytoken", - ); - assert!(result.contains("Authenticator=oauth")); - assert!(result.contains("Token=mytoken")); - } - - #[test] - fn test_resolve_connection_name_with_toml() { - use std::io::Write; - - // Create a temp dir with a connections.toml - let dir = tempfile::tempdir().unwrap(); - let toml_path = dir.path().join("connections.toml"); - let mut f = std::fs::File::create(&toml_path).unwrap(); - writeln!( - f, - r#" -default_connection_name = "myconn" - -[myconn] -account = "myaccount" -user = "myuser" -password = "mypass" -warehouse = "mywh" -database = "mydb" -schema = "public" -role = "myrole" - -[other] -account = "otheraccount" -"# - ) - .unwrap(); - - // Point SNOWFLAKE_HOME at our temp dir - std::env::set_var("SNOWFLAKE_HOME", dir.path()); - - let result = resolve_connection_name("Driver=Snowflake;ConnectionName=myconn"); - assert!(result.is_some()); - let conn = result.unwrap(); - assert!(conn.contains("Driver=Snowflake")); - assert!(conn.contains("Server=myaccount.snowflakecomputing.com")); - assert!(conn.contains("UID=myuser")); - assert!(conn.contains("PWD=mypass")); - assert!(conn.contains("Warehouse=mywh")); - assert!(conn.contains("Database=mydb")); - assert!(conn.contains("Schema=public")); - assert!(conn.contains("Role=myrole")); - - // Test with a connection that has fewer fields - let result2 = resolve_connection_name("Driver=Snowflake;ConnectionName=other"); - assert!(result2.is_some()); - let conn2 = result2.unwrap(); - assert!(conn2.contains("Server=otheraccount.snowflakecomputing.com")); - assert!(!conn2.contains("UID=")); - - // Test with non-existent connection name - let result3 = resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent"); - assert!(result3.is_none()); - - // No ConnectionName param → None - let result4 = resolve_connection_name("Driver=Snowflake;Server=foo"); - assert!(result4.is_none()); - - // Clean up env - std::env::remove_var("SNOWFLAKE_HOME"); - } - - #[test] - fn test_arrow_dtype_to_sql() { - assert_eq!(arrow_dtype_to_sql(&DataType::Int64), "BIGINT"); - assert_eq!(arrow_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION"); - 
assert_eq!(arrow_dtype_to_sql(&DataType::Boolean), "BOOLEAN"); - assert_eq!(arrow_dtype_to_sql(&DataType::Date32), "DATE"); - assert_eq!(arrow_dtype_to_sql(&DataType::Utf8), "TEXT"); - } -} diff --git a/src/reader/odbc/ffi.rs b/src/reader/odbc/ffi.rs new file mode 100644 index 00000000..ab65b675 --- /dev/null +++ b/src/reader/odbc/ffi.rs @@ -0,0 +1,383 @@ +//! Raw ODBC FFI types, constants, and runtime-loaded function pointers. +//! +//! Loads `libodbc` via `libloading` at runtime instead of linking at compile +//! time. Only the subset of ODBC functions used by our reader is included. + +use std::sync::OnceLock; + +// ============================================================================ +// Primitive type aliases (must match ODBC spec sizes) +// ============================================================================ + +pub type SqlChar = u8; +pub type SqlSmallInt = i16; +pub type SqlUSmallInt = u16; +pub type SqlInteger = i32; +pub type SqlUInteger = u32; +pub type SqlLen = isize; // pointer-sized: critical for indicator buffers +pub type SqlULen = usize; // pointer-sized: critical for column size +pub type SqlReturn = i16; +pub type SqlHandle = *mut std::ffi::c_void; +pub type SqlHEnv = SqlHandle; +pub type SqlHDbc = SqlHandle; +pub type SqlHStmt = SqlHandle; +pub type SqlHWnd = SqlHandle; +pub type SqlPointer = *mut std::ffi::c_void; + +// ============================================================================ +// Return codes +// ============================================================================ + +pub const SQL_SUCCESS: SqlReturn = 0; +pub const SQL_SUCCESS_WITH_INFO: SqlReturn = 1; +pub const SQL_ERROR: SqlReturn = -1; +pub const SQL_INVALID_HANDLE: SqlReturn = -2; +pub const SQL_NO_DATA: SqlReturn = 100; + +#[inline] +pub fn succeeded(rc: SqlReturn) -> bool { + rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO +} + +// ============================================================================ +// Handle types (for SQLAllocHandle / SQLFreeHandle) +// ============================================================================ + +pub const SQL_HANDLE_ENV: SqlSmallInt = 1; +pub const SQL_HANDLE_DBC: SqlSmallInt = 2; +pub const SQL_HANDLE_STMT: SqlSmallInt = 3; + +pub const SQL_NULL_HANDLE: SqlHandle = std::ptr::null_mut(); + +// ============================================================================ +// Environment attributes +// ============================================================================ + +pub const SQL_ATTR_ODBC_VERSION: SqlInteger = 200; +pub const SQL_OV_ODBC3: SqlInteger = 3; + +// ============================================================================ +// Connection attributes +// ============================================================================ + +pub const SQL_ATTR_AUTOCOMMIT: SqlInteger = 102; +pub const SQL_AUTOCOMMIT_ON: SqlUInteger = 1; + +// ============================================================================ +// Statement attributes +// ============================================================================ + +pub const SQL_ATTR_ROW_BIND_TYPE: SqlInteger = 5; +pub const SQL_ATTR_ROW_ARRAY_SIZE: SqlInteger = 27; +pub const SQL_ATTR_ROWS_FETCHED_PTR: SqlInteger = 26; +pub const SQL_BIND_BY_COLUMN: SqlULen = 0; + +// ============================================================================ +// SQLFreeStmt options +// ============================================================================ + +pub const SQL_CLOSE: SqlUSmallInt = 0; +pub const SQL_UNBIND: SqlUSmallInt = 2; +pub const 
SQL_RESET_PARAMS: SqlUSmallInt = 3; + +// ============================================================================ +// SQLDriverConnect options +// ============================================================================ + +pub const SQL_DRIVER_NOPROMPT: SqlUSmallInt = 0; + +// ============================================================================ +// SQLEndTran completion types +// ============================================================================ + +pub const SQL_COMMIT: SqlSmallInt = 0; +pub const SQL_ROLLBACK: SqlSmallInt = 1; + +// ============================================================================ +// SQL data types (returned by SQLDescribeCol) +// ============================================================================ + +pub const SQL_UNKNOWN_TYPE: SqlSmallInt = 0; +pub const SQL_CHAR: SqlSmallInt = 1; +pub const SQL_NUMERIC: SqlSmallInt = 2; +pub const SQL_DECIMAL: SqlSmallInt = 3; +pub const SQL_INTEGER: SqlSmallInt = 4; +pub const SQL_SMALLINT: SqlSmallInt = 5; +pub const SQL_FLOAT: SqlSmallInt = 6; +pub const SQL_REAL: SqlSmallInt = 7; +pub const SQL_DOUBLE: SqlSmallInt = 8; +pub const SQL_VARCHAR: SqlSmallInt = 12; +pub const SQL_TYPE_DATE: SqlSmallInt = 91; +pub const SQL_TYPE_TIME: SqlSmallInt = 92; +pub const SQL_TYPE_TIMESTAMP: SqlSmallInt = 93; +pub const SQL_BIGINT: SqlSmallInt = -5; +pub const SQL_TINYINT: SqlSmallInt = -6; +pub const SQL_BIT: SqlSmallInt = -7; +pub const SQL_WCHAR: SqlSmallInt = -8; +pub const SQL_WVARCHAR: SqlSmallInt = -9; +pub const SQL_LONGVARCHAR: SqlSmallInt = -1; +pub const SQL_WLONGVARCHAR: SqlSmallInt = -10; + +// ============================================================================ +// C data types (for SQLBindCol / SQLBindParameter / SQLGetData) +// ============================================================================ + +pub const SQL_C_CHAR: SqlSmallInt = 1; +pub const SQL_C_SLONG: SqlSmallInt = -16; +pub const SQL_C_SSHORT: SqlSmallInt = -15; +pub const SQL_C_STINYINT: SqlSmallInt = -26; +pub const SQL_C_SBIGINT: SqlSmallInt = -25; +pub const SQL_C_FLOAT: SqlSmallInt = 7; +pub const SQL_C_DOUBLE: SqlSmallInt = 8; +pub const SQL_C_BIT: SqlSmallInt = -7; +pub const SQL_C_TYPE_DATE: SqlSmallInt = 91; +pub const SQL_C_TYPE_TIME: SqlSmallInt = 92; +pub const SQL_C_TYPE_TIMESTAMP: SqlSmallInt = 93; + +// ============================================================================ +// Indicator / length constants +// ============================================================================ + +pub const SQL_NULL_DATA: SqlLen = -1; +pub const SQL_NTS: SqlLen = -3; + +// ============================================================================ +// SQLGetInfo info types +// ============================================================================ + +pub const SQL_DBMS_NAME: SqlUSmallInt = 17; + +// ============================================================================ +// SQLBindParameter input/output type +// ============================================================================ + +pub const SQL_PARAM_INPUT: SqlSmallInt = 1; + +// ============================================================================ +// Nullable constants (returned by SQLDescribeCol) +// ============================================================================ + +pub const SQL_NO_NULLS: SqlSmallInt = 0; +pub const SQL_NULLABLE: SqlSmallInt = 1; + +// ============================================================================ +// Date/time structs (matching ODBC C structs) +// 
============================================================================
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy, Default)]
+pub struct SqlDateStruct {
+ pub year: SqlSmallInt,
+ pub month: SqlUSmallInt,
+ pub day: SqlUSmallInt,
+}
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy, Default)]
+pub struct SqlTimeStruct {
+ pub hour: SqlUSmallInt,
+ pub minute: SqlUSmallInt,
+ pub second: SqlUSmallInt,
+}
+
+#[repr(C)]
+#[derive(Debug, Clone, Copy, Default)]
+pub struct SqlTimestampStruct {
+ pub year: SqlSmallInt,
+ pub month: SqlUSmallInt,
+ pub day: SqlUSmallInt,
+ pub hour: SqlUSmallInt,
+ pub minute: SqlUSmallInt,
+ pub second: SqlUSmallInt,
+ pub fraction: u32,
+}
+
+// ============================================================================
+// Runtime-loaded function table
+// ============================================================================
+
+macro_rules! odbc_functions {
+ ($(
+ $(#[$meta:meta])*
+ fn $name:ident( $($arg:ident : $argty:ty),* $(,)? ) -> SqlReturn;
+ )*) => {
+ #[allow(non_snake_case)]
+ pub(crate) struct OdbcFunctions {
+ $( pub $name: unsafe extern "system" fn( $($argty),* ) -> SqlReturn, )*
+ }
+
+ unsafe fn load_functions(lib: &libloading::Library) -> Result<OdbcFunctions, String> {
+ unsafe {
+ Ok(OdbcFunctions {
+ $( $name: *lib.get(concat!(stringify!($name), "\0").as_bytes())
+ .map_err(|e| format!("{}: {}", stringify!($name), e))?, )*
+ })
+ }
+ }
+ };
+}
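For orientation, roughly what this macro generates for a single declaration (a hand-expanded sketch with the type aliases spelled out, not the literal expansion):

#[allow(non_snake_case)]
pub(crate) struct OdbcFunctions {
    // SqlHStmt = *mut c_void, SqlReturn = i16 (see the aliases above).
    pub SQLFetch: unsafe extern "system" fn(*mut std::ffi::c_void) -> i16,
}

unsafe fn load_functions(lib: &libloading::Library) -> Result<OdbcFunctions, String> {
    unsafe {
        Ok(OdbcFunctions {
            // Symbols are resolved by NUL-terminated name; one missing symbol
            // fails the whole load, with the function's name in the error.
            SQLFetch: *lib
                .get(b"SQLFetch\0")
                .map_err(|e| format!("SQLFetch: {}", e))?,
        })
    }
}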
+odbc_functions! {
+ fn SQLAllocHandle(handle_type: SqlSmallInt, input_handle: SqlHandle, output_handle: *mut SqlHandle) -> SqlReturn;
+ fn SQLFreeHandle(handle_type: SqlSmallInt, handle: SqlHandle) -> SqlReturn;
+ fn SQLSetEnvAttr(env: SqlHEnv, attribute: SqlInteger, value: SqlPointer, string_length: SqlInteger) -> SqlReturn;
+ fn SQLDriverConnect(
+ dbc: SqlHDbc, hwnd: SqlHWnd,
+ in_conn_str: *const SqlChar, in_len: SqlSmallInt,
+ out_conn_str: *mut SqlChar, out_max: SqlSmallInt, out_len: *mut SqlSmallInt,
+ driver_completion: SqlUSmallInt
+ ) -> SqlReturn;
+ fn SQLDisconnect(dbc: SqlHDbc) -> SqlReturn;
+ fn SQLExecDirect(stmt: SqlHStmt, text: *const SqlChar, text_length: SqlInteger) -> SqlReturn;
+ fn SQLPrepare(stmt: SqlHStmt, text: *const SqlChar, text_length: SqlInteger) -> SqlReturn;
+ fn SQLExecute(stmt: SqlHStmt) -> SqlReturn;
+ fn SQLNumResultCols(stmt: SqlHStmt, column_count: *mut SqlSmallInt) -> SqlReturn;
+ fn SQLDescribeCol(
+ stmt: SqlHStmt, col_number: SqlUSmallInt,
+ col_name: *mut SqlChar, buf_len: SqlSmallInt, name_len: *mut SqlSmallInt,
+ data_type: *mut SqlSmallInt, col_size: *mut SqlULen,
+ decimal_digits: *mut SqlSmallInt, nullable: *mut SqlSmallInt
+ ) -> SqlReturn;
+ fn SQLBindCol(
+ stmt: SqlHStmt, col_number: SqlUSmallInt, target_type: SqlSmallInt,
+ target_value: SqlPointer, buffer_length: SqlLen, indicator: *mut SqlLen
+ ) -> SqlReturn;
+ fn SQLFetch(stmt: SqlHStmt) -> SqlReturn;
+ fn SQLBindParameter(
+ stmt: SqlHStmt, param_number: SqlUSmallInt, input_output_type: SqlSmallInt,
+ value_type: SqlSmallInt, parameter_type: SqlSmallInt,
+ column_size: SqlULen, decimal_digits: SqlSmallInt,
+ parameter_value: SqlPointer, buffer_length: SqlLen,
+ str_len_or_ind: *mut SqlLen
+ ) -> SqlReturn;
+ fn SQLGetDiagRec(
+ handle_type: SqlSmallInt, handle: SqlHandle, rec_number: SqlSmallInt,
+ sql_state: *mut SqlChar, native_error: *mut SqlInteger,
+ message_text: *mut SqlChar, buffer_length: SqlSmallInt,
+ text_length: *mut SqlSmallInt
+ ) -> SqlReturn;
+ fn SQLFreeStmt(stmt: SqlHStmt, option: SqlUSmallInt) -> SqlReturn;
+ fn SQLRowCount(stmt: SqlHStmt, row_count: *mut SqlLen) -> SqlReturn;
+ fn SQLSetStmtAttr(stmt: SqlHStmt, attribute: SqlInteger, value: SqlPointer, string_length: SqlInteger) -> SqlReturn;
+ fn SQLGetInfo(
+ dbc: SqlHDbc, info_type: SqlUSmallInt,
+ info_value: SqlPointer, buffer_length: SqlSmallInt,
+ string_length: *mut SqlSmallInt
+ ) -> SqlReturn;
+ fn SQLTables(
+ stmt: SqlHStmt,
+ catalog: *const SqlChar, catalog_len: SqlSmallInt,
+ schema: *const SqlChar, schema_len: SqlSmallInt,
+ table: *const SqlChar, table_len: SqlSmallInt,
+ table_type: *const SqlChar, table_type_len: SqlSmallInt
+ ) -> SqlReturn;
+ fn SQLColumns(
+ stmt: SqlHStmt,
+ catalog: *const SqlChar, catalog_len: SqlSmallInt,
+ schema: *const SqlChar, schema_len: SqlSmallInt,
+ table: *const SqlChar, table_len: SqlSmallInt,
+ column: *const SqlChar, column_len: SqlSmallInt
+ ) -> SqlReturn;
+ fn SQLSetConnectAttr(dbc: SqlHDbc, attribute: SqlInteger, value: SqlPointer, string_length: SqlInteger) -> SqlReturn;
+ fn SQLEndTran(handle_type: SqlSmallInt, handle: SqlHandle, completion_type: SqlSmallInt) -> SqlReturn;
+}
+
+struct OdbcLibrary {
+ _lib: libloading::Library,
+ pub fns: OdbcFunctions,
+}
+
+// Safety: libloading::Library is Send+Sync. Function pointers are valid for
+// the library's lifetime, and ODBC thread safety is the caller's responsibility.
+unsafe impl Send for OdbcLibrary {}
+unsafe impl Sync for OdbcLibrary {}
+
+static ODBC: OnceLock<Result<OdbcLibrary, String>> = OnceLock::new();
+
+fn load_odbc() -> Result<OdbcLibrary, String> {
+ // Check env var first
+ if let Ok(path) = std::env::var("GGSQL_ODBC_LIBRARY") {
+ match unsafe { libloading::Library::new(&path) } {
+ Ok(lib) => {
+ let fns = unsafe { load_functions(&lib)? };
+ return Ok(OdbcLibrary { _lib: lib, fns });
+ }
+ Err(e) => {
+ return Err(format!(
+ "GGSQL_ODBC_LIBRARY={} could not be loaded: {}",
+ path, e
+ ));
+ }
+ }
+ }
+
+ let mut names: Vec<String> = Vec::new();
+
+ #[cfg(target_os = "linux")]
+ {
+ names.push("libodbc.so.2".into());
+ names.push("libodbc.so".into());
+ }
+ #[cfg(target_os = "macos")]
+ {
+ names.push("libodbc.2.dylib".into());
+ names.push("libodbc.dylib".into());
+ // Homebrew on Apple Silicon
+ names.push("/opt/homebrew/lib/libodbc.2.dylib".into());
+ // Homebrew on Intel
+ names.push("/usr/local/lib/libodbc.2.dylib".into());
+ }
+ #[cfg(target_os = "windows")]
+ {
+ names.push("odbc32.dll".into());
+ }
+ #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
+ {
+ names.push("libodbc.so.2".into());
+ names.push("libodbc.so".into());
+ }
+
+ let mut errors = Vec::new();
+ for name in &names {
+ match unsafe { libloading::Library::new(name) } {
+ Ok(lib) => {
+ let fns = unsafe { load_functions(&lib)? };
+ return Ok(OdbcLibrary { _lib: lib, fns });
+ }
+ Err(e) => errors.push(format!(" {}: {}", name, e)),
+ }
+ }
+
+ Err(format!(
+ "ODBC driver manager not found. Install unixODBC for your platform:\n\
+ \n\
+ \x20 macOS: brew install unixodbc\n\
+ \x20 Debian: sudo apt install unixodbc\n\
+ \x20 Fedora: sudo dnf install unixODBC\n\
+ \x20 RHEL: sudo yum install unixODBC\n\
+ \n\
+ Or set GGSQL_ODBC_LIBRARY to the path of your ODBC driver manager.\n\
+ \n\
+ Tried:\n{}",
+ errors.join("\n")
+ ))
+}
+
+/// Pre-load the ODBC driver manager. Returns Ok(()) if available.
+pub fn try_load() -> Result<(), String> {
+ match ODBC.get_or_init(load_odbc) {
+ Ok(_) => Ok(()),
+ Err(e) => Err(e.clone()),
+ }
+}
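Callers can probe availability up front rather than failing on first connection. A crate-internal sketch (the `ffi` module is `pub(crate)`, so this only works inside the crate):

// Optionally point GGSQL_ODBC_LIBRARY at a specific driver manager first,
// e.g. a vendored unixODBC build; otherwise the platform list above is tried.
fn ensure_odbc() -> Result<(), String> {
    // std::env::set_var("GGSQL_ODBC_LIBRARY", "/opt/unixodbc/lib/libodbc.so.2");
    crate::reader::odbc::ffi::try_load()
}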
+pub(crate) fn fns() -> &'static OdbcFunctions {
+    match ODBC.get_or_init(load_odbc) {
+        Ok(lib) => &lib.fns,
+        Err(e) => panic!(
+            "ODBC function called but driver manager is not available: {}",
+            e
+        ),
+    }
+}
diff --git a/src/reader/odbc/mod.rs b/src/reader/odbc/mod.rs
new file mode 100644
index 00000000..b7d3a74e
--- /dev/null
+++ b/src/reader/odbc/mod.rs
@@ -0,0 +1,1464 @@
+//! Generic ODBC data source implementation
+//!
+//! Provides a reader for any ODBC-compatible database using runtime-loaded
+//! ODBC bindings via `libloading`. The ODBC driver manager (`libodbc`) is
+//! loaded on first use — the binary runs fine without it until an ODBC
+//! connection is requested.
+
+#[allow(dead_code)]
+pub(crate) mod ffi;
+mod snowflake;
+#[allow(dead_code)]
+mod wrapper;
+
+use crate::reader::Reader;
+use crate::{naming, DataFrame, GgsqlError, Result};
+use arrow::array::*;
+use arrow::datatypes::DataType;
+use ffi::*;
+use std::cell::RefCell;
+use std::collections::HashSet;
+use std::sync::Arc;
+use wrapper::{Connection, Statement};
+
+/// Detect the backend SQL dialect from the DBMS name and connection string.
+fn detect_dialect(dbms_name: Option<&str>, conn_str: &str) -> Box<dyn super::SqlDialect> {
+    if let Some(name) = dbms_name {
+        let lower = name.to_lowercase();
+        if lower.contains("snowflake") {
+            return Box::new(super::AnsiDialect);
+        }
+        #[cfg(feature = "sqlite")]
+        if lower.contains("sqlite") {
+            return Box::new(super::sqlite::SqliteDialect);
+        }
+        #[cfg(feature = "duckdb")]
+        if lower.contains("duckdb") {
+            return Box::new(super::duckdb::DuckDbDialect);
+        }
+    }
+
+    // Fall back to connection string matching
+    let lower = conn_str.to_lowercase();
+    if lower.contains("driver=snowflake") {
+        Box::new(super::AnsiDialect)
+    } else if lower.contains("driver=sqlite") || lower.contains("driver={sqlite") {
+        #[cfg(feature = "sqlite")]
+        {
+            Box::new(super::sqlite::SqliteDialect)
+        }
+        #[cfg(not(feature = "sqlite"))]
+        {
+            Box::new(super::AnsiDialect)
+        }
+    } else if lower.contains("driver=duckdb") || lower.contains("driver={duckdb") {
+        #[cfg(feature = "duckdb")]
+        {
+            Box::new(super::duckdb::DuckDbDialect)
+        }
+        #[cfg(not(feature = "duckdb"))]
+        {
+            Box::new(super::AnsiDialect)
+        }
+    } else {
+        Box::new(super::AnsiDialect)
+    }
+}
+
+/// Generic ODBC reader implementing the `Reader` trait.
+pub struct OdbcReader {
+    connection: Connection,
+    dialect: Box<dyn super::SqlDialect>,
+    registered_tables: RefCell<HashSet<String>>,
+}
+
+// Safety: ODBC connections are safe to use from one thread at a time.
+// The Reader trait requires &self (immutable) for execute_sql.
+unsafe impl Send for OdbcReader {}
+
+impl OdbcReader {
+    /// Create a new ODBC reader from an `odbc://` connection URI.
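+    ///
+    /// A minimal sketch of intended use (illustrative; assumes a DSN named
+    /// `mydb` is configured with the driver manager):
+    ///
+    /// ```ignore
+    /// let reader = OdbcReader::from_connection_string("odbc://DSN=mydb")?;
+    /// let df = reader.execute_sql("SELECT 1 AS x")?;
+    /// ```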
+    pub fn from_connection_string(uri: &str) -> Result<Self> {
+        ffi::try_load()
+            .map_err(|e| GgsqlError::ReaderError(format!("ODBC is not available: {}", e)))?;
+
+        let conn_str = uri
+            .strip_prefix("odbc://")
+            .ok_or_else(|| GgsqlError::ReaderError("ODBC URI must start with odbc://".into()))?;
+
+        let mut conn_str = conn_str.to_string();
+
+        if snowflake::is_snowflake(&conn_str) {
+            if let Some(resolved) = snowflake::resolve_connection_name(&conn_str) {
+                conn_str = resolved;
+            }
+        }
+
+        if snowflake::is_snowflake(&conn_str) && !snowflake::has_token(&conn_str) {
+            if let Some(token) = snowflake::detect_workbench_token() {
+                conn_str = snowflake::inject_snowflake_token(&conn_str, &token);
+            }
+        }
+
+        let env = wrapper::odbc_env()?;
+        let connection = Connection::connect(env, &conn_str)?;
+
+        let dbms_name = connection.dbms_name();
+        let dialect = detect_dialect(dbms_name.as_deref(), &conn_str);
+
+        Ok(Self {
+            connection,
+            dialect,
+            registered_tables: RefCell::new(HashSet::new()),
+        })
+    }
+}
+
+impl Reader for OdbcReader {
+    fn execute_sql(&self, sql: &str) -> Result<DataFrame> {
+        let cursor = self.connection.execute(sql)?;
+
+        let Some(cursor) = cursor else {
+            return Ok(DataFrame::empty());
+        };
+
+        cursor_to_dataframe(cursor)
+    }
+
+    fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> {
+        super::validate_table_name(name)?;
+
+        if replace {
+            let drop_sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name));
+            let _ = self.connection.execute(&drop_sql);
+        }
+
+        let schema = df.schema();
+        let col_defs: Vec<String> = schema
+            .fields()
+            .iter()
+            .map(|field| {
+                format!(
+                    "{} {}",
+                    naming::quote_ident(field.name()),
+                    arrow_dtype_to_sql(field.data_type())
+                )
+            })
+            .collect();
+        let create_sql = format!(
+            "CREATE TEMPORARY TABLE {} ({})",
+            naming::quote_ident(name),
+            col_defs.join(", ")
+        );
+        self.connection.execute(&create_sql).map_err(|e| {
+            GgsqlError::ReaderError(format!("Failed to create temp table '{}': {}", name, e))
+        })?;
+
+        let num_rows = df.height();
+        if num_rows > 0 {
+            let num_cols = df.width();
+            let placeholders: Vec<&str> = vec!["?"; num_cols];
+            let insert_sql = format!(
+                "INSERT INTO {} VALUES ({})",
+                naming::quote_ident(name),
+                placeholders.join(", ")
+            );
+
+            let columns = df.get_columns();
+            let string_columns: Vec<Vec<Option<String>>> = columns
+                .iter()
+                .map(|col| {
+                    (0..num_rows)
+                        .map(|row| {
+                            if col.is_null(row) {
+                                None
+                            } else {
+                                Some(crate::array_util::value_to_string(col, row))
+                            }
+                        })
+                        .collect()
+                })
+                .collect();
+
+            let prepared = self.connection.prepare(&insert_sql).map_err(|e| {
+                GgsqlError::ReaderError(format!("Failed to prepare INSERT for '{}': {}", name, e))
+            })?;
+
+            for row_idx in 0..num_rows {
+                let row_values: Vec<Option<&[u8]>> = string_columns
+                    .iter()
+                    .map(|col| col[row_idx].as_ref().map(|s| s.as_bytes()))
+                    .collect();
+
+                let mut indicators: Vec<SqlLen> = row_values
+                    .iter()
+                    .map(|v| match v {
+                        Some(bytes) => bytes.len() as SqlLen,
+                        None => SQL_NULL_DATA,
+                    })
+                    .collect();
+
+                for (col_idx, value) in row_values.iter().enumerate() {
+                    let (ptr, len) = match value {
+                        Some(bytes) => (bytes.as_ptr(), bytes.len() as SqlLen),
+                        None => (std::ptr::null(), 0),
+                    };
+                    unsafe {
+                        prepared.bind_text_parameter(
+                            (col_idx + 1) as u16,
+                            ptr,
+                            len,
+                            &mut indicators[col_idx],
+                        )?;
+                    }
+                }
+
+                prepared.execute().map_err(|e| {
+                    GgsqlError::ReaderError(format!(
+                        "Failed to insert row {} into '{}': {}",
+                        row_idx, name, e
+                    ))
+                })?;
+
+                prepared.reset_params()?;
+            }
+        }
+
+        self.registered_tables.borrow_mut().insert(name.to_string());
+
+        Ok(())
+    }
+
+    fn unregister(&self, name: &str) -> Result<()> {
+        if !self.registered_tables.borrow().contains(name) {
+            return Err(GgsqlError::ReaderError(format!(
+                "Table '{}' was not registered via this reader",
+                name
+            )));
+        }
+
+        let sql = format!("DROP TABLE IF EXISTS {}", naming::quote_ident(name));
+        self.connection.execute(&sql).map_err(|e| {
+            GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e))
+        })?;
+
+        self.registered_tables.borrow_mut().remove(name);
+        Ok(())
+    }
+
+    fn execute(&self, query: &str) -> Result<DataFrame> {
+        super::execute_with_reader(self, query)
+    }
+
+    fn dialect(&self) -> &dyn super::SqlDialect {
+        &*self.dialect
+    }
+
+    fn list_catalogs(&self) -> Result<Vec<String>> {
+        // ODBC spec: CatalogName="%", SchemaName="", TableName=""
+        let stmt = wrapper::sql_tables(&self.connection, Some("%"), Some(""), Some(""), None)?;
+        let df = cursor_to_dataframe(stmt)?;
+        let mut catalogs = extract_string_column_ci(&df, "TABLE_CAT")?;
+        catalogs.sort();
+        catalogs.dedup();
+        Ok(catalogs)
+    }
+
+    fn list_schemas(&self, _catalog: &str) -> Result<Vec<String>> {
+        // ODBC spec: CatalogName="", SchemaName="%", TableName=""
+        let stmt = wrapper::sql_tables(&self.connection, Some(""), Some("%"), Some(""), None)?;
+        let df = cursor_to_dataframe(stmt)?;
+        let mut schemas = extract_string_column_ci(&df, "TABLE_SCHEM")?;
+        schemas.sort();
+        schemas.dedup();
+        Ok(schemas)
+    }
+
+    fn list_tables(&self, catalog: &str, schema: &str) -> Result<Vec<super::TableInfo>> {
+        let cat = if catalog.is_empty() {
+            None
+        } else {
+            Some(catalog)
+        };
+        let sch = if schema.is_empty() {
+            None
+        } else {
+            Some(schema)
+        };
+        let stmt = wrapper::sql_tables(&self.connection, cat, sch, Some("%"), Some("TABLE,VIEW"))?;
+        let df = cursor_to_dataframe(stmt)?;
+        extract_table_infos_ci(&df)
+    }
+
+    fn list_columns(
+        &self,
+        catalog: &str,
+        schema: &str,
+        table: &str,
+    ) -> Result<Vec<super::ColumnInfo>> {
+        let cat = if catalog.is_empty() {
+            None
+        } else {
+            Some(catalog)
+        };
+        let sch = if schema.is_empty() {
+            None
+        } else {
+            Some(schema)
+        };
+        let stmt = wrapper::sql_columns(&self.connection, cat, sch, Some(table), None)?;
+        let df = cursor_to_dataframe(stmt)?;
+        extract_column_infos_ci(&df)
+    }
+}
+
+/// Find a column in a DataFrame by name (case-insensitive).
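+///
+/// ODBC drivers disagree on the case of catalog result columns (`TABLE_NAME`
+/// vs. `table_name`), so the catalog helpers below all use this lookup.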
+fn find_column_ci<'a>(df: &'a DataFrame, name: &str) -> Option<&'a ArrayRef> {
+    let lower = name.to_lowercase();
+    let schema = df.schema();
+    for (i, field) in schema.fields().iter().enumerate() {
+        if field.name().to_lowercase() == lower {
+            return Some(&df.get_columns()[i]);
+        }
+    }
+    None
+}
+
+fn extract_string_column_ci(df: &DataFrame, col_name: &str) -> Result<Vec<String>> {
+    let col = find_column_ci(df, col_name).ok_or_else(|| {
+        GgsqlError::ReaderError(format!("Column '{}' not found in ODBC result", col_name))
+    })?;
+    let mut results = Vec::with_capacity(df.height());
+    for i in 0..df.height() {
+        if !col.is_null(i) {
+            results.push(crate::array_util::value_to_string(col, i));
+        }
+    }
+    Ok(results)
+}
+
+fn extract_table_infos_ci(df: &DataFrame) -> Result<Vec<super::TableInfo>> {
+    let name_col = find_column_ci(df, "TABLE_NAME").ok_or_else(|| {
+        GgsqlError::ReaderError("Column 'TABLE_NAME' not found in ODBC result".into())
+    })?;
+    let type_col = find_column_ci(df, "TABLE_TYPE").ok_or_else(|| {
+        GgsqlError::ReaderError("Column 'TABLE_TYPE' not found in ODBC result".into())
+    })?;
+    let mut results = Vec::with_capacity(df.height());
+    for i in 0..df.height() {
+        if !name_col.is_null(i) {
+            results.push(super::TableInfo {
+                name: crate::array_util::value_to_string(name_col, i),
+                table_type: crate::array_util::value_to_string(type_col, i),
+            });
+        }
+    }
+    Ok(results)
+}
+
+fn extract_column_infos_ci(df: &DataFrame) -> Result<Vec<super::ColumnInfo>> {
+    let name_col = find_column_ci(df, "COLUMN_NAME").ok_or_else(|| {
+        GgsqlError::ReaderError("Column 'COLUMN_NAME' not found in ODBC result".into())
+    })?;
+    let type_col = find_column_ci(df, "TYPE_NAME").ok_or_else(|| {
+        GgsqlError::ReaderError("Column 'TYPE_NAME' not found in ODBC result".into())
+    })?;
+    let mut results = Vec::with_capacity(df.height());
+    for i in 0..df.height() {
+        if !name_col.is_null(i) {
+            results.push(super::ColumnInfo {
+                name: crate::array_util::value_to_string(name_col, i),
+                data_type: crate::array_util::value_to_string(type_col, i),
+            });
+        }
+    }
+    Ok(results)
+}
+
+// ============================================================================
+// SQL type mapping
+// ============================================================================
+
+fn arrow_dtype_to_sql(dtype: &DataType) -> &'static str {
+    match dtype {
+        DataType::Boolean => "BOOLEAN",
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => "BIGINT",
+        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => "BIGINT",
+        DataType::Float32 | DataType::Float64 => "DOUBLE PRECISION",
+        DataType::Date32 => "DATE",
+        DataType::Timestamp(_, _) => "TIMESTAMP",
+        DataType::Time64(_) => "TIME",
+        _ => "TEXT",
+    }
+}
+
+// ============================================================================
+// Column builder (accumulates typed values across batches)
+// ============================================================================
+
+enum ColumnBuilder {
+    Int8(Vec<Option<i8>>),
+    Int16(Vec<Option<i16>>),
+    Int32(Vec<Option<i32>>),
+    Int64(Vec<Option<i64>>),
+    Float32(Vec<Option<f32>>),
+    Float64(Vec<Option<f64>>),
+    Boolean(Vec<Option<bool>>),
+    Date(Vec<Option<i32>>),
+    Time(Vec<Option<i64>>),
+    Timestamp(Vec<Option<i64>>),
+    Text(Vec<Option<String>>),
+}
+
+impl ColumnBuilder {
+    fn from_sql_type(
+        sql_type: SqlSmallInt,
+        col_size: SqlULen,
+        decimal_digits: SqlSmallInt,
+    ) -> Self {
+        match sql_type {
+            SQL_TINYINT => Self::Int8(Vec::new()),
+            SQL_SMALLINT => Self::Int16(Vec::new()),
+            SQL_INTEGER => Self::Int32(Vec::new()),
+            SQL_BIGINT => Self::Int64(Vec::new()),
+            SQL_REAL => Self::Float32(Vec::new()),
+            SQL_DOUBLE | SQL_FLOAT => Self::Float64(Vec::new()),
+            SQL_NUMERIC | SQL_DECIMAL => {
+                if decimal_digits == 0 {
+                    if col_size < 10 {
+                        Self::Int32(Vec::new())
+                    } else if col_size < 19 {
+                        Self::Int64(Vec::new())
+                    } else {
+                        Self::Float64(Vec::new())
+                    }
+                } else {
+                    Self::Float64(Vec::new())
+                }
+            }
+            SQL_BIT => Self::Boolean(Vec::new()),
+            SQL_TYPE_DATE => Self::Date(Vec::new()),
+            SQL_TYPE_TIME => Self::Time(Vec::new()),
+            SQL_TYPE_TIMESTAMP => Self::Timestamp(Vec::new()),
+            _ => Self::Text(Vec::new()),
+        }
+    }
+
+    fn c_type(&self) -> SqlSmallInt {
+        match self {
+            Self::Int8(_) => SQL_C_STINYINT,
+            Self::Int16(_) => SQL_C_SSHORT,
+            Self::Int32(_) => SQL_C_SLONG,
+            Self::Int64(_) => SQL_C_SBIGINT,
+            Self::Float32(_) => SQL_C_FLOAT,
+            Self::Float64(_) => SQL_C_DOUBLE,
+            Self::Boolean(_) => SQL_C_BIT,
+            Self::Date(_) => SQL_C_TYPE_DATE,
+            Self::Time(_) => SQL_C_TYPE_TIME,
+            Self::Timestamp(_) => SQL_C_TYPE_TIMESTAMP,
+            Self::Text(_) => SQL_C_CHAR,
+        }
+    }
+
+    fn element_size(&self) -> usize {
+        match self {
+            Self::Int8(_) => std::mem::size_of::<i8>(),
+            Self::Int16(_) => std::mem::size_of::<i16>(),
+            Self::Int32(_) => std::mem::size_of::<i32>(),
+            Self::Int64(_) => std::mem::size_of::<i64>(),
+            Self::Float32(_) => std::mem::size_of::<f32>(),
+            Self::Float64(_) => std::mem::size_of::<f64>(),
+            Self::Boolean(_) => 1,
+            Self::Date(_) => std::mem::size_of::<SqlDateStruct>(),
+            Self::Time(_) => std::mem::size_of::<SqlTimeStruct>(),
+            Self::Timestamp(_) => std::mem::size_of::<SqlTimestampStruct>(),
+            Self::Text(_) => 0, // text uses a separate buffer
+        }
+    }
+
+    fn into_named_array(self, name: &str) -> (String, ArrayRef) {
+        let array: ArrayRef = match self {
+            Self::Int8(v) => Arc::new(Int8Array::from(v)),
+            Self::Int16(v) => Arc::new(Int16Array::from(v)),
+            Self::Int32(v) => Arc::new(Int32Array::from(v)),
+            Self::Int64(v) => Arc::new(Int64Array::from(v)),
+            Self::Float32(v) => Arc::new(Float32Array::from(v)),
+            Self::Float64(v) => Arc::new(Float64Array::from(v)),
+            Self::Boolean(v) => Arc::new(BooleanArray::from(v)),
+            Self::Date(v) => Arc::new(Date32Array::from(v)),
+            Self::Time(v) => Arc::new(Time64NanosecondArray::from(v)),
+            Self::Timestamp(v) => Arc::new(TimestampMicrosecondArray::from(v)),
+            Self::Text(v) => {
+                let refs: Vec<Option<&str>> = v.iter().map(|s| s.as_deref()).collect();
+                Arc::new(StringArray::from(refs))
+            }
+        };
+        (name.to_string(), array)
+    }
+}
+
+// ============================================================================
+// Date/time conversion helpers
+// ============================================================================
+
+fn odbc_date_to_days(d: &SqlDateStruct) -> Option<i32> {
+    chrono::NaiveDate::from_ymd_opt(d.year as i32, d.month as u32, d.day as u32).map(|date| {
+        let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
+        (date - epoch).num_days() as i32
+    })
+}
+
+fn odbc_time_to_nanos(t: &SqlTimeStruct) -> i64 {
+    let h = t.hour as i64;
+    let m = t.minute as i64;
+    let s = t.second as i64;
+    (h * 3600 + m * 60 + s) * 1_000_000_000
+}
+
+fn odbc_timestamp_to_micros(ts: &SqlTimestampStruct) -> Option<i64> {
+    chrono::NaiveDate::from_ymd_opt(ts.year as i32, ts.month as u32, ts.day as u32)
+        .and_then(|date| {
+            date.and_hms_nano_opt(
+                ts.hour as u32,
+                ts.minute as u32,
+                ts.second as u32,
+                ts.fraction,
+            )
+        })
+        .map(|dt| dt.and_utc().timestamp_micros())
+}
+
+// ============================================================================
+// Cursor → DataFrame conversion
+// ============================================================================
+
+const BATCH_SIZE: usize = 1000;
+const DEFAULT_TEXT_BUF_SIZE: usize = 65536;
+
+struct ColumnBuffer {
+    data: Vec<u8>,
+    indicators: Vec<SqlLen>,
+    text_buf_size: usize,
+}
+
+fn cursor_to_dataframe(stmt: Statement) -> Result<DataFrame> {
+    let col_count = stmt.num_result_cols()?;
+    if col_count == 0 {
+        return Ok(DataFrame::empty());
+    }
+
+    // Describe all columns
+    let mut col_names = Vec::with_capacity(col_count);
+    let mut builders = Vec::with_capacity(col_count);
+
+    for i in 1..=col_count as u16 {
+        let (name, data_type, col_size, decimal_digits, _nullable) = stmt.describe_col(i)?;
+        col_names.push(name);
+        builders.push(ColumnBuilder::from_sql_type(
+            data_type,
+            col_size,
+            decimal_digits,
+        ));
+    }
+
+    // Set up batch fetching
+    stmt.setup_batch_fetch(BATCH_SIZE)?;
+    let mut rows_fetched: SqlULen = 0;
+    unsafe { stmt.set_rows_fetched_ptr(&mut rows_fetched)? };
+
+    // Allocate per-column buffers; binding happens in the next loop
+    let mut buffers: Vec<ColumnBuffer> = builders
+        .iter()
+        .map(|builder| {
+            let (elem_size, text_buf_size) = if matches!(builder, ColumnBuilder::Text(_)) {
+                (DEFAULT_TEXT_BUF_SIZE + 1, DEFAULT_TEXT_BUF_SIZE) // +1 for null terminator
+            } else {
+                (builder.element_size(), 0)
+            };
+
+            let data = vec![0u8; elem_size * BATCH_SIZE];
+            let indicators = vec![0isize; BATCH_SIZE];
+
+            ColumnBuffer {
+                data,
+                indicators,
+                text_buf_size,
+            }
+        })
+        .collect();
+
+    // Bind columns
+    for (i, (builder, buf)) in builders.iter().zip(buffers.iter_mut()).enumerate() {
+        let col_num = (i + 1) as u16;
+        let c_type = builder.c_type();
+        let elem_size = if matches!(builder, ColumnBuilder::Text(_)) {
+            (buf.text_buf_size + 1) as SqlLen
+        } else {
+            builder.element_size() as SqlLen
+        };
+
+        stmt.bind_col(
+            col_num,
+            c_type,
+            buf.data.as_mut_ptr() as SqlPointer,
+            elem_size,
+            buf.indicators.as_mut_ptr(),
+        )?;
+    }
+
+    // Fetch loop
+    loop {
+        rows_fetched = 0;
+        let rc = stmt.fetch_raw();
+
+        match rc {
+            SQL_NO_DATA => break,
+            SQL_SUCCESS | SQL_SUCCESS_WITH_INFO => {}
+            _ => {
+                return Err(GgsqlError::ReaderError("Failed to fetch batch".to_string()));
+            }
+        }
+
+        let n = rows_fetched as usize;
+        if n == 0 {
+            break;
+        }
+
+        // Extract data from buffers into builders
+        for (col_idx, (builder, buf)) in builders.iter_mut().zip(buffers.iter()).enumerate() {
+            extract_batch(builder, buf, n, col_idx)?;
+        }
+    }
+
+    // Convert builders to named arrays
+    let named_arrays: Vec<(String, ArrayRef)> = col_names
+        .iter()
+        .zip(builders)
+        .map(|(name, builder)| builder.into_named_array(name))
+        .collect();
+
+    DataFrame::new(named_arrays)
+}
+
+fn extract_batch(
+    builder: &mut ColumnBuilder,
+    buf: &ColumnBuffer,
+    num_rows: usize,
+    _col_idx: usize,
+) -> Result<()> {
+    for row in 0..num_rows {
+        let indicator = buf.indicators[row];
+        let is_null = indicator == SQL_NULL_DATA;
+
+        match builder {
+            ColumnBuilder::Int8(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let val = buf.data[row * std::mem::size_of::<i8>()] as i8;
+                    v.push(Some(val));
+                }
+            }
+            ColumnBuilder::Int16(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let offset = row * std::mem::size_of::<i16>();
+                    let val = i16::from_ne_bytes(buf.data[offset..offset + 2].try_into().unwrap());
+                    v.push(Some(val));
+                }
+            }
+            ColumnBuilder::Int32(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let offset = row * std::mem::size_of::<i32>();
+                    let val = i32::from_ne_bytes(buf.data[offset..offset + 4].try_into().unwrap());
+                    v.push(Some(val));
+                }
+            }
+            ColumnBuilder::Int64(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let offset = row * std::mem::size_of::<i64>();
+                    let val = i64::from_ne_bytes(buf.data[offset..offset + 8].try_into().unwrap());
+                    v.push(Some(val));
+                }
+            }
+            ColumnBuilder::Float32(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let offset = row * std::mem::size_of::<f32>();
+                    let val = f32::from_ne_bytes(buf.data[offset..offset + 4].try_into().unwrap());
+                    v.push(Some(val));
+                }
+            }
+            ColumnBuilder::Float64(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let offset = row * std::mem::size_of::<f64>();
+                    let val = f64::from_ne_bytes(buf.data[offset..offset + 8].try_into().unwrap());
+                    v.push(Some(val));
+                }
+            }
+            ColumnBuilder::Boolean(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    v.push(Some(buf.data[row] != 0));
+                }
+            }
+            ColumnBuilder::Date(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let size = std::mem::size_of::<SqlDateStruct>();
+                    let offset = row * size;
+                    let d: SqlDateStruct =
+                        unsafe { std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _) };
+                    v.push(odbc_date_to_days(&d));
+                }
+            }
+            ColumnBuilder::Time(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let size = std::mem::size_of::<SqlTimeStruct>();
+                    let offset = row * size;
+                    let t: SqlTimeStruct =
+                        unsafe { std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _) };
+                    v.push(Some(odbc_time_to_nanos(&t)));
+                }
+            }
+            ColumnBuilder::Timestamp(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let size = std::mem::size_of::<SqlTimestampStruct>();
+                    let offset = row * size;
+                    let ts: SqlTimestampStruct =
+                        unsafe { std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _) };
+                    v.push(odbc_timestamp_to_micros(&ts));
+                }
+            }
+            ColumnBuilder::Text(v) => {
+                if is_null {
+                    v.push(None);
+                } else {
+                    let elem_size = buf.text_buf_size + 1;
+                    let offset = row * elem_size;
+                    let len = indicator as usize;
+                    let actual_len = len.min(buf.text_buf_size);
+                    let bytes = &buf.data[offset..offset + actual_len];
+                    let s = String::from_utf8_lossy(bytes).into_owned();
+                    v.push(Some(s));
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_detect_dialect_from_dbms_name() {
+        let d = detect_dialect(Some("Snowflake"), "anything");
+        assert!(!d.sql_greatest(&["a", "b"]).is_empty());
+
+        let d = detect_dialect(None, "Driver=Snowflake;Server=foo");
+        assert!(!d.sql_greatest(&["a", "b"]).is_empty());
+
+        let d = detect_dialect(None, "Driver=SomeOther;Server=localhost");
+        assert!(!d.sql_greatest(&["a", "b"]).is_empty());
+    }
+
+    #[test]
+    fn test_arrow_dtype_to_sql() {
+        assert_eq!(arrow_dtype_to_sql(&DataType::Int64), "BIGINT");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Boolean), "BOOLEAN");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Date32), "DATE");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Utf8), "TEXT");
+    }
+
+    #[test]
+    fn test_column_builder_from_sql_type() {
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_TINYINT, 0, 0),
+            ColumnBuilder::Int8(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_SMALLINT, 0, 0),
+            ColumnBuilder::Int16(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_INTEGER, 0, 0),
+            ColumnBuilder::Int32(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_BIGINT, 0, 0),
+            ColumnBuilder::Int64(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_REAL, 0, 0),
+            ColumnBuilder::Float32(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_DOUBLE, 0, 0),
+            ColumnBuilder::Float64(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_FLOAT, 0, 0),
+            ColumnBuilder::Float64(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_BIT, 0, 0),
+            ColumnBuilder::Boolean(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_TYPE_DATE, 0, 0),
+            ColumnBuilder::Date(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_TYPE_TIME, 0, 0),
+            ColumnBuilder::Time(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_TYPE_TIMESTAMP, 0, 0),
+            ColumnBuilder::Timestamp(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_VARCHAR, 0, 0),
+            ColumnBuilder::Text(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_WVARCHAR, 0, 0),
+            ColumnBuilder::Text(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_LONGVARCHAR, 0, 0),
+            ColumnBuilder::Text(_)
+        ));
+        // Decimal with scale=0 maps to integer types based on precision
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_NUMERIC, 5, 0),
+            ColumnBuilder::Int32(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_NUMERIC, 15, 0),
+            ColumnBuilder::Int64(_)
+        ));
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_NUMERIC, 25, 0),
+            ColumnBuilder::Float64(_)
+        ));
+        // Decimal with scale>0 maps to Float64
+        assert!(matches!(
+            ColumnBuilder::from_sql_type(SQL_DECIMAL, 10, 2),
+            ColumnBuilder::Float64(_)
+        ));
+    }
+
+    #[test]
+    fn test_column_builder_c_types() {
+        assert_eq!(ColumnBuilder::Int8(vec![]).c_type(), SQL_C_STINYINT);
+        assert_eq!(ColumnBuilder::Int16(vec![]).c_type(), SQL_C_SSHORT);
+        assert_eq!(ColumnBuilder::Int32(vec![]).c_type(), SQL_C_SLONG);
+        assert_eq!(ColumnBuilder::Int64(vec![]).c_type(), SQL_C_SBIGINT);
+        assert_eq!(ColumnBuilder::Float32(vec![]).c_type(), SQL_C_FLOAT);
+        assert_eq!(ColumnBuilder::Float64(vec![]).c_type(), SQL_C_DOUBLE);
+        assert_eq!(ColumnBuilder::Boolean(vec![]).c_type(), SQL_C_BIT);
+        assert_eq!(ColumnBuilder::Date(vec![]).c_type(), SQL_C_TYPE_DATE);
+        assert_eq!(ColumnBuilder::Time(vec![]).c_type(), SQL_C_TYPE_TIME);
+        assert_eq!(
+            ColumnBuilder::Timestamp(vec![]).c_type(),
+            SQL_C_TYPE_TIMESTAMP
+        );
+        assert_eq!(ColumnBuilder::Text(vec![]).c_type(), SQL_C_CHAR);
+    }
+
+    #[test]
+    fn test_column_builder_element_sizes() {
+        assert_eq!(ColumnBuilder::Int8(vec![]).element_size(), 1);
+        assert_eq!(ColumnBuilder::Int16(vec![]).element_size(), 2);
+        assert_eq!(ColumnBuilder::Int32(vec![]).element_size(), 4);
+        assert_eq!(ColumnBuilder::Int64(vec![]).element_size(), 8);
+        assert_eq!(ColumnBuilder::Float32(vec![]).element_size(), 4);
+        assert_eq!(ColumnBuilder::Float64(vec![]).element_size(), 8);
+        assert_eq!(ColumnBuilder::Boolean(vec![]).element_size(), 1);
+        assert!(ColumnBuilder::Date(vec![]).element_size() >= 6);
+        assert!(ColumnBuilder::Time(vec![]).element_size() >= 6);
+        assert!(ColumnBuilder::Timestamp(vec![]).element_size() >= 14);
+        assert_eq!(ColumnBuilder::Text(vec![]).element_size(), 0);
+    }
+
+    #[test]
+    fn test_column_builder_into_named_array() {
+        let builder = ColumnBuilder::Int64(vec![Some(1), None, Some(3)]);
+        let (name, array) = builder.into_named_array("col");
+        assert_eq!(name, "col");
+        assert_eq!(array.len(), 3);
+        assert!(!array.is_null(0));
+        assert!(array.is_null(1));
+        assert!(!array.is_null(2));
+
+        let builder = ColumnBuilder::Text(vec![Some("hello".into()), None]);
+        let (_, array) = builder.into_named_array("t");
+        assert_eq!(array.len(), 2);
+        assert!(!array.is_null(0));
+        assert!(array.is_null(1));
+
+        let builder = ColumnBuilder::Boolean(vec![Some(true), Some(false)]);
+        let (_, array) = builder.into_named_array("b");
+        assert_eq!(array.len(), 2);
+    }
+
+    #[test]
+    fn test_odbc_date_to_days() {
+        let d = SqlDateStruct {
+            year: 1970,
+            month: 1,
+            day: 1,
+        };
+        assert_eq!(odbc_date_to_days(&d), Some(0));
+
+        let d = SqlDateStruct {
+            year: 2000,
+            month: 1,
+            day: 1,
+        };
+        assert_eq!(odbc_date_to_days(&d), Some(10957));
+
+        let d = SqlDateStruct {
+            year: 2024,
+            month: 2,
+            day: 29,
+        };
+        assert!(odbc_date_to_days(&d).is_some());
+
+        let d = SqlDateStruct {
+            year: 2024,
+            month: 13,
+            day: 1,
+        };
+        assert_eq!(odbc_date_to_days(&d), None);
+    }
+
+    #[test]
+    fn test_odbc_time_to_nanos() {
+        let t = SqlTimeStruct {
+            hour: 0,
+            minute: 0,
+            second: 0,
+        };
+        assert_eq!(odbc_time_to_nanos(&t), 0);
+
+        let t = SqlTimeStruct {
+            hour: 1,
+            minute: 30,
+            second: 45,
+        };
+        assert_eq!(odbc_time_to_nanos(&t), (3600 + 1800 + 45) * 1_000_000_000);
+
+        let t = SqlTimeStruct {
+            hour: 23,
+            minute: 59,
+            second: 59,
+        };
+        assert_eq!(odbc_time_to_nanos(&t), (86399) * 1_000_000_000);
+    }
+
+    #[test]
+    fn test_odbc_timestamp_to_micros() {
+        let ts = SqlTimestampStruct {
+            year: 1970,
+            month: 1,
+            day: 1,
+            hour: 0,
+            minute: 0,
+            second: 0,
+            fraction: 0,
+        };
+        assert_eq!(odbc_timestamp_to_micros(&ts), Some(0));
+
+        let ts = SqlTimestampStruct {
+            year: 2024,
+            month: 6,
+            day: 15,
+            hour: 12,
+            minute: 30,
+            second: 45,
+            fraction: 0,
+        };
+        assert!(odbc_timestamp_to_micros(&ts).unwrap() > 0);
+
+        let ts = SqlTimestampStruct {
+            year: 2024,
+            month: 13,
+            day: 1,
+            hour: 0,
+            minute: 0,
+            second: 0,
+            fraction: 0,
+        };
+        assert_eq!(odbc_timestamp_to_micros(&ts), None);
+    }
+
+    #[test]
+    fn test_succeeded() {
+        assert!(ffi::succeeded(SQL_SUCCESS));
+        assert!(ffi::succeeded(SQL_SUCCESS_WITH_INFO));
+        assert!(!ffi::succeeded(SQL_ERROR));
+        assert!(!ffi::succeeded(SQL_NO_DATA));
+        assert!(!ffi::succeeded(SQL_INVALID_HANDLE));
+    }
+
+    #[test]
+    fn test_try_load_error_without_odbc() {
+        // This tests the error path — if ODBC *is* installed it returns Ok,
+        // if not it returns a descriptive error. Either way it shouldn't panic.
+        let result = ffi::try_load();
+        match result {
+            Ok(()) => {} // ODBC is available on this machine
+            Err(e) => {
+                assert!(
+                    e.contains("ODBC driver manager not found") || e.contains("GGSQL_ODBC_LIBRARY")
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_connect_missing_prefix() {
+        let result = OdbcReader::from_connection_string("DSN=foo");
+        match result {
+            Err(e) => assert!(e.to_string().contains("odbc://"), "Got: {}", e),
+            Ok(_) => panic!("Should have failed without odbc:// prefix"),
+        }
+    }
+
+    // ========================================================================
+    // ODBC integration tests
+    //
+    // These require ODBC DSNs to be configured on the machine. Run with:
+    //   cargo test --package ggsql -- odbc::tests --include-ignored --nocapture
+    // ========================================================================
+
+    fn try_connect(dsn: &str) -> Option<OdbcReader> {
+        OdbcReader::from_connection_string(&format!("odbc://DSN={}", dsn)).ok()
+    }
+
+    // --- PostgreSQL via ODBC -------------------------------------------------
+
+    const PG_DSN: &str = "ggsql-pg-test";
+
+    #[test]
+    #[ignore]
+    fn pg_connect_and_detect_dialect() {
+        let reader = try_connect(PG_DSN).expect("Cannot connect to PostgreSQL ODBC DSN");
+        assert_eq!(reader.connection.dbms_name().as_deref(), Some("PostgreSQL"));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_integer() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader.execute_sql("SELECT 42 AS value").unwrap();
+        assert_eq!(df.height(), 1);
+        assert_eq!(df.width(), 1);
+        let col = df.column("value").unwrap();
+        assert_eq!(crate::array_util::value_to_string(col, 0), "42");
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_float() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT 4.28::double precision AS foo")
+            .unwrap();
+        assert_eq!(df.height(), 1);
+        let col = df.column("foo").unwrap();
+        let val: f64 = crate::array_util::value_to_string(col, 0).parse().unwrap();
+        assert!((val - 4.28).abs() < 0.001);
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_boolean() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader.execute_sql("SELECT true AS t, false AS f").unwrap();
+        assert_eq!(df.height(), 1);
+        let t = df.column("t").unwrap();
+        let f = df.column("f").unwrap();
+        assert!(!t.is_null(0));
+        assert!(!f.is_null(0));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_date() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader.execute_sql("SELECT DATE '2024-06-15' AS d").unwrap();
+        assert_eq!(df.height(), 1);
+        let col = df.column("d").unwrap();
+        assert!(!col.is_null(0));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_timestamp() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT TIMESTAMP '2024-06-15 12:30:45' AS ts")
+            .unwrap();
+        assert_eq!(df.height(), 1);
+        let col = df.column("ts").unwrap();
+        assert!(!col.is_null(0));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_multiple_types() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT 1 AS i, 2.5::double precision AS f, 'hello' AS s, true AS b")
+            .unwrap();
+        assert_eq!(df.height(), 1);
+        assert_eq!(df.width(), 4);
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_multiple_rows() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT generate_series AS n FROM generate_series(1, 100)")
+            .unwrap();
+        assert_eq!(df.height(), 100);
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_large_batch() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT generate_series AS n FROM generate_series(1, 5000)")
+            .unwrap();
+        assert_eq!(df.height(), 5000);
+        let col = df.column("n").unwrap();
+        assert!(!col.is_null(0));
+        assert!(!col.is_null(4999));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_nulls() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT NULL::integer AS x UNION ALL SELECT 1")
+            .unwrap();
+        assert_eq!(df.height(), 2);
+        let col = df.column("x").unwrap();
+        assert!(col.is_null(0));
+        assert!(!col.is_null(1));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_null_text() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT NULL::text AS s UNION ALL SELECT 'hello'")
+            .unwrap();
+        assert_eq!(df.height(), 2);
+        let col = df.column("s").unwrap();
+        assert!(col.is_null(0));
+        assert_eq!(crate::array_util::value_to_string(col, 1), "hello");
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_empty_result() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader.execute_sql("SELECT 1 AS x WHERE false").unwrap();
+        assert_eq!(df.height(), 0);
+        assert_eq!(df.width(), 1);
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_ddl_returns_empty() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("CREATE TEMPORARY TABLE __ggsql_test_ddl (x int)")
+            .unwrap();
+        assert_eq!(df.height(), 0);
+        let _ = reader.execute_sql("DROP TABLE IF EXISTS __ggsql_test_ddl");
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_execute_sql_unicode() {
+        let reader = try_connect(PG_DSN).unwrap();
+        let df = reader
+            .execute_sql("SELECT 'héllo wörld 日本語' AS s")
+            .unwrap();
+        assert_eq!(df.height(), 1);
+        let col = df.column("s").unwrap();
+        let val = crate::array_util::value_to_string(col, 0);
+        assert!(val.contains("héllo"));
+        assert!(val.contains("日本語"));
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_list_catalogs() {
+        use crate::reader::Reader;
+        let reader = try_connect(PG_DSN).unwrap();
+        let catalogs = reader.list_catalogs().unwrap();
+        assert!(!catalogs.is_empty(), "Should have at least one catalog");
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_list_schemas() {
+        use crate::reader::Reader;
+        let reader = try_connect(PG_DSN).unwrap();
+        let schemas = reader.list_schemas("").unwrap();
+        assert!(
+            schemas.iter().any(|s| s == "public"),
+            "Should contain 'public' schema, got: {:?}",
+            schemas
+        );
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_list_tables() {
+        use crate::reader::Reader;
+        let reader = try_connect(PG_DSN).unwrap();
+        let tables = reader.list_tables("", "public").unwrap();
+        for t in &tables {
+            assert!(!t.name.is_empty());
+            assert!(!t.table_type.is_empty());
+        }
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_list_columns() {
+        use crate::reader::Reader;
+        let reader = try_connect(PG_DSN).unwrap();
+        // Create a temp table so we have something to list columns for
+        let _ = reader.execute_sql(
+            "CREATE TEMPORARY TABLE __ggsql_test_cols (id int, name text, score double precision)",
+        );
+        let cols = reader.list_columns("", "pg_temp_3", "__ggsql_test_cols");
+        // May fail if schema name differs, just check it doesn't crash
+        if let Ok(cols) = cols {
+            if !cols.is_empty() {
+                assert!(!cols[0].name.is_empty());
+                assert!(!cols[0].data_type.is_empty());
+            }
+        }
+        let _ = reader.execute_sql("DROP TABLE IF EXISTS __ggsql_test_cols");
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_register_and_query() {
+        use crate::reader::Reader;
+        let reader = try_connect(PG_DSN).unwrap();
+
+        let df = crate::df!(
+            "name" => vec!["alice", "bob", "carol"],
+            "score" => vec![85i64, 92, 78]
+        )
+        .unwrap();
+
+        reader.register("__ggsql_test_reg", df, true).unwrap();
+
+        let result = reader
+            .execute_sql("SELECT name, score FROM __ggsql_test_reg ORDER BY name")
+            .unwrap();
+        assert_eq!(result.height(), 3);
+
+        let name_col = result.column("name").unwrap();
+        assert_eq!(crate::array_util::value_to_string(name_col, 0), "alice");
+
+        reader.unregister("__ggsql_test_reg").unwrap();
+    }
+
+    #[test]
+    #[ignore]
+    fn pg_register_with_nulls() {
+        use crate::reader::Reader;
+        let reader = try_connect(PG_DSN).unwrap();
+
+        let df = crate::df!(
+            "x" => vec![Some(1i64), None, Some(3)],
+            "s" => vec![Some("a"), None, Some("c")]
+        )
+        .unwrap();
+
reader.register("__ggsql_test_null", df, true).unwrap(); + + let result = reader + .execute_sql("SELECT x, s FROM __ggsql_test_null ORDER BY x NULLS FIRST") + .unwrap(); + assert_eq!(result.height(), 3); + let x = result.column("x").unwrap(); + assert!(x.is_null(0)); + assert!(!x.is_null(1)); + + reader.unregister("__ggsql_test_null").unwrap(); + } + + #[test] + #[ignore] + fn pg_unregister_nonexistent_errors() { + use crate::reader::Reader; + let reader = try_connect(PG_DSN).unwrap(); + assert!(reader.unregister("__does_not_exist__").is_err()); + } + + // --- SQLite via ODBC ----------------------------------------------------- + + const SQLITE_DSN: &str = "ggsql-sqlite-test"; + + #[test] + #[ignore] + fn sqlite_connect_and_detect_dialect() { + let reader = try_connect(SQLITE_DSN).expect("Cannot connect to SQLite ODBC DSN"); + let dbms = reader.connection.dbms_name().unwrap_or_default(); + assert!( + dbms.to_lowercase().contains("sqlite"), + "Expected SQLite, got: {}", + dbms + ); + } + + #[test] + #[ignore] + fn sqlite_execute_sql_integer() { + let reader = try_connect(SQLITE_DSN).unwrap(); + let df = reader.execute_sql("SELECT 42 AS value").unwrap(); + assert_eq!(df.height(), 1); + let col = df.column("value").unwrap(); + assert_eq!(crate::array_util::value_to_string(col, 0), "42"); + } + + #[test] + #[ignore] + fn sqlite_execute_sql_text() { + let reader = try_connect(SQLITE_DSN).unwrap(); + let df = reader + .execute_sql("SELECT 'hello world' AS greeting") + .unwrap(); + assert_eq!(df.height(), 1); + let col = df.column("greeting").unwrap(); + assert_eq!(crate::array_util::value_to_string(col, 0), "hello world"); + } + + #[test] + #[ignore] + fn sqlite_execute_sql_multiple_rows() { + let reader = try_connect(SQLITE_DSN).unwrap(); + // SQLite doesn't have generate_series by default, use VALUES + let df = reader + .execute_sql( + "WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM cnt WHERE x < 50) SELECT x FROM cnt", + ) + .unwrap(); + assert_eq!(df.height(), 50); + } + + #[test] + #[ignore] + fn sqlite_execute_sql_nulls() { + let reader = try_connect(SQLITE_DSN).unwrap(); + let df = reader + .execute_sql("SELECT NULL AS x UNION ALL SELECT 1") + .unwrap(); + assert_eq!(df.height(), 2); + let col = df.column("x").unwrap(); + assert!(col.is_null(0)); + } + + #[test] + #[ignore] + fn sqlite_list_catalogs_empty() { + use crate::reader::Reader; + let reader = try_connect(SQLITE_DSN).unwrap(); + let catalogs = reader.list_catalogs().unwrap(); + // SQLite ODBC driver typically returns no catalogs + let _ = catalogs; + } + + #[test] + #[ignore] + fn sqlite_list_schemas() { + use crate::reader::Reader; + let reader = try_connect(SQLITE_DSN).unwrap(); + let schemas = reader.list_schemas("").unwrap(); + let _ = schemas; + } + + #[test] + #[ignore] + fn sqlite_list_tables() { + use crate::reader::Reader; + let reader = try_connect(SQLITE_DSN).unwrap(); + let tables = reader.list_tables("", "").unwrap(); + let _ = tables; + } + + #[test] + #[ignore] + fn sqlite_list_columns() { + use crate::reader::Reader; + let reader = try_connect(SQLITE_DSN).unwrap(); + // Create a table to list columns for + let _ = reader + .execute_sql("CREATE TABLE IF NOT EXISTS __ggsql_test_cols (id INTEGER, name TEXT)"); + let cols = reader.list_columns("", "", "__ggsql_test_cols").unwrap(); + if !cols.is_empty() { + assert!(!cols[0].name.is_empty()); + } + let _ = reader.execute_sql("DROP TABLE IF EXISTS __ggsql_test_cols"); + } + + #[test] + #[ignore] + fn sqlite_register_and_query() { + use 
+        let reader = try_connect(SQLITE_DSN).unwrap();
+
+        let df = crate::df!(
+            "x" => vec![1i64, 2, 3],
+            "y" => vec![10i64, 20, 30]
+        )
+        .unwrap();
+
+        reader.register("__ggsql_test_reg", df, true).unwrap();
+
+        let result = reader
+            .execute_sql("SELECT x, y FROM __ggsql_test_reg ORDER BY x")
+            .unwrap();
+        assert_eq!(result.height(), 3);
+
+        reader.unregister("__ggsql_test_reg").unwrap();
+    }
+}
diff --git a/src/reader/odbc/snowflake.rs b/src/reader/odbc/snowflake.rs
new file mode 100644
index 00000000..165ab689
--- /dev/null
+++ b/src/reader/odbc/snowflake.rs
@@ -0,0 +1,248 @@
+//! Snowflake Workbench credential detection and connection resolution.
+
+pub(super) fn is_snowflake(conn_str: &str) -> bool {
+    conn_str.to_lowercase().contains("driver=snowflake")
+}
+
+pub(super) fn has_token(conn_str: &str) -> bool {
+    conn_str.to_lowercase().contains("token=")
+}
+
+fn home_dir() -> Option<std::path::PathBuf> {
+    #[cfg(target_os = "windows")]
+    {
+        std::env::var("USERPROFILE")
+            .ok()
+            .map(std::path::PathBuf::from)
+    }
+    #[cfg(not(target_os = "windows"))]
+    {
+        std::env::var("HOME").ok().map(std::path::PathBuf::from)
+    }
+}
+
+/// Find the Snowflake connections.toml file, checking standard locations.
+fn find_snowflake_connections_toml() -> Option<std::path::PathBuf> {
+    use std::path::PathBuf;
+
+    if let Ok(snowflake_home) = std::env::var("SNOWFLAKE_HOME") {
+        let p = PathBuf::from(&snowflake_home).join("connections.toml");
+        if p.exists() {
+            return Some(p);
+        }
+    }
+
+    if let Some(home) = home_dir() {
+        let p = home.join(".snowflake").join("connections.toml");
+        if p.exists() {
+            return Some(p);
+        }
+    }
+
+    if let Some(home) = home_dir() {
+        #[cfg(target_os = "macos")]
+        {
+            let p = home.join("Library/Application Support/snowflake/connections.toml");
+            if p.exists() {
+                return Some(p);
+            }
+        }
+
+        #[cfg(target_os = "linux")]
+        {
+            let xdg = std::env::var("XDG_CONFIG_HOME")
+                .map(PathBuf::from)
+                .unwrap_or_else(|_| home.join(".config"));
+            let p = xdg.join("snowflake").join("connections.toml");
+            if p.exists() {
+                return Some(p);
+            }
+        }
+
+        #[cfg(target_os = "windows")]
+        {
+            let p = home.join("AppData/Local/snowflake/connections.toml");
+            if p.exists() {
+                return Some(p);
+            }
+        }
+    }
+
+    None
+}
+
+/// Resolve a `ConnectionName=` parameter in a Snowflake ODBC connection
+/// string by reading the named entry from `~/.snowflake/connections.toml` and
+/// building a full ODBC connection string from it.
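+///
+/// For example (illustrative values), `Driver=Snowflake;ConnectionName=dev`
+/// with a `[dev]` entry containing `account = "acme-xyz"` resolves to
+/// `Driver=Snowflake;Server=acme-xyz.snowflakecomputing.com;...`.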
+pub(super) fn resolve_connection_name(conn_str: &str) -> Option<String> {
+    let lower = conn_str.to_lowercase();
+    let cn_key = "connectionname=";
+    let cn_start = lower.find(cn_key)?;
+    let value_start = cn_start + cn_key.len();
+
+    let rest = &conn_str[value_start..];
+    let value_end = rest.find(';').unwrap_or(rest.len());
+    let connection_name = rest[..value_end].trim();
+
+    if connection_name.is_empty() {
+        return None;
+    }
+
+    let toml_path = find_snowflake_connections_toml()?;
+    let content = std::fs::read_to_string(&toml_path).ok()?;
+    let doc = content.parse::<toml_edit::DocumentMut>().ok()?;
+
+    let entry = doc.get(connection_name)?;
+    if !entry.is_table() && !entry.is_inline_table() {
+        return None;
+    }
+
+    let get_str = |key: &str| -> Option<String> { entry.get(key)?.as_str().map(|s| s.to_string()) };
+
+    let account = get_str("account")?;
+    let mut parts = vec![
+        "Driver=Snowflake".to_string(),
+        format!("Server={}.snowflakecomputing.com", account),
+    ];
+
+    if let Some(user) = get_str("user") {
+        parts.push(format!("UID={}", user));
+    }
+    if let Some(password) = get_str("password") {
+        parts.push(format!("PWD={}", password));
+    }
+    if let Some(authenticator) = get_str("authenticator") {
+        parts.push(format!("Authenticator={}", authenticator));
+    }
+    if let Some(token) = get_str("token") {
+        parts.push(format!("Token={}", token));
+    }
+    if let Some(warehouse) = get_str("warehouse") {
+        parts.push(format!("Warehouse={}", warehouse));
+    }
+    if let Some(database) = get_str("database") {
+        parts.push(format!("Database={}", database));
+    }
+    if let Some(schema) = get_str("schema") {
+        parts.push(format!("Schema={}", schema));
+    }
+    if let Some(role) = get_str("role") {
+        parts.push(format!("Role={}", role));
+    }
+
+    Some(parts.join(";"))
+}
+
+/// Detect Posit Workbench Snowflake OAuth token.
+pub(super) fn detect_workbench_token() -> Option<String> {
+    let snowflake_home = std::env::var("SNOWFLAKE_HOME").ok()?;
+
+    if !snowflake_home.contains("posit-workbench") {
+        return None;
+    }
+
+    let toml_path = std::path::Path::new(&snowflake_home).join("connections.toml");
+    let content = std::fs::read_to_string(&toml_path).ok()?;
+
+    let doc = content.parse::<toml_edit::DocumentMut>().ok()?;
+    let token = doc.get("workbench")?.get("token")?.as_str()?.to_string();
+
+    if token.is_empty() {
+        None
+    } else {
+        Some(token)
+    }
+}
+
+/// Inject OAuth token into a Snowflake ODBC connection string.
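+///
+/// For example, `Driver=Snowflake;Server=foo;` becomes
+/// `Driver=Snowflake;Server=foo;Authenticator=oauth;Token=<token>`.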
+pub(super) fn inject_snowflake_token(conn_str: &str, token: &str) -> String {
+    let mut result = conn_str.trim_end_matches(';').to_string();
+    result.push_str(";Authenticator=oauth;Token=");
+    result.push_str(token);
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_is_snowflake() {
+        assert!(is_snowflake(
+            "Driver=Snowflake;Server=foo.snowflakecomputing.com"
+        ));
+        assert!(!is_snowflake("Driver={PostgreSQL};Server=localhost"));
+    }
+
+    #[test]
+    fn test_has_token() {
+        assert!(has_token("Driver=Snowflake;Token=abc123"));
+        assert!(!has_token("Driver=Snowflake;Server=foo"));
+    }
+
+    #[test]
+    fn test_inject_snowflake_token() {
+        let result = inject_snowflake_token(
+            "Driver=Snowflake;Server=foo.snowflakecomputing.com",
+            "mytoken",
+        );
+        assert!(result.contains("Authenticator=oauth"));
+        assert!(result.contains("Token=mytoken"));
+    }
+
+    #[test]
+    fn test_resolve_connection_name_with_toml() {
+        use std::io::Write;
+
+        let dir = tempfile::tempdir().unwrap();
+        let toml_path = dir.path().join("connections.toml");
+        let mut f = std::fs::File::create(&toml_path).unwrap();
+        writeln!(
+            f,
+            r#"
+default_connection_name = "myconn"
+
+[myconn]
+account = "myaccount"
+user = "myuser"
+password = "mypass"
+warehouse = "mywh"
+database = "mydb"
+schema = "public"
+role = "myrole"
+
+[other]
+account = "otheraccount"
+"#
+        )
+        .unwrap();
+
+        std::env::set_var("SNOWFLAKE_HOME", dir.path());
+
+        let result = resolve_connection_name("Driver=Snowflake;ConnectionName=myconn");
+        assert!(result.is_some());
+        let conn = result.unwrap();
+        assert!(conn.contains("Driver=Snowflake"));
+        assert!(conn.contains("Server=myaccount.snowflakecomputing.com"));
+        assert!(conn.contains("UID=myuser"));
+        assert!(conn.contains("PWD=mypass"));
+        assert!(conn.contains("Warehouse=mywh"));
+        assert!(conn.contains("Database=mydb"));
+        assert!(conn.contains("Schema=public"));
+        assert!(conn.contains("Role=myrole"));
+
+        let result2 = resolve_connection_name("Driver=Snowflake;ConnectionName=other");
+        assert!(result2.is_some());
+        let conn2 = result2.unwrap();
+        assert!(conn2.contains("Server=otheraccount.snowflakecomputing.com"));
+        assert!(!conn2.contains("UID="));
+
+        let result3 = resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent");
+        assert!(result3.is_none());
+
+        let result4 = resolve_connection_name("Driver=Snowflake;Server=foo");
+        assert!(result4.is_none());
+
+        std::env::remove_var("SNOWFLAKE_HOME");
+    }
+}
diff --git a/src/reader/odbc/wrapper.rs b/src/reader/odbc/wrapper.rs
new file mode 100644
index 00000000..5e3a105f
--- /dev/null
+++ b/src/reader/odbc/wrapper.rs
@@ -0,0 +1,674 @@
+//! Safe Rust wrappers around ODBC handles.
+//!
+//! Provides `Environment`, `Connection`, `Statement`, and `PreparedStatement`
+//! types that own ODBC handles and call through the runtime-loaded FFI.
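+//!
+//! A sketch of the intended call flow (illustrative; assumes a configured
+//! DSN named `mydb`):
+//!
+//! ```ignore
+//! let env = odbc_env()?;
+//! let conn = Connection::connect(env, "DSN=mydb")?;
+//! if let Some(stmt) = conn.execute("SELECT 1")? {
+//!     let cols = stmt.num_result_cols()?;
+//! }
+//! ```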
+
+use super::ffi::*;
+use crate::{GgsqlError, Result};
+use std::sync::OnceLock;
+
+// ============================================================================
+// Diagnostic helpers
+// ============================================================================
+
+fn extract_diagnostic(handle_type: SqlSmallInt, handle: SqlHandle) -> String {
+    let f = fns();
+    let mut state = [0u8; 6];
+    let mut native_error: SqlInteger = 0;
+    let mut buf = vec![0u8; 512];
+    let mut text_len: SqlSmallInt = 0;
+
+    let rc = unsafe {
+        (f.SQLGetDiagRec)(
+            handle_type,
+            handle,
+            1,
+            state.as_mut_ptr(),
+            &mut native_error,
+            buf.as_mut_ptr(),
+            buf.len() as SqlSmallInt,
+            &mut text_len,
+        )
+    };
+
+    if !succeeded(rc) {
+        return "Unknown ODBC error (no diagnostic record)".to_string();
+    }
+
+    // Retry with larger buffer if truncated
+    if text_len as usize >= buf.len() {
+        buf.resize(text_len as usize + 1, 0);
+        unsafe {
+            (f.SQLGetDiagRec)(
+                handle_type,
+                handle,
+                1,
+                state.as_mut_ptr(),
+                &mut native_error,
+                buf.as_mut_ptr(),
+                buf.len() as SqlSmallInt,
+                &mut text_len,
+            );
+        }
+    }
+
+    let state_str = std::str::from_utf8(&state[..5]).unwrap_or("?????");
+    let msg = std::str::from_utf8(&buf[..text_len as usize]).unwrap_or("(invalid UTF-8)");
+    format!("[{}] {}", state_str, msg)
+}
+
+fn check(rc: SqlReturn, handle_type: SqlSmallInt, handle: SqlHandle, context: &str) -> Result<()> {
+    match rc {
+        SQL_SUCCESS | SQL_SUCCESS_WITH_INFO => Ok(()),
+        SQL_NO_DATA => Ok(()),
+        _ => {
+            let diag = extract_diagnostic(handle_type, handle);
+            Err(GgsqlError::ReaderError(format!("{}: {}", context, diag)))
+        }
+    }
+}
+
+/// Check if the SQLSTATE from the last operation matches a specific code.
+fn sqlstate_is(handle_type: SqlSmallInt, handle: SqlHandle, expected: &[u8; 5]) -> bool {
+    let f = fns();
+    let mut state = [0u8; 6];
+    let mut native_error: SqlInteger = 0;
+    let mut text_len: SqlSmallInt = 0;
+    let rc = unsafe {
+        (f.SQLGetDiagRec)(
+            handle_type,
+            handle,
+            1,
+            state.as_mut_ptr(),
+            &mut native_error,
+            std::ptr::null_mut(),
+            0,
+            &mut text_len,
+        )
+    };
+    succeeded(rc) && state[..5] == expected[..]
+}
+
+// ============================================================================
+// Environment
+// ============================================================================
+
+pub struct Environment {
+    handle: SqlHEnv,
+}
+
+unsafe impl Send for Environment {}
+unsafe impl Sync for Environment {}
+
+impl Environment {
+    fn new() -> Result<Self> {
+        let f = fns();
+        let mut handle = SQL_NULL_HANDLE;
+        let rc = unsafe { (f.SQLAllocHandle)(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &mut handle) };
+        if !succeeded(rc) {
+            return Err(GgsqlError::ReaderError(
+                "Failed to allocate ODBC environment handle".into(),
+            ));
+        }
+
+        let rc = unsafe {
+            (f.SQLSetEnvAttr)(handle, SQL_ATTR_ODBC_VERSION, SQL_OV_ODBC3 as SqlPointer, 0)
+        };
+        check(rc, SQL_HANDLE_ENV, handle, "Failed to set ODBC version")?;
+
+        Ok(Environment { handle })
+    }
+
+    pub fn handle(&self) -> SqlHEnv {
+        self.handle
+    }
+}
+
+impl Drop for Environment {
+    fn drop(&mut self) {
+        let f = fns();
+        unsafe { (f.SQLFreeHandle)(SQL_HANDLE_ENV, self.handle) };
+    }
+}
+
+/// Global ODBC environment (singleton per process).
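+///
+/// The environment is allocated once per process and requests ODBC 3.x
+/// behaviour via `SQL_ATTR_ODBC_VERSION` before any connection is created.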
+pub fn odbc_env() -> Result<&'static Environment> {
+    static ENV: OnceLock<std::result::Result<Environment, String>> = OnceLock::new();
+    let result = ENV.get_or_init(|| Environment::new().map_err(|e| e.to_string()));
+    match result {
+        Ok(env) => Ok(env),
+        Err(e) => Err(GgsqlError::ReaderError(e.clone())),
+    }
+}
+
+// ============================================================================
+// Connection
+// ============================================================================
+
+pub struct Connection {
+    handle: SqlHDbc,
+}
+
+unsafe impl Send for Connection {}
+
+impl Connection {
+    pub fn connect(env: &Environment, conn_str: &str) -> Result<Self> {
+        let f = fns();
+        let mut handle = SQL_NULL_HANDLE;
+        let rc = unsafe { (f.SQLAllocHandle)(SQL_HANDLE_DBC, env.handle(), &mut handle) };
+        if !succeeded(rc) {
+            return Err(GgsqlError::ReaderError(
+                "Failed to allocate ODBC connection handle".into(),
+            ));
+        }
+
+        let conn_cstr = std::ffi::CString::new(conn_str)
+            .map_err(|_| GgsqlError::ReaderError("Connection string contains null byte".into()))?;
+        let rc = unsafe {
+            (f.SQLDriverConnect)(
+                handle,
+                std::ptr::null_mut(), // no window handle
+                conn_cstr.as_ptr() as *const SqlChar,
+                conn_str.len() as SqlSmallInt,
+                std::ptr::null_mut(), // no output buffer
+                0,
+                std::ptr::null_mut(),
+                SQL_DRIVER_NOPROMPT,
+            )
+        };
+        if !succeeded(rc) {
+            let diag = extract_diagnostic(SQL_HANDLE_DBC, handle);
+            unsafe { (f.SQLFreeHandle)(SQL_HANDLE_DBC, handle) };
+            return Err(GgsqlError::ReaderError(format!(
+                "ODBC connection failed: {}",
+                diag
+            )));
+        }
+
+        Ok(Connection { handle })
+    }
+
+    pub fn handle(&self) -> SqlHDbc {
+        self.handle
+    }
+
+    /// Execute a SQL statement, returning a Statement if it produces a result set.
+    pub fn execute(&self, sql: &str) -> Result<Option<Statement>> {
+        let f = fns();
+        let mut stmt_handle = SQL_NULL_HANDLE;
+        let rc = unsafe { (f.SQLAllocHandle)(SQL_HANDLE_STMT, self.handle, &mut stmt_handle) };
+        check(
+            rc,
+            SQL_HANDLE_DBC,
+            self.handle,
+            "Failed to allocate statement",
+        )?;
+
+        let sql_cstr = std::ffi::CString::new(sql)
+            .map_err(|_| GgsqlError::ReaderError("SQL string contains null byte".into()))?;
+        let rc = unsafe {
+            (f.SQLExecDirect)(
+                stmt_handle,
+                sql_cstr.as_ptr() as *const SqlChar,
+                sql.len() as SqlInteger,
+            )
+        };
+
+        match rc {
+            SQL_SUCCESS | SQL_SUCCESS_WITH_INFO => {}
+            SQL_NO_DATA => {
+                // DDL or statement with no result set
+                unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, stmt_handle) };
+                return Ok(None);
+            }
+            _ => {
+                let diag = extract_diagnostic(SQL_HANDLE_STMT, stmt_handle);
+                unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, stmt_handle) };
+                return Err(GgsqlError::ReaderError(format!(
+                    "ODBC execute failed: {}",
+                    diag
+                )));
+            }
+        }
+
+        // Check if there's a result set
+        let mut col_count: SqlSmallInt = 0;
+        let rc = unsafe { (f.SQLNumResultCols)(stmt_handle, &mut col_count) };
+        check(
+            rc,
+            SQL_HANDLE_STMT,
+            stmt_handle,
+            "Failed to get column count",
+        )?;
+
+        if col_count == 0 {
+            unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, stmt_handle) };
+            return Ok(None);
+        }
+
+        Ok(Some(Statement {
+            handle: stmt_handle,
+        }))
+    }
+
+    /// Prepare a SQL statement for repeated execution with parameters.
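+    ///
+    /// A sketch of the intended bulk-insert loop (illustrative):
+    ///
+    /// ```ignore
+    /// let stmt = conn.prepare("INSERT INTO t VALUES (?)")?;
+    /// for row in rows {
+    ///     // bind each value with bind_text_parameter, then:
+    ///     stmt.execute()?;
+    ///     stmt.reset_params()?;
+    /// }
+    /// ```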
+    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement> {
+        let f = fns();
+        let mut stmt_handle = SQL_NULL_HANDLE;
+        let rc = unsafe { (f.SQLAllocHandle)(SQL_HANDLE_STMT, self.handle, &mut stmt_handle) };
+        check(
+            rc,
+            SQL_HANDLE_DBC,
+            self.handle,
+            "Failed to allocate statement",
+        )?;
+
+        let sql_cstr = std::ffi::CString::new(sql)
+            .map_err(|_| GgsqlError::ReaderError("SQL string contains null byte".into()))?;
+        let rc = unsafe {
+            (f.SQLPrepare)(
+                stmt_handle,
+                sql_cstr.as_ptr() as *const SqlChar,
+                sql.len() as SqlInteger,
+            )
+        };
+        if !succeeded(rc) {
+            let diag = extract_diagnostic(SQL_HANDLE_STMT, stmt_handle);
+            unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, stmt_handle) };
+            return Err(GgsqlError::ReaderError(format!(
+                "ODBC prepare failed: {}",
+                diag
+            )));
+        }
+
+        Ok(PreparedStatement {
+            handle: stmt_handle,
+        })
+    }
+
+    /// Get DBMS name via SQLGetInfo.
+    pub fn dbms_name(&self) -> Option<String> {
+        let f = fns();
+        let mut buf = vec![0u8; 256];
+        let mut len: SqlSmallInt = 0;
+        let rc = unsafe {
+            (f.SQLGetInfo)(
+                self.handle,
+                SQL_DBMS_NAME,
+                buf.as_mut_ptr() as SqlPointer,
+                buf.len() as SqlSmallInt,
+                &mut len,
+            )
+        };
+        if succeeded(rc) && len > 0 {
+            let s = std::str::from_utf8(&buf[..len as usize]).ok()?.to_string();
+            Some(s)
+        } else {
+            None
+        }
+    }
+}
+
+impl Drop for Connection {
+    fn drop(&mut self) {
+        let f = fns();
+        let rc = unsafe { (f.SQLDisconnect)(self.handle) };
+        // If there's an open transaction, attempt rollback then retry
+        if !succeeded(rc) && sqlstate_is(SQL_HANDLE_DBC, self.handle, b"25000") {
+            unsafe {
+                (f.SQLEndTran)(SQL_HANDLE_DBC, self.handle, SQL_ROLLBACK);
+                (f.SQLDisconnect)(self.handle);
+            }
+        }
+        unsafe { (f.SQLFreeHandle)(SQL_HANDLE_DBC, self.handle) };
+    }
+}
+
+// ============================================================================
+// Statement (result cursor)
+// ============================================================================
+
+pub struct Statement {
+    handle: SqlHStmt,
+}
+
+impl Statement {
+    pub fn handle(&self) -> SqlHStmt {
+        self.handle
+    }
+
+    pub fn num_result_cols(&self) -> Result<usize> {
+        let f = fns();
+        let mut count: SqlSmallInt = 0;
+        let rc = unsafe { (f.SQLNumResultCols)(self.handle, &mut count) };
+        check(
+            rc,
+            SQL_HANDLE_STMT,
+            self.handle,
+            "Failed to get column count",
+        )?;
+        Ok(count as usize)
+    }
+
+    /// Describe column `col` (1-based).
+    /// Returns (name, sql_data_type, column_size, decimal_digits, nullable).
+    pub fn describe_col(
+        &self,
+        col: u16,
+    ) -> Result<(String, SqlSmallInt, SqlULen, SqlSmallInt, bool)> {
+        let f = fns();
+        let mut name_buf = vec![0u8; 256];
+        let mut name_len: SqlSmallInt = 0;
+        let mut data_type: SqlSmallInt = 0;
+        let mut col_size: SqlULen = 0;
+        let mut decimal_digits: SqlSmallInt = 0;
+        let mut nullable: SqlSmallInt = 0;
+
+        let rc = unsafe {
+            (f.SQLDescribeCol)(
+                self.handle,
+                col,
+                name_buf.as_mut_ptr(),
+                name_buf.len() as SqlSmallInt,
+                &mut name_len,
+                &mut data_type,
+                &mut col_size,
+                &mut decimal_digits,
+                &mut nullable,
+            )
+        };
+        check(
+            rc,
+            SQL_HANDLE_STMT,
+            self.handle,
+            "Failed to describe column",
+        )?;
+
+        let name = std::str::from_utf8(&name_buf[..name_len as usize])
+            .unwrap_or("?")
+            .to_string();
+
+        Ok((
+            name,
+            data_type,
+            col_size,
+            decimal_digits,
+            nullable != SQL_NO_NULLS,
+        ))
+    }
+
+    /// Bind a column buffer for batch fetching (1-based column index).
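+    ///
+    /// With column-wise batch binding (see `setup_batch_fetch`), `buffer`
+    /// must hold `SQL_ATTR_ROW_ARRAY_SIZE` elements of `buffer_len` bytes
+    /// each, and `indicator` must point to an equally sized array.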
+    pub fn bind_col(
+        &self,
+        col: u16,
+        c_type: SqlSmallInt,
+        buffer: SqlPointer,
+        buffer_len: SqlLen,
+        indicator: *mut SqlLen,
+    ) -> Result<()> {
+        let f = fns();
+        let rc = unsafe { (f.SQLBindCol)(self.handle, col, c_type, buffer, buffer_len, indicator) };
+        check(rc, SQL_HANDLE_STMT, self.handle, "Failed to bind column")
+    }
+
+    /// Set a statement attribute.
+    pub fn set_stmt_attr(
+        &self,
+        attribute: SqlInteger,
+        value: SqlPointer,
+        string_length: SqlInteger,
+    ) -> Result<()> {
+        let f = fns();
+        let rc = unsafe { (f.SQLSetStmtAttr)(self.handle, attribute, value, string_length) };
+        check(
+            rc,
+            SQL_HANDLE_STMT,
+            self.handle,
+            "Failed to set statement attribute",
+        )
+    }
+
+    /// Set up column-wise batch fetching with the given batch size.
+    /// Pair with `set_rows_fetched_ptr` to learn how many rows each fetch returns.
+    pub fn setup_batch_fetch(&self, batch_size: usize) -> Result<()> {
+        // Column-wise binding
+        self.set_stmt_attr(SQL_ATTR_ROW_BIND_TYPE, SQL_BIND_BY_COLUMN as SqlPointer, 0)?;
+        // Array size
+        self.set_stmt_attr(SQL_ATTR_ROW_ARRAY_SIZE, batch_size as SqlPointer, 0)?;
+        Ok(())
+    }
+
+    /// Set the rows-fetched pointer.
+    ///
+    /// # Safety
+    /// `rows_fetched` must remain valid and pinned for the lifetime of the cursor.
+    pub unsafe fn set_rows_fetched_ptr(&self, rows_fetched: *mut SqlULen) -> Result<()> {
+        self.set_stmt_attr(SQL_ATTR_ROWS_FETCHED_PTR, rows_fetched as SqlPointer, 0)
+    }
+
+    /// Fetch the next batch of rows. Returns the ODBC return code.
+    pub fn fetch_raw(&self) -> SqlReturn {
+        let f = fns();
+        unsafe { (f.SQLFetch)(self.handle) }
+    }
+
+    /// Unbind all columns.
+    pub fn unbind_cols(&self) -> Result<()> {
+        let f = fns();
+        let rc = unsafe { (f.SQLFreeStmt)(self.handle, SQL_UNBIND) };
+        check(rc, SQL_HANDLE_STMT, self.handle, "Failed to unbind columns")
+    }
+}
+
+impl Drop for Statement {
+    fn drop(&mut self) {
+        let f = fns();
+        let rc = unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, self.handle) };
+        if !succeeded(rc) && !std::thread::panicking() {
+            panic!(
+                "SQLFreeHandle(STMT) failed: {}",
+                extract_diagnostic(SQL_HANDLE_STMT, self.handle)
+            );
+        }
+    }
+}
+
+// ============================================================================
+// PreparedStatement (for bulk insert)
+// ============================================================================
+
+pub struct PreparedStatement {
+    handle: SqlHStmt,
+}
+
+impl PreparedStatement {
+    /// Bind a text parameter (1-based index).
+    ///
+    /// # Safety
+    /// `value_ptr` and `indicator` must remain valid until execute() or reset_params().
+    pub unsafe fn bind_text_parameter(
+        &self,
+        param_num: u16,
+        value_ptr: *const u8,
+        buffer_len: SqlLen,
+        indicator: *mut SqlLen,
+    ) -> Result<()> {
+        let f = fns();
+        let rc = unsafe {
+            (f.SQLBindParameter)(
+                self.handle,
+                param_num,
+                SQL_PARAM_INPUT,
+                SQL_C_CHAR,
+                SQL_VARCHAR,
+                buffer_len as SqlULen,
+                0,
+                value_ptr as SqlPointer,
+                buffer_len,
+                indicator,
+            )
+        };
+        check(rc, SQL_HANDLE_STMT, self.handle, "Failed to bind parameter")
+    }
+
+    pub fn execute(&self) -> Result<()> {
+        let f = fns();
+        let rc = unsafe { (f.SQLExecute)(self.handle) };
+        check(
+            rc,
+            SQL_HANDLE_STMT,
+            self.handle,
+            "Failed to execute prepared statement",
+        )
+    }
+
+    pub fn reset_params(&self) -> Result<()> {
+        let f = fns();
+        let rc = unsafe { (f.SQLFreeStmt)(self.handle, SQL_RESET_PARAMS) };
+        check(
+            rc,
+            SQL_HANDLE_STMT,
+            self.handle,
+            "Failed to reset parameters",
+        )
+    }
+}
+
+impl Drop for PreparedStatement {
+    fn drop(&mut self) {
+        let f = fns();
+        let rc = unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, self.handle) };
+        if !succeeded(rc) && !std::thread::panicking() {
+            panic!(
+                "SQLFreeHandle(STMT) failed: {}",
+                extract_diagnostic(SQL_HANDLE_STMT, self.handle)
+            );
+        }
+    }
+}
+
+// ============================================================================
+// ODBC catalog function helpers
+// ============================================================================
+
+/// Execute SQLTables and return the result as a Statement cursor.
+pub fn sql_tables(
+    conn: &Connection,
+    catalog: Option<&str>,
+    schema: Option<&str>,
+    table: Option<&str>,
+    table_type: Option<&str>,
+) -> Result<Statement> {
+    let f = fns();
+    let mut stmt_handle = SQL_NULL_HANDLE;
+    let rc = unsafe { (f.SQLAllocHandle)(SQL_HANDLE_STMT, conn.handle(), &mut stmt_handle) };
+    check(
+        rc,
+        SQL_HANDLE_DBC,
+        conn.handle(),
+        "Failed to allocate statement for SQLTables",
+    )?;
+
+    let (cat_cs, cat_len) = str_to_odbc_cstring(catalog)?;
+    let (sch_cs, sch_len) = str_to_odbc_cstring(schema)?;
+    let (tbl_cs, tbl_len) = str_to_odbc_cstring(table)?;
+    let (typ_cs, typ_len) = str_to_odbc_cstring(table_type)?;
+
+    let rc = unsafe {
+        (f.SQLTables)(
+            stmt_handle,
+            cstring_ptr(&cat_cs),
+            cat_len,
+            cstring_ptr(&sch_cs),
+            sch_len,
+            cstring_ptr(&tbl_cs),
+            tbl_len,
+            cstring_ptr(&typ_cs),
+            typ_len,
+        )
+    };
+    if !succeeded(rc) {
+        let diag = extract_diagnostic(SQL_HANDLE_STMT, stmt_handle);
+        unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, stmt_handle) };
+        return Err(GgsqlError::ReaderError(format!(
+            "SQLTables failed: {}",
+            diag
+        )));
+    }
+
+    Ok(Statement {
+        handle: stmt_handle,
+    })
+}
+
+/// Execute SQLColumns and return the result as a Statement cursor.
+pub fn sql_columns(
+    conn: &Connection,
+    catalog: Option<&str>,
+    schema: Option<&str>,
+    table: Option<&str>,
+    column: Option<&str>,
+) -> Result<Statement> {
+    let f = fns();
+    let mut stmt_handle = SQL_NULL_HANDLE;
+    let rc = unsafe { (f.SQLAllocHandle)(SQL_HANDLE_STMT, conn.handle(), &mut stmt_handle) };
+    check(
+        rc,
+        SQL_HANDLE_DBC,
+        conn.handle(),
+        "Failed to allocate statement for SQLColumns",
+    )?;
+
+    let (cat_cs, cat_len) = str_to_odbc_cstring(catalog)?;
+    let (sch_cs, sch_len) = str_to_odbc_cstring(schema)?;
+    let (tbl_cs, tbl_len) = str_to_odbc_cstring(table)?;
+    let (col_cs, col_len) = str_to_odbc_cstring(column)?;
+
+    let rc = unsafe {
+        (f.SQLColumns)(
+            stmt_handle,
+            cstring_ptr(&cat_cs),
+            cat_len,
+            cstring_ptr(&sch_cs),
+            sch_len,
+            cstring_ptr(&tbl_cs),
+            tbl_len,
+            cstring_ptr(&col_cs),
+            col_len,
+        )
+    };
+    if !succeeded(rc) {
+        let diag = extract_diagnostic(SQL_HANDLE_STMT, stmt_handle);
+        unsafe { (f.SQLFreeHandle)(SQL_HANDLE_STMT, stmt_handle) };
+        return Err(GgsqlError::ReaderError(format!(
+            "SQLColumns failed: {}",
+            diag
+        )));
+    }
+
+    Ok(Statement {
+        handle: stmt_handle,
+    })
+}
+
+fn str_to_odbc_cstring(
+    s: Option<&str>,
+) -> Result<(Option<std::ffi::CString>, SqlSmallInt)> {
+    match s {
+        Some(s) => {
+            let cs = std::ffi::CString::new(s).map_err(|_| {
+                GgsqlError::ReaderError("ODBC catalog argument contains null byte".into())
+            })?;
+            let len = s.len() as SqlSmallInt;
+            Ok((Some(cs), len))
+        }
+        None => Ok((None, 0)),
+    }
+}
+
+fn cstring_ptr(cs: &Option<std::ffi::CString>) -> *const SqlChar {
+    match cs {
+        Some(cs) => cs.as_ptr() as *const SqlChar,
+        None => std::ptr::null(),
+    }
+}
diff --git a/src/reader/snowflake.rs b/src/reader/snowflake.rs
deleted file mode 100644
index 9257052f..00000000
--- a/src/reader/snowflake.rs
+++ /dev/null
@@ -1,35 +0,0 @@
-//! Snowflake-specific SQL dialect.
-//!
-//! Overrides schema introspection to use Snowflake's SHOW commands
-//! instead of information_schema queries.
- -use crate::naming; - -pub struct SnowflakeDialect; - -impl super::SqlDialect for SnowflakeDialect { - fn sql_list_catalogs(&self) -> String { - "SHOW DATABASES".into() - } - - fn sql_list_schemas(&self, catalog: &str) -> String { - format!("SHOW SCHEMAS IN DATABASE {}", naming::quote_ident(catalog)) - } - - fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { - format!( - "SHOW OBJECTS IN SCHEMA {}.{}", - naming::quote_ident(catalog), - naming::quote_ident(schema) - ) - } - - fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { - format!( - "SHOW COLUMNS IN TABLE {}.{}.{}", - naming::quote_ident(catalog), - naming::quote_ident(schema), - naming::quote_ident(table) - ) - } -} diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs index 67f1033b..9ed8c525 100644 --- a/src/reader/sqlite.rs +++ b/src/reader/sqlite.rs @@ -70,29 +70,6 @@ impl super::SqlDialect for SqliteDialect { } } - fn sql_list_catalogs(&self) -> String { - "SELECT name AS catalog_name FROM pragma_database_list ORDER BY name".into() - } - - fn sql_list_schemas(&self, _catalog: &str) -> String { - "SELECT 'main' AS schema_name".into() - } - - fn sql_list_tables(&self, catalog: &str, _schema: &str) -> String { - format!( - "SELECT name AS table_name, type AS table_type FROM {}.sqlite_master \ - WHERE type IN ('table', 'view') ORDER BY name", - naming::quote_ident(catalog) - ) - } - - fn sql_list_columns(&self, _catalog: &str, _schema: &str, table: &str) -> String { - format!( - "SELECT name AS column_name, type AS data_type FROM pragma_table_info('{}') ORDER BY cid", - table.replace('\'', "''") - ) - } - /// SQLite does not support `CREATE OR REPLACE`, so emit a drop-then-create /// pair. Column aliases are preserved portably via the default CTE wrapper. 
     fn create_or_replace_temp_table_sql(
@@ -537,6 +514,66 @@ impl Reader for SqliteReader {
     fn dialect(&self) -> &dyn super::SqlDialect {
         &SqliteDialect
     }
+
+    fn list_catalogs(&self) -> Result<Vec<String>> {
+        let df = self.execute_sql("SELECT name FROM pragma_database_list ORDER BY name")?;
+        let col = df.column("name")?;
+        let mut results = Vec::with_capacity(df.height());
+        for i in 0..df.height() {
+            if !col.is_null(i) {
+                results.push(crate::array_util::value_to_string(col, i));
+            }
+        }
+        Ok(results)
+    }
+
+    fn list_schemas(&self, _catalog: &str) -> Result<Vec<String>> {
+        Ok(vec![])
+    }
+
+    fn list_tables(&self, catalog: &str, _schema: &str) -> Result<Vec<super::TableInfo>> {
+        let df = self.execute_sql(&format!(
+            "SELECT name, type FROM {}.sqlite_master \
+             WHERE type IN ('table', 'view') ORDER BY name",
+            naming::quote_ident(catalog)
+        ))?;
+        let name_col = df.column("name")?;
+        let type_col = df.column("type")?;
+        let mut results = Vec::with_capacity(df.height());
+        for i in 0..df.height() {
+            if !name_col.is_null(i) {
+                results.push(super::TableInfo {
+                    name: crate::array_util::value_to_string(name_col, i),
+                    table_type: crate::array_util::value_to_string(type_col, i),
+                });
+            }
+        }
+        Ok(results)
+    }
+
+    fn list_columns(
+        &self,
+        _catalog: &str,
+        _schema: &str,
+        table: &str,
+    ) -> Result<Vec<super::ColumnInfo>> {
+        let df = self.execute_sql(&format!(
+            "SELECT name, type FROM pragma_table_info('{}') ORDER BY cid",
+            table.replace('\'', "''")
+        ))?;
+        let name_col = df.column("name")?;
+        let type_col = df.column("type")?;
+        let mut results = Vec::with_capacity(df.height());
+        for i in 0..df.height() {
+            if !name_col.is_null(i) {
+                results.push(super::ColumnInfo {
+                    name: crate::array_util::value_to_string(name_col, i),
+                    data_type: crate::array_util::value_to_string(type_col, i),
+                });
+            }
+        }
+        Ok(results)
+    }
 }
 
 /// Try to parse all non-null TEXT values as ISO-8601 dates (YYYY-MM-DD).
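
The wrapper introduced in this first patch is consumed by the reader layer later in the
series. For orientation, a minimal sketch of how the pieces compose, using only items
defined above (`odbc_env`, `Connection`, `Statement`); the DSN string and the `sketch`
function name are placeholders, not part of the patch:

    // Sketch: connect via the process-wide environment and run a query.
    fn sketch() -> Result<()> {
        let env = odbc_env()?;                               // lazily-initialised singleton
        let conn = Connection::connect(env, "DSN=example")?; // "example" is a placeholder DSN
        if let Some(stmt) = conn.execute("SELECT 1")? {      // Some(_) only for result sets
            let _cols = stmt.num_result_cols()?;
            let (_name, _ty, _size, _digits, _nullable) = stmt.describe_col(1)?;
            // bind buffers with bind_col() and drive the cursor with fetch_raw()
        }
        Ok(())                                               // handles are freed in Drop
    }
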
From a0d00f95fefe62e8f71fbe0440c90a804c2de670 Mon Sep 17 00:00:00 2001
From: George Stagg
Date: Fri, 1 May 2026 14:28:23 +0100
Subject: [PATCH 2/8] Fixup sqlite and cargo fmt

---
 ggsql-jupyter/src/connection.rs | 12 +++++++--
 ggsql-jupyter/src/executor.rs   | 43 ++++++++++++++++++++++++++-------
 ggsql-vscode/src/connections.ts | 40 ++++++++++++++++++++++++++++++
 src/reader/odbc/mod.rs          | 15 +++++++-----
 src/reader/odbc/wrapper.rs      |  4 +--
 src/reader/sqlite.rs            | 19 ++++-----------
 6 files changed, 99 insertions(+), 34 deletions(-)

diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs
index e2ecf7a9..2fbe4009 100644
--- a/ggsql-jupyter/src/connection.rs
+++ b/ggsql-jupyter/src/connection.rs
@@ -49,7 +49,11 @@ pub fn list_objects(reader: &dyn Reader, path: &[String]) -> Result
     match effective {
         0 => list_catalogs(reader),
         1 => {
-            let catalog = if offset >= 1 { &default_catalog } else { &path[0] };
+            let catalog = if offset >= 1 {
+                &default_catalog
+            } else {
+                &path[0]
+            };
             list_schemas(reader, catalog)
         }
         2 => {
@@ -74,7 +78,11 @@ pub fn list_fields(reader: &dyn Reader, path: &[String]) -> Result
     }
 
     let (catalog, schema, table) = match offset {
-        2 => (default_catalog.as_str(), default_schema.as_str(), path[0].as_str()),
+        2 => (
+            default_catalog.as_str(),
+            default_schema.as_str(),
+            path[0].as_str(),
+        ),
         1 => (default_catalog.as_str(), path[0].as_str(), path[1].as_str()),
         _ => (path[0].as_str(), path[1].as_str(), path[2].as_str()),
     };
diff --git a/ggsql-jupyter/src/executor.rs b/ggsql-jupyter/src/executor.rs
index 435c7295..a845f471 100644
--- a/ggsql-jupyter/src/executor.rs
+++ b/ggsql-jupyter/src/executor.rs
@@ -69,20 +69,16 @@ pub fn display_name_for_uri(uri: &str) -> String {
         return format!("DuckDB ({})", path);
     }
     if let Some(path) = uri.strip_prefix("sqlite://") {
-        if path.is_empty() {
+        if path == ":memory:" || path.is_empty() {
             return "SQLite (memory)".to_string();
         }
         return format!("SQLite ({})", path);
     }
    if let Some(odbc) = uri.strip_prefix("odbc://") {
-        // Try to extract driver name from ODBC string
-        if let Some(driver_start) = odbc.to_lowercase().find("driver=") {
-            let rest = &odbc[driver_start + 7..];
-            let driver = rest
-                .split(';')
-                .next()
-                .unwrap_or("ODBC")
-                .trim_matches(|c| c == '{' || c == '}');
+        if let Some(dsn) = extract_odbc_value(odbc, "dsn") {
+            return format!("{} (ODBC)", dsn);
+        }
+        if let Some(driver) = extract_odbc_value(odbc, "driver") {
             return format!("{} (ODBC)", driver);
         }
         return "ODBC".to_string();
@@ -90,6 +86,20 @@ pub fn display_name_for_uri(uri: &str) -> String {
     uri.to_string()
 }
 
+fn extract_odbc_value(conn_str: &str, key: &str) -> Option<String> {
+    let lower = conn_str.to_lowercase();
+    let prefix = format!("{}=", key);
+    let start = lower.find(&prefix)?;
+    let rest = &conn_str[start + prefix.len()..];
+    let value = rest.split(';').next().unwrap_or("");
+    let value = value.trim().trim_matches(|c| c == '{' || c == '}');
+    if value.is_empty() {
+        None
+    } else {
+        Some(value.to_string())
+    }
+}
+
 /// Detect the database type name from a connection URI (e.g. "DuckDB", "Snowflake").
 pub fn type_name_for_uri(uri: &str) -> String {
     if uri.starts_with("duckdb://") {
@@ -302,5 +312,20 @@ mod tests {
     fn test_display_name_for_uri() {
         assert_eq!(display_name_for_uri("duckdb://memory"), "DuckDB (memory)");
         assert_eq!(display_name_for_uri("duckdb://my.db"), "DuckDB (my.db)");
+        assert_eq!(display_name_for_uri("sqlite://:memory:"), "SQLite (memory)");
+        assert_eq!(display_name_for_uri("sqlite://data.db"), "SQLite (data.db)");
+        assert_eq!(
+            display_name_for_uri("odbc://DSN=my-postgres"),
+            "my-postgres (ODBC)"
+        );
+        assert_eq!(
+            display_name_for_uri("odbc://Driver=Snowflake;Server=foo"),
+            "Snowflake (ODBC)"
+        );
+        assert_eq!(
+            display_name_for_uri("odbc://Driver={PostgreSQL};DSN=pg-test"),
+            "pg-test (ODBC)"
+        );
+        assert_eq!(display_name_for_uri("odbc://"), "ODBC");
     }
 }
diff --git a/ggsql-vscode/src/connections.ts b/ggsql-vscode/src/connections.ts
index 0df0641f..66ee87cd 100644
--- a/ggsql-vscode/src/connections.ts
+++ b/ggsql-vscode/src/connections.ts
@@ -23,6 +23,7 @@ export function createConnectionDrivers(
 ): positron.ConnectionsDriver[] {
   return [
     createDuckDBDriver(positronApi),
+    createSQLiteDriver(positronApi),
     createSnowflakeDefaultDriver(positronApi),
     createSnowflakePasswordDriver(positronApi),
     createSnowflakeSSODriver(positronApi),
@@ -70,6 +71,45 @@ function createDuckDBDriver(
   };
 }
 
+// ============================================================================
+// SQLite
+// ============================================================================
+
+/**
+ * SQLite connection driver.
+ *
+ * Inputs: database file path (required).
+ */
+function createSQLiteDriver(
+  positronApi: PositronApi
+): positron.ConnectionsDriver {
+  return {
+    driverId: 'ggsql-sqlite',
+    metadata: {
+      languageId: 'ggsql',
+      name: 'SQLite',
+      inputs: [
+        {
+          id: 'database',
+          label: 'Database',
+          type: 'string',
+          value: '',
+        },
+      ],
+    },
+    generateCode: (inputs) => {
+      const db = inputs.find((i) => i.id === 'database')?.value?.trim();
+      if (!db) {
+        return '-- @connect: sqlite://:memory:';
+      }
+      return `-- @connect: sqlite://${db}`;
+    },
+    connect: async (code: string) => {
+      await positronApi.runtime.executeCode('ggsql', code, false);
+    },
+  };
+}
+
 // ============================================================================
 // Snowflake — shared helpers
 // ============================================================================
diff --git a/src/reader/odbc/mod.rs b/src/reader/odbc/mod.rs
index b7d3a74e..cd15a32d 100644
--- a/src/reader/odbc/mod.rs
+++ b/src/reader/odbc/mod.rs
@@ -722,8 +722,9 @@ fn extract_batch(
         } else {
             let size = std::mem::size_of::<SqlDateStruct>();
             let offset = row * size;
-            let d: SqlDateStruct =
-                unsafe { std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _) };
+            let d: SqlDateStruct = unsafe {
+                std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _)
+            };
             v.push(odbc_date_to_days(&d));
         }
     }
@@ -733,8 +734,9 @@ fn extract_batch(
         } else {
            let size = std::mem::size_of::<SqlTimeStruct>();
             let offset = row * size;
-            let t: SqlTimeStruct =
-                unsafe { std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _) };
+            let t: SqlTimeStruct = unsafe {
+                std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _)
+            };
             v.push(Some(odbc_time_to_nanos(&t)));
         }
     }
@@ -744,8 +746,9 @@ fn extract_batch(
         } else {
             let size = std::mem::size_of::<SqlTimestampStruct>();
             let offset = row * size;
-            let ts: SqlTimestampStruct =
-                unsafe { std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _) };
+            let ts: SqlTimestampStruct = unsafe {
+                std::ptr::read_unaligned(buf.data[offset..].as_ptr() as *const _)
+            };
             v.push(odbc_timestamp_to_micros(&ts));
         }
     }
diff --git a/src/reader/odbc/wrapper.rs b/src/reader/odbc/wrapper.rs
index 5e3a105f..f33f8d3f 100644
--- a/src/reader/odbc/wrapper.rs
+++ b/src/reader/odbc/wrapper.rs
@@ -651,9 +651,7 @@ pub fn sql_columns(
     })
 }
 
-fn str_to_odbc_cstring(
-    s: Option<&str>,
-) -> Result<(Option<std::ffi::CString>, SqlSmallInt)> {
+fn str_to_odbc_cstring(s: Option<&str>) -> Result<(Option<std::ffi::CString>, SqlSmallInt)> {
     match s {
         Some(s) => {
             let cs = std::ffi::CString::new(s).map_err(|_| {
diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs
index 9ed8c525..1b780629 100644
--- a/src/reader/sqlite.rs
+++ b/src/reader/sqlite.rs
@@ -516,27 +516,18 @@ impl Reader for SqliteReader {
     }
 
     fn list_catalogs(&self) -> Result<Vec<String>> {
-        let df = self.execute_sql("SELECT name FROM pragma_database_list ORDER BY name")?;
-        let col = df.column("name")?;
-        let mut results = Vec::with_capacity(df.height());
-        for i in 0..df.height() {
-            if !col.is_null(i) {
-                results.push(crate::array_util::value_to_string(col, i));
-            }
-        }
-        Ok(results)
+        Ok(vec![])
     }
 
     fn list_schemas(&self, _catalog: &str) -> Result<Vec<String>> {
         Ok(vec![])
     }
 
-    fn list_tables(&self, catalog: &str, _schema: &str) -> Result<Vec<super::TableInfo>> {
-        let df = self.execute_sql(&format!(
-            "SELECT name, type FROM {}.sqlite_master \
+    fn list_tables(&self, _catalog: &str, _schema: &str) -> Result<Vec<super::TableInfo>> {
+        let df = self.execute_sql(
+            "SELECT name, type FROM sqlite_master \
             WHERE type IN ('table', 'view') ORDER BY name",
-            naming::quote_ident(catalog)
-        ))?;
+        )?;
         let name_col = df.column("name")?;
         let type_col = df.column("type")?;
         let mut results = Vec::with_capacity(df.height());

From 75b322041fa21d264ea3edb82c311e12573ec8b0 Mon Sep 17 00:00:00 2001
From: George Stagg
Date: Fri, 1 May 2026 15:33:11 +0100
Subject: [PATCH 3/8] Apply changes from code review

---
 src/reader/odbc/mod.rs | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/reader/odbc/mod.rs b/src/reader/odbc/mod.rs
index cd15a32d..5bcd6b67 100644
--- a/src/reader/odbc/mod.rs
+++ b/src/reader/odbc/mod.rs
@@ -758,8 +758,18 @@ fn extract_batch(
         } else {
             let elem_size = buf.text_buf_size + 1;
             let offset = row * elem_size;
-            let len = indicator as usize;
-            let actual_len = len.min(buf.text_buf_size);
+            // indicator is the actual byte length, but may be
+            // SQL_NO_TOTAL (-4) if the driver can't determine length.
+            // In that case, scan for null terminator in the buffer.
+ let actual_len = if indicator >= 0 { + (indicator as usize).min(buf.text_buf_size) + } else { + let slice = &buf.data[offset..offset + buf.text_buf_size]; + slice + .iter() + .position(|&b| b == 0) + .unwrap_or(buf.text_buf_size) + }; let bytes = &buf.data[offset..offset + actual_len]; let s = String::from_utf8_lossy(bytes).into_owned(); v.push(Some(s)); From 5541ce29163f09030848e95a1abfa0bba1607b96 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Fri, 1 May 2026 15:56:55 +0100 Subject: [PATCH 4/8] Link with Rstrtmgr on Windows --- src/build.rs | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/build.rs diff --git a/src/build.rs b/src/build.rs new file mode 100644 index 00000000..0b2e6827 --- /dev/null +++ b/src/build.rs @@ -0,0 +1,5 @@ +fn main() { + if std::env::var("CARGO_CFG_WINDOWS").is_ok() { + println!("cargo:rustc-link-lib=Rstrtmgr"); + } +} From e3c18bc01934ad253d3853d2a1f3c41877c7f628 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Fri, 1 May 2026 17:22:46 +0100 Subject: [PATCH 5/8] Add CHANGELOG entry --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b9923ff..fb419989 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ ## [Unreleased] +### Changed + +- Restructured how ggsql integrates with ODBC drivers to use system ODBC, +rather than bundling unixodbc as part of binary releases. This fixes several +issues on Linux and macOS caused by using relative paths to dynamic libraries. + ## 0.3.1 - 2026-04-30 ### Fixed From 97ba2c7b0455488f79009f904a442e70da1b6cee Mon Sep 17 00:00:00 2001 From: George Stagg Date: Fri, 1 May 2026 17:58:35 +0100 Subject: [PATCH 6/8] Allow customising the snowflake driver name --- ggsql-vscode/src/connections.ts | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ggsql-vscode/src/connections.ts b/ggsql-vscode/src/connections.ts index 66ee87cd..ee15ddfe 100644 --- a/ggsql-vscode/src/connections.ts +++ b/ggsql-vscode/src/connections.ts @@ -200,7 +200,8 @@ function readSnowflakeConnections(): { * Build an ODBC connection string for Snowflake with the given parts. 
 */
 function buildSnowflakeOdbc(parts: Record<string, string | undefined>): string {
-  let connStr = `Driver=Snowflake;Server=${parts.account}.snowflakecomputing.com`;
+  const driver = parts.driver || 'Snowflake';
+  let connStr = `Driver=${driver};Server=${parts.account}.snowflakecomputing.com`;
   if (parts.uid) {
     connStr += `;UID=${parts.uid}`;
   }
@@ -271,6 +272,8 @@ function createSnowflakeDefaultDriver(
     ];
   }
 
+  inputs.unshift({ id: 'driver', label: 'Driver', type: 'string', value: 'Snowflake' });
+
   return {
     driverId: 'ggsql-snowflake-default',
     metadata: {
@@ -282,7 +285,9 @@ function createSnowflakeDefaultDriver(
     generateCode: (inputs) => {
       const name =
         inputs.find((i) => i.id === 'connection_name')?.value?.trim() || 'default';
-      return `-- @connect: odbc://Driver=Snowflake;ConnectionName=${name}`;
+      const driver =
+        inputs.find((i) => i.id === 'driver')?.value?.trim() || 'Snowflake';
+      return `-- @connect: odbc://Driver=${driver};ConnectionName=${name}`;
     },
     connect: snowflakeConnect(positronApi),
   };
@@ -302,6 +307,7 @@ function createSnowflakePasswordDriver(
       name: 'Snowflake',
       description: 'Username/Password',
       inputs: [
+        { id: 'driver', label: 'Driver', type: 'string', value: 'Snowflake' },
         { id: 'account', label: 'Account', type: 'string' },
         { id: 'user', label: 'User', type: 'string' },
         { id: 'password', label: 'Password', type: 'string' },
@@ -320,6 +326,7 @@ function createSnowflakePasswordDriver(
         warehouse: get('warehouse'),
         database: get('database') || undefined,
         schema: get('schema') || undefined,
+        driver: get('driver') || undefined,
       });
     },
     connect: snowflakeConnect(positronApi),
@@ -340,6 +347,7 @@ function createSnowflakeSSODriver(
       name: 'Snowflake',
       description: 'External Browser (SSO)',
       inputs: [
+        { id: 'driver', label: 'Driver', type: 'string', value: 'Snowflake' },
         { id: 'account', label: 'Account', type: 'string' },
         { id: 'user', label: 'User', type: 'string', value: '' },
         { id: 'warehouse', label: 'Warehouse', type: 'string' },
@@ -357,6 +365,7 @@ function createSnowflakeSSODriver(
         warehouse: get('warehouse'),
         database: get('database') || undefined,
         schema: get('schema') || undefined,
+        driver: get('driver') || undefined,
      });
     },
     connect: snowflakeConnect(positronApi),
@@ -377,6 +386,7 @@ function createSnowflakePATDriver(
       name: 'Snowflake',
       description: 'Programmatic Access Token (PAT)',
       inputs: [
+        { id: 'driver', label: 'Driver', type: 'string', value: 'Snowflake' },
         { id: 'account', label: 'Account', type: 'string' },
         { id: 'token', label: 'Token', type: 'string' },
         { id: 'warehouse', label: 'Warehouse', type: 'string' },
@@ -394,6 +404,7 @@ function createSnowflakePATDriver(
         warehouse: get('warehouse'),
         database: get('database') || undefined,
         schema: get('schema') || undefined,
+        driver: get('driver') || undefined,
       });
     },
     connect: snowflakeConnect(positronApi),

From ea8d227b587f71aa8f77c2471d46dd0a2d9a2bd2 Mon Sep 17 00:00:00 2001
From: George Stagg
Date: Sat, 2 May 2026 10:00:32 +0100
Subject: [PATCH 7/8] Use SQL_OV_ODBC3_80

---
 src/reader/odbc/ffi.rs     | 2 +-
 src/reader/odbc/wrapper.rs | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/reader/odbc/ffi.rs b/src/reader/odbc/ffi.rs
index ab65b675..5d5ed70e 100644
--- a/src/reader/odbc/ffi.rs
+++ b/src/reader/odbc/ffi.rs
@@ -54,7 +54,7 @@ pub const SQL_NULL_HANDLE: SqlHandle = std::ptr::null_mut();
 // ============================================================================
 
 pub const SQL_ATTR_ODBC_VERSION: SqlInteger = 200;
-pub const SQL_OV_ODBC3: SqlInteger = 3;
+pub const SQL_OV_ODBC3_80: SqlInteger = 380;
 
 // ============================================================================
 // Connection attributes
diff --git a/src/reader/odbc/wrapper.rs b/src/reader/odbc/wrapper.rs
index f33f8d3f..31cdb2b9 100644
--- a/src/reader/odbc/wrapper.rs
+++ b/src/reader/odbc/wrapper.rs
@@ -112,7 +112,7 @@ impl Environment {
         }
 
         let rc = unsafe {
-            (f.SQLSetEnvAttr)(handle, SQL_ATTR_ODBC_VERSION, SQL_OV_ODBC3 as SqlPointer, 0)
+            (f.SQLSetEnvAttr)(handle, SQL_ATTR_ODBC_VERSION, SQL_OV_ODBC3_80 as SqlPointer, 0)
         };
         check(rc, SQL_HANDLE_ENV, handle, "Failed to set ODBC version")?;

From ade3b0b0ae71e660df42b6f8fa964abd6ea770b5 Mon Sep 17 00:00:00 2001
From: George Stagg
Date: Sat, 2 May 2026 10:04:11 +0100
Subject: [PATCH 8/8] cargo fmt

---
 src/reader/odbc/wrapper.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/reader/odbc/wrapper.rs b/src/reader/odbc/wrapper.rs
index 31cdb2b9..2cb9c79b 100644
--- a/src/reader/odbc/wrapper.rs
+++ b/src/reader/odbc/wrapper.rs
@@ -112,7 +112,12 @@ impl Environment {
         }
 
         let rc = unsafe {
-            (f.SQLSetEnvAttr)(handle, SQL_ATTR_ODBC_VERSION, SQL_OV_ODBC3_80 as SqlPointer, 0)
+            (f.SQLSetEnvAttr)(
+                handle,
+                SQL_ATTR_ODBC_VERSION,
+                SQL_OV_ODBC3_80 as SqlPointer,
+                0,
+            )
         };
         check(rc, SQL_HANDLE_ENV, handle, "Failed to set ODBC version")?;
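
With the series applied, batch fetching is driven through `setup_batch_fetch`,
`set_rows_fetched_ptr`, `bind_col`, and `fetch_raw` from patch 1. A rough sketch of that
loop for a single `i64` column; it assumes `ffi.rs` also defines the standard
`SQL_C_SBIGINT` and `SQL_NULL_DATA` constants (names assumed, not shown in these patches;
`SQL_NO_DATA` and `succeeded` do appear in patch 1):

    // Sketch: drain one i64 result column in batches of 1024 rows.
    fn drain_i64_column(stmt: &Statement) -> Result<()> {
        const BATCH: usize = 1024;
        let mut rows_fetched: SqlULen = 0;
        let mut vals = vec![0i64; BATCH];        // bound data buffer
        let mut inds = vec![0 as SqlLen; BATCH]; // per-row length/NULL indicators
        stmt.setup_batch_fetch(BATCH)?;          // column-wise binding + row array size
        unsafe { stmt.set_rows_fetched_ptr(&mut rows_fetched)? };
        stmt.bind_col(
            1,                                   // 1-based column index
            SQL_C_SBIGINT,                       // assumed constant in ffi.rs
            vals.as_mut_ptr() as SqlPointer,
            std::mem::size_of::<i64>() as SqlLen,
            inds.as_mut_ptr(),
        )?;
        loop {
            let rc = stmt.fetch_raw();
            if rc == SQL_NO_DATA {
                break;                           // cursor exhausted
            }
            if !succeeded(rc) {
                return Err(GgsqlError::ReaderError("ODBC fetch failed".into()));
            }
            for i in 0..rows_fetched as usize {
                if inds[i] != SQL_NULL_DATA {    // assumed constant (-1) in ffi.rs
                    // consume vals[i] here
                }
            }
        }
        stmt.unbind_cols()                       // drop bindings before statement reuse
    }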