From db8a214589e67db346beacf5d3b3ea7b431e163e Mon Sep 17 00:00:00 2001 From: Eran Sandler Date: Mon, 18 May 2026 10:35:23 -0700 Subject: [PATCH] fix catalog alias projections --- src/catalog/pg_roles.rs | 73 ++++++++++++++++++++++---------- src/catalog/query_interceptor.rs | 64 ++++++++++++++++++++-------- tests/catalog_alias_test.rs | 56 ++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 41 deletions(-) create mode 100644 tests/catalog_alias_test.rs diff --git a/src/catalog/pg_roles.rs b/src/catalog/pg_roles.rs index 15bc12f5..30e2979a 100644 --- a/src/catalog/pg_roles.rs +++ b/src/catalog/pg_roles.rs @@ -1,9 +1,9 @@ -use crate::session::db_handler::{DbHandler, DbResponse}; +use super::where_evaluator::WhereEvaluator; use crate::PgSqliteError; -use sqlparser::ast::{Select, SelectItem, Expr}; -use tracing::debug; +use crate::session::db_handler::{DbHandler, DbResponse}; +use sqlparser::ast::{Expr, Select, SelectItem}; use std::collections::HashMap; -use super::where_evaluator::WhereEvaluator; +use tracing::debug; pub struct PgRolesHandler; @@ -31,7 +31,7 @@ impl PgRolesHandler { "rolconfig".to_string(), ]; - // Determine which columns to return + // Determine which output columns to return and which source columns supply values let selected_columns = Self::get_selected_columns(&select.projection, &all_columns); // Build default roles (since SQLite doesn't have role management) @@ -39,7 +39,7 @@ impl PgRolesHandler { // Apply WHERE clause filtering if present let filtered_roles = if let Some(where_clause) = &select.selection { - Self::apply_where_filter(&roles, where_clause, &selected_columns)? + Self::apply_where_filter(&roles, where_clause)? } else { roles }; @@ -48,8 +48,11 @@ impl PgRolesHandler { let mut rows = Vec::new(); for role in filtered_roles { let mut row = Vec::new(); - for column in &selected_columns { - let value = role.get(column).cloned().unwrap_or_else(|| b"".to_vec()); + for (_, source_column) in &selected_columns { + let value = role + .get(source_column) + .cloned() + .unwrap_or_else(|| b"".to_vec()); row.push(Some(value)); } rows.push(row); @@ -57,45 +60,70 @@ impl PgRolesHandler { let rows_count = rows.len(); Ok(DbResponse { - columns: selected_columns, + columns: selected_columns + .into_iter() + .map(|(output_column, _)| output_column) + .collect(), rows, rows_affected: rows_count, }) } - fn get_selected_columns(projection: &[SelectItem], all_columns: &[String]) -> Vec { + fn get_selected_columns( + projection: &[SelectItem], + all_columns: &[String], + ) -> Vec<(String, String)> { let mut selected = Vec::new(); for item in projection { match item { SelectItem::Wildcard(_) => { - selected.extend_from_slice(all_columns); + selected.extend( + all_columns + .iter() + .map(|column| (column.clone(), column.clone())), + ); break; } - SelectItem::UnnamedExpr(Expr::Identifier(ident)) => { - let col_name = ident.value.to_lowercase(); - if all_columns.contains(&col_name) { - selected.push(col_name); + SelectItem::UnnamedExpr(expr) => { + if let Some(col_name) = Self::extract_source_column(expr) { + if all_columns.contains(&col_name) { + selected.push((col_name.clone(), col_name)); + } } } - SelectItem::ExprWithAlias { expr: Expr::Identifier(ident), alias } => { - let col_name = ident.value.to_lowercase(); - if all_columns.contains(&col_name) { - selected.push(alias.value.clone()); + SelectItem::ExprWithAlias { expr, alias } => { + if let Some(col_name) = Self::extract_source_column(expr) { + if all_columns.contains(&col_name) { + selected.push((alias.value.clone(), col_name)); + } } } SelectItem::QualifiedWildcard(_, _) => { // For qualified wildcard like pg_roles.*, return all columns - selected.extend_from_slice(all_columns); + selected.extend( + all_columns + .iter() + .map(|column| (column.clone(), column.clone())), + ); break; } - _ => {} } } selected } + fn extract_source_column(expr: &Expr) -> Option { + match expr { + Expr::Identifier(ident) => Some(ident.value.to_lowercase()), + Expr::CompoundIdentifier(parts) => parts.last().map(|ident| ident.value.to_lowercase()), + Expr::Cast { expr, .. } => Self::extract_source_column(expr), + Expr::Nested(expr) => Self::extract_source_column(expr), + _ => None, + } + } + fn get_default_roles() -> Vec>> { let mut roles = Vec::new(); @@ -156,7 +184,6 @@ impl PgRolesHandler { fn apply_where_filter( roles: &[HashMap>], where_clause: &Expr, - _selected_columns: &[String], ) -> Result>>, PgSqliteError> { let mut filtered = Vec::new(); @@ -177,4 +204,4 @@ impl PgRolesHandler { Ok(filtered) } -} \ No newline at end of file +} diff --git a/src/catalog/query_interceptor.rs b/src/catalog/query_interceptor.rs index 96189268..0497afae 100644 --- a/src/catalog/query_interceptor.rs +++ b/src/catalog/query_interceptor.rs @@ -1054,10 +1054,11 @@ impl CatalogInterceptor { } } - fn handle_pg_namespace_query(_select: &Select) -> DbResponse { - // Return basic namespaces - let columns = vec!["oid".to_string(), "nspname".to_string()]; - let rows = vec![ + fn handle_pg_namespace_query(select: &Select) -> DbResponse { + let all_columns = vec!["oid".to_string(), "nspname".to_string()]; + let (columns, column_indices) = Self::extract_selected_columns(select, &all_columns); + + let full_rows = vec![ vec![ Some("11".to_string().into_bytes()), Some("pg_catalog".to_string().into_bytes()), @@ -1067,6 +1068,15 @@ impl CatalogInterceptor { Some("public".to_string().into_bytes()), ], ]; + let rows: Vec>>> = full_rows + .into_iter() + .map(|full_row| { + column_indices + .iter() + .map(|&idx| full_row[idx].clone()) + .collect() + }) + .collect(); let rows_affected = rows.len(); debug!("Returning {} rows for pg_type query with {} columns: {:?}", rows_affected, columns.len(), columns); @@ -1713,10 +1723,13 @@ impl CatalogInterceptor { }) } - /// Extract selected columns from a SELECT query for information_schema views + /// Extract selected output columns and source indices from a SELECT query. fn extract_selected_columns(select: &Select, all_columns: &[String]) -> (Vec, Vec) { if select.projection.len() == 1 - && let SelectItem::Wildcard(_) = &select.projection[0] { + && matches!( + &select.projection[0], + SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) + ) { // SELECT * - return all columns return (all_columns.to_vec(), (0..all_columns.len()).collect::>()); } @@ -1726,29 +1739,44 @@ impl CatalogInterceptor { let mut indices = Vec::new(); for item in &select.projection { match item { - SelectItem::UnnamedExpr(Expr::Identifier(ident)) => { - let col_name = ident.value.to_string(); - if let Some(idx) = all_columns.iter().position(|c| c == &col_name) { - cols.push(col_name); - indices.push(idx); + SelectItem::UnnamedExpr(expr) => { + if let Some(col_name) = Self::extract_projection_source_column(expr) { + if let Some(idx) = all_columns.iter().position(|c| c == &col_name) { + cols.push(all_columns[idx].clone()); + indices.push(idx); + } } } - SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => { - // Handle compound identifiers like c.table_name - if let Some(last_part) = parts.last() { - let col_name = last_part.value.to_string(); + SelectItem::ExprWithAlias { expr, alias } => { + if let Some(col_name) = Self::extract_projection_source_column(expr) { if let Some(idx) = all_columns.iter().position(|c| c == &col_name) { - cols.push(col_name); + cols.push(alias.value.clone()); indices.push(idx); } } } - _ => {} + SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => { + cols.extend_from_slice(all_columns); + indices.extend(0..all_columns.len()); + break; + } } } (cols, indices) } + fn extract_projection_source_column(expr: &Expr) -> Option { + match expr { + Expr::Identifier(ident) => Some(ident.value.to_lowercase()), + Expr::CompoundIdentifier(parts) => { + parts.last().map(|ident| ident.value.to_lowercase()) + } + Expr::Cast { expr, .. } => Self::extract_projection_source_column(expr), + Expr::Nested(expr) => Self::extract_projection_source_column(expr), + _ => None, + } + } + async fn handle_information_schema_schemata_query(select: &Select, _db: &DbHandler) -> DbResponse { debug!("Handling information_schema.schemata query"); @@ -3865,4 +3893,4 @@ impl CatalogInterceptor { Ok(filtered) } -} \ No newline at end of file +} diff --git a/tests/catalog_alias_test.rs b/tests/catalog_alias_test.rs new file mode 100644 index 00000000..6446bf12 --- /dev/null +++ b/tests/catalog_alias_test.rs @@ -0,0 +1,56 @@ +use pgsqlite::catalog::CatalogInterceptor; +use pgsqlite::session::db_handler::DbHandler; +use std::sync::Arc; + +async fn catalog_query(query: &str) -> pgsqlite::session::db_handler::DbResponse { + let db = Arc::new(DbHandler::new(":memory:").unwrap()); + CatalogInterceptor::intercept_query(query, db, None) + .await + .expect("query should be intercepted") + .expect("catalog query should succeed") +} + +fn text_cell(row: &[Option>], index: usize) -> String { + String::from_utf8(row[index].clone().expect("cell should not be NULL")).unwrap() +} + +#[tokio::test] +async fn test_pg_catalog_database_alias_projection() { + let response = catalog_query("SELECT oid AS did, datname FROM pg_catalog.pg_database").await; + + assert_eq!(response.columns, vec!["did", "datname"]); + assert_eq!(response.rows.len(), 1); + assert_eq!(text_cell(&response.rows[0], 0), "1"); + assert_eq!(text_cell(&response.rows[0], 1), "main"); +} + +#[tokio::test] +async fn test_pg_catalog_roles_alias_projection_uses_source_column() { + let response = catalog_query("SELECT rolname AS rolsuper FROM pg_catalog.pg_roles").await; + + assert_eq!(response.columns, vec!["rolsuper"]); + assert_eq!(response.rows.len(), 3); + let role_names: Vec = response.rows.iter().map(|row| text_cell(row, 0)).collect(); + assert_eq!(role_names, vec!["postgres", "public", "pgsqlite_user"]); +} + +#[tokio::test] +async fn test_pg_catalog_namespace_alias_projection() { + let response = catalog_query("SELECT oid AS did FROM pg_catalog.pg_namespace").await; + + assert_eq!(response.columns, vec!["did"]); + assert_eq!(response.rows.len(), 2); + assert_eq!(text_cell(&response.rows[0], 0), "11"); + assert_eq!(text_cell(&response.rows[1], 0), "2200"); +} + +#[tokio::test] +async fn test_information_schema_schemata_alias_projection() { + let response = catalog_query("SELECT catalog_name AS x FROM information_schema.schemata").await; + + assert_eq!(response.columns, vec!["x"]); + assert_eq!(response.rows.len(), 3); + for row in &response.rows { + assert_eq!(text_cell(row, 0), "main"); + } +}