diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc index 6923d7fafbe4..7fede0e22d5a 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/get_info_cache.cc @@ -25,6 +25,7 @@ #include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" #include "arrow/scalar.h" #include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_stream_chunk_buffer.h" #include "arrow/flight/sql/odbc/odbc_impl/scalar_function_reporter.h" @@ -76,6 +77,9 @@ #define ARROW_CONVERT_VARCHAR 19 namespace arrow::flight::sql::odbc { + +using arrow::internal::checked_cast; + namespace { // Return the corresponding field in SQLGetInfo's SQL_CONVERT_* field // types for the given Arrow SqlConvert enum value. @@ -190,7 +194,7 @@ inline int64_t ScalarToInt64(UnionScalar* scalar) { } inline std::string ScalarToBoolString(UnionScalar* scalar) { - return reinterpret_cast(scalar->child_value().get())->value ? "Y" : "N"; + return checked_cast(scalar->child_value().get())->value ? "Y" : "N"; } inline void SetDefaultIfMissing(std::unordered_map& cache, @@ -395,25 +399,21 @@ bool GetInfoCache::LoadInfoFromServer() { // Unused by ODBC. break; case SqlInfoOptions::SQL_DDL_SCHEMA: { - // GH-49500 TODO: use scalar bool to determine `SQL_CREATE_SCHEMA` and - // `SQL_DROP_SCHEMA` values - - // Note: this is a bitmask and we can't describe cascade or restrict - // flags. - info_[SQL_DROP_SCHEMA] = static_cast(SQL_DS_DROP_SCHEMA); - - // Note: this is a bitmask and we can't describe authorization or - // collation - info_[SQL_CREATE_SCHEMA] = static_cast(SQL_CS_CREATE_SCHEMA); + bool supported = + checked_cast(scalar->child_value().get())->value; + info_[SQL_DROP_SCHEMA] = + static_cast(supported ? SQL_DS_DROP_SCHEMA : 0); + info_[SQL_CREATE_SCHEMA] = + static_cast(supported ? SQL_CS_CREATE_SCHEMA : 0); break; } case SqlInfoOptions::SQL_DDL_TABLE: { - // GH-49500 TODO: use scalar bool to determine `SQL_CREATE_TABLE` and - // `SQL_DROP_TABLE` values - - // This is a bitmask and we cannot describe all clauses. - info_[SQL_CREATE_TABLE] = static_cast(SQL_CT_CREATE_TABLE); - info_[SQL_DROP_TABLE] = static_cast(SQL_DT_DROP_TABLE); + bool supported = + checked_cast(scalar->child_value().get())->value; + info_[SQL_CREATE_TABLE] = + static_cast(supported ? SQL_CT_CREATE_TABLE : 0); + info_[SQL_DROP_TABLE] = + static_cast(supported ? SQL_DT_DROP_TABLE : 0); break; } case SqlInfoOptions::SQL_ALL_TABLES_ARE_SELECTABLE: { @@ -426,7 +426,7 @@ bool GetInfoCache::LoadInfoFromServer() { } case SqlInfoOptions::SQL_NULL_PLUS_NULL_IS_NULL: { info_[SQL_CONCAT_NULL_BEHAVIOR] = static_cast( - reinterpret_cast(scalar->child_value().get())->value + checked_cast(scalar->child_value().get())->value ? SQL_CB_NULL : SQL_CB_NON_NULL); break; @@ -436,7 +436,7 @@ bool GetInfoCache::LoadInfoFromServer() { // SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES since we need both // properties to determine the value for SQL_CORRELATION_NAME. supports_correlation_name = - reinterpret_cast(scalar->child_value().get())->value; + checked_cast(scalar->child_value().get())->value; break; } case SqlInfoOptions::SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES: { @@ -444,7 +444,7 @@ bool GetInfoCache::LoadInfoFromServer() { // SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES since we need both // properties to determine the value for SQL_CORRELATION_NAME. requires_different_correlation_name = - reinterpret_cast(scalar->child_value().get())->value; + checked_cast(scalar->child_value().get())->value; break; } case SqlInfoOptions::SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY: { @@ -454,7 +454,7 @@ bool GetInfoCache::LoadInfoFromServer() { case SqlInfoOptions::SQL_SUPPORTS_ORDER_BY_UNRELATED: { // Note: this is the negation of the Flight SQL property. info_[SQL_ORDER_BY_COLUMNS_IN_SELECT] = - reinterpret_cast(scalar->child_value().get())->value + checked_cast(scalar->child_value().get())->value ? "N" : "Y"; break; @@ -465,7 +465,7 @@ bool GetInfoCache::LoadInfoFromServer() { } case SqlInfoOptions::SQL_SUPPORTS_NON_NULLABLE_COLUMNS: { info_[SQL_NON_NULLABLE_COLUMNS] = static_cast( - reinterpret_cast(scalar->child_value().get())->value + checked_cast(scalar->child_value().get())->value ? SQL_NNC_NON_NULL : SQL_NNC_NULL); break; @@ -475,10 +475,15 @@ bool GetInfoCache::LoadInfoFromServer() { break; } case SqlInfoOptions::SQL_CATALOG_AT_START: { - info_[SQL_CATALOG_LOCATION] = static_cast( - reinterpret_cast(scalar->child_value().get())->value - ? SQL_CL_START - : SQL_CL_END); + // Only use this as a fallback if ARROW_SQL_CATALOG_TERM has not already + // set SQL_CATALOG_LOCATION (to avoid conflicting writes depending on + // response key ordering). + SetDefaultIfMissing( + info_, SQL_CATALOG_LOCATION, + static_cast( + checked_cast(scalar->child_value().get())->value + ? SQL_CL_START + : SQL_CL_END)); break; } case SqlInfoOptions::SQL_SELECT_FOR_UPDATE_SUPPORTED: @@ -494,22 +499,22 @@ bool GetInfoCache::LoadInfoFromServer() { } case SqlInfoOptions::SQL_TRANSACTIONS_SUPPORTED: { transactions_supported = - reinterpret_cast(scalar->child_value().get())->value; + checked_cast(scalar->child_value().get())->value; break; } case SqlInfoOptions::SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT: { transaction_ddl_commit = - reinterpret_cast(scalar->child_value().get())->value; + checked_cast(scalar->child_value().get())->value; break; } case SqlInfoOptions::SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED: { transaction_ddl_ignore = - reinterpret_cast(scalar->child_value().get())->value; + checked_cast(scalar->child_value().get())->value; break; } case SqlInfoOptions::SQL_BATCH_UPDATES_SUPPORTED: { info_[SQL_BATCH_SUPPORT] = static_cast( - reinterpret_cast(scalar->child_value().get())->value + checked_cast(scalar->child_value().get())->value ? SQL_BS_ROW_COUNT_EXPLICIT : 0); break; diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc index 257e8affa48e..f3afe4a5a1bd 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_info_test.cc @@ -616,18 +616,11 @@ TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoAlterTable) { EXPECT_EQ(static_cast(0), value); } -TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoCatalogLocation) { - // GH-49482 TODO: resolve inconsitent return value for SQL_CATALOG_LOCATION and change - // test type to `ConnectionInfoTest` - this->ConnectWithString(this->GetConnectionString(), this->conn); - +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCatalogLocation) { SQLUSMALLINT value; GetInfo(this->conn, SQL_CATALOG_LOCATION, &value); EXPECT_EQ(static_cast(0), value); - - EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) - << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); } TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoCatalogName) { @@ -752,32 +745,18 @@ TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropDomain) { EXPECT_EQ(static_cast(0), value); } -TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoDropSchema) { - // GH-49482 TODO: resolve inconsitent return value for SQL_DROP_SCHEMA and change test - // type to `ConnectionInfoTest` - this->ConnectWithString(this->GetConnectionString(), this->conn); - +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropSchema) { SQLUINTEGER value; GetInfo(this->conn, SQL_DROP_SCHEMA, &value); - EXPECT_EQ(static_cast(0), value); - - EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) - << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + EXPECT_EQ(static_cast(SQL_DS_DROP_SCHEMA), value); } -TYPED_TEST(ConnectionInfoHandleTest, TestSQLGetInfoDropTable) { - // GH-49482 TODO: resolve inconsitent return value for SQL_DROP_TABLE and change test - // type to `ConnectionInfoTest` - this->ConnectWithString(this->GetConnectionString(), this->conn); - +TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropTable) { SQLUINTEGER value; GetInfo(this->conn, SQL_DROP_TABLE, &value); - EXPECT_EQ(static_cast(0), value); - - EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) - << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + EXPECT_EQ(static_cast(SQL_DT_DROP_TABLE), value); } TYPED_TEST(ConnectionInfoTest, TestSQLGetInfoDropTranslation) {