diff --git a/crates/squawk_ide/src/code_actions/quote_identifier.rs b/crates/squawk_ide/src/code_actions/quote_identifier.rs index 8d6c8827..2e4e574c 100644 --- a/crates/squawk_ide/src/code_actions/quote_identifier.rs +++ b/crates/squawk_ide/src/code_actions/quote_identifier.rs @@ -1,9 +1,6 @@ use rowan::TextSize; use salsa::Database as Db; -use squawk_syntax::{ - ast::{self, AstNode}, - quote::normalize_identifier, -}; +use squawk_syntax::ast::{self, AstNode}; use crate::{db::File, offsets::token_from_offset}; @@ -18,25 +15,27 @@ pub(super) fn quote_identifier( let token = token_from_offset(db, file, offset)?; let parent = token.parent()?; - let name_node = if let Some(name) = ast::Name::cast(parent.clone()) { - name.syntax().clone() - } else if let Some(name_ref) = ast::NameRef::cast(parent) { - name_ref.syntax().clone() + let (is_quoted, text, text_range) = if let Some(name) = ast::Name::cast(parent.clone()) { + (name.is_quoted(), name.text(), name.syntax().text_range()) + } else if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { + ( + name_ref.is_quoted(), + name_ref.text(), + name_ref.syntax().text_range(), + ) } else { return None; }; - let text = name_node.text().to_string(); - - if text.starts_with('"') { + if is_quoted { return None; } - let quoted = format!(r#""{}""#, normalize_identifier(&text)); + let quoted = format!(r#""{text}""#); actions.push(CodeAction { title: "Quote identifier".to_owned(), - edits: vec![squawk_linter::Edit::replace(name_node.text_range(), quoted)], + edits: vec![squawk_linter::Edit::replace(text_range, quoted)], kind: ActionKind::RefactorRewrite, }); diff --git a/crates/squawk_ide/src/code_actions/test_utils.rs b/crates/squawk_ide/src/code_actions/test_utils.rs index 0a9aa707..2a2dd08e 100644 --- a/crates/squawk_ide/src/code_actions/test_utils.rs +++ b/crates/squawk_ide/src/code_actions/test_utils.rs @@ -95,7 +95,7 @@ fn code_action_not_applicable_( allow_errors: bool, ) -> bool { let fixture = Fixture::new(sql); - let offset = fixture.marker().offset(); + let offset = fixture.marker().offset_before(); let sql = fixture.sql(); let db = Database::default(); let file = File::new(&db, sql.into()); diff --git a/crates/squawk_ide/src/collect.rs b/crates/squawk_ide/src/collect.rs index 0e9f4540..7b839c48 100644 --- a/crates/squawk_ide/src/collect.rs +++ b/crates/squawk_ide/src/collect.rs @@ -1,7 +1,6 @@ use crate::ast_nav; -use crate::builtins::builtins_file; use crate::column_name::ColumnName; -use crate::db::{File, parse}; +use crate::db::{File, list_files, parse}; use crate::goto_definition::goto_definition; use crate::infer::{Type, infer_type_from_expr, infer_type_from_ty}; use crate::location::{Location, LocationKind}; @@ -216,7 +215,7 @@ pub(crate) fn create_table_as_columns_with_types( file: File, create_table_as: &ast::CreateTableAs, ) -> Vec<(Name, Option)> { - for file in [file, builtins_file(db)] { + for file in list_files(db, file) { let columns = select_columns_with_types(db, file, &create_table_as.query()); if !columns.is_empty() { return columns; @@ -426,7 +425,7 @@ pub(crate) fn view_like_columns_with_types( .collect(); let mut base_columns = vec![]; - for file in [file, builtins_file(db)] { + for file in list_files(db, file) { base_columns = select_columns_with_types(db, file, &create_view.query()); if !base_columns.is_empty() { break; @@ -464,7 +463,7 @@ pub(crate) fn with_table_columns_with_types( .collect(); let mut base_columns = vec![]; - for file in [file, builtins_file(db)] { + for file in list_files(db, file) { base_columns = with_table_query_columns_with_types(db, file, with_table.clone()); if !base_columns.is_empty() { break; @@ -686,7 +685,7 @@ fn columns_for_star_from_table_ptr( columns_for_star_from_alias(db, file, &from_item, &alias) } Some(ast_nav::ParentSouce::WithTable(with_table)) => { - for f in [file, builtins_file(db)] { + for f in list_files(db, file) { let columns = with_table_columns_with_types(db, f, with_table.clone()); if !columns.is_empty() { return columns; @@ -826,7 +825,7 @@ pub(crate) fn star_column_names(db: &dyn Db, file: File, table_ptr: &SyntaxNodeP .filter_map(|column| column.name().map(|name| Name::from_node(&name))) .collect(), Some(ast_nav::ParentSouce::WithTable(with_table)) => { - for file in [file, builtins_file(db)] { + for file in list_files(db, file) { let columns: Vec<_> = with_table_columns_with_types(db, file, with_table.clone()) .into_iter() .map(|(name, _)| name) diff --git a/crates/squawk_ide/src/column_name.rs b/crates/squawk_ide/src/column_name.rs index c1781132..9a99d38f 100644 --- a/crates/squawk_ide/src/column_name.rs +++ b/crates/squawk_ide/src/column_name.rs @@ -3,8 +3,6 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use squawk_syntax::quote::normalize_identifier; - #[derive(Clone, Debug, PartialEq)] pub(crate) enum ColumnName { Column(String), @@ -29,9 +27,10 @@ impl ColumnName { if let Some(as_name) = target.as_name() && let Some(name_node) = as_name.name() { - let text = name_node.text(); - let normalized = normalize_identifier(&text); - return Some((ColumnName::Column(normalized), name_node.syntax().clone())); + return Some(( + ColumnName::Column(name_node.text()), + name_node.syntax().clone(), + )); } Self::inferred_from_target(target) } @@ -133,7 +132,7 @@ fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, Sy name.push_str("tz"); }; return Some(( - ColumnName::new(name.to_string(), unknown_column), + ColumnName::new(name, unknown_column), time_type.syntax().clone(), )); } @@ -228,9 +227,10 @@ fn name_from_name_ref( } } } - let text = name_ref.text(); - let normalized = normalize_identifier(&text); - return Some((ColumnName::Column(normalized), name_ref.syntax().clone())); + return Some(( + ColumnName::Column(name_ref.text()), + name_ref.syntax().clone(), + )); } /* diff --git a/crates/squawk_ide/src/db.rs b/crates/squawk_ide/src/db.rs index 675b9a10..3ae21645 100644 --- a/crates/squawk_ide/src/db.rs +++ b/crates/squawk_ide/src/db.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use crate::binder; use crate::binder::Binder; +use crate::builtins::builtins_file; #[salsa::input] pub struct File { @@ -23,6 +24,11 @@ pub fn line_index(db: &dyn Db, file: File) -> LineIndex { LineIndex::new(file.content(db)) } +#[inline] +pub(crate) fn list_files(db: &dyn Db, file: File) -> impl Iterator { + [file, builtins_file(db)].into_iter() +} + #[salsa::tracked] pub fn bind(db: &dyn Db, file: File) -> Binder { let result = parse(db, file); diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 45053ad1..19966ab2 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -1,5 +1,4 @@ -use crate::builtins::builtins_file; -use crate::db::{File, parse}; +use crate::db::{File, list_files, parse}; use crate::location::{Location, LocationKind}; use crate::offsets::token_from_offset; use crate::resolve; @@ -66,7 +65,7 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L } if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { - for definition_file in [file, builtins_file(db)] { + for definition_file in list_files(db, file) { if let Some(locations) = resolve::resolve_name_ref(db, definition_file, &name_ref) { return locations; } @@ -82,7 +81,7 @@ pub fn goto_definition(db: &dyn Db, file: File, offset: TextSize) -> SmallVec<[L } }); if let Some(ty) = type_node { - for definition_file in [file, builtins_file(db)] { + for definition_file in list_files(db, file) { let position = token.text_range().start(); if let Some(ptr) = resolve::resolve_type_ptr_from_type(db, definition_file, &ty, position) diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 04b07a52..52ec534a 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -1,9 +1,8 @@ use crate::ast_nav; -use crate::builtins::builtins_file; use crate::collect; use crate::column_name::ColumnName; use crate::comments::preceding_comment; -use crate::db::{File, bind, parse}; +use crate::db::{File, bind, list_files, parse}; use crate::infer::infer_type_from_expr; use crate::location::{Location, LocationKind}; use crate::name; @@ -548,7 +547,7 @@ fn hover_qualified_star(db: &dyn Db, file: File, field_expr: ast::FieldExpr) -> fn hover_unqualified_star(db: &dyn Db, file: File, target: ast::Target) -> Option { let mut results = vec![]; - for file in [file, builtins_file(db)] { + for file in list_files(db, file) { results = hover_unqualified_star_with_binder(db, file, &target); if results.is_empty() && target_has_schema_qualified_from_item(&target) { continue; diff --git a/crates/squawk_ide/src/name.rs b/crates/squawk_ide/src/name.rs index f49a0d6d..0c818fbb 100644 --- a/crates/squawk_ide/src/name.rs +++ b/crates/squawk_ide/src/name.rs @@ -2,8 +2,6 @@ use smol_str::SmolStr; use squawk_syntax::ast::{self, AstNode}; use std::fmt; -use squawk_syntax::quote::normalize_identifier; - #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct Name(pub(crate) SmolStr); @@ -23,15 +21,20 @@ impl fmt::Display for Schema { } impl Name { + // TODO: we should get rid of this and update the ast methods to return + // normalized idents. pub(crate) fn from_string(text: impl Into) -> Self { let text = text.into(); - let normalized = normalize_identifier(&text); - Name(normalized.into()) + let text = text + .strip_prefix('"') + .and_then(|t| t.strip_suffix('"')) + .map(|x| x.replace(r#""""#, "\"")) + .unwrap_or(text.to_ascii_lowercase()); + Name(text.into()) } pub(crate) fn from_node(node: &impl ast::NameLike) -> Self { - let text = node.syntax().text().to_string(); - let normalized = normalize_identifier(&text); - Name(normalized.into()) + let text = node.text(); + Name(text.into()) } } @@ -198,7 +201,7 @@ mod test { use super::*; #[test] fn name_case_insensitive_compare() { - assert_eq!(Name::from_string("foo"), Name::from_string("FOO")); + assert_eq!(Name::from_string("foo"), Name::from_string(r#""foo""#)); } #[test] diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 46841b54..d35fca62 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -1,4 +1,4 @@ -use crate::ast_nav; +use crate::{ast_nav, db::list_files}; use rowan::TextSize; use smallvec::{SmallVec, smallvec}; use squawk_syntax::{ @@ -6,7 +6,6 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use crate::builtins::builtins_file; use crate::column_name::ColumnName; use crate::db::File; use crate::location::{Location, LocationKind}; @@ -1379,7 +1378,7 @@ pub(crate) fn resolve_table_name( position: TextSize, ) -> Option<(File, ResolvedTableName)> { use ResolvedTableName::*; - for file in [file, builtins_file(db)] { + for file in list_files(db, file) { let Some((ptr, kind)) = resolve_table_like(db, file, None, table_name, schema, position) else { continue; diff --git a/crates/squawk_linter/src/rules/adding_field_with_default.rs b/crates/squawk_linter/src/rules/adding_field_with_default.rs index 0d112393..d5d53e38 100644 --- a/crates/squawk_linter/src/rules/adding_field_with_default.rs +++ b/crates/squawk_linter/src/rules/adding_field_with_default.rs @@ -2,20 +2,19 @@ use std::sync::OnceLock; use rustc_hash::FxHashSet; +use squawk_syntax::ast; use squawk_syntax::ast::AstNode; use squawk_syntax::{Parse, SourceFile, SyntaxKind}; -use squawk_syntax::{ast, identifier::Identifier}; use crate::{Linter, Rule, Version, Violation}; -fn non_volatile_funcs() -> &'static FxHashSet { - static NON_VOLATILE_FUNCS: OnceLock> = OnceLock::new(); +fn non_volatile_funcs() -> &'static FxHashSet<&'static str> { + static NON_VOLATILE_FUNCS: OnceLock> = OnceLock::new(); NON_VOLATILE_FUNCS.get_or_init(|| { NON_VOLATILE_BUILT_IN_FUNCTIONS .split('\n') .map(|x| x.trim()) .filter(|x| !x.is_empty()) - .map(Identifier::new) .collect() }) } @@ -41,8 +40,7 @@ fn is_non_volatile_or_const(expr: &ast::Expr) -> bool { return false; }; - let non_volatile_name = - non_volatile_funcs().contains(&Identifier::new(name_ref.text().as_str())); + let non_volatile_name = non_volatile_funcs().contains(name_ref.text().as_str()); no_args && non_volatile_name } else { diff --git a/crates/squawk_linter/src/rules/adding_foreign_key_constraint.rs b/crates/squawk_linter/src/rules/adding_foreign_key_constraint.rs index 92a2cacc..6e0660b8 100644 --- a/crates/squawk_linter/src/rules/adding_foreign_key_constraint.rs +++ b/crates/squawk_linter/src/rules/adding_foreign_key_constraint.rs @@ -1,7 +1,6 @@ use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{ @@ -27,7 +26,7 @@ pub(crate) fn adding_foreign_key_constraint(ctx: &mut Linter, parse: &Parse { if add_constraint.not_valid().is_some() - || tables_created.contains(&Identifier::new(&table_name.text())) + || tables_created.contains(&table_name.text()) { // Adding foreign key is okay when: // - NOT VALID is specified. diff --git a/crates/squawk_linter/src/rules/adding_not_null_field.rs b/crates/squawk_linter/src/rules/adding_not_null_field.rs index 38307f75..1bed0c2c 100644 --- a/crates/squawk_linter/src/rules/adding_not_null_field.rs +++ b/crates/squawk_linter/src/rules/adding_not_null_field.rs @@ -3,18 +3,17 @@ use rustc_hash::{FxHashMap, FxHashSet}; use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{Linter, Rule, Version, Violation}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct TableColumn { - table: Identifier, - column: Identifier, + table: String, + column: String, } -fn is_not_null_check(expr: &ast::Expr) -> Option { +fn is_not_null_check(expr: &ast::Expr) -> Option { let ast::Expr::BinExpr(bin_expr) = expr else { return None; }; @@ -30,18 +29,18 @@ fn is_not_null_check(expr: &ast::Expr) -> Option { } match bin_expr.lhs()? { - ast::Expr::NameRef(name_ref) => Some(Identifier::new(&name_ref.text())), + ast::Expr::NameRef(name_ref) => Some(name_ref.text()), _ => None, } } -fn get_table_name(alter_table: &ast::AlterTable) -> Option { +fn get_table_name(alter_table: &ast::AlterTable) -> Option { alter_table .relation_name()? .path()? .segment()? .name_ref() - .map(|x| Identifier::new(&x.text())) + .map(|x| x.text()) } pub(crate) fn adding_not_null_field(ctx: &mut Linter, parse: &Parse) { @@ -49,12 +48,11 @@ pub(crate) fn adding_not_null_field(ctx: &mut Linter, parse: &Parse) let is_pg12_plus = ctx.settings.pg_version >= Version::new(12, None, None); - let mut not_null_constraints: FxHashMap = FxHashMap::default(); + let mut not_null_constraints: FxHashMap = FxHashMap::default(); let mut validated_not_null_columns: FxHashSet = FxHashSet::default(); // Tables where VALIDATE CONSTRAINT was seen without a matching ADD CONSTRAINT // in the same file (cross-migration pattern). - let mut tables_with_external_validated_constraints: FxHashSet = - FxHashSet::default(); + let mut tables_with_external_validated_constraints: FxHashSet = FxHashSet::default(); for stmt in file.stmts() { if let ast::Stmt::AlterTable(alter_table) = stmt { @@ -76,7 +74,7 @@ pub(crate) fn adding_not_null_field(ctx: &mut Linter, parse: &Parse) && let Some(column) = is_not_null_check(&expr) { not_null_constraints.insert( - Identifier::new(&constraint_name.text()), + constraint_name.text(), TableColumn { table: table.clone(), column, @@ -88,9 +86,8 @@ pub(crate) fn adding_not_null_field(ctx: &mut Linter, parse: &Parse) ast::AlterTableAction::ValidateConstraint(validate_constraint) if is_pg12_plus => { - if let Some(constraint_name) = validate_constraint - .name_ref() - .map(|x| Identifier::new(&x.text())) + if let Some(constraint_name) = + validate_constraint.name_ref().map(|x| x.text()) { if let Some(table_column) = not_null_constraints.get(&constraint_name) && table_column.table == table @@ -113,8 +110,7 @@ pub(crate) fn adding_not_null_field(ctx: &mut Linter, parse: &Parse) }; if is_pg12_plus - && let Some(column) = - alter_column.name_ref().map(|x| Identifier::new(&x.text())) + && let Some(column) = alter_column.name_ref().map(|x| x.text()) { let table_column = TableColumn { table: table.clone(), diff --git a/crates/squawk_linter/src/rules/ban_char_field.rs b/crates/squawk_linter/src/rules/ban_char_field.rs index f36ef7e1..46611ad7 100644 --- a/crates/squawk_linter/src/rules/ban_char_field.rs +++ b/crates/squawk_linter/src/rules/ban_char_field.rs @@ -2,9 +2,8 @@ use rustc_hash::FxHashSet; use rowan::TextRange; use squawk_syntax::{ - Parse, SourceFile, TokenText, + Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::visitors::check_not_allowed_types; @@ -12,19 +11,13 @@ use crate::{Edit, Fix, Linter, Rule, Violation}; use std::sync::OnceLock; -fn char_types() -> &'static FxHashSet { - static CHAR_TYPES: OnceLock> = OnceLock::new(); - CHAR_TYPES.get_or_init(|| { - FxHashSet::from_iter([ - Identifier::new("char"), - Identifier::new("character"), - Identifier::new("bpchar"), - ]) - }) +fn char_types() -> &'static FxHashSet<&'static str> { + static CHAR_TYPES: OnceLock> = OnceLock::new(); + CHAR_TYPES.get_or_init(|| FxHashSet::from_iter(["char", "character", "bpchar"])) } -fn is_char_type(x: TokenText<'_>) -> bool { - char_types().contains(&Identifier::new(x.as_ref())) +fn is_char_type(x: &str) -> bool { + char_types().contains(x.to_ascii_lowercase().as_str()) } fn create_fix(range: TextRange, args: Option) -> Fix { @@ -43,7 +36,7 @@ fn check_path_type(ctx: &mut Linter, path_type: ast::PathType) { .path() .and_then(|x| x.segment()) .and_then(|x| x.name_ref()) - && is_char_type(name_ref.text()) + && is_char_type(&name_ref.text()) { let fix = create_fix(name_ref.syntax().text_range(), path_type.arg_list()); ctx.report(Violation::for_node( @@ -55,7 +48,7 @@ fn check_path_type(ctx: &mut Linter, path_type: ast::PathType) { } fn check_char_type(ctx: &mut Linter, char_type: ast::CharType) { - if is_char_type(char_type.text()) { + if is_char_type(&char_type.text()) { let fix = create_fix(char_type.syntax().text_range(), char_type.arg_list()); ctx.report(Violation::for_node( Rule::BanCharField, diff --git a/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs b/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs index 03c0f8ed..2055c345 100644 --- a/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs +++ b/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs @@ -3,7 +3,6 @@ use rustc_hash::FxHashSet; use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{Linter, Rule, Violation}; @@ -11,7 +10,7 @@ use crate::{Linter, Rule, Violation}; pub fn tables_created_in_transaction( assume_in_transaction: bool, file: &ast::SourceFile, -) -> FxHashSet { +) -> FxHashSet { let mut created_table_names = FxHashSet::default(); let mut inside_transaction = assume_in_transaction; for stmt in file.stmts() { @@ -30,7 +29,7 @@ pub fn tables_created_in_transaction( else { continue; }; - created_table_names.insert(Identifier::new(&table_name.text())); + created_table_names.insert(table_name.text()); } _ => (), } @@ -44,7 +43,7 @@ fn not_valid_validate_in_transaction( file: &ast::SourceFile, ) { let mut inside_transaction = assume_in_transaction; - let mut not_valid_names: FxHashSet = FxHashSet::default(); + let mut not_valid_names: FxHashSet = FxHashSet::default(); for stmt in file.stmts() { match stmt { ast::Stmt::AlterTable(alter_table) => { @@ -52,10 +51,9 @@ fn not_valid_validate_in_transaction( match action { ast::AlterTableAction::ValidateConstraint(validate_constraint) => { if let Some(constraint_name) = - validate_constraint.name_ref().map(|x| x.text().to_string()) + validate_constraint.name_ref().map(|x| x.text()) { - if inside_transaction - && not_valid_names.contains(&Identifier::new(&constraint_name)) + if inside_transaction && not_valid_names.contains(&constraint_name) { ctx.report( Violation::for_node( @@ -72,7 +70,7 @@ fn not_valid_validate_in_transaction( && let Some(constraint_name) = constraint.constraint_name().and_then(|c| c.name()) { - not_valid_names.insert(Identifier::new(&constraint_name.text())); + not_valid_names.insert(constraint_name.text()); } } _ => (), @@ -109,14 +107,13 @@ pub(crate) fn constraint_missing_not_valid(ctx: &mut Linter, parse: &Parse) { } fn check_name(ctx: &mut Linter, name_like: &impl ast::NameLike) { - let text = name_like.syntax().text().to_string(); - let ident = normalize_identifier(&text); + let ident = name_like.text(); if ident.len() <= MAX_IDENT_BYTES { return; } - let fix = truncate(&text).map(|truncated| { + let fix = truncate(name_like).map(|truncated| { Fix::new( format!("Rename to `{truncated}`"), vec![Edit::replace(name_like.syntax().text_range(), truncated)], @@ -45,15 +44,16 @@ fn check_name(ctx: &mut Linter, name_like: &impl ast::NameLike) { ); } -fn truncate(text: &str) -> Option { - if has_escaped_quotes(text) { +fn truncate(name_like: &impl ast::NameLike) -> Option { + let raw = name_like.syntax().text().to_string(); + if has_escaped_quotes(&raw) { return None; } - let unquoted = normalize_identifier(text); - let truncated = &unquoted[..unquoted.floor_char_boundary(MAX_IDENT_BYTES)]; + let ident = name_like.text(); + let truncated = &ident[..ident.floor_char_boundary(MAX_IDENT_BYTES)]; - Some(if text.starts_with('"') { + Some(if raw.starts_with('"') || needs_quoting(truncated) { format!("\"{truncated}\"") } else { truncated.to_owned() diff --git a/crates/squawk_linter/src/rules/prefer_bigint_over_int.rs b/crates/squawk_linter/src/rules/prefer_bigint_over_int.rs index 253fedbe..de3af7d4 100644 --- a/crates/squawk_linter/src/rules/prefer_bigint_over_int.rs +++ b/crates/squawk_linter/src/rules/prefer_bigint_over_int.rs @@ -1,8 +1,7 @@ use rustc_hash::FxHashSet; use squawk_syntax::ast::AstNode; -use squawk_syntax::quote::normalize_identifier; -use squawk_syntax::{Parse, SourceFile, ast, identifier::Identifier}; +use squawk_syntax::{Parse, SourceFile, ast}; use crate::{Edit, Fix, Linter, Rule, Violation}; @@ -11,21 +10,13 @@ use crate::visitors::is_not_valid_int_type; use std::sync::OnceLock; -fn int_types() -> &'static FxHashSet { - static INT_TYPES: OnceLock> = OnceLock::new(); - INT_TYPES.get_or_init(|| { - FxHashSet::from_iter([ - Identifier::new("int"), - Identifier::new("integer"), - Identifier::new("int4"), - Identifier::new("serial"), - Identifier::new("serial4"), - ]) - }) +fn int_types() -> &'static FxHashSet<&'static str> { + static INT_TYPES: OnceLock> = OnceLock::new(); + INT_TYPES.get_or_init(|| FxHashSet::from_iter(["int", "integer", "int4", "serial", "serial4"])) } fn int_to_bigint_replacement(int_type: &str) -> &'static str { - match normalize_identifier(int_type).as_str() { + match int_type { "int" | "integer" => "bigint", "int4" => "int8", "serial" => "bigserial", @@ -35,9 +26,20 @@ fn int_to_bigint_replacement(int_type: &str) -> &'static str { } fn create_bigint_fix(ty: &ast::Type) -> Option { - let type_name = ty.syntax().first_token()?; - let i64 = int_to_bigint_replacement(type_name.text()); - let edit = Edit::replace(type_name.text_range(), i64); + let name = match ty { + ast::Type::ArrayType(array_type) => return create_bigint_fix(&array_type.ty()?), + ast::Type::PathType(path_type) => path_type.path()?.segment()?.name_ref()?, + ast::Type::BitType(_) + | ast::Type::CharType(_) + | ast::Type::DoubleType(_) + | ast::Type::ExprType(_) + | ast::Type::PercentType(_) + | ast::Type::TimeType(_) + | ast::Type::IntervalType(_) => return None, + }; + let int_type = name.text(); + let i64 = int_to_bigint_replacement(&int_type); + let edit = Edit::replace(name.syntax().text_range(), i64); Some(Fix::new( format!("Replace with a 64-bit integer type: `{i64}`"), vec![edit], diff --git a/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs b/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs index 696cd633..f2257200 100644 --- a/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs +++ b/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs @@ -1,8 +1,7 @@ use rustc_hash::FxHashSet; use squawk_syntax::ast::AstNode; -use squawk_syntax::quote::normalize_identifier; -use squawk_syntax::{Parse, SourceFile, ast, identifier::Identifier}; +use squawk_syntax::{Parse, SourceFile, ast}; use crate::{Edit, Fix, Linter, Rule, Violation}; @@ -11,20 +10,14 @@ use crate::visitors::is_not_valid_int_type; use std::sync::OnceLock; -fn small_int_types() -> &'static FxHashSet { - static SMALL_INT_TYPES: OnceLock> = OnceLock::new(); - SMALL_INT_TYPES.get_or_init(|| { - FxHashSet::from_iter([ - Identifier::new("smallint"), - Identifier::new("int2"), - Identifier::new("smallserial"), - Identifier::new("serial2"), - ]) - }) +fn small_int_types() -> &'static FxHashSet<&'static str> { + static SMALL_INT_TYPES: OnceLock> = OnceLock::new(); + SMALL_INT_TYPES + .get_or_init(|| FxHashSet::from_iter(["smallint", "int2", "smallserial", "serial2"])) } fn smallint_to_bigint(smallint_type: &str) -> &'static str { - match normalize_identifier(smallint_type).as_str() { + match smallint_type { "smallint" => "bigint", "int2" => "int8", "smallserial" => "bigserial", @@ -34,9 +27,19 @@ fn smallint_to_bigint(smallint_type: &str) -> &'static str { } fn create_bigint_fix(ty: &ast::Type) -> Option { - let type_name = ty.syntax().first_token()?; - let i64 = smallint_to_bigint(type_name.text()); - let edit = Edit::replace(type_name.text_range(), i64); + let name = match ty { + ast::Type::ArrayType(array_type) => return create_bigint_fix(&array_type.ty()?), + ast::Type::PathType(path_type) => path_type.path()?.segment()?.name_ref()?, + ast::Type::BitType(_) + | ast::Type::CharType(_) + | ast::Type::DoubleType(_) + | ast::Type::ExprType(_) + | ast::Type::PercentType(_) + | ast::Type::TimeType(_) + | ast::Type::IntervalType(_) => return None, + }; + let i64 = smallint_to_bigint(&name.text()); + let edit = Edit::replace(name.syntax().text_range(), i64); Some(Fix::new( format!("Replace with a 64-bit integer type: `{i64}`"), vec![edit], diff --git a/crates/squawk_linter/src/rules/prefer_identity.rs b/crates/squawk_linter/src/rules/prefer_identity.rs index 1a2797bc..6c4543a6 100644 --- a/crates/squawk_linter/src/rules/prefer_identity.rs +++ b/crates/squawk_linter/src/rules/prefer_identity.rs @@ -3,8 +3,6 @@ use rustc_hash::FxHashSet; use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, - quote::normalize_identifier, }; use crate::{Edit, Fix, Linter, Rule, Violation}; @@ -13,22 +11,22 @@ use std::sync::OnceLock; use crate::visitors::{check_not_allowed_types, is_not_valid_int_type}; -fn serial_types() -> &'static FxHashSet { - static SERIAL_TYPES: OnceLock> = OnceLock::new(); +fn serial_types() -> &'static FxHashSet<&'static str> { + static SERIAL_TYPES: OnceLock> = OnceLock::new(); SERIAL_TYPES.get_or_init(|| { FxHashSet::from_iter([ - Identifier::new("serial"), - Identifier::new("serial2"), - Identifier::new("serial4"), - Identifier::new("serial8"), - Identifier::new("smallserial"), - Identifier::new("bigserial"), + "serial", + "serial2", + "serial4", + "serial8", + "smallserial", + "bigserial", ]) }) } fn replace_serial(serial_type: &str) -> &'static str { - match normalize_identifier(serial_type).as_str() { + match serial_type { "serial" | "serial4" => "integer generated by default as identity", "serial2" | "smallserial" => "smallint generated by default as identity", "serial8" | "bigserial" => "bigint generated by default as identity", @@ -37,9 +35,19 @@ fn replace_serial(serial_type: &str) -> &'static str { } fn create_identity_fix(ty: &ast::Type) -> Option { - let type_name = ty.syntax().first_token()?; - let text = replace_serial(type_name.text()); - let edit = Edit::replace(ty.syntax().text_range(), text); + let name = match ty { + ast::Type::ArrayType(array_type) => return create_identity_fix(&array_type.ty()?), + ast::Type::PathType(path_type) => path_type.path()?.segment()?.name_ref()?, + ast::Type::BitType(_) + | ast::Type::CharType(_) + | ast::Type::DoubleType(_) + | ast::Type::ExprType(_) + | ast::Type::PercentType(_) + | ast::Type::TimeType(_) + | ast::Type::IntervalType(_) => return None, + }; + let text = replace_serial(&name.text()); + let edit = Edit::replace(name.syntax().text_range(), text); Some(Fix::new("Replace with IDENTITY column", vec![edit])) } diff --git a/crates/squawk_linter/src/rules/prefer_robust_stmts.rs b/crates/squawk_linter/src/rules/prefer_robust_stmts.rs index bc30f8eb..11e907c2 100644 --- a/crates/squawk_linter/src/rules/prefer_robust_stmts.rs +++ b/crates/squawk_linter/src/rules/prefer_robust_stmts.rs @@ -3,7 +3,6 @@ use rustc_hash::FxHashMap; use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{Edit, Fix, Linter, Rule, Violation}; @@ -17,7 +16,7 @@ enum Constraint { pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { let file = parse.tree(); let mut inside_transaction = ctx.settings.assume_in_transaction; - let mut constraint_names: FxHashMap = FxHashMap::default(); + let mut constraint_names: FxHashMap = FxHashMap::default(); enum ActionErrorMessage { IfExists, @@ -38,10 +37,8 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { let (message_type, fix) = match &action { ast::AlterTableAction::DropConstraint(drop_constraint) => { if let Some(constraint_name) = drop_constraint.name_ref() { - constraint_names.insert( - Identifier::new(constraint_name.text().as_str()), - Constraint::Dropped, - ); + constraint_names + .insert(constraint_name.text(), Constraint::Dropped); } if drop_constraint.if_exists().is_some() { continue; @@ -68,12 +65,10 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { (ActionErrorMessage::IfNotExists, fix) } ast::AlterTableAction::ValidateConstraint(validate_constraint) => { - if let Some(constraint_name) = validate_constraint.name_ref() { - if constraint_names - .contains_key(&Identifier::new(constraint_name.text().as_str())) - { - continue; - } + if let Some(constraint_name) = validate_constraint.name_ref() + && constraint_names.contains_key(&constraint_name.text()) + { + continue; } (ActionErrorMessage::None, None) } @@ -82,8 +77,8 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { if let Some(constraint_name) = constraint .and_then(|x| x.constraint_name()) .and_then(|x| x.name()) - && let Some(constraint) = constraint_names - .get_mut(&Identifier::new(constraint_name.text().as_str())) + && let Some(constraint) = + constraint_names.get_mut(&constraint_name.text()) && *constraint == Constraint::Dropped { *constraint = Constraint::Added; diff --git a/crates/squawk_linter/src/rules/prefer_text_field.rs b/crates/squawk_linter/src/rules/prefer_text_field.rs index 95bcc88a..37bfa042 100644 --- a/crates/squawk_linter/src/rules/prefer_text_field.rs +++ b/crates/squawk_linter/src/rules/prefer_text_field.rs @@ -1,7 +1,6 @@ use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{Edit, Fix, Linter, Rule, Violation}; @@ -23,17 +22,15 @@ fn is_not_allowed_varchar(ty: &ast::Type) -> bool { .path() .and_then(|x| x.segment()) .and_then(|x| x.name_ref()) - .map(|x| x.text().to_string()) + .map(|x| x.text()) else { return false; }; // if we don't have any args, then it's the same as `text` - Identifier::new(ty_name.as_str()) == Identifier::new("varchar") - && path_type.arg_list().is_some() + ty_name == "varchar" && path_type.arg_list().is_some() } ast::Type::CharType(char_type) => { - Identifier::new(&char_type.text()) == Identifier::new("varchar") - && char_type.arg_list().is_some() + char_type.text().eq_ignore_ascii_case("varchar") && char_type.arg_list().is_some() } ast::Type::BitType(_) => false, ast::Type::DoubleType(_) => false, diff --git a/crates/squawk_linter/src/rules/prefer_timestamptz.rs b/crates/squawk_linter/src/rules/prefer_timestamptz.rs index 8b7c03f8..f68d6ff5 100644 --- a/crates/squawk_linter/src/rules/prefer_timestamptz.rs +++ b/crates/squawk_linter/src/rules/prefer_timestamptz.rs @@ -1,7 +1,6 @@ use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::visitors::check_not_allowed_types; @@ -22,13 +21,12 @@ pub fn is_not_allowed_timestamp(ty: &ast::Type) -> bool { .path() .and_then(|x| x.segment()) .and_then(|x| x.name_ref()) - .map(|x| x.text().to_string()) + .map(|x| x.text()) else { return false; }; // if we don't have any args, then it's the same as `text` - Identifier::new(ty_name.as_str()) == Identifier::new("varchar") - && path_type.arg_list().is_some() + ty_name == "varchar" && path_type.arg_list().is_some() } ast::Type::CharType(_) => false, ast::Type::BitType(_) => false, diff --git a/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs b/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs index 4178f600..b4db3476 100644 --- a/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs +++ b/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs @@ -1,7 +1,6 @@ use squawk_syntax::{ Parse, SourceFile, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{Edit, Fix, Linter, Rule, Violation}; @@ -27,7 +26,7 @@ pub(crate) fn require_concurrent_index_creation(ctx: &mut Linter, parse: &Parse< .and_then(|x| x.name_ref()) { if create_index.concurrently_token().is_none() - && !tables_created.contains(&Identifier::new(&table_name.text())) + && !tables_created.contains(&table_name.text()) { let fix = concurrently_fix(&create_index); diff --git a/crates/squawk_linter/src/rules/require_timeout_settings.rs b/crates/squawk_linter/src/rules/require_timeout_settings.rs index af5c2949..fe93b0d0 100644 --- a/crates/squawk_linter/src/rules/require_timeout_settings.rs +++ b/crates/squawk_linter/src/rules/require_timeout_settings.rs @@ -2,7 +2,6 @@ use rowan::TextSize; use squawk_syntax::{ Parse, SourceFile, SyntaxKind, ast::{self, AstNode}, - identifier::Identifier, }; use crate::{Edit, Fix, Linter, Rule, Violation, analyze}; @@ -62,10 +61,10 @@ pub(crate) fn require_timeout_settings(ctx: &mut Linter, parse: &Parse S let mut linter = Linter::from([rule]); linter.settings = settings; let errors = linter.lint(&file, &result); - assert_eq!( - errors.len(), - 0, - "Fixes should remove all the linter errors." - ); + assert_eq!(errors, vec![], "Fixes should remove all the linter errors."); result } diff --git a/crates/squawk_linter/src/visitors.rs b/crates/squawk_linter/src/visitors.rs index d14976dd..39a65dc7 100644 --- a/crates/squawk_linter/src/visitors.rs +++ b/crates/squawk_linter/src/visitors.rs @@ -1,12 +1,12 @@ use rustc_hash::FxHashSet; -use squawk_syntax::{ast, identifier::Identifier}; +use squawk_syntax::ast; use crate::Linter; pub(crate) fn is_not_valid_int_type( ty: &ast::Type, - invalid_type_names: &FxHashSet, + invalid_type_names: &FxHashSet<&'static str>, ) -> bool { match ty { ast::Type::ArrayType(array_type) => { @@ -22,12 +22,11 @@ pub(crate) fn is_not_valid_int_type( .path() .and_then(|x| x.segment()) .and_then(|x| x.name_ref()) - .map(|x| x.text().to_string()) + .map(|x| x.text()) else { return false; }; - let name = Identifier::new(ty_name.as_str()); - invalid_type_names.contains(&name) + invalid_type_names.contains(ty_name.as_str()) } ast::Type::CharType(_) => false, ast::Type::BitType(_) => false, diff --git a/crates/squawk_syntax/src/ast/node_ext.rs b/crates/squawk_syntax/src/ast/node_ext.rs index e7f93b6e..be01155b 100644 --- a/crates/squawk_syntax/src/ast/node_ext.rs +++ b/crates/squawk_syntax/src/ast/node_ext.rs @@ -36,6 +36,7 @@ use rowan::Direction; use crate::ast; use crate::ast::AstNode; +use crate::unescape::{escape_unicode_esc_str, uescape_char}; use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText}; use super::support; @@ -396,18 +397,81 @@ impl ast::CompoundSelect { impl ast::NameRef { #[inline] - pub fn text(&self) -> TokenText<'_> { - text_of_first_token(self.syntax()) + pub fn text(&self) -> String { + normalize_name_node(self.syntax()) + } + + #[inline] + pub fn is_quoted(&self) -> bool { + is_quoted(self.syntax()) } } impl ast::Name { #[inline] - pub fn text(&self) -> TokenText<'_> { - text_of_first_token(self.syntax()) + pub fn text(&self) -> String { + normalize_name_node(self.syntax()) + } + + #[inline] + pub fn is_quoted(&self) -> bool { + is_quoted(self.syntax()) } } +fn is_quoted(node: &SyntaxNode) -> bool { + let text = node.text(); + let first = text.char_at(0.into()); + let second = text.char_at(1.into()); + matches!( + (first, second), + (Some('u' | 'U'), Some('"')) | (Some('"'), Some(_)) + ) +} + +// TODO: return a NewType wrapper around String? +fn normalize_name_node(node: &SyntaxNode) -> String { + let mut tokens = node + .children_with_tokens() + .filter_map(|el| el.into_token()) + .filter(|t| !t.kind().is_trivia()); + + let Some(ident_token) = tokens.next() else { + return String::new(); + }; + let raw = ident_token.text(); + + let unicode_inner = raw + .strip_prefix(['u', 'U']) + .and_then(|s| s.strip_prefix("&\"")) + .and_then(|s| s.strip_suffix('"')); + + if let Some(inner) = unicode_inner { + let mut escape_char = '\\'; + if let Some(uesc) = tokens.next() + && uesc.kind() == SyntaxKind::UESCAPE_KW + && let Some(token) = tokens.next() + && let Some(ch) = uescape_char(token.text()) + { + escape_char = ch; + } + + let inner = inner.replace(r#""""#, "\""); + let mut result = String::with_capacity(inner.len()); + escape_unicode_esc_str(&inner, escape_char, |_range, r| { + if let Ok(ch) = r { + result.push(ch); + } + }); + return result; + } + + raw.strip_prefix('"') + .and_then(|t| t.strip_suffix('"')) + .map(|x| x.replace(r#""""#, "\"")) + .unwrap_or_else(|| raw.to_ascii_lowercase()) +} + impl ast::CharType { #[inline] pub fn text(&self) -> TokenText<'_> { @@ -509,8 +573,18 @@ impl ast::SelectVariant { impl ast::HasParamList for ast::FunctionSig {} impl ast::HasParamList for ast::Aggregate {} -impl ast::NameLike for ast::Name {} -impl ast::NameLike for ast::NameRef {} +impl ast::NameLike for ast::Name { + #[inline] + fn text(&self) -> String { + self.text() + } +} +impl ast::NameLike for ast::NameRef { + #[inline] + fn text(&self) -> String { + self.text() + } +} impl ast::HasWithClause for ast::Select {} impl ast::HasWithClause for ast::SelectInto {} @@ -522,6 +596,70 @@ impl ast::HasCreateTable for ast::CreateTable {} impl ast::HasCreateTable for ast::CreateForeignTable {} impl ast::HasCreateTable for ast::CreateTableLike {} +#[test] +fn name() { + assert_snapshot!(extract_name("select 1 foo"), @"foo"); + assert_snapshot!(extract_name("select 1 FOO"), @"foo"); + assert_snapshot!(extract_name(r#"select 1 "foo""#), @"foo"); + assert_snapshot!(extract_name(r#"select 1 "Foo""#), @"Foo"); + assert_snapshot!(extract_name(r#"select 1 "FOO""#), @"FOO"); + assert_snapshot!(extract_name(r#"select 1 U&"\0066\006f\006f""#), @"foo"); + assert_snapshot!(extract_name(r#"select 1 U&"@0066@006f@006f" uescape '@'"#), @"foo"); + + fn extract_name(source_code: &str) -> String { + let parse = SourceFile::parse(source_code); + assert!(parse.errors().is_empty()); + let stmt = parse.tree().stmts().next().unwrap(); + let ast::Stmt::Select(select) = stmt else { + unreachable!() + }; + let name = select + .select_clause() + .unwrap() + .target_list() + .unwrap() + .targets() + .next() + .unwrap() + .as_name() + .unwrap() + .name() + .unwrap(); + name.text().to_string() + } +} + +#[test] +fn name_ref() { + assert_snapshot!(extract_name_ref("select foo"), @"foo"); + assert_snapshot!(extract_name_ref("select FOO"), @"foo"); + assert_snapshot!(extract_name_ref(r#"select "foo""#), @"foo"); + assert_snapshot!(extract_name_ref(r#"select "Foo""#), @"Foo"); + assert_snapshot!(extract_name_ref(r#"select "FOO""#), @"FOO"); + assert_snapshot!(extract_name_ref(r#"select U&"\0066\006f\006f""#), @"foo"); + assert_snapshot!(extract_name_ref(r#"select U&"@0066@006f@006f" uescape '@'"#), @"foo"); + + fn extract_name_ref(source_code: &str) -> String { + let parse = SourceFile::parse(source_code); + assert!(parse.errors().is_empty()); + let stmt = parse.tree().stmts().next().unwrap(); + let ast::Stmt::Select(select) = stmt else { + unreachable!() + }; + let select_clause = select.select_clause().unwrap(); + let target = select_clause + .target_list() + .unwrap() + .targets() + .next() + .unwrap(); + let ast::Expr::NameRef(name_ref) = target.expr().unwrap() else { + unreachable!() + }; + name_ref.text().to_string() + } +} + #[test] fn index_expr() { let source_code = " @@ -529,8 +667,7 @@ fn index_expr() { "; let parse = SourceFile::parse(source_code); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::Select(select) = stmt else { unreachable!() }; @@ -558,8 +695,7 @@ fn slice_expr() { "; let parse = SourceFile::parse(source_code); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::Select(select) = stmt else { unreachable!() }; @@ -606,8 +742,7 @@ fn field_expr() { "; let parse = SourceFile::parse(source_code); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::Select(select) = stmt else { unreachable!() }; @@ -634,8 +769,7 @@ fn between_expr() { "; let parse = SourceFile::parse(source_code); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::Select(select) = stmt else { unreachable!() }; @@ -730,8 +864,8 @@ fn cast_expr() { fn extract_expr(sql: &str) -> ast::CastExpr { let parse = SourceFile::parse(sql); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let node = file + let node = parse + .tree() .stmts() .map(|x| match x { ast::Stmt::Select(select) => select @@ -764,8 +898,7 @@ fn op_sig() { "; let parse = SourceFile::parse(source_code); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::AlterOperator(alter_op) = stmt else { unreachable!() }; @@ -783,8 +916,7 @@ fn cast_sig() { "; let parse = SourceFile::parse(source_code); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::DropCast(alter_op) = stmt else { unreachable!() }; @@ -799,8 +931,7 @@ fn cast_sig() { fn extract_vacuum(sql: &str) -> ast::Vacuum { let parse = SourceFile::parse(sql); assert!(parse.errors().is_empty()); - let file: SourceFile = parse.tree(); - let stmt = file.stmts().next().unwrap(); + let stmt = parse.tree().stmts().next().unwrap(); let ast::Stmt::Vacuum(vacuum) = stmt else { unreachable!() }; diff --git a/crates/squawk_syntax/src/ast/traits.rs b/crates/squawk_syntax/src/ast/traits.rs index a057603e..c804bd80 100644 --- a/crates/squawk_syntax/src/ast/traits.rs +++ b/crates/squawk_syntax/src/ast/traits.rs @@ -3,7 +3,9 @@ use crate::ast; use crate::ast::{AstNode, support}; -pub trait NameLike: AstNode {} +pub trait NameLike: AstNode { + fn text(&self) -> String; +} pub trait HasCreateTable: AstNode { #[inline] diff --git a/crates/squawk_syntax/src/identifier.rs b/crates/squawk_syntax/src/identifier.rs deleted file mode 100644 index 440d2c4b..00000000 --- a/crates/squawk_syntax/src/identifier.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::quote::normalize_identifier; - -/// Postgres Identifiers are case insensitive unless they're quoted. -/// -/// This type handles the casing rules for us to make comparisions easier. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Identifier(String); - -impl Identifier { - // TODO: we need to handle more advanced identifiers like: - // U&"d!0061t!+000061" UESCAPE '!' - pub fn new(s: &str) -> Self { - let normalized = normalize_identifier(s); - Identifier(normalized) - } -} - -#[cfg(test)] -mod test { - use crate::identifier::Identifier; - - #[test] - fn case_folds_correctly() { - // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS - // For example, the identifiers FOO, foo, and "foo" are considered the - // same by PostgreSQL, but "Foo" and "FOO" are different from these - // three and each other. - assert_eq!(Identifier::new("FOO"), Identifier::new("foo")); - assert_eq!(Identifier::new(r#""foo""#), Identifier::new("foo")); - assert_eq!(Identifier::new(r#""foo""#), Identifier::new("FOO")); - } -} diff --git a/crates/squawk_syntax/src/lib.rs b/crates/squawk_syntax/src/lib.rs index e3090d10..e6662bf8 100644 --- a/crates/squawk_syntax/src/lib.rs +++ b/crates/squawk_syntax/src/lib.rs @@ -26,7 +26,6 @@ pub mod ast; mod generated; -pub mod identifier; mod parsing; mod ptr; pub mod quote; diff --git a/crates/squawk_syntax/src/quote.rs b/crates/squawk_syntax/src/quote.rs index 0d099e63..3528a018 100644 --- a/crates/squawk_syntax/src/quote.rs +++ b/crates/squawk_syntax/src/quote.rs @@ -44,7 +44,7 @@ pub fn unquote_ident(node: &SyntaxNode) -> Option { Some(text.to_string()) } -fn needs_quoting(text: &str) -> bool { +pub fn needs_quoting(text: &str) -> bool { if text.is_empty() { return true; } @@ -76,14 +76,6 @@ pub fn is_reserved_word(text: &str) -> bool { .is_ok() } -pub fn normalize_identifier(text: &str) -> String { - // TODO: Cow/SmolStr/Salsa Interned? - text.strip_prefix('"') - .and_then(|t| t.strip_suffix('"')) - .map(|x| x.replace(r#""""#, "\"")) - .unwrap_or_else(|| text.to_ascii_lowercase()) -} - #[cfg(test)] mod tests { use insta::assert_snapshot; diff --git a/crates/squawk_syntax/src/unescape.rs b/crates/squawk_syntax/src/unescape.rs index 70e81a5e..a2e3b6ab 100644 --- a/crates/squawk_syntax/src/unescape.rs +++ b/crates/squawk_syntax/src/unescape.rs @@ -152,6 +152,26 @@ where } } +// https://github.com/postgres/postgres/blob/228a1f9542792c6533ef74c2e7aefad0da1d9a7a/src/backend/parser/parser.c#L350 +const fn is_valid_uescape_char(byte: u8) -> bool { + !byte.is_ascii_hexdigit() + && byte != b'+' + && byte != b'\'' + && byte != b'"' + && !matches!( + byte, + b' ' | b'\t' | b'\n' | b'\r' | /* b'\v' */ 0x0B | /* b'\f' */ 0x0C + ) +} + +pub(crate) fn uescape_char(text: &str) -> Option { + let inner = text.strip_prefix('\'')?.strip_suffix('\'')?; + let &[byte] = inner.as_bytes() else { + return None; + }; + is_valid_uescape_char(byte).then(|| char::from(byte)) +} + #[cfg(test)] mod tests { use insta::assert_snapshot; diff --git a/crates/squawk_syntax/src/validation.rs b/crates/squawk_syntax/src/validation.rs index af27fbff..4ce02b31 100644 --- a/crates/squawk_syntax/src/validation.rs +++ b/crates/squawk_syntax/src/validation.rs @@ -7,7 +7,7 @@ use std::ops::Range; use crate::ast::AstNode; -use crate::unescape::escape_unicode_esc_str; +use crate::unescape::{escape_unicode_esc_str, uescape_char}; use crate::{SyntaxNode, SyntaxToken, ast, match_ast, syntax_error::SyntaxError}; use rowan::{TextRange, TextSize}; use squawk_parser::SyntaxKind::*; @@ -322,7 +322,7 @@ fn validate_unicode_esc_string(lit: &ast::Literal, acc: &mut Vec) { UNICODE_ESC_STRING => unicode_esc = Some(token), UESCAPE_KW => seen_uescape = true, STRING if seen_uescape => { - escape_char = match uescape_char(&token) { + escape_char = match uescape_char(token.text()) { Some(ch) => ch, None => { acc.push(SyntaxError::new( @@ -397,7 +397,7 @@ fn validate_unicode_esc_ident(token: &SyntaxToken, acc: &mut Vec) { UESCAPE_KW => seen_uescape = true, STRING if seen_uescape => { if let Some(string_token) = element.as_token() { - escape_char = match uescape_char(string_token) { + escape_char = match uescape_char(string_token.text()) { Some(ch) => ch, None => { acc.push(SyntaxError::new( @@ -432,26 +432,6 @@ fn offset_range(start: TextSize, range: Range) -> TextRange { TextRange::new(begin, end) } -// https://github.com/postgres/postgres/blob/228a1f9542792c6533ef74c2e7aefad0da1d9a7a/src/backend/parser/parser.c#L350 -const fn is_valid_uescape_char(byte: u8) -> bool { - !byte.is_ascii_hexdigit() - && byte != b'+' - && byte != b'\'' - && byte != b'"' - && !matches!( - byte, - b' ' | b'\t' | b'\n' | b'\r' | /* b'\v' */ 0x0B | /* b'\f' */ 0x0C - ) -} - -fn uescape_char(string_token: &SyntaxToken) -> Option { - let inner = string_token.text().strip_prefix('\'')?.strip_suffix('\'')?; - let &[byte] = inner.as_bytes() else { - return None; - }; - is_valid_uescape_char(byte).then(|| char::from(byte)) -} - fn validate_join_expr(join_expr: ast::JoinExpr, acc: &mut Vec) { let Some(join) = join_expr.join() else { return;