diff --git a/src/query/extended.rs b/src/query/extended.rs index f7cb910f..8eca4f38 100644 --- a/src/query/extended.rs +++ b/src/query/extended.rs @@ -1862,8 +1862,10 @@ impl ExtendedQueryHandler { || query_starts_with_ignore_case(&final_query, "DROP") || query_starts_with_ignore_case(&final_query, "ALTER") { Self::execute_ddl(framed, db, session, &final_query).await?; - } else if query_starts_with_ignore_case(&final_query, "BEGIN") - || query_starts_with_ignore_case(&final_query, "COMMIT") + } else if query_starts_with_ignore_case(&final_query, "BEGIN") + || query_starts_with_ignore_case(&final_query, "START") + || query_starts_with_ignore_case(&final_query, "COMMIT") + || query_starts_with_ignore_case(&final_query, "END") || query_starts_with_ignore_case(&final_query, "ROLLBACK") { Self::execute_transaction(framed, db, session, &final_query).await?; } else if crate::query::SetHandler::is_set_command(&final_query) { @@ -5632,11 +5634,13 @@ impl ExtendedQueryHandler { where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { - if query_starts_with_ignore_case(query, "BEGIN") { + if query_starts_with_ignore_case(query, "BEGIN") + || query_starts_with_ignore_case(query, "START") { db.begin_with_session(&session.id).await?; framed.send(BackendMessage::CommandComplete { tag: "BEGIN".to_string() }).await .map_err(PgSqliteError::Io)?; - } else if query_starts_with_ignore_case(query, "COMMIT") { + } else if query_starts_with_ignore_case(query, "COMMIT") + || query_starts_with_ignore_case(query, "END") { db.commit_with_session(&session.id).await?; framed.send(BackendMessage::CommandComplete { tag: "COMMIT".to_string() }).await .map_err(PgSqliteError::Io)?; diff --git a/src/query/query_type_detection.rs b/src/query/query_type_detection.rs index 20f4d221..7ff33547 100644 --- a/src/query/query_type_detection.rs +++ b/src/query/query_type_detection.rs @@ -38,10 +38,20 @@ impl QueryTypeDetector { match &bytes[0..5] { b"ALTER" | b"alter" | b"Alter" => return QueryType::Alter, b"BEGIN" | b"begin" | b"Begin" => return QueryType::Begin, + b"START" | b"start" | b"Start" => return QueryType::Begin, _ => {} } } - + + if bytes.len() >= 3 { + let first3 = &bytes[0..3]; + if (first3 == b"END" || first3 == b"end" || first3 == b"End") + && (bytes.len() == 3 || bytes[3].is_ascii_whitespace()) + { + return QueryType::Commit; + } + } + if bytes.len() >= 6 { match &bytes[0..6] { b"COMMIT" | b"commit" | b"Commit" => return QueryType::Commit, @@ -84,8 +94,13 @@ impl QueryTypeDetector { QueryType::Truncate } else if trimmed.len() >= 5 && trimmed[..5].eq_ignore_ascii_case("BEGIN") { QueryType::Begin + } else if trimmed.len() >= 5 && trimmed[..5].eq_ignore_ascii_case("START") { + QueryType::Begin } else if trimmed.len() >= 6 && trimmed[..6].eq_ignore_ascii_case("COMMIT") { QueryType::Commit + } else if trimmed.len() >= 3 && trimmed[..3].eq_ignore_ascii_case("END") + && (trimmed.len() == 3 || trimmed.as_bytes()[3].is_ascii_whitespace()) { + QueryType::Commit } else if trimmed.len() >= 8 && trimmed[..8].eq_ignore_ascii_case("ROLLBACK") { QueryType::Rollback } else if trimmed.len() >= 7 && trimmed[..7].eq_ignore_ascii_case("COMMENT") { @@ -232,6 +247,21 @@ mod tests { assert_eq!(QueryTypeDetector::detect_query_type("ROLLBACK"), QueryType::Rollback); assert_eq!(QueryTypeDetector::detect_query_type("rollback"), QueryType::Rollback); + // START TRANSACTION (issue #70) + assert_eq!(QueryTypeDetector::detect_query_type("START TRANSACTION"), QueryType::Begin); + assert_eq!(QueryTypeDetector::detect_query_type("start transaction"), QueryType::Begin); + assert_eq!(QueryTypeDetector::detect_query_type("Start Transaction"), QueryType::Begin); + assert_eq!(QueryTypeDetector::detect_query_type("START TRANSACTION ISOLATION LEVEL REPEATABLE READ"), QueryType::Begin); + + // END / END TRANSACTION (issue #70) + assert_eq!(QueryTypeDetector::detect_query_type("END"), QueryType::Commit); + assert_eq!(QueryTypeDetector::detect_query_type("end"), QueryType::Commit); + assert_eq!(QueryTypeDetector::detect_query_type("END TRANSACTION"), QueryType::Commit); + assert_eq!(QueryTypeDetector::detect_query_type("end transaction"), QueryType::Commit); + // Word boundary guard — END must not match identifiers starting with "END" + assert_ne!(QueryTypeDetector::detect_query_type("ENDLESS"), QueryType::Commit); + assert_ne!(QueryTypeDetector::detect_query_type("ENDTABLE"), QueryType::Commit); + assert_eq!(QueryTypeDetector::detect_query_type("EXPLAIN SELECT * FROM test"), QueryType::Other); assert_eq!(QueryTypeDetector::detect_query_type(" SELECT * FROM test"), QueryType::Select); @@ -242,6 +272,17 @@ mod tests { assert_eq!(QueryTypeDetector::detect_query_type("WiTh cte AS (SELECT 1) SELECT * FROM cte"), QueryType::Select); } + #[test] + fn test_is_transaction_with_synonyms() { + assert!(QueryTypeDetector::is_transaction("BEGIN")); + assert!(QueryTypeDetector::is_transaction("START TRANSACTION")); + assert!(QueryTypeDetector::is_transaction("COMMIT")); + assert!(QueryTypeDetector::is_transaction("END")); + assert!(QueryTypeDetector::is_transaction("END TRANSACTION")); + assert!(QueryTypeDetector::is_transaction("ROLLBACK")); + assert!(!QueryTypeDetector::is_transaction("SELECT 1")); + } + #[test] fn test_is_ddl() { assert!(QueryTypeDetector::is_ddl("CREATE TABLE test"));