diff --git a/CMakeLists.txt b/CMakeLists.txt index 6df450c510c..7f89e2365d0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ option(YDB_SDK_TESTS "Build YDB C++ SDK tests" Off) option(YDB_SDK_EXAMPLES "Build YDB C++ SDK examples" On) option(YDB_SDK_ENABLE_OTEL_METRICS "Build OpenTelemetry metrics plugin" Off) option(YDB_SDK_ENABLE_OTEL_TRACE "Build OpenTelemetry trace plugin" Off) +option(YDB_SDK_ODBC "Build YDB ODBC driver" Off) set(YDB_SDK_GOOGLE_COMMON_PROTOS_TARGET "" CACHE STRING "Name of cmake target preparing google common proto library") option(YDB_SDK_USE_RAPID_JSON "Search for rapid json library in system" ON) @@ -64,6 +65,10 @@ add_subdirectory(plugins) #_ydb_sdk_validate_public_headers() +if (YDB_SDK_ODBC) + add_subdirectory(odbc) +endif() + if (YDB_SDK_EXAMPLES) add_subdirectory(examples) endif() diff --git a/CMakePresets.json b/CMakePresets.json index d33c11b6811..ef642c40fb6 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -56,6 +56,7 @@ "cacheVariables": { "YDB_SDK_TESTS": "TRUE", "YDB_SDK_EXAMPLES": "TRUE", + "YDB_SDK_ODBC": "TRUE", "ARCADIA_ROOT": "..", "ARCADIA_BUILD_ROOT": "." } diff --git a/cmake/common.cmake b/cmake/common.cmake index bd017425190..37db9cc9067 100644 --- a/cmake/common.cmake +++ b/cmake/common.cmake @@ -115,7 +115,7 @@ function(generate_enum_serilization Tgt Input) endfunction() function(add_global_library_for TgtName MainName) - add_library(${TgtName} STATIC ${ARGN}) + _ydb_sdk_add_library(${TgtName} STATIC ${ARGN}) if(APPLE) target_link_options(${MainName} INTERFACE "SHELL:-Wl,-force_load,$${TgtName}>") else() @@ -182,7 +182,7 @@ endfunction() function(_ydb_sdk_add_library Tgt) cmake_parse_arguments(ARG - "INTERFACE" "" "" + "INTERFACE;OBJECT;SHARED" "" "" ${ARGN} ) @@ -192,6 +192,12 @@ function(_ydb_sdk_add_library Tgt) set(libraryMode "INTERFACE") set(includeMode "INTERFACE") endif() + if (ARG_OBJECT) + set(libraryMode "OBJECT") + endif() + if (ARG_SHARED) + set(libraryMode "SHARED") + endif() add_library(${Tgt} ${libraryMode}) target_include_directories(${Tgt} ${includeMode} $ @@ -201,6 +207,7 @@ function(_ydb_sdk_add_library Tgt) target_compile_definitions(${Tgt} ${includeMode} YDB_SDK_OSS ) + set_property(TARGET ${Tgt} PROPERTY POSITION_INDEPENDENT_CODE ON) endfunction() function(_ydb_sdk_validate_public_headers) @@ -255,4 +262,3 @@ function(_ydb_sdk_validate_public_headers) ) target_include_directories(validate_public_interface PUBLIC ${YDB_SDK_BINARY_DIR}/__validate_headers_dir/include) endfunction() - diff --git a/cmake/external_libs.cmake b/cmake/external_libs.cmake index 4560fd662b3..2915a13d5b0 100644 --- a/cmake/external_libs.cmake +++ b/cmake/external_libs.cmake @@ -19,6 +19,10 @@ if (YDB_SDK_ENABLE_OTEL_METRICS OR YDB_SDK_ENABLE_OTEL_TRACE) find_package(opentelemetry-cpp REQUIRED) endif() +if (YDB_SDK_ODBC) + find_package(ODBC REQUIRED) +endif() + # RapidJSON if (YDB_SDK_USE_RAPID_JSON) find_package(RapidJSON REQUIRED) diff --git a/cmake/testing.cmake b/cmake/testing.cmake index 0d8ba73d64d..053bfa39baa 100644 --- a/cmake/testing.cmake +++ b/cmake/testing.cmake @@ -121,3 +121,35 @@ function(add_ydb_test) vcs_info(${YDB_TEST_NAME}) endfunction() + +if (YDB_SDK_ODBC) + function(add_odbc_test) + set(opts "") + set(oneval_args NAME WORKING_DIRECTORY OUTPUT_DIRECTORY) + set(multival_args SOURCES LINK_LIBRARIES LABELS) + cmake_parse_arguments(ODBC_TEST + "${opts}" + "${oneval_args}" + "${multival_args}" + ${ARGN} + ) + + add_ydb_test(GTEST + NAME ${ODBC_TEST_NAME} + SOURCES ${ODBC_TEST_SOURCES} + LINK_LIBRARIES + ${ODBC_TEST_LINK_LIBRARIES} + ODBC::ODBC + LABELS + integration + ${ODBC_TEST_LABELS} + ) + + target_compile_definitions(${ODBC_TEST_NAME} + PRIVATE + ODBC_DRIVER_PATH="$" + ) + + add_dependencies(${ODBC_TEST_NAME} ydb-odbc) + endfunction() +endif() diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt new file mode 100644 index 00000000000..a74800a52d2 --- /dev/null +++ b/odbc/CMakeLists.txt @@ -0,0 +1,50 @@ +add_library(ydb-odbc SHARED + src/utils/attr.cpp + src/utils/escape.cpp + src/utils/cursor.cpp + src/utils/types.cpp + src/utils/util.cpp + src/utils/convert.cpp + src/utils/error_manager.cpp + src/odbc_driver.cpp + src/connection_attr.cpp + src/connection.cpp + src/statement_attr.cpp + src/statement.cpp + src/environment.cpp + src/metadata.cpp +) + +target_include_directories(ydb-odbc + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${ODBC_INCLUDE_DIRS} +) + +target_link_libraries(ydb-odbc + PRIVATE + YDB-CPP-SDK::Query + YDB-CPP-SDK::Table + YDB-CPP-SDK::Scheme + YDB-CPP-SDK::Driver + ODBC::ODBC + odbcinst +) + +set_target_properties(ydb-odbc PROPERTIES + POSITION_INDEPENDENT_CODE ON +) + +include(GNUInstallDirs) + +install(TARGETS ydb-odbc + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +install(DIRECTORY include/ + DESTINATION include/ydb-odbc +) + +add_subdirectory(examples) +add_subdirectory(tests) diff --git a/odbc/README.md b/odbc/README.md new file mode 100644 index 00000000000..826666f30aa --- /dev/null +++ b/odbc/README.md @@ -0,0 +1,98 @@ +# YDB ODBC Driver + +ODBC driver for YDB. + +## Requirements + +- CMake 3.10 or higher +- C/C++ compiler with C11 and C++20 support +- YDB C++ SDK +- unixODBC (for Linux/macOS) + +## Build + +```bash +cmake -DYDB_SDK_ODBC=1 --preset release-test-clang +cmake --build --preset default +``` + +The shared library is produced as `build/odbc/libydb-odbc.so`. + +## Install + +```bash +cmake --install build --prefix /usr/local +``` + +## Configuration + +For `SQLConnect("YDB", ...)`, `isql -v YDB`, or `Driver=YDB`. + +**`odbcinst.ini`** — driver registration. Section `[YDB]` is the driver name used as `Driver=YDB` in connection strings and DSNs. `Driver` and `Setup` are the full path to `libydb-odbc.so`. Use `/etc/odbcinst.ini`, a file in `/etc/odbcinst.d/`, or set `ODBCSYSINI` to the directory that contains `odbcinst.ini`. + +```ini +[YDB] +Description=YDB ODBC Driver +Driver=/path/to/libydb-odbc.so +Setup=/path/to/libydb-odbc.so +``` + +**`odbc.ini`** — DSN named `YDB`. In section `[YDB]`: `Driver` is the registered driver name, `Server` is the YDB endpoint, `Database` is the database path. Use `/etc/odbc.ini` or set `ODBCINI` to your file path. + +```ini +[ODBC Data Sources] +YDB=YDB ODBC Driver + +[YDB] +Driver=YDB +Server=localhost:2136 +Database=/local +``` + +## Usage + +Example of connecting via isql: +```bash +isql -v YDB +``` + +Example usage in C: +```c +SQLHENV env; +SQLHDBC dbc; +SQLHSTMT stmt; + +// Initialize environment +SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); +SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + +// Connect +SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); +SQLConnect(dbc, (SQLCHAR*)"YDB", SQL_NTS, + (SQLCHAR*)"", SQL_NTS, + (SQLCHAR*)"", SQL_NTS); + +// Execute query +SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt); +SQLExecDirect(stmt, (SQLCHAR*)"SELECT * FROM mytable", SQL_NTS); + +// Cleanup +SQLFreeHandle(SQL_HANDLE_STMT, stmt); +SQLDisconnect(dbc); +SQLFreeHandle(SQL_HANDLE_DBC, dbc); +SQLFreeHandle(SQL_HANDLE_ENV, env); +``` + +Alternatively, use `SQLDriverConnect` with a connection string (does not require DSN in odbc.ini): +```c +SQLCHAR connStr[] = "Driver=YDB;Endpoint=localhost:2136;Database=/local"; +SQLDriverConnect(dbc, NULL, connStr, SQL_NTS, NULL, 0, NULL, SQL_DRIVER_NOPROMPT); +``` + +## Parameters + +Use names $p1, $p2, ... for parameter names + +## License + +Apache License 2.0 diff --git a/odbc/examples/CMakeLists.txt b/odbc/examples/CMakeLists.txt new file mode 100644 index 00000000000..88b1f27cc60 --- /dev/null +++ b/odbc/examples/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(basic) +add_subdirectory(scheme) diff --git a/odbc/examples/basic/CMakeLists.txt b/odbc/examples/basic/CMakeLists.txt new file mode 100644 index 00000000000..b99d1175f43 --- /dev/null +++ b/odbc/examples/basic/CMakeLists.txt @@ -0,0 +1,14 @@ +add_executable(odbc_basic + main.cpp +) + +target_link_libraries(odbc_basic + PRIVATE + ODBC::ODBC +) +target_compile_definitions(odbc_basic + PRIVATE + ODBC_DRIVER_PATH="$" +) + +add_dependencies(odbc_basic ydb-odbc) diff --git a/odbc/examples/basic/main.cpp b/odbc/examples/basic/main.cpp new file mode 100644 index 00000000000..8084e32f3d1 --- /dev/null +++ b/odbc/examples/basic/main.cpp @@ -0,0 +1,132 @@ +#include +#include + +#include + +void PrintOdbcError(SQLSMALLINT handleType, SQLHANDLE handle) { + SQLCHAR sqlState[6] = {0}; + SQLINTEGER nativeError = 0; + SQLCHAR message[256] = {0}; + SQLSMALLINT textLength = 0; + SQLGetDiagRec(handleType, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + std::cerr << "ODBC error: [" << sqlState << "] " << message << std::endl; +} + +int main() { + SQLHENV henv = nullptr; + SQLHDBC hdbc = nullptr; + SQLHSTMT hstmt = nullptr; + SQLRETURN ret; + + std::cout << "1. Allocating environment handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating environment handle" << std::endl; + return 1; + } + SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + std::cout << "2. Allocating connection handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating connection handle" << std::endl; + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "3. Building connection string" << std::endl; + std::string connStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + SQLCHAR outConnStr[1024] = {0}; + SQLSMALLINT outConnStrLen = 0; + + std::cout << "4. Connecting with SQLDriverConnect" << std::endl; + ret = SQLDriverConnect(hdbc, NULL, (SQLCHAR*)connStr.c_str(), SQL_NTS, + outConnStr, sizeof(outConnStr), &outConnStrLen, SQL_DRIVER_COMPLETE); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error connecting with SQLDriverConnect" << std::endl; + PrintOdbcError(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "5. Allocating statement handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating statement handle" << std::endl; + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "6. Executing query" << std::endl; + SQLCHAR query[] = R"( + DECLARE $p1 AS Int64?; + SELECT id, data from test_table WHERE id == $p1; + )"; + + int64_t paramValue = 1; + SQLLEN paramInd = 0; + ret = SQLBindParameter(hstmt, 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, ¶mValue, 0, ¶mInd); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error binding parameter" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + ret = SQLExecDirect(hstmt, query, SQL_NTS); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error executing query" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "7. Fetching result" << std::endl; + + SQLLEN ind = 0; + int value1 = 0; + if (SQLBindCol(hstmt, 1, SQL_C_SLONG, &value1, 0, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 1" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + SQLCHAR value2[1024] = {0}; + if (SQLBindCol(hstmt, 2, SQL_C_CHAR, &value2, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 2" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + while ((ret = SQLFetch(hstmt)) == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO) { + if (ret != SQL_SUCCESS) { + std::cerr << "Error fetching result" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + std::cout << "Result column 1: " << value1 << std::endl; + std::cout << "Result column 2: " << value2 << std::endl; + + std::cout << "--------------------------------" << std::endl; + } + + std::cout << "8. Cleaning up" << std::endl; + + SQLCloseCursor(hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + + return 0; +} diff --git a/odbc/examples/erlang_client/Makefile b/odbc/examples/erlang_client/Makefile new file mode 100644 index 00000000000..5f402ef9720 --- /dev/null +++ b/odbc/examples/erlang_client/Makefile @@ -0,0 +1,50 @@ +.PHONY: all compile run run-shell clean distclean check help + +CONN ?= Driver=YDB;Endpoint=localhost:2136;Database=/local; +ERLC ?= erlc +ERL ?= erl + +all: compile + +help: + @echo "YDB Series Example - Erlang ODBC" + @echo "" + @echo "Targets:" + @echo " make compile Compile Erlang modules" + @echo " make run Run the example" + @echo " make run CONN='...' Run with a custom connection string" + @echo " make run-shell Start Erlang shell with compiled modules" + @echo " make check Check Erlang ODBC availability" + @echo " make clean Remove compiled files" + @echo "" + @echo "Examples:" + @echo " make run" + @echo " make run CONN=\"Driver=YDB;Endpoint=myhost:2136;Database=/mydb;\"" + +prepare: + @mkdir -p ebin + +compile: prepare + @echo "Compiling Erlang modules..." + $(ERLC) -o ebin src/*.erl + @echo "Done." + +run: compile + @echo "Running YDB Series Example..." + $(ERL) -pa ebin -noshell -eval 'application:load(odbc), application:start(odbc), ydb_series_client:run("$(CONN)"), halt().' + +run-shell: compile + @echo "Starting Erlang shell with ydb_series_client..." + $(ERL) -pa ebin + +check: + @echo "Checking ODBC support in Erlang..." + @$(ERL) -noshell -eval 'application:load(odbc), io:format("~p~n", [application:start(odbc)]), halt().' + +clean: + @rm -rf ebin/*.beam + @rm -rf *.beam + @rm -rf erl_crash.dump + +distclean: clean + @rm -rf ebin diff --git a/odbc/examples/erlang_client/README.md b/odbc/examples/erlang_client/README.md new file mode 100644 index 00000000000..3a782ba9444 --- /dev/null +++ b/odbc/examples/erlang_client/README.md @@ -0,0 +1,30 @@ +# Erlang ODBC Series Example + +Minimal Erlang client for YDB ODBC. Mirrors the main scenario of the C++ `basic_example`: creates `series`, `seasons`, and `episodes` tables, fills them with test data, runs several queries, and drops the tables. + +## Requirements + +- Erlang/OTP with the `odbc` module +- unixODBC +- Built and registered YDB ODBC driver +- Running YDB instance reachable via the connection string + +Verify driver registration: + +```bash +odbcinst -q -d +``` + +## Running + +By default, `Driver=YDB;Endpoint=localhost:2136;Database=/local;` is used. + +```bash +make run +``` + +With a different connection string: + +```bash +make run CONN='...' +``` diff --git a/odbc/examples/erlang_client/src/sample_data.erl b/odbc/examples/erlang_client/src/sample_data.erl new file mode 100644 index 00000000000..5fccdce72ee --- /dev/null +++ b/odbc/examples/erlang_client/src/sample_data.erl @@ -0,0 +1,64 @@ +-module(sample_data). +-export([series/0, seasons/0, episodes/0, format_date/1]). + +series() -> + [ + [1, "IT Crowd", "The IT Crowd is a British sitcom by Channel 4.", days_from_date({2006, 2, 3})], + [2, "Silicon Valley", "Silicon Valley is an American comedy series.", days_from_date({2014, 4, 6})] + ]. + +seasons() -> + [ + [1, 1, "Season 1", days_from_date({2006, 2, 3}), days_from_date({2006, 5, 5})], + [1, 2, "Season 2", days_from_date({2007, 8, 24}), days_from_date({2007, 11, 16})], + [1, 3, "Season 3", days_from_date({2008, 11, 21}), days_from_date({2008, 12, 26})], + [1, 4, "Season 4", days_from_date({2010, 6, 25}), days_from_date({2010, 7, 30})], + [2, 1, "Season 1", days_from_date({2014, 4, 6}), days_from_date({2014, 6, 15})], + [2, 2, "Season 2", days_from_date({2015, 4, 12}), days_from_date({2015, 6, 14})], + [2, 3, "Season 3", days_from_date({2016, 4, 24}), days_from_date({2016, 6, 26})], + [2, 4, "Season 4", days_from_date({2017, 4, 23}), days_from_date({2017, 6, 25})], + [2, 5, "Season 5", days_from_date({2018, 3, 25}), days_from_date({2018, 5, 13})], + [2, 6, "Season 6", days_from_date({2019, 10, 27}), days_from_date({2019, 12, 8})] + ]. + +episodes() -> + [ + [1, 1, 1, "Yesterday's Jam", days_from_date({2006, 2, 3})], + [1, 1, 2, "Calamity Jen", days_from_date({2006, 2, 10})], + [1, 1, 3, "Fifty-Fifty", days_from_date({2006, 2, 17})], + [1, 1, 4, "The Red Door", days_from_date({2006, 2, 24})], + [1, 1, 5, "The Haunting of Bill Crouse", days_from_date({2006, 3, 3})], + [1, 1, 6, "Aunt Irma Visits", days_from_date({2006, 3, 10})], + [1, 2, 1, "The Work Outing", days_from_date({2007, 8, 24})], + [1, 2, 2, "Return of the Golden Child", days_from_date({2007, 8, 31})], + [1, 2, 3, "Moss and the German", days_from_date({2007, 9, 7})], + [2, 1, 1, "Minimum Viable Product", days_from_date({2014, 4, 6})], + [2, 1, 2, "The Cap Table", days_from_date({2014, 4, 13})], + [2, 1, 3, "Articles of Incorporation", days_from_date({2014, 4, 20})], + [2, 1, 4, "Fiduciary Duties", days_from_date({2014, 4, 27})], + [2, 1, 5, "Signaling Risk", days_from_date({2014, 5, 4})], + [2, 3, 1, "Founder Friendly", days_from_date({2016, 4, 24})], + [2, 3, 2, "Two in the Box", days_from_date({2016, 5, 1})], + [2, 3, 3, "Meinertzhagen's Haversack", days_from_date({2016, 5, 8})], + [2, 3, 4, "Maleant Data Systems Solutions", days_from_date({2016, 5, 15})], + [2, 5, 1, "Grow Fast or Die Slow", days_from_date({2018, 3, 25})], + [2, 5, 2, "Reorientation", days_from_date({2018, 4, 1})], + [2, 5, 3, "Chief Operating Officer", days_from_date({2018, 4, 8})], + [2, 5, 4, "Tech Evangelist", days_from_date({2018, 4, 15})], + [2, 5, 5, "Facial Recognition", days_from_date({2018, 4, 22})], + [2, 6, 1, "Artificial Emotional Intelligence", days_from_date({2019, 10, 27})], + [2, 6, 2, "Blood Money", days_from_date({2019, 11, 3})], + [2, 6, 3, "Hooli Smokes!", days_from_date({2019, 11, 10})], + [2, 6, 4, "Maximizing Alphaness", days_from_date({2019, 11, 17})], + [2, 6, 5, "Tethics", days_from_date({2019, 11, 24})], + [2, 6, 6, "RussFest", days_from_date({2019, 12, 1})], + [2, 6, 7, "Exit Event", days_from_date({2019, 12, 8})] + ]. + +days_from_date({Year, Month, Day}) -> + calendar:date_to_gregorian_days(Year, Month, Day) - calendar:date_to_gregorian_days(1970, 1, 1). + +format_date(Days) -> + Date = calendar:gregorian_days_to_date(Days + calendar:date_to_gregorian_days(1970, 1, 1)), + {Year, Month, Day} = Date, + io_lib:format("~4..0B-~2..0B-~2..0B", [Year, Month, Day]). diff --git a/odbc/examples/erlang_client/src/ydb_series_client.erl b/odbc/examples/erlang_client/src/ydb_series_client.erl new file mode 100644 index 00000000000..7b52b6687b5 --- /dev/null +++ b/odbc/examples/erlang_client/src/ydb_series_client.erl @@ -0,0 +1,223 @@ +-module(ydb_series_client). +-export([run/0, run/1, run_with_dsn/1]). + +run() -> + ConnectionString = "Driver=YDB;Endpoint=localhost:2136;Database=/local;", + run(ConnectionString). + +run(ConnectionString) when is_list(ConnectionString) -> + io:format("=== ODBC YDB Series Example ===~n"), + + application:load(odbc), + application:start(odbc), + + case odbc:connect(ConnectionString, [{tuple_format, list}]) of + {ok, Ref} -> + Result = run_example(Ref), + odbc:disconnect(Ref), + Result; + {error, Reason} -> + io:format("Connection failed: ~p~n", [Reason]), + error + end. + +run_with_dsn(DSN) -> + ConnectionString = lists:flatten(io_lib:format("DSN=~s;", [DSN])), + run(ConnectionString). + +run_example(Ref) -> + try + drop_tables(Ref), + create_tables(Ref), + fill_table_data(Ref), + select_simple(Ref), + upsert_simple(Ref), + select_with_params(Ref), + multistep(Ref), + select_seasons_by_series(Ref), + drop_tables(Ref), + + io:format("Completed successfully~n"), + ok + catch + Class:Reason:Stacktrace -> + io:format("~nError: ~p:~p~n", [Class, Reason]), + io:format("Stacktrace: ~p~n", [Stacktrace]), + error + end. + +create_tables(Ref) -> + Tables = [ + {"CREATE TABLE series ( + series_id Uint64, + title Utf8, + series_info Utf8, + release_date Uint64, + PRIMARY KEY (series_id) + );"}, + {"CREATE TABLE seasons ( + series_id Uint64, + season_id Uint64, + title Utf8, + first_aired Uint64, + last_aired Uint64, + PRIMARY KEY (series_id, season_id) + );"}, + {"CREATE TABLE episodes ( + series_id Uint64, + season_id Uint64, + episode_id Uint64, + title Utf8, + air_date Uint64, + PRIMARY KEY (series_id, season_id, episode_id) + );"} + ], + + lists:foreach(fun({Query}) -> + execute_update(Ref, Query) + end, Tables). + +fill_table_data(Ref) -> + SeriesData = sample_data:series(), + SeasonsData = sample_data:seasons(), + EpisodesData = sample_data:episodes(), + + lists:foreach(fun(Row) -> + [Id, Title, Info, Date] = Row, + Query = io_lib:format( + "UPSERT INTO series (series_id, title, series_info, release_date) VALUES (~p, \"~s\", \"~s\", ~p);", + [Id, escape_string(Title), escape_string(Info), Date] + ), + execute_update(Ref, Query) + end, SeriesData), + + lists:foreach(fun(Row) -> + [SeriesId, SeasonId, Title, FirstAired, LastAired] = Row, + Query = io_lib:format( + "UPSERT INTO seasons (series_id, season_id, title, first_aired, last_aired) VALUES (~p, ~p, \"~s\", ~p, ~p);", + [SeriesId, SeasonId, escape_string(Title), FirstAired, LastAired] + ), + execute_update(Ref, Query) + end, SeasonsData), + + lists:foreach(fun(Row) -> + [SeriesId, SeasonId, EpisodeId, Title, AirDate] = Row, + Query = io_lib:format( + "UPSERT INTO episodes (series_id, season_id, episode_id, title, air_date) VALUES (~p, ~p, ~p, \"~s\", ~p);", + [SeriesId, SeasonId, EpisodeId, escape_string(Title), AirDate] + ), + execute_update(Ref, Query) + end, EpisodesData), + + io:format("Inserted ~p series, ~p seasons, ~p episodes~n", + [length(SeriesData), length(SeasonsData), length(EpisodesData)]). + +select_simple(Ref) -> + Query = "SELECT CAST(series_id AS Utf8) AS series_id, title, CAST(release_date AS Date) AS release_date FROM series WHERE series_id = 1;", + Rows = selected_rows(Ref, select_simple, Query), + lists:foreach(fun(Row) -> + [Id, Title, ReleaseDate] = row_values(Row), + io:format("Series: Id=~p, Title=~p, Release=~p~n", [Id, Title, ReleaseDate]) + end, Rows). + +upsert_simple(Ref) -> + Query = "UPSERT INTO episodes (series_id, season_id, episode_id, title) VALUES (2, 6, 1, \"TBD\");", + execute_update(Ref, Query). + +select_with_params(Ref) -> + SeriesId = 2, + SeasonId = 3, + + Query = + "SELECT sa.title AS season_title, sr.title AS series_title " + "FROM seasons AS sa INNER JOIN series AS sr ON sa.series_id = sr.series_id " + "WHERE sa.series_id = CAST($p1 AS Uint64) AND sa.season_id = CAST($p2 AS Uint64);", + Params = [{sql_integer, [SeriesId]}, {sql_integer, [SeasonId]}], + + Rows = selected_param_rows(Ref, select_with_params, Query, Params), + lists:foreach(fun(Row) -> + [SeasonTitle, SeriesTitle] = row_values(Row), + io:format("Season: ~p (Series: ~p)~n", [SeasonTitle, SeriesTitle]) + end, Rows). + +multistep(Ref) -> + SeriesId = 2, + SeasonId = 5, + + Query1 = io_lib:format( + "SELECT CAST(first_aired AS Utf8) AS first_aired FROM seasons WHERE series_id = ~p AND season_id = ~p;", + [SeriesId, SeasonId] + ), + + [FirstAiredRow] = selected_rows(Ref, multistep_step1, Query1), + [Date] = row_values(FirstAiredRow), + FromDate = list_to_integer(Date), + + ToDate = FromDate + 15, + + Query2 = io_lib:format( + "SELECT CAST(season_id AS Utf8) AS season_id, CAST(episode_id AS Utf8) AS episode_id, title, CAST(air_date AS Utf8) AS air_date FROM episodes " + "WHERE series_id = ~p AND air_date >= ~p AND air_date <= ~p;", + [SeriesId, FromDate, ToDate] + ), + + Rows = selected_rows(Ref, multistep_step2, Query2), + lists:foreach(fun(Row) -> + [SId, EId, Title, AirDate] = row_values(Row), + io:format("Episode: S~pE~p ~p (aired: ~p)~n", [SId, EId, Title, AirDate]) + end, Rows). + +select_seasons_by_series(Ref) -> + SeriesList = [1, 2], + InClause = string:join([integer_to_list(X) || X <- SeriesList], ", "), + + Query = io_lib:format( + "SELECT CAST(series_id AS Utf8) AS series_id, CAST(season_id AS Utf8) AS season_id, title, CAST(first_aired AS Date) AS first_aired " + "FROM seasons WHERE series_id IN (~s) ORDER BY season_id;", + [InClause] + ), + + Rows = selected_rows(Ref, select_seasons_by_series, Query), + lists:foreach(fun(Row) -> + [SeriesId, SeasonId, Title, FirstAired] = row_values(Row), + io:format("Season: Series=~p, Season=~p, Title=~p, FirstAired=~p~n", + [SeriesId, SeasonId, Title, FirstAired]) + end, Rows). + +drop_tables(Ref) -> + Tables = ["series", "seasons", "episodes"], + + lists:foreach(fun(Table) -> + Query = io_lib:format("DROP TABLE ~s;", [Table]), + case odbc:sql_query(Ref, lists:flatten(Query)) of + {updated, _} -> ok; + {error, _} -> ok + end + end, Tables). + +escape_string(String) -> + EscapedBackslash = string:replace(String, "\\", "\\\\", all), + lists:flatten(string:replace(EscapedBackslash, "\"", "\\\"", all)). + +execute_update(Ref, Query) -> + case odbc:sql_query(Ref, lists:flatten(Query)) of + {updated, _} -> ok; + Error -> throw({query_failed, update, Error}) + end. + +selected_rows(Ref, Step, Query) -> + case odbc:sql_query(Ref, lists:flatten(Query)) of + {selected, _, Rows} -> Rows; + Error -> throw({query_failed, Step, Error}) + end. + +selected_param_rows(Ref, Step, Query, Params) -> + case odbc:param_query(Ref, lists:flatten(Query), Params) of + {selected, _, Rows} -> Rows; + Error -> throw({query_failed, Step, Error}) + end. + +row_values(Row) when is_tuple(Row) -> + tuple_to_list(Row); +row_values(Row) -> + Row. diff --git a/odbc/examples/scheme/CMakeLists.txt b/odbc/examples/scheme/CMakeLists.txt new file mode 100644 index 00000000000..ffab881aed5 --- /dev/null +++ b/odbc/examples/scheme/CMakeLists.txt @@ -0,0 +1,14 @@ +add_executable(odbc_scheme + main.cpp +) + +target_link_libraries(odbc_scheme + PRIVATE + ODBC::ODBC +) +target_compile_definitions(odbc_scheme + PRIVATE + ODBC_DRIVER_PATH="$" +) + +add_dependencies(odbc_scheme ydb-odbc) diff --git a/odbc/examples/scheme/main.cpp b/odbc/examples/scheme/main.cpp new file mode 100644 index 00000000000..3ae2cd6fe40 --- /dev/null +++ b/odbc/examples/scheme/main.cpp @@ -0,0 +1,116 @@ +#include +#include + +#include + +void PrintOdbcError(SQLSMALLINT handleType, SQLHANDLE handle) { + SQLCHAR sqlState[6] = {0}; + SQLINTEGER nativeError = 0; + SQLCHAR message[256] = {0}; + SQLSMALLINT textLength = 0; + SQLGetDiagRec(handleType, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + std::cerr << "ODBC error: [" << sqlState << "] " << message << std::endl; +} + +int main() { + SQLHENV henv = nullptr; + SQLHDBC hdbc = nullptr; + SQLHSTMT hstmt = nullptr; + SQLRETURN ret; + + std::cout << "1. Allocating environment handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating environment handle" << std::endl; + return 1; + } + SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + std::cout << "2. Allocating connection handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating connection handle" << std::endl; + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "3. Building connection string" << std::endl; + std::string connStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + SQLCHAR outConnStr[1024] = {0}; + SQLSMALLINT outConnStrLen = 0; + + std::cout << "4. Connecting with SQLDriverConnect" << std::endl; + ret = SQLDriverConnect(hdbc, NULL, (SQLCHAR*)connStr.c_str(), SQL_NTS, + outConnStr, sizeof(outConnStr), &outConnStrLen, SQL_DRIVER_COMPLETE); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error connecting with SQLDriverConnect" << std::endl; + PrintOdbcError(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "5. Allocating statement handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating statement handle" << std::endl; + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "6. Getting tables" << std::endl; + + SQLCHAR pattern[] = "/local"; + SQLCHAR tableType[] = "TABLE"; + + ret = SQLTables(hstmt, NULL, 0, NULL, 0, pattern, SQL_NTS, tableType, SQL_NTS); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error executing query" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "7. Fetching result" << std::endl; + + SQLLEN ind = 0; + SQLCHAR value1[1024] = {0}; + if (SQLBindCol(hstmt, 3, SQL_C_CHAR, &value1, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 1" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + SQLCHAR value2[1024] = {0}; + if (SQLBindCol(hstmt, 4, SQL_C_CHAR, &value2, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 2" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + while ((ret = SQLFetch(hstmt)) == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO) { + if (ret != SQL_SUCCESS) { + std::cerr << "Error fetching result" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + std::cout << "Table name: " << value1 << std::endl; + std::cout << "Table type: " << value2 << std::endl; + + std::cout << "--------------------------------" << std::endl; + } + + std::cout << "8. Cleaning up" << std::endl; + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + + return 0; +} diff --git a/odbc/odbc.ini b/odbc/odbc.ini new file mode 100644 index 00000000000..a1ba3c951c3 --- /dev/null +++ b/odbc/odbc.ini @@ -0,0 +1,9 @@ +[ODBC Data Sources] +YDB=YDB ODBC Driver + +[YDB] +Driver=YDB +Description=YDB Database Connection +Server=localhost:2136 +Database=/local +AuthMode=none diff --git a/odbc/odbcinst.ini b/odbc/odbcinst.ini new file mode 100644 index 00000000000..db2a9b8378e --- /dev/null +++ b/odbc/odbcinst.ini @@ -0,0 +1,4 @@ +[YDB] +Description=YDB ODBC Driver +Driver=/app/build/odbc/libydb-odbc.so +Setup=/app/build/odbc/libydb-odbc.so diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp new file mode 100644 index 00000000000..91d632aaa66 --- /dev/null +++ b/odbc/src/connection.cpp @@ -0,0 +1,341 @@ +#include "connection.h" +#include "statement.h" +#include "utils/error_manager.h" + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + +struct TDriverKey { + std::string Endpoint; + std::string Database; + + bool operator==(const TDriverKey& other) const noexcept { + return Endpoint == other.Endpoint && Database == other.Database; + } +}; + +struct TDriverKeyHash { + size_t operator()(const TDriverKey& key) const noexcept { + return std::hash{}(key.Endpoint) ^ (std::hash{}(key.Database) << 1U); + } +}; + +struct TDriverPool { + std::unordered_map, TDriverKeyHash> DriversByKey; + size_t InsertionsSinceCleanup = 0; +}; + +void CleanupExpiredDrivers(TDriverPool& pool) { + for (auto mapIt = pool.DriversByKey.begin(); mapIt != pool.DriversByKey.end();) { + if (mapIt->second.expired()) { + mapIt = pool.DriversByKey.erase(mapIt); + } else { + ++mapIt; + } + } +} + +std::shared_ptr AcquireSharedDriver(const std::string& endpoint, const std::string& database) { + static TDriverPool pool; + TDriverKey key{endpoint, database}; + auto it = pool.DriversByKey.find(key); + if (it != pool.DriversByKey.end()) { + if (std::shared_ptr existing = it->second.lock()) { + return existing; + } + } + auto driver = std::make_shared( + NYdb::TDriverConfig().SetEndpoint(endpoint).SetDatabase(database)); + pool.DriversByKey[std::move(key)] = driver; + ++pool.InsertionsSinceCleanup; + if (pool.InsertionsSinceCleanup >= 32) { + CleanupExpiredDrivers(pool); + pool.InsertionsSinceCleanup = 0; + } + return driver; +} + +} // namespace + +SQLRETURN TConnection::DriverConnect(const std::string& connectionString) { + std::map params; + size_t pos = 0; + while (pos < connectionString.size()) { + size_t eq = connectionString.find('=', pos); + if (eq == std::string::npos) { + break; + } + + size_t sc = connectionString.find(';', eq); + std::string key = connectionString.substr(pos, eq-pos); + std::string val = connectionString.substr(eq+1, (sc == std::string::npos ? std::string::npos : sc-eq-1)); + params[key] = val; + if (sc == std::string::npos) { + break; + } + pos = sc+1; + } + Endpoint_ = params.contains("Server") ? params["Server"] : params["Endpoint"]; + Database_ = params["Database"]; + DataSourceName_ = params.contains("DSN") ? params["DSN"] : ""; + + if (Endpoint_.empty() || Database_.empty()) { + throw TOdbcException("08001", 0, "Missing Endpoint (or Server) or Database in connection string"); + } + + TConnectionAttributes::NormalizeCatalogPath(Database_); + RecreateYdbClients(); + Attributes_.SetCurrentCatalog(Database_); + + return SQL_SUCCESS; +} + +SQLRETURN TConnection::Connect(const std::string& serverName, + const std::string& userName, + const std::string& auth) { + DataSourceName_ = serverName; + + char endpoint[256] = {0}; + char server[256] = {0}; + char database[256] = {0}; + + SQLGetPrivateProfileString(serverName.c_str(), "Endpoint", "", endpoint, sizeof(endpoint), nullptr); + SQLGetPrivateProfileString(serverName.c_str(), "Server", "", server, sizeof(server), nullptr); + SQLGetPrivateProfileString(serverName.c_str(), "Database", "", database, sizeof(database), nullptr); + + Endpoint_ = endpoint[0] ? endpoint : server; + Database_ = database; + + if (Endpoint_.empty() || Database_.empty()) { + throw TOdbcException("08001", 0, "Missing Endpoint (or Server) or Database in DSN"); + } + + TConnectionAttributes::NormalizeCatalogPath(Database_); + RecreateYdbClients(); + Attributes_.SetCurrentCatalog(Database_); + + return SQL_SUCCESS; +} + +SQLRETURN TConnection::Disconnect() { + QuerySession_.reset(); + Tx_.reset(); + DbmsVersionCache_.reset(); + DataSourceName_.clear(); + YdbSchemeClient_.reset(); + YdbTableClient_.reset(); + YdbClient_.reset(); + YdbDriver_.reset(); + return SQL_SUCCESS; +} + +NQuery::TSession& TConnection::GetOrCreateQuerySession() { + if (!QuerySession_) { + auto sessionResult = YdbClient_->GetSession().ExtractValueSync(); + NStatusHelpers::ThrowOnError(sessionResult); + QuerySession_.emplace(std::move(sessionResult.GetSession())); + } + return *QuerySession_; +} + +std::unique_ptr TConnection::CreateStatement() { + return std::make_unique(this); +} + +void TConnection::RemoveStatement(TStatement* stmt) { + Statements_.erase(std::remove_if(Statements_.begin(), Statements_.end(), + [stmt](const std::unique_ptr& s) { return s.get() == stmt; }), Statements_.end()); +} + +SQLRETURN TConnection::SetAutocommit(bool value) { + Attributes_.SetAutocommit(value); + if (Attributes_.GetAutocommit() && Tx_) { + auto status = Tx_->Commit().ExtractValueSync(); + NStatusHelpers::ThrowOnError(status); + Tx_.reset(); + } + return SQL_SUCCESS; +} + +bool TConnection::GetAutocommit() const { + return Attributes_.GetAutocommit(); +} + +SQLRETURN TConnection::SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + if (attr == SQL_ATTR_CURRENT_CATALOG) { + std::optional rebindDatabase; + SQLRETURN rc = Attributes_.ApplyCatalogChange(value, stringLength, Database_, rebindDatabase, *this); + if (rc != SQL_SUCCESS) { + return rc; + } + if (rebindDatabase) { + RebindToDatabase(*rebindDatabase); + } + return SQL_SUCCESS; + } + return Attributes_.SetConnectAttr(attr, value, stringLength, [this](bool autocommit) { + return SetAutocommit(autocommit); + }, *this); +} + +SQLRETURN TConnection::GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return Attributes_.GetConnectAttr(attr, value, bufferLength, stringLengthPtr, *this); +} + +NQuery::TTxSettings TConnection::MakeTxSettings() const { + return Attributes_.MakeTxSettings(); +} + +const std::optional& TConnection::GetTx() { + return Tx_; +} + +void TConnection::SetTx(const NQuery::TTransaction& tx) { + Tx_ = tx; +} + +void TConnection::ResetTx() { + Tx_.reset(); +} + +void TConnection::ResetQuerySession() { + QuerySession_.reset(); +} + +SQLRETURN TConnection::CommitTx() { + if (!Tx_) { + return AddError("25000", 0, "Invalid transaction state: no active transaction"); + } + auto status = Tx_->Commit().ExtractValueSync(); + NStatusHelpers::ThrowOnError(status); + Tx_.reset(); + return SQL_SUCCESS; +} + +SQLRETURN TConnection::RollbackTx() { + if (!Tx_) { + return AddError("25000", 0, "Invalid transaction state: no active transaction"); + } + auto status = Tx_->Rollback().ExtractValueSync(); + NStatusHelpers::ThrowOnError(status); + Tx_.reset(); + return SQL_SUCCESS; +} + +void TConnection::SetEnvironment(TEnvironment* env){ + if (ParentEnv_){ + throw std::logic_error("Connection already bound to environment"); + } + ParentEnv_ = env; +} + +TEnvironment* TConnection::GetEnvironment(){ + return ParentEnv_; +} + +const std::string& TConnection::GetDataSourceName() const { + return DataSourceName_; +} + +SQLUINTEGER TConnection::GetSupportedTxnIsolationOptions() const { + return Attributes_.GetSupportedTxnIsolationOptions(); +} + +bool TConnection::IsDataSourceReadOnly() const { + return Attributes_.GetAccessMode() == SQL_MODE_READ_ONLY; +} + +const std::string& TConnection::GetDbmsVersion() { + if (DbmsVersionCache_) { + return *DbmsVersionCache_; + } + + auto* client = GetClient(); + if (!client) { + throw TOdbcException("08003", 0, "Connection is not established"); + } + + std::optional fetched; + const NYdb::TStatus status = client->RetryQuerySync( + [&fetched](NQuery::TSession session) -> NYdb::TStatus { + auto result = session.ExecuteQuery( + "SELECT Version();", + NQuery::TTxControl::NoTx(), + NYdb::TParamsBuilder().Build()).ExtractValueSync(); + if (!result.IsSuccess()) { + return result; + } + if (result.GetResultSets().empty()) { + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + } + TResultSetParser parser(result.GetResultSetParser(0)); + if (parser.TryNextRow()) { + fetched = parser.ColumnParser(0).GetUtf8(); + } + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + }); + + NStatusHelpers::ThrowOnError(status); + if (!fetched || fetched->empty()) { + throw TOdbcException("HY000", 0, "Failed to retrieve DBMS version"); + } + + DbmsVersionCache_ = std::move(*fetched); + return *DbmsVersionCache_; +} + +void TConnection::RecreateYdbClients() { + QuerySession_.reset(); + Tx_.reset(); + DbmsVersionCache_.reset(); + YdbSchemeClient_.reset(); + YdbTableClient_.reset(); + YdbClient_.reset(); + YdbDriver_ = AcquireSharedDriver(Endpoint_, Database_); + YdbClient_ = std::make_unique(*YdbDriver_); + YdbSchemeClient_ = std::make_unique(*YdbDriver_); + YdbTableClient_ = std::make_unique(*YdbDriver_); +} + +void TConnection::RebindToDatabase(const std::string& newDatabase) { + std::string db = newDatabase; + TConnectionAttributes::NormalizeCatalogPath(db); + Database_ = std::move(db); + Attributes_.SetCurrentCatalog(Database_); + RecreateYdbClients(); +} + + +std::string TConnection::WrapQueryForCurrentCatalog(const std::string& sql) const { + std::optional rel = Attributes_.ResolveCatalogRoute(Database_).TablePathPrefix; + if (!rel) { + return sql; + } + std::string escapedPrefix; + escapedPrefix.reserve(rel->size() + 8); + for (const char ch : *rel) { + if (ch == '\\' || ch == '"') { + escapedPrefix.push_back('\\'); + } + escapedPrefix.push_back(ch); + } + return "PRAGMA TablePathPrefix = \"" + escapedPrefix + "\";\n" + sql; +} +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection.h b/odbc/src/connection.h new file mode 100644 index 00000000000..0bf9cc3d78c --- /dev/null +++ b/odbc/src/connection.h @@ -0,0 +1,88 @@ +#pragma once + +#include "environment.h" +#include "connection_attr.h" +#include "utils/error_manager.h" + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TStatement; + +class TConnection : public TErrorManager { +private: + std::shared_ptr YdbDriver_; + std::unique_ptr YdbClient_; + std::unique_ptr YdbTableClient_; + std::unique_ptr YdbSchemeClient_; + std::optional Tx_; + std::optional QuerySession_; + + std::vector> Statements_; + std::string Endpoint_; + std::string Database_; + std::string DataSourceName_; + std::string AuthToken_; + TEnvironment* ParentEnv_; + + TConnectionAttributes Attributes_; + mutable std::optional DbmsVersionCache_; + + void RecreateYdbClients(); + void RebindToDatabase(const std::string& newDatabase); +public: + SQLRETURN Connect(const std::string& serverName, + const std::string& userName, + const std::string& auth); + + SQLRETURN DriverConnect(const std::string& connectionString); + SQLRETURN Disconnect(); + + std::unique_ptr CreateStatement(); + void RemoveStatement(TStatement* stmt); + + NYdb::NQuery::TQueryClient* GetClient() { return YdbClient_.get(); } + NQuery::TSession& GetOrCreateQuerySession(); + NYdb::NTable::TTableClient* GetTableClient() { return YdbTableClient_.get(); } + NScheme::TSchemeClient* GetSchemeClient() { return YdbSchemeClient_.get(); } + + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + SQLRETURN SetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetConnectAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); + NQuery::TTxSettings MakeTxSettings() const; + + std::string WrapQueryForCurrentCatalog(const std::string& sql) const; + const std::string& GetDbmsVersion(); + const std::string& GetDataSourceName() const; + SQLUINTEGER GetSupportedTxnIsolationOptions() const; + bool IsDataSourceReadOnly() const; + + const std::optional& GetTx(); + void SetTx(const NQuery::TTransaction& tx); + void ResetTx(); + void ResetQuerySession(); + + SQLRETURN CommitTx(); + SQLRETURN RollbackTx(); + + void SetEnvironment(TEnvironment* env); + TEnvironment* GetEnvironment(); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attr.cpp b/odbc/src/connection_attr.cpp new file mode 100644 index 00000000000..576f093b656 --- /dev/null +++ b/odbc/src/connection_attr.cpp @@ -0,0 +1,319 @@ + +#include "connection_attr.h" +#include "utils/attr.h" +#include "utils/diag.h" + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + +namespace Catalog { + +void NormalizePath(std::string& path) { + if (path.empty() || path == "/") { + return; + } + const size_t trailingSlashStart = path.find_last_not_of('/'); + if (trailingSlashStart == std::string::npos) { + path.assign("/"); + return; + } + path.erase(trailingSlashStart + 1); +} + +TConnectionAttributes::TCatalogBinding BuildBinding(const std::string& currentCatalog, const std::string& database) { + TConnectionAttributes::TCatalogBinding binding; + binding.Catalog = currentCatalog; + binding.Database = database; + NormalizePath(binding.Catalog); + NormalizePath(binding.Database); + if (binding.Catalog == binding.Database) { + return binding; + } + + const std::string databasePrefix = binding.Database + "/"; + if (binding.Catalog.size() <= databasePrefix.size() || + binding.Catalog.compare(0, databasePrefix.size(), databasePrefix) != 0) { + return binding; + } + + std::string relativeCatalog = binding.Catalog.substr(databasePrefix.size()); + if (!relativeCatalog.empty()) { + binding.RelativeCatalog = std::move(relativeCatalog); + } + return binding; +} + +} // namespace Catalog + +namespace Tx { + +bool IsKnownTxnIsolation(SQLUINTEGER txnIsolation) { + switch (txnIsolation) { + case SQL_TXN_READ_UNCOMMITTED: + case SQL_TXN_READ_COMMITTED: + case SQL_TXN_REPEATABLE_READ: + case SQL_TXN_SERIALIZABLE: + return true; + default: + return false; + } +} + +std::optional ResolveTxMode(SQLUINTEGER accessMode, SQLUINTEGER txnIsolation) { + if (accessMode == SQL_MODE_READ_ONLY) { + return NQuery::TTxSettings::TS_SNAPSHOT_RO; + } + + switch (txnIsolation) { + case SQL_TXN_REPEATABLE_READ: + return NQuery::TTxSettings::TS_SNAPSHOT_RW; + case SQL_TXN_SERIALIZABLE: + return NQuery::TTxSettings::TS_SERIALIZABLE_RW; + default: + return std::nullopt; + } +} + +} // namespace Tx + +namespace Autocommit { + +SQLRETURN Get(bool autocommitEnabled, SQLPOINTER value) { + auto* out = reinterpret_cast(value); + *out = autocommitEnabled ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF; + return SQL_SUCCESS; +} + +} // namespace Autocommit + +} + +void TConnectionAttributes::NormalizeCatalogPath(std::string& path) { + Catalog::NormalizePath(path); +} + +SQLRETURN TConnectionAttributes::SetAutocommit(bool value) { + Autocommit_ = value; + return SQL_SUCCESS; +} + +bool TConnectionAttributes::GetAutocommit() const { + return Autocommit_; +} + +SQLRETURN TConnectionAttributes::SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + return SetAutocommit(value, applyAutocommit, errors); + case SQL_ATTR_ACCESS_MODE: + return SetAccessMode(value, errors); + case SQL_ATTR_TXN_ISOLATION: + return SetTxnIsolation(value, errors); + case SQL_ATTR_CURRENT_CATALOG: + return SetCurrentCatalog(value, stringLength, errors); + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TConnectionAttributes::GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return Diag::AddNullPointer(errors); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + switch (attr) { + case SQL_ATTR_AUTOCOMMIT: + return GetAutocommit(value); + case SQL_ATTR_ACCESS_MODE: + return GetAccessMode(value); + case SQL_ATTR_TXN_ISOLATION: + return GetTxnIsolation(value); + case SQL_ATTR_CURRENT_CATALOG: + return GetCurrentCatalog(value, bufferLength, stringLengthPtr, errors); + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TConnectionAttributes::SetAutocommit( + SQLPOINTER value, + const std::function& applyAutocommit, + TErrorManager& errors) { + const auto token = ReadIntegerAttrIfIn( + value, + {static_cast(SQL_AUTOCOMMIT_ON), static_cast(SQL_AUTOCOMMIT_OFF)}); + if (!token) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_AUTOCOMMIT"); + } + if (*token == static_cast(SQL_AUTOCOMMIT_ON)) { + return applyAutocommit(true); + } + return applyAutocommit(false); +} + +SQLRETURN TConnectionAttributes::SetAccessMode(SQLPOINTER value, TErrorManager& errors) { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_MODE_READ_WRITE, SQL_MODE_READ_ONLY}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_ACCESS_MODE"); + } + AccessMode_ = *mode; + auto txMode = Tx::ResolveTxMode(AccessMode_, TxnIsolation_); + if (!txMode) { + return errors.AddError( + "HYC00", + 0, + AccessMode_ == SQL_MODE_READ_WRITE + ? "Transaction isolation is not supported for read-write mode" + : "Transaction isolation is not supported for read-only mode"); + } + TxMode_ = *txMode; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::SetTxnIsolation(SQLPOINTER value, TErrorManager& errors) { + const SQLUINTEGER isolation = ReadIntegerAttr(value); + if (!Tx::IsKnownTxnIsolation(isolation)) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_TXN_ISOLATION"); + } + auto txMode = Tx::ResolveTxMode(AccessMode_, isolation); + if (!txMode) { + return errors.AddError("HYC00", 0, "SQL_ATTR_TXN_ISOLATION value is not supported"); + } + TxnIsolation_ = isolation; + TxMode_ = *txMode; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::SetCurrentCatalog(SQLPOINTER value, SQLINTEGER stringLength, TErrorManager& errors) { + if (!value) { + return Diag::AddNullPointer(errors); + } + CurrentCatalog_ = ReadAttributeString(value, stringLength); + Catalog::NormalizePath(CurrentCatalog_); + if (CurrentCatalog_.empty()) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_CURRENT_CATALOG"); + } + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetAutocommit(SQLPOINTER value) const { + return Autocommit::Get(Autocommit_, value); +} + +SQLRETURN TConnectionAttributes::GetAccessMode(SQLPOINTER value) const { + auto* out = reinterpret_cast(value); + *out = AccessMode_; + return SQL_SUCCESS; +} + +SQLRETURN TConnectionAttributes::GetTxnIsolation(SQLPOINTER value) const { + auto* out = reinterpret_cast(value); + *out = TxnIsolation_; + return SQL_SUCCESS; +} + +SQLUINTEGER TConnectionAttributes::GetAccessMode() const { + return AccessMode_; +} + +SQLUINTEGER TConnectionAttributes::GetSupportedTxnIsolationOptions() const { + static constexpr SQLUINTEGER kLevels[] = { + SQL_TXN_READ_UNCOMMITTED, + SQL_TXN_READ_COMMITTED, + SQL_TXN_REPEATABLE_READ, + SQL_TXN_SERIALIZABLE, + }; + SQLUINTEGER mask = 0; + for (const SQLUINTEGER level : kLevels) { + if (Tx::ResolveTxMode(AccessMode_, level)) { + mask |= level; + } + } + return mask; +} + +SQLRETURN TConnectionAttributes::GetCurrentCatalog( + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + return WriteAttributeString(CurrentCatalog_, value, bufferLength, stringLengthPtr, errors); +} + +NQuery::TTxSettings TConnectionAttributes::MakeTxSettings() const { + switch (TxMode_) { + case NQuery::TTxSettings::TS_ONLINE_RO: + return NQuery::TTxSettings::OnlineRO(); + case NQuery::TTxSettings::TS_STALE_RO: + return NQuery::TTxSettings::StaleRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RO: + return NQuery::TTxSettings::SnapshotRO(); + case NQuery::TTxSettings::TS_SNAPSHOT_RW: + return NQuery::TTxSettings::SnapshotRW(); + case NQuery::TTxSettings::TS_SERIALIZABLE_RW: + default: + return NQuery::TTxSettings::SerializableRW(); + } +} + +void TConnectionAttributes::SetCurrentCatalog(const std::string& value) { + CurrentCatalog_ = value; + Catalog::NormalizePath(CurrentCatalog_); +} + +const std::string& TConnectionAttributes::GetCurrentCatalog() const { + return CurrentCatalog_; +} + +TConnectionAttributes::TCatalogBinding TConnectionAttributes::BuildCatalogBinding(const std::string& database) const { + return Catalog::BuildBinding(CurrentCatalog_, database); +} + +TConnectionAttributes::TCatalogRoute TConnectionAttributes::ResolveCatalogRoute(const std::string& currentDatabase) const { + const TCatalogBinding binding = BuildCatalogBinding(currentDatabase); + if (binding.Catalog == binding.Database) { + return {binding.Database, std::nullopt}; + } + if (binding.RelativeCatalog) { + return {binding.Database, binding.Catalog}; + } + return {binding.Catalog, std::nullopt}; +} + +SQLRETURN TConnectionAttributes::ApplyCatalogChange( + SQLPOINTER value, + SQLINTEGER stringLength, + const std::string& currentDatabase, + std::optional& rebindDatabase, + TErrorManager& errors) { + SQLRETURN rc = SetCurrentCatalog(value, stringLength, errors); + if (rc != SQL_SUCCESS) { + return rc; + } + const TCatalogRoute route = ResolveCatalogRoute(currentDatabase); + if (route.EffectiveDatabase != currentDatabase) { + rebindDatabase = route.EffectiveDatabase; + } else { + rebindDatabase.reset(); + } + return SQL_SUCCESS; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection_attr.h b/odbc/src/connection_attr.h new file mode 100644 index 00000000000..cd530554f2f --- /dev/null +++ b/odbc/src/connection_attr.h @@ -0,0 +1,88 @@ +#pragma once + +#include "utils/error_manager.h" + +#include + +#include +#include +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnectionAttributes { +public: + struct TCatalogBinding { + std::string Catalog; + std::string Database; + std::optional RelativeCatalog; + }; + + struct TCatalogRoute { + std::string EffectiveDatabase; + std::optional TablePathPrefix; + }; + + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + SQLRETURN SetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + const std::function& applyAutocommit, + TErrorManager& errors); + + SQLRETURN GetConnectAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + NQuery::TTxSettings MakeTxSettings() const; + void SetCurrentCatalog(const std::string& value); + const std::string& GetCurrentCatalog() const; + TCatalogBinding BuildCatalogBinding(const std::string& database) const; + TCatalogRoute ResolveCatalogRoute(const std::string& currentDatabase) const; + SQLRETURN ApplyCatalogChange( + SQLPOINTER value, + SQLINTEGER stringLength, + const std::string& currentDatabase, + std::optional& rebindDatabase, + TErrorManager& errors); + static void NormalizeCatalogPath(std::string& path); + SQLUINTEGER GetSupportedTxnIsolationOptions() const; + SQLUINTEGER GetAccessMode() const; + +private: + SQLRETURN SetAutocommit( + SQLPOINTER value, + const std::function& applyAutocommit, + TErrorManager& errors); + SQLRETURN SetAccessMode(SQLPOINTER value, TErrorManager& errors); + SQLRETURN SetTxnIsolation(SQLPOINTER value, TErrorManager& errors); + SQLRETURN SetCurrentCatalog(SQLPOINTER value, SQLINTEGER stringLength, TErrorManager& errors); + + SQLRETURN GetAutocommit(SQLPOINTER value) const; + SQLRETURN GetAccessMode(SQLPOINTER value) const; + SQLRETURN GetTxnIsolation(SQLPOINTER value) const; + SQLRETURN GetCurrentCatalog( + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + bool Autocommit_ = true; + std::string CurrentCatalog_; + SQLUINTEGER AccessMode_ = SQL_MODE_READ_WRITE; + SQLUINTEGER TxnIsolation_ = SQL_TXN_SERIALIZABLE; + NQuery::TTxSettings::ETransactionMode TxMode_ = NQuery::TTxSettings::TS_SERIALIZABLE_RW; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.cpp b/odbc/src/environment.cpp new file mode 100644 index 00000000000..8df0949faa8 --- /dev/null +++ b/odbc/src/environment.cpp @@ -0,0 +1,85 @@ +#include "environment.h" +#include "connection.h" + + #include + #include + +namespace NYdb { +namespace NOdbc { + +TEnvironment::TEnvironment() : OdbcVersion_(SQL_OV_ODBC3) {} +TEnvironment::~TEnvironment() {} + +SQLRETURN TEnvironment::SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + switch (attribute) { + case SQL_ATTR_ODBC_VERSION: { + if (!value) { + return AddError("HY009", 0, "Invalid use of null pointer"); + } + OdbcVersion_ = static_cast(reinterpret_cast(value)); + return SQL_SUCCESS; + } + case SQL_ATTR_OUTPUT_NTS: { + if (value && static_cast(reinterpret_cast(value)) != SQL_TRUE) { + return AddError("HY024", 0, "SQL_ATTR_OUTPUT_NTS must be SQL_TRUE"); + } + return SQL_SUCCESS; + } + default: + return AddError("HYC00", 0, "Optional feature not implemented"); + } +} + +void TEnvironment::RegisterConnection(TConnection* conn){ + if (conn == nullptr){ + throw std::invalid_argument("null connection"); + } + connections_.insert(conn); +} + +void TEnvironment::UnregisterConnection(TConnection* conn){ + if (conn == nullptr){ + throw std::invalid_argument("null connection"); + } + connections_.erase(conn); +} + +std::vector TEnvironment::GetConnectionsSnapshot() const { + return std::vector(connections_.begin(), connections_.end()); +} + +SQLRETURN TEnvironment::EndTran(SQLSMALLINT completionType){ + if (completionType != SQL_COMMIT && completionType != SQL_ROLLBACK){ + return AddError("HY012", 0, "Invalid transaction operation code"); + } + bool hasFailures = false; + int failedCount = 0; + + for (auto* conn : connections_) { + if (!conn || !conn->GetTx()) { + continue; + } + try { + if (completionType == SQL_COMMIT) { + conn->CommitTx(); + } else { + conn->RollbackTx(); + } + } catch (const std::exception& ex) { + hasFailures = true; + ++failedCount; + AddError("HY000", 0, ex.what(), SQL_SUCCESS_WITH_INFO); + } catch (...) { + hasFailures = true; + ++failedCount; + AddError("HY000", 0, "Unknown error during ENV-level transaction completion", SQL_SUCCESS_WITH_INFO); + } + } + if (hasFailures) { + AddError("01000", 0, "SQLEndTran(SQL_HANDLE_ENV): some connections failed", SQL_SUCCESS_WITH_INFO); + return SQL_SUCCESS_WITH_INFO; + } + return SQL_SUCCESS; +} +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.h b/odbc/src/environment.h new file mode 100644 index 00000000000..70a785f45d7 --- /dev/null +++ b/odbc/src/environment.h @@ -0,0 +1,34 @@ +#pragma once + +#include "utils/error_manager.h" + +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnection; + +class TEnvironment : public TErrorManager { +private: + SQLINTEGER OdbcVersion_; + std::unordered_set connections_; + +public: + TEnvironment(); + ~TEnvironment(); + + SQLRETURN SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength); + + void RegisterConnection(TConnection*); + void UnregisterConnection(TConnection*); + std::vector GetConnectionsSnapshot() const; + + SQLRETURN EndTran(SQLSMALLINT completionType); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/metadata.cpp b/odbc/src/metadata.cpp new file mode 100644 index 00000000000..b8857b245d9 --- /dev/null +++ b/odbc/src/metadata.cpp @@ -0,0 +1,279 @@ +#include "metadata.h" + +#include +#include + +namespace NYdb::NOdbc { +namespace { + +SQLRETURN WriteInfoString( + TConnection* conn, + const char* value, + SQLPOINTER infoValuePtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + if (!infoValuePtr) { + return conn->AddError("HY009", 0, "Invalid use of null pointer"); + } + if (bufferLength < 0) { + return conn->AddError("HY090", 0, "Invalid string or buffer length"); + } + const SQLSMALLINT fullLen = static_cast(std::strlen(value)); + if (stringLengthPtr) { + *stringLengthPtr = fullLen; + } + if (bufferLength == 0) { + return fullLen == 0 ? SQL_SUCCESS : conn->AddError("01004", 0, "String data, right truncated", SQL_SUCCESS_WITH_INFO); + } + + auto* out = reinterpret_cast(infoValuePtr); + const SQLSMALLINT copyLen = static_cast(std::min(fullLen, bufferLength - 1)); + if (copyLen > 0) { + std::memcpy(out, value, static_cast(copyLen)); + } + out[copyLen] = '\0'; + if (copyLen < fullLen) { + return conn->AddError("01004", 0, "String data, right truncated", SQL_SUCCESS_WITH_INFO); + } + return SQL_SUCCESS; +} + +template +SQLRETURN WriteInfoScalar( + TConnection* conn, + T value, + SQLPOINTER infoValuePtr, + SQLSMALLINT* stringLengthPtr) { + if (!infoValuePtr) { + return conn->AddError("HY009", 0, "Invalid use of null pointer"); + } + *reinterpret_cast(infoValuePtr) = value; + if (stringLengthPtr) { + *stringLengthPtr = static_cast(sizeof(T)); + } + return SQL_SUCCESS; +} + + +bool IsSupportedFunction(SQLUSMALLINT functionId) { + switch (functionId) { + case SQL_API_SQLALLOCHANDLE: + case SQL_API_SQLBINDCOL: + case SQL_API_SQLBINDPARAMETER: + case SQL_API_SQLCLOSECURSOR: + case SQL_API_SQLCOLUMNS: + case SQL_API_SQLCONNECT: + case SQL_API_SQLDESCRIBECOL: + case SQL_API_SQLDISCONNECT: + case SQL_API_SQLDRIVERCONNECT: + case SQL_API_SQLENDTRAN: + case SQL_API_SQLEXECDIRECT: + case SQL_API_SQLEXECUTE: + case SQL_API_SQLFETCH: + case SQL_API_SQLFETCHSCROLL: + case SQL_API_SQLFREEHANDLE: + case SQL_API_SQLFREESTMT: + case SQL_API_SQLGETDATA: + case SQL_API_SQLGETDIAGFIELD: + case SQL_API_SQLGETDIAGREC: + case SQL_API_SQLGETFUNCTIONS: + case SQL_API_SQLGETCONNECTATTR: + case SQL_API_SQLGETINFO: + case SQL_API_SQLGETSTMTATTR: + case SQL_API_SQLMORERESULTS: + case SQL_API_SQLNUMRESULTCOLS: + case SQL_API_SQLPREPARE: + case SQL_API_SQLROWCOUNT: + case SQL_API_SQLSETCONNECTATTR: + case SQL_API_SQLSETENVATTR: + case SQL_API_SQLSETSTMTATTR: + case SQL_API_SQLTABLES: + return true; + default: + return false; + } +} + +} // namespace + +SQLRETURN TMetadata::GetInfo( + TConnection* conn, + SQLUSMALLINT infoType, + SQLPOINTER infoValuePtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + switch (infoType) { + // Driver Information + case SQL_DRIVER_NAME: + return WriteInfoString(conn, "ydb-odbc", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_DRIVER_VER: + return WriteInfoString(conn, "unknown", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_DRIVER_ODBC_VER: + return WriteInfoString(conn, "03.00", infoValuePtr, bufferLength, stringLengthPtr); + + // DBMS Information + case SQL_DBMS_NAME: + return WriteInfoString(conn, "YDB", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_DBMS_VER: + return WriteInfoString(conn, conn->GetDbmsVersion().c_str(), infoValuePtr, bufferLength, stringLengthPtr); + + // Identifier Handling + case SQL_IDENTIFIER_QUOTE_CHAR: + return WriteInfoString(conn, "\"", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_IDENTIFIER_CASE: + return WriteInfoScalar(conn, SQL_IC_LOWER, infoValuePtr, stringLengthPtr); + + // Catalog Support + case SQL_CATALOG_NAME: + return WriteInfoString(conn, "Y", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_CATALOG_NAME_SEPARATOR: + return WriteInfoString(conn, "/", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_CATALOG_TERM: + return WriteInfoString(conn, "path", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_CATALOG_USAGE: + return WriteInfoScalar(conn, SQL_CU_DML_STATEMENTS, infoValuePtr, stringLengthPtr); + + // Schema Support (YDB doesn't use schemas) + case SQL_SCHEMA_USAGE: + return WriteInfoScalar(conn, 0, infoValuePtr, stringLengthPtr); + case SQL_SCHEMA_TERM: + return WriteInfoString(conn, "", infoValuePtr, bufferLength, stringLengthPtr); + + // Data Source Capabilities + case SQL_DATA_SOURCE_READ_ONLY: + return WriteInfoString( + conn, conn->IsDataSourceReadOnly() ? "Y" : "N", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_DATA_SOURCE_NAME: + return WriteInfoString(conn, conn->GetDataSourceName().c_str(), infoValuePtr, bufferLength, stringLengthPtr); + + // Result Set Capabilities + case SQL_MULT_RESULT_SETS: + return WriteInfoString(conn, "N", infoValuePtr, bufferLength, stringLengthPtr); + case SQL_DYNAMIC_CURSOR_ATTRIBUTES1: + case SQL_FORWARD_ONLY_CURSOR_ATTRIBUTES1: + case SQL_STATIC_CURSOR_ATTRIBUTES1: + return WriteInfoScalar(conn, SQL_CA1_NEXT, infoValuePtr, stringLengthPtr); + case SQL_CURSOR_COMMIT_BEHAVIOR: + case SQL_CURSOR_ROLLBACK_BEHAVIOR: + return WriteInfoScalar(conn, SQL_CB_CLOSE, infoValuePtr, stringLengthPtr); + + // Transaction Support + case SQL_TXN_CAPABLE: + return WriteInfoScalar(conn, SQL_TC_ALL, infoValuePtr, stringLengthPtr); + case SQL_DEFAULT_TXN_ISOLATION: + return WriteInfoScalar(conn, SQL_TXN_SERIALIZABLE, infoValuePtr, stringLengthPtr); + case SQL_TXN_ISOLATION_OPTION: + return WriteInfoScalar( + conn, conn->GetSupportedTxnIsolationOptions(), infoValuePtr, stringLengthPtr); + + // Stored Procedures (not supported) + case SQL_PROCEDURES: + return WriteInfoString(conn, "N", infoValuePtr, bufferLength, stringLengthPtr); + + case SQL_OUTER_JOINS: + return WriteInfoString(conn, "Y", infoValuePtr, bufferLength, stringLengthPtr); + + // Positioned Operations (not supported) + case SQL_POSITIONED_STATEMENTS: + return WriteInfoScalar(conn, 0, infoValuePtr, stringLengthPtr); + + // Batch Operations (not supported) + case SQL_BATCH_SUPPORT: + return WriteInfoScalar(conn, 0, infoValuePtr, stringLengthPtr); + case SQL_BATCH_ROW_COUNT: + return WriteInfoScalar(conn, 0, infoValuePtr, stringLengthPtr); + + // Bookmarks (not supported) + case SQL_BOOKMARK_PERSISTENCE: + return WriteInfoScalar(conn, 0, infoValuePtr, stringLengthPtr); + + // Named Cursors (not supported) + case SQL_FILE_USAGE: + return WriteInfoScalar(conn, SQL_FILE_NOT_SUPPORTED, infoValuePtr, stringLengthPtr); + + // GetData Extensions + case SQL_GETDATA_EXTENSIONS: + return WriteInfoScalar(conn, SQL_GD_ANY_COLUMN | SQL_GD_ANY_ORDER, infoValuePtr, stringLengthPtr); + + // Async Execution (not supported) + case SQL_ASYNC_MODE: + return WriteInfoScalar(conn, SQL_AM_NONE, infoValuePtr, stringLengthPtr); + + case SQL_QUOTED_IDENTIFIER_CASE: + return WriteInfoScalar(conn, SQL_IC_SENSITIVE, infoValuePtr, stringLengthPtr); + + default: + return conn->AddError("HYC00", 0, "Optional feature not implemented"); + } +} + + +SQLRETURN TMetadata::GetFunctions(SQLUSMALLINT functionId, SQLUSMALLINT* supportedPtr) { + if (!supportedPtr) { + return SQL_ERROR; + } + + if (functionId == SQL_API_ALL_FUNCTIONS) { + std::memset(supportedPtr, 0, 100 * sizeof(SQLUSMALLINT)); + for (SQLUSMALLINT id = 0; id < 100; ++id) { + if (IsSupportedFunction(id)) { + supportedPtr[id] = SQL_TRUE; + } + } + return SQL_SUCCESS; + } + + if (functionId == SQL_API_ODBC3_ALL_FUNCTIONS) { + std::memset(supportedPtr, 0, SQL_API_ODBC3_ALL_FUNCTIONS_SIZE * sizeof(SQLUSMALLINT)); + for (SQLUSMALLINT id = 0; id < SQL_API_ODBC3_ALL_FUNCTIONS_SIZE * 16; ++id) { + if (IsSupportedFunction(id)) { + supportedPtr[id >> 4] |= (1 << (id & 0x000F)); + } + } + return SQL_SUCCESS; + } + + *supportedPtr = IsSupportedFunction(functionId) ? SQL_TRUE : SQL_FALSE; + return SQL_SUCCESS; +} + +SQLRETURN TMetadata::DescribeCol( + TStatement* stmt, + SQLUSMALLINT columnNumber, + SQLCHAR* columnName, + SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, + SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, + SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr) { + const auto& columns = stmt->GetColumnMeta(); + if (columnNumber < 1 || columnNumber > columns.size()) { + throw TOdbcException("07009", 0, "Invalid descriptor index"); + } + + const auto& column = columns[columnNumber - 1]; + if (nameLengthPtr) { + *nameLengthPtr = static_cast(column.Name.size()); + } + if (columnName && bufferLength > 0) { + const auto copyLength = std::min(column.Name.size(), static_cast(bufferLength - 1)); + std::memcpy(columnName, column.Name.data(), copyLength); + columnName[copyLength] = '\0'; + } + if (dataTypePtr) { + *dataTypePtr = column.SqlType; + } + if (columnSizePtr) { + *columnSizePtr = column.Size; + } + if (decimalDigitsPtr) { + *decimalDigitsPtr = column.DecimalDigits; + } + if (nullablePtr) { + *nullablePtr = column.Nullable; + } + return SQL_SUCCESS; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/metadata.h b/odbc/src/metadata.h new file mode 100644 index 00000000000..7374e45e53a --- /dev/null +++ b/odbc/src/metadata.h @@ -0,0 +1,33 @@ +#pragma once + +#include "connection.h" +#include "statement.h" + +namespace NYdb::NOdbc { + +class TMetadata { +public: + static SQLRETURN GetInfo( + TConnection* conn, + SQLUSMALLINT infoType, + SQLPOINTER infoValuePtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr); + + static SQLRETURN GetFunctions( + SQLUSMALLINT functionId, + SQLUSMALLINT* supportedPtr); + + static SQLRETURN DescribeCol( + TStatement* stmt, + SQLUSMALLINT columnNumber, + SQLCHAR* columnName, + SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, + SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, + SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr); +}; + +} // namespace NYdb::NOdbc diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp new file mode 100644 index 00000000000..3d9a2019c3a --- /dev/null +++ b/odbc/src/odbc_driver.cpp @@ -0,0 +1,459 @@ +#include "environment.h" +#include "connection.h" +#include "statement.h" +#include "metadata.h" + +#include "utils/util.h" +#include "utils/error_manager.h" + +#include +#include + +namespace { + template + Handle* GetHandle(SQLHANDLE handle) { + if (!handle) { + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid handle", SQL_INVALID_HANDLE); + } + return static_cast(handle); + } + +} + +extern "C" { + +SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, + SQLHANDLE inputHandle, + SQLHANDLE* outputHandle) { + if (!outputHandle) { + return SQL_INVALID_HANDLE; + } + + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions( + inputHandle, + [&]() { + auto* const env = new NYdb::NOdbc::TEnvironment(); + *outputHandle = env; + env->SetLastReturnCode(SQL_SUCCESS); + return SQL_SUCCESS; + }, + NYdb::NOdbc::ENullInputHandlePolicy::Allow); + } + + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&](auto* env) { + auto conn = std::make_unique(); + conn->SetEnvironment(env); + env->RegisterConnection(conn.get()); + auto* const raw = conn.release(); + *outputHandle = raw; + raw->SetLastReturnCode(SQL_SUCCESS); + return SQL_SUCCESS; + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(inputHandle, [&](auto* conn) { + auto stmt = conn->CreateStatement(); + auto* const raw = stmt.release(); + *outputHandle = raw; + raw->SetLastReturnCode(SQL_SUCCESS); + return SQL_SUCCESS; + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLFreeHandle(SQLSMALLINT handleType, SQLHANDLE handle) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* env) { + delete env; + return SQL_SUCCESS; + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* conn) { + auto* env = conn->GetEnvironment(); + if (env != nullptr){ + env->UnregisterConnection(conn); + } + delete conn; + return SQL_SUCCESS; + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [](auto* stmt) { + if (stmt->GetConnection()) { + stmt->GetConnection()->RemoveStatement(stmt); + } + delete stmt; + return SQL_SUCCESS; + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLSetEnvAttr(SQLHENV environmentHandle, + SQLINTEGER attribute, + SQLPOINTER value, + SQLINTEGER stringLength) { + auto env = static_cast(environmentHandle); + if (!env) { + return SQL_INVALID_HANDLE; + } + + return NYdb::NOdbc::HandleOdbcExceptions(env, [&]() { + return env->SetAttribute(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLDriverConnect(SQLHDBC connectionHandle, + SQLHWND /*WindowHandle*/, + SQLCHAR* inConnectionString, + SQLSMALLINT stringLength1, + SQLCHAR* /*outConnectionString*/, + SQLSMALLINT /*bufferLength*/, + SQLSMALLINT* /*stringLength2Ptr*/, + SQLUSMALLINT /*driverCompletion*/) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->DriverConnect(NYdb::NOdbc::GetString(inConnectionString, stringLength1)); + }); +} + +SQLRETURN SQL_API SQLConnect(SQLHDBC connectionHandle, + SQLCHAR* serverName, SQLSMALLINT nameLength1, + SQLCHAR* userName, SQLSMALLINT nameLength2, + SQLCHAR* authentication, SQLSMALLINT nameLength3) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->Connect(NYdb::NOdbc::GetString(serverName, nameLength1), + NYdb::NOdbc::GetString(userName, nameLength2), + NYdb::NOdbc::GetString(authentication, nameLength3)); + }); +} + +SQLRETURN SQL_API SQLDisconnect(SQLHDBC connectionHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->Disconnect(); + }); +} + +SQLRETURN SQL_API SQLExecDirect(SQLHSTMT statementHandle, + SQLCHAR* statementText, + SQLINTEGER textLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + auto ret = stmt->Prepare(NYdb::NOdbc::GetString(statementText, textLength)); + if (ret != SQL_SUCCESS) { + return ret; + } + return stmt->Execute(); + }); +} + +SQLRETURN SQL_API SQLPrepare(SQLHSTMT statementHandle, + SQLCHAR* statementText, + SQLINTEGER textLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Prepare(NYdb::NOdbc::GetString(statementText, textLength)); + }); +} + +SQLRETURN SQL_API SQLExecute(SQLHSTMT statementHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Execute(); + }); +} + +SQLRETURN SQL_API SQLFetch(SQLHSTMT statementHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Fetch(); + }); +} + +SQLRETURN SQL_API SQLGetData(SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLSMALLINT targetType, + SQLPOINTER targetValue, + SQLLEN bufferLength, + SQLLEN* strLenOrInd) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->GetData(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); + }); +} + +SQLRETURN SQL_API SQLBindCol(SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLSMALLINT targetType, + SQLPOINTER targetValue, + SQLLEN bufferLength, + SQLLEN* strLenOrInd) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->BindCol(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); + }); +} + +SQLRETURN SQL_API SQLGetDiagRec(SQLSMALLINT handleType, + SQLHANDLE handle, + SQLSMALLINT recNumber, + SQLCHAR* sqlState, + SQLINTEGER* nativeError, + SQLCHAR* messageText, + SQLSMALLINT bufferLength, + SQLSMALLINT* textLength) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) { + return env->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + return conn->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) { + return stmt->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLGetDiagField(SQLSMALLINT handleType, + SQLHANDLE handle, + SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + switch (handleType) { + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) { + return env->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + return conn->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) { + return stmt->GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLBindParameter(SQLHSTMT statementHandle, + SQLUSMALLINT paramNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN* strLenOrIndPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->BindParameter(paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr); + }); +} + +SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT completionType) { + switch (handleType) { + case SQL_HANDLE_DBC: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* conn) { + if (completionType == SQL_COMMIT) { + return conn->CommitTx(); + } else if (completionType == SQL_ROLLBACK) { + return conn->RollbackTx(); + } else { + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid completion type"); + } + }); + } + case SQL_HANDLE_STMT: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* stmt) -> SQLRETURN { + auto conn = stmt->GetConnection(); + if (!conn) return SQL_INVALID_HANDLE; + if (completionType == SQL_COMMIT) { + return conn->CommitTx(); + } else if (completionType == SQL_ROLLBACK) { + return conn->RollbackTx(); + } else { + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid completion type"); + } + }); + } + case SQL_HANDLE_ENV: { + return NYdb::NOdbc::HandleOdbcExceptions(handle, [&](auto* env) -> SQLRETURN { + return env->EndTran(completionType); + }); + } + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->SetConnectAttr(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLGetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return conn->GetConnectAttr(attribute, value, bufferLength, stringLengthPtr); + }); +} + +SQLRETURN SQL_API SQLColumns(SQLHSTMT statementHandle, + SQLCHAR* catalogName, SQLSMALLINT nameLength1, + SQLCHAR* schemaName, SQLSMALLINT nameLength2, + SQLCHAR* tableName, SQLSMALLINT nameLength3, + SQLCHAR* columnName, SQLSMALLINT nameLength4) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Columns( + NYdb::NOdbc::GetString(catalogName, nameLength1), + NYdb::NOdbc::GetString(schemaName, nameLength2), + NYdb::NOdbc::GetString(tableName, nameLength3), + NYdb::NOdbc::GetString(columnName, nameLength4)); + }); +} + +SQLRETURN SQL_API SQLTables(SQLHSTMT statementHandle, + SQLCHAR* catalogName, SQLSMALLINT nameLength1, + SQLCHAR* schemaName, SQLSMALLINT nameLength2, + SQLCHAR* tableName, SQLSMALLINT nameLength3, + SQLCHAR* tableType, SQLSMALLINT nameLength4) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Tables( + NYdb::NOdbc::GetString(catalogName, nameLength1), + NYdb::NOdbc::GetString(schemaName, nameLength2), + NYdb::NOdbc::GetString(tableName, nameLength3), + NYdb::NOdbc::GetString(tableType, nameLength4)); + }); +} + +SQLRETURN SQL_API SQLCloseCursor(SQLHSTMT statementHandle) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->Close(false); + }); +} + +SQLRETURN SQL_API SQLFreeStmt(SQLHSTMT statementHandle, SQLUSMALLINT option) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) -> SQLRETURN { + switch (option) { + case SQL_CLOSE: + return stmt->Close(true); + case SQL_DROP: + return SQLFreeHandle(SQL_HANDLE_STMT, statementHandle); + case SQL_UNBIND: + stmt->UnbindColumns(); + return SQL_SUCCESS; + case SQL_RESET_PARAMS: + stmt->ResetParams(); + return SQL_SUCCESS; + default: + throw NYdb::NOdbc::TOdbcException("HY000", 0, "Invalid option"); + } + }); +} + +SQLRETURN SQL_API SQLFetchScroll(SQLHSTMT statementHandle, SQLSMALLINT fetchOrientation, SQLLEN fetchOffset) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + if (fetchOrientation == SQL_FETCH_NEXT) { + return stmt->Fetch(); + } else { + throw NYdb::NOdbc::TOdbcException("HYC00", 0, "Only SQL_FETCH_NEXT is supported"); + } + //TODO other fetch-orientation + }); +} + +SQLRETURN SQL_API SQLRowCount(SQLHSTMT statementHandle, SQLLEN* rowCount) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->RowCount(rowCount); + }); +} + +SQLRETURN SQL_API SQLNumResultCols(SQLHSTMT statementHandle, SQLSMALLINT* colCount) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->NumResultCols(colCount); + }); +} + +SQLRETURN SQL_API SQLDescribeCol( + SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLCHAR* columnName, + SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, + SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, + SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return NYdb::NOdbc::TMetadata::DescribeCol( + stmt, + columnNumber, + columnName, + bufferLength, + nameLengthPtr, + dataTypePtr, + columnSizePtr, + decimalDigitsPtr, + nullablePtr); + }); +} + +SQLRETURN SQL_API SQLMoreResults(SQLHSTMT) { + // YDB ODBC currently exposes only one result set per statement. + return SQL_NO_DATA; +} + +SQLRETURN SQL_API SQLGetFunctions(SQLHDBC connectionHandle, SQLUSMALLINT functionId, SQLUSMALLINT* supportedPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto*) { + return NYdb::NOdbc::TMetadata::GetFunctions(functionId, supportedPtr); + }); +} + +SQLRETURN SQL_API SQLSetStmtAttr(SQLHSTMT statementHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->SetStmtAttr(attribute, value, stringLength); + }); +} + +SQLRETURN SQL_API SQLGetStmtAttr( + SQLHSTMT statementHandle, + SQLINTEGER attribute, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(statementHandle, [&](auto* stmt) { + return stmt->GetStmtAttr(attribute, value, bufferLength, stringLengthPtr); + }); +} + +SQLRETURN SQL_API SQLGetInfo(SQLHDBC connectionHandle, + SQLUSMALLINT infoType, + SQLPOINTER infoValuePtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + return NYdb::NOdbc::HandleOdbcExceptions(connectionHandle, [&](auto* conn) { + return NYdb::NOdbc::TMetadata::GetInfo(conn, infoType, infoValuePtr, bufferLength, stringLengthPtr); + }); +} + +} diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp new file mode 100644 index 00000000000..4115dc93c84 --- /dev/null +++ b/odbc/src/statement.cpp @@ -0,0 +1,571 @@ +#include "statement.h" + +#include "utils/convert.h" +#include "utils/types.h" +#include "utils/error_manager.h" +#include "utils/escape.h" +#include "utils/sql_like.h" + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + + bool StartsWithPrefix(const char* s, size_t sLen, const char* prefix, size_t prefixLen) { + if (sLen < prefixLen) { + return false; + } + for (size_t i = 0; i < prefixLen; ++i) { + if (std::tolower(static_cast(s[i])) != + std::tolower(static_cast(prefix[i]))) { + return false; + } + } + return true; + } + + bool IsDdlQuery(const std::string& queryText) { + size_t pos = 0; + while (pos < queryText.size() && std::isspace(static_cast(queryText[pos]))) { + ++pos; + } + if (queryText.size() - pos < 6) { + return false; + } + const char* start = queryText.c_str() + pos; + const size_t remaining = queryText.size() - pos; + return StartsWithPrefix(start, remaining, "CREATE", 6) || + StartsWithPrefix(start, remaining, "DROP", 4) || + StartsWithPrefix(start, remaining, "ALTER", 5); + } + + NYdb::TStatus StatusFrom(const NYdb::TStatus& ydb_status) { + return NYdb::TStatus(ydb_status.GetStatus(), NYdb::NIssue::TIssues(ydb_status.GetIssues())); + } +} + +TStatement::TStatement(TConnection* conn) + : Conn_(conn) {} + +SQLRETURN TStatement::Prepare(const std::string& statementText) { + StreamFetchError_ = false; + RowsFetched_ = 0; + Cursor_.reset(); + PreparedQuery_ = statementText; + IsPrepared_ = true; + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Execute() { + if (!IsPrepared_ || PreparedQuery_.empty()) { + throw TOdbcException("HY007", 0, "No prepared statement"); + } + StreamFetchError_ = false; + RowsFetched_ = 0; + Cursor_.reset(); + auto* client = Conn_->GetClient(); + if (!client) { + throw TOdbcException("HY000", 0, "No client connection"); + } + NYdb::TParams params = NYdb::TParamsBuilder().Build(); + const SQLRETURN buildRc = BuildParams(params); + if (buildRc != SQL_SUCCESS) { + return buildRc; + } + + if (Conn_->GetAutocommit()) { + Conn_->ResetTx(); + Conn_->ResetQuerySession(); + const NYdb::NRetry::TRetryOperationSettings retrySettings = MakeAutocommitRetrySettings(); + + const NYdb::TStatus execStatus = client->RetryQuerySync( + [this, ¶ms](NQuery::TSession session) -> NYdb::TStatus { + auto retryIterator = CreateExecuteIterator(session, params); + if (!retryIterator.IsSuccess()) { + return StatusFrom(retryIterator); + } + TExecCursorCreateResult created = TryCreateExecCursor(this, std::move(retryIterator)); + if (!created.Status.IsSuccess()) { + return created.Status; + } + Cursor_ = std::move(created.Cursor); + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + }, + retrySettings); + + NStatusHelpers::ThrowOnError(execStatus); + } else { + NQuery::TSession& session = Conn_->GetOrCreateQuerySession(); + auto iterator = CreateExecuteIterator(session, params); + NStatusHelpers::ThrowOnError(iterator); + TExecCursorCreateResult created = TryCreateExecCursor(this, std::move(iterator)); + NStatusHelpers::ThrowOnError(created.Status); + Cursor_ = std::move(created.Cursor); + } + return SQL_SUCCESS; +} + +NYdb::NRetry::TRetryOperationSettings TStatement::MakeAutocommitRetrySettings() { + NYdb::NRetry::TRetryOperationSettings settings; + settings.Idempotent(true); + SQLUINTEGER queryTimeoutSec = Attributes_.GetQueryTimeoutSec(); + if (queryTimeoutSec > 0) { + const TDuration deadline = TDuration::Seconds(queryTimeoutSec); + settings.MaxTimeout(deadline).GetSessionClientTimeout(deadline); + } + return settings; +} + +NQuery::TExecuteQueryIterator TStatement::CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params){ + const std::string sqlText = Attributes_.GetNoScanMode() == SQL_NOSCAN_ON + ? PreparedQuery_ + : RewriteOdbcEscapes(PreparedQuery_); + const std::string queryText = Conn_->WrapQueryForCurrentCatalog(sqlText); + NQuery::TExecuteQuerySettings execSettings; + const SQLUINTEGER queryTimeoutSec = Attributes_.GetQueryTimeoutSec(); + if (queryTimeoutSec > 0) { + execSettings.ClientTimeout(TDuration::Seconds(queryTimeoutSec)); + } + const auto txSettings = Conn_->MakeTxSettings(); + if (Conn_->GetAutocommit()) { + // TS_SNAPSHOT_RW doesn't support explicit BeginTx() - we use NoTx() instead + // DDL must use NoTx() per YDB documentation + const bool isSnapshotRw = (txSettings.GetMode() == NQuery::TTxSettings::TS_SNAPSHOT_RW); + + const bool isDdl = IsDdlQuery(queryText); + + if (isSnapshotRw || isDdl) { + return session.StreamExecuteQuery( + queryText, + NQuery::TTxControl::NoTx(), + params, + execSettings).ExtractValueSync(); + } + return session.StreamExecuteQuery( + queryText, + NQuery::TTxControl::BeginTx(txSettings).CommitTx(), + params, + execSettings).ExtractValueSync(); + } + if (!Conn_->GetTx()) { + auto beginTxResult = session.BeginTransaction(txSettings).ExtractValueSync(); + NStatusHelpers::ThrowOnError(beginTxResult); + Conn_->SetTx(beginTxResult.GetTransaction()); + } + return session.StreamExecuteQuery( + queryText, + NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(false), + params, + execSettings).ExtractValueSync(); +} + + + +SQLRETURN TStatement::Fetch() { + if (!Cursor_) { + Cursor_.reset(); + return SQL_NO_DATA; + } + const SQLULEN maxRows = Attributes_.GetMaxRows(); + if (maxRows > 0 && RowsFetched_ >= maxRows) { + return SQL_NO_DATA; + } + StreamFetchError_ = false; + if (!Cursor_->Fetch()) { + return StreamFetchError_ ? SQL_ERROR : SQL_NO_DATA; + } + ++RowsFetched_; + return SQL_SUCCESS; +} + +void TStatement::OnStreamPartError(const TStatus& status) { + ClearErrors(); + AddError(status); + StreamFetchError_ = true; +} + +SQLRETURN TStatement::GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (!Cursor_) { + return SQL_NO_DATA; + } + return Cursor_->GetData(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); +} + +void TStatement::FillBoundColumns() { + if (!Cursor_) { + return; + } + for (const auto& col : BoundColumns_) { + Cursor_->GetData(col.ColumnNumber, col.TargetType, col.TargetValue, col.BufferLength, col.StrLenOrInd); + } +} + +SQLRETURN TStatement::BindCol(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (targetValue && columnNumber < 1) { + return AddError("07009", 0, "Invalid descriptor index"); + } + if (Cursor_) { + const size_t n = Cursor_->GetColumnMeta().size(); + if (targetValue && n > 0 && static_cast(columnNumber) > n) { + return AddError("07009", 0, "Invalid descriptor index"); + } + } + + BoundColumns_.erase(std::remove_if(BoundColumns_.begin(), BoundColumns_.end(), + [columnNumber](const TBoundColumn& col) { return col.ColumnNumber == columnNumber; }), BoundColumns_.end()); + + if (!targetValue) { + return SQL_SUCCESS; + } + BoundColumns_.push_back({columnNumber, targetType, targetValue, bufferLength, strLenOrInd}); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::BindParameter(SQLUSMALLINT paramNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN* strLenOrIndPtr) { + + if (inputOutputType != SQL_PARAM_INPUT) { + throw TOdbcException("HYC00", 0, "Only input parameters are supported"); + } + + BoundParams_.erase(std::remove_if(BoundParams_.begin(), BoundParams_.end(), + [paramNumber](const TBoundParam& p) { return p.ParamNumber == paramNumber; }), BoundParams_.end()); + + if (!parameterValuePtr) { + return SQL_SUCCESS; + } + BoundParams_.push_back({paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr}); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::BuildParams(NYdb::TParams& out) { + ClearErrors(); + NYdb::TParamsBuilder paramsBuilder; + for (const auto& param : BoundParams_) { + const std::string paramName = "$p" + std::to_string(param.ParamNumber); + const SQLRETURN convRc = ConvertParam(param, paramsBuilder.AddParam(paramName)); + if (convRc != SQL_SUCCESS) { + return AddError( + "07006", + 0, + "Unsupported or invalid ODBC parameter type for parameter " + std::to_string(param.ParamNumber) + + " (C type " + std::to_string(static_cast(param.ValueType)) + ", SQL type " + + std::to_string(static_cast(param.ParameterType)) + ")"); + } + } + out = paramsBuilder.Build(); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Columns(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName) { + ClearErrors(); + RowsFetched_ = 0; + Cursor_.reset(); + + std::vector columns = { + {"TABLE_CAT", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_SCHEM", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"COLUMN_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"DATA_TYPE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"TYPE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"COLUMN_SIZE", SQL_INTEGER, 0, SQL_NULLABLE}, + {"BUFFER_LENGTH", SQL_INTEGER, 0, SQL_NULLABLE}, + {"DECIMAL_DIGITS", SQL_INTEGER, 0, SQL_NULLABLE}, + {"NUM_PREC_RADIX", SQL_INTEGER, 0, SQL_NULLABLE}, + {"NULLABLE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"REMARKS", SQL_VARCHAR, 762, SQL_NULLABLE}, + {"COLUMN_DEF", SQL_VARCHAR, 254, SQL_NULLABLE}, + {"SQL_DATA_TYPE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"SQL_DATETIME_SUB", SQL_INTEGER, 0, SQL_NULLABLE}, + {"CHAR_OCTET_LENGTH", SQL_INTEGER, 0, SQL_NULLABLE}, + {"ORDINAL_POSITION", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"IS_NULLABLE", SQL_VARCHAR, 254, SQL_NO_NULLS} + }; + + auto entries = GetPatternEntries(tableName); + + TTable table; + table.reserve(entries.size()); + + if (entries.empty()) { + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; + } + + for (const auto& entry : entries) { + if (entry.Type != NScheme::ESchemeEntryType::Table && + entry.Type != NScheme::ESchemeEntryType::ColumnTable) { + continue; + } + + auto status = Conn_->GetTableClient()->RetryOperationSync([this, path = entry.Name, &table, &columnName](NTable::TSession session) -> TStatus { + auto result = session.DescribeTable(path).ExtractValueSync(); + NStatusHelpers::ThrowOnError(result); + + auto columns = result.GetTableDescription().GetTableColumns(); + + auto columnMatches = [&](const NTable::TTableColumn& column) { + if (columnName.empty()) { + return true; + } + if (Attributes_.GetMetadataId() == SQL_TRUE) { + return column.Name == columnName; + } + return SqlLikeMatch(column.Name, columnName); + }; + + bool foundColumn = false; + for (size_t columnIndex = 0; columnIndex < columns.size(); ++columnIndex) { + const auto& column = columns[columnIndex]; + if (!columnMatches(column)) { + continue; + } + foundColumn = true; + + table.push_back({ + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Utf8(path).Build(), + TValueBuilder().Utf8(column.Name).Build(), + TValueBuilder().Int16(GetTypeId(column.Type)).Build(), + TValueBuilder().Utf8(column.Type.ToString()).Build(), + TValueBuilder().OptionalInt32(std::nullopt).Build(), + TValueBuilder().OptionalInt32(std::nullopt).Build(), + TValueBuilder().OptionalInt16(GetDecimalDigits(column.Type)).Build(), + TValueBuilder().OptionalInt16(GetRadix(column.Type)).Build(), + TValueBuilder().Int16(column.NotNull && *column.NotNull ? SQL_NO_NULLS : SQL_NULLABLE).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Int16(GetTypeId(column.Type)).Build(), + TValueBuilder().OptionalInt16(std::nullopt).Build(), + TValueBuilder().OptionalInt32(8).Build(), + TValueBuilder().OptionalInt32(columnIndex + 1).Build(), + TValueBuilder().Utf8(column.NotNull && *column.NotNull ? "NO" : "YES").Build(), + }); + } + if (!foundColumn) { + throw TOdbcException("42S22", 0, "Column not found", SQL_ERROR); + } + return TStatus(EStatus::SUCCESS, {}); + }); + + NStatusHelpers::ThrowOnError(status); + } + + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Tables(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType) { + ClearErrors(); + RowsFetched_ = 0; + Cursor_.reset(); + + std::vector columns = { + {"TABLE_CAT", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_SCHEM", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"TABLE_TYPE", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"REMARKS", SQL_VARCHAR, 254, SQL_NULLABLE} + }; + + auto entries = GetPatternEntries(tableName); + + TTable table; + table.reserve(entries.size()); + + for (const auto& entry : entries) { + auto tableType = GetTableType(entry.Type); + if (!tableType) { + continue; + } + + table.push_back({ + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Utf8(entry.Name).Build(), + TValueBuilder().Utf8(*tableType).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + }); + } + + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; +} + +std::vector TStatement::GetPatternEntries(const std::string& pattern) { + std::vector entries; + VisitEntry("", pattern, entries); + return entries; +} + +SQLRETURN TStatement::VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries) { + auto schemeClient = Conn_->GetSchemeClient(); + auto listDirectoryResult = schemeClient->ListDirectory(path + "/").ExtractValueSync(); + NStatusHelpers::ThrowOnError(listDirectoryResult); + + for (const auto& entry : listDirectoryResult.GetChildren()) { + std::string fullPath = path + "/" + entry.Name; + if (entry.Type == NScheme::ESchemeEntryType::Directory || + entry.Type == NScheme::ESchemeEntryType::SubDomain) { + VisitEntry(fullPath, pattern, resultEntries); + } else if (IsPatternMatch(fullPath, pattern)) { + NScheme::TSchemeEntry entryCopy = entry; + entryCopy.Name = fullPath; + resultEntries.push_back(entryCopy); + } + } + return SQL_SUCCESS; +} + +bool TStatement::IsPatternMatch(const std::string& path, const std::string& pattern) { + if (pattern.empty()) { + return true; + } + if (Attributes_.GetMetadataId() == SQL_TRUE) { + return path == pattern; + } + return SqlLikeMatch(path, pattern); +} + +std::optional TStatement::GetTableType(NScheme::ESchemeEntryType type) { + switch (type) { + case NScheme::ESchemeEntryType::Table: + return "TABLE"; + case NScheme::ESchemeEntryType::View: + return "VIEW"; + case NScheme::ESchemeEntryType::ColumnStore: + return "COLUMN_STORE"; + case NScheme::ESchemeEntryType::ColumnTable: + return "COLUMN_TABLE"; + case NScheme::ESchemeEntryType::Sequence: + return "SEQUENCE"; + case NScheme::ESchemeEntryType::Replication: + return "REPLICATION"; + case NScheme::ESchemeEntryType::Topic: + return "TOPIC"; + case NScheme::ESchemeEntryType::ExternalTable: + return "EXTERNAL_TABLE"; + case NScheme::ESchemeEntryType::ExternalDataSource: + return "EXTERNAL_DATA_SOURCE"; + case NScheme::ESchemeEntryType::ResourcePool: + return "RESOURCE_POOL"; + case NScheme::ESchemeEntryType::PqGroup: + return "PQ_GROUP"; + case NScheme::ESchemeEntryType::RtmrVolume: + return "RTMR_VOLUME"; + case NScheme::ESchemeEntryType::BlockStoreVolume: + return "BLOCK_STORE_VOLUME"; + case NScheme::ESchemeEntryType::CoordinationNode: + return "COORDINATION_NODE"; + case NScheme::ESchemeEntryType::Unknown: + return "UNKNOWN"; + case NScheme::ESchemeEntryType::SysView: + return "SYSTEM VIEW"; + case NScheme::ESchemeEntryType::Transfer: + return "TRANSFER"; + case NScheme::ESchemeEntryType::Directory: + case NScheme::ESchemeEntryType::SubDomain: + return std::nullopt; + default: + return std::nullopt; + } +} + +SQLRETURN TStatement::Close(bool force) { + if (!force && !Cursor_) { + throw TOdbcException("24000", 0, "Invalid handle"); + } + + Cursor_.reset(); + RowsFetched_ = 0; + ClearErrors(); + return SQL_SUCCESS; +} + +void TStatement::UnbindColumns() { + BoundColumns_.clear(); +} + +void TStatement::ResetParams() { + BoundParams_.clear(); +} + +SQLRETURN TStatement::RowCount(SQLLEN* rowCount) { + if (!rowCount) { + throw TOdbcException("HY000", 0, "Invalid parameter"); + } + + *rowCount = -1; + return SQL_SUCCESS; +} + +SQLRETURN TStatement::NumResultCols(SQLSMALLINT* colCount) { + if (!colCount) { + throw TOdbcException("HY000", 0, "Invalid parameter"); + } + if (!Cursor_) { + *colCount = 0; + return SQL_SUCCESS; + } + *colCount = static_cast(Cursor_->GetColumnMeta().size()); + return SQL_SUCCESS; +} + +const std::vector& TStatement::GetColumnMeta() const { + static const std::vector EmptyColumns; + return Cursor_ ? Cursor_->GetColumnMeta() : EmptyColumns; +} + +SQLRETURN TStatement::SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength) { + return Attributes_.SetStmtAttr(attr, value, stringLength, *this); +} + +SQLRETURN TStatement::GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr) { + return Attributes_.GetStmtAttr(attr, value, bufferLength, stringLengthPtr, *this); +} + +SQLRETURN TStatement::GetDiagField( + SQLSMALLINT recNumber, + SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) { + if (recNumber == 0 && diagIdentifier == SQL_DIAG_ROW_COUNT) { + if (!diagInfoPtr) { + return SQL_ERROR; + } + *reinterpret_cast(diagInfoPtr) = -1; + return SQL_SUCCESS; + } + return TErrorManager::GetDiagField(recNumber, diagIdentifier, diagInfoPtr, bufferLength, stringLengthPtr); +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement.h b/odbc/src/statement.h new file mode 100644 index 00000000000..e74ce58b84e --- /dev/null +++ b/odbc/src/statement.h @@ -0,0 +1,90 @@ +#pragma once + +#include "connection.h" +#include "statement_attr.h" +#include "utils/error_manager.h" +#include "utils/bindings.h" +#include "utils/cursor.h" + +#include + +#include +#include + +#include +#include +#include + + +namespace NYdb { +namespace NOdbc { + +class TStatement : public TErrorManager, public IBindingFiller { +public: + TStatement(TConnection* conn); + + SQLRETURN Prepare(const std::string& statementText); + SQLRETURN Execute(); + + SQLRETURN Fetch(); + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + + void FillBoundColumns() override; + void OnStreamPartError(const TStatus& status) override; + + SQLRETURN Close(bool force = false); + void UnbindColumns(); + void ResetParams(); + + SQLRETURN BindCol(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + SQLRETURN BindParameter(SQLUSMALLINT paramNumber, SQLSMALLINT inputOutputType, SQLSMALLINT valueType, SQLSMALLINT parameterType, SQLULEN columnSize, SQLSMALLINT decimalDigits, SQLPOINTER parameterValuePtr, SQLLEN bufferLength, SQLLEN* strLenOrIndPtr); + + SQLRETURN Columns(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName); + + SQLRETURN Tables(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType); + + SQLRETURN RowCount(SQLLEN* rowCount); + SQLRETURN NumResultCols(SQLSMALLINT* colCount); + const std::vector& GetColumnMeta() const; + SQLRETURN SetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetStmtAttr(SQLINTEGER attr, SQLPOINTER value, SQLINTEGER bufferLength, SQLINTEGER* stringLengthPtr); + + SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr) override; + + TConnection* GetConnection() { + return Conn_; + } + +private: + TConnection* Conn_; + std::unique_ptr Cursor_; + std::string PreparedQuery_; + bool IsPrepared_ = false; + + std::vector BoundColumns_; + std::vector BoundParams_; + bool StreamFetchError_ = false; + SQLULEN RowsFetched_ = 0; + TStatementAttributes Attributes_; + + SQLRETURN BuildParams(NYdb::TParams& out); + + NQuery::TExecuteQueryIterator CreateExecuteIterator(NQuery::TSession& session, const NYdb::TParams& params); + + NYdb::NRetry::TRetryOperationSettings MakeAutocommitRetrySettings(); + std::vector GetPatternEntries(const std::string& pattern); + SQLRETURN VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries); + bool IsPatternMatch(const std::string& path, const std::string& pattern); + std::optional GetTableType(NScheme::ESchemeEntryType type); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement_attr.cpp b/odbc/src/statement_attr.cpp new file mode 100644 index 00000000000..f0baad0016a --- /dev/null +++ b/odbc/src/statement_attr.cpp @@ -0,0 +1,101 @@ +#include "statement_attr.h" + +#include "utils/attr.h" +#include "utils/diag.h" + +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN TStatementAttributes::SetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*stringLength*/, + TErrorManager& errors) { + switch (attr) { + case SQL_ATTR_QUERY_TIMEOUT: { + const SQLINTEGER timeout = ReadIntegerAttr(value); + if (timeout < 0) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_QUERY_TIMEOUT"); + } + QueryTimeoutSec_ = static_cast(timeout); + return SQL_SUCCESS; + } + case SQL_ATTR_MAX_ROWS: { + const SQLLEN maxRows = ReadIntegerAttr(value); + if (maxRows < 0) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_MAX_ROWS"); + } + MaxRows_ = static_cast(maxRows); + return SQL_SUCCESS; + } + case SQL_ATTR_NOSCAN: { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_NOSCAN_OFF, SQL_NOSCAN_ON}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_NOSCAN"); + } + NoScan_ = *mode; + return SQL_SUCCESS; + } + case SQL_ATTR_METADATA_ID: { + const auto mode = ReadIntegerAttrIfIn(value, {SQL_FALSE, SQL_TRUE}); + if (!mode) { + return Diag::AddInvalidAttrValue(errors, "SQL_ATTR_METADATA_ID"); + } + MetadataId_ = *mode; + return SQL_SUCCESS; + } + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLRETURN TStatementAttributes::GetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER /*bufferLength*/, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const { + if (!value) { + return Diag::AddNullPointer(errors); + } + if (stringLengthPtr) { + *stringLengthPtr = 0; + } + switch (attr) { + case SQL_ATTR_QUERY_TIMEOUT: + *reinterpret_cast(value) = QueryTimeoutSec_; + return SQL_SUCCESS; + case SQL_ATTR_MAX_ROWS: + *reinterpret_cast(value) = MaxRows_; + return SQL_SUCCESS; + case SQL_ATTR_NOSCAN: + *reinterpret_cast(value) = NoScan_; + return SQL_SUCCESS; + case SQL_ATTR_METADATA_ID: + *reinterpret_cast(value) = MetadataId_; + return SQL_SUCCESS; + default: + return Diag::AddNotImplemented(errors); + } +} + +SQLUINTEGER TStatementAttributes::GetQueryTimeoutSec() const noexcept{ + return QueryTimeoutSec_; +} + +SQLULEN TStatementAttributes::GetMaxRows() const noexcept { + return MaxRows_; +} + +SQLULEN TStatementAttributes::GetNoScanMode() const noexcept { + return NoScan_; +} + +SQLULEN TStatementAttributes::GetMetadataId() const noexcept { + return MetadataId_; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement_attr.h b/odbc/src/statement_attr.h new file mode 100644 index 00000000000..b0d6e9bd97f --- /dev/null +++ b/odbc/src/statement_attr.h @@ -0,0 +1,39 @@ +#pragma once + +#include "utils/error_manager.h" + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TStatementAttributes { +public: + SQLRETURN SetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER stringLength, + TErrorManager& errors); + + SQLRETURN GetStmtAttr( + SQLINTEGER attr, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) const; + + SQLUINTEGER GetQueryTimeoutSec() const noexcept; + SQLULEN GetMaxRows() const noexcept; + SQLULEN GetNoScanMode() const noexcept; + SQLULEN GetMetadataId() const noexcept; + +private: + SQLUINTEGER QueryTimeoutSec_ = 0; + SQLULEN MaxRows_ = 0; + SQLULEN NoScan_ = SQL_NOSCAN_OFF; + SQLULEN MetadataId_ = SQL_FALSE; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/attr.cpp b/odbc/src/utils/attr.cpp new file mode 100644 index 00000000000..1fb2a83324a --- /dev/null +++ b/odbc/src/utils/attr.cpp @@ -0,0 +1,51 @@ +#include "attr.h" +#include "diag.h" + +#include +#include + +namespace NYdb::NOdbc { + +std::string ReadAttributeString(SQLPOINTER value, SQLINTEGER stringLength) { + const char* const str = static_cast(value); + if (stringLength == SQL_NTS) { + return std::string(str); + } + if (stringLength < 0) { + return {}; + } + return std::string(str, static_cast(stringLength)); +} + +SQLRETURN WriteAttributeString( + const std::string& source, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors) { + const SQLINTEGER length = static_cast(source.size()); + if (stringLengthPtr != nullptr) { + *stringLengthPtr = length; + } + if (value == nullptr) { + return SQL_SUCCESS; + } + if (bufferLength <= 0) { + return Diag::AddInvalidBufferLength(errors); + } + + auto* dest = static_cast(value); + const size_t maxData = static_cast(bufferLength - 1); + const size_t nCopy = std::min(source.size(), maxData); + if (nCopy > 0) { + std::memcpy(dest, source.data(), nCopy); + } + dest[nCopy] = 0; + + if (length >= bufferLength) { + return Diag::AddRightTruncated(errors); + } + return SQL_SUCCESS; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/attr.h b/odbc/src/utils/attr.h new file mode 100644 index 00000000000..96695c221aa --- /dev/null +++ b/odbc/src/utils/attr.h @@ -0,0 +1,39 @@ +#pragma once + +#include "error_manager.h" + +#include +#include +#include + +#include +#include + +namespace NYdb::NOdbc { + +std::string ReadAttributeString(SQLPOINTER value, SQLINTEGER stringLength); + +SQLRETURN WriteAttributeString( + const std::string& source, + SQLPOINTER value, + SQLINTEGER bufferLength, + SQLINTEGER* stringLengthPtr, + TErrorManager& errors); + +template +T ReadIntegerAttr(SQLPOINTER value) noexcept { + return static_cast(reinterpret_cast(value)); +} + +template +std::optional ReadIntegerAttrIfIn(SQLPOINTER value, std::initializer_list allowed) noexcept { + const T token = ReadIntegerAttr(value); + for (const T allowedValue : allowed) { + if (token == allowedValue) { + return token; + } + } + return std::nullopt; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/bindings.h b/odbc/src/utils/bindings.h new file mode 100644 index 00000000000..2480f5367af --- /dev/null +++ b/odbc/src/utils/bindings.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +struct TBoundParam { + SQLUSMALLINT ParamNumber; + SQLSMALLINT InputOutputType; + SQLSMALLINT ValueType; + SQLSMALLINT ParameterType; + SQLULEN ColumnSize; + SQLSMALLINT DecimalDigits; + SQLPOINTER ParameterValuePtr; + SQLLEN BufferLength; + SQLLEN* StrLenOrIndPtr; +}; + +struct TBoundColumn { + SQLUSMALLINT ColumnNumber; + SQLSMALLINT TargetType; + SQLPOINTER TargetValue; + SQLLEN BufferLength; + SQLLEN* StrLenOrInd; +}; + +class IBindingFiller { +public: + virtual void FillBoundColumns() = 0; + virtual void OnStreamPartError([[maybe_unused]] const TStatus& status) { + } + + virtual ~IBindingFiller() = default; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/convert.cpp b/odbc/src/utils/convert.cpp new file mode 100644 index 00000000000..db7928ce659 --- /dev/null +++ b/odbc/src/utils/convert.cpp @@ -0,0 +1,539 @@ +#include "convert.h" + +#include +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +template +struct TSqlTypeTraits; + +template<> struct TSqlTypeTraits { using Type = std::string; }; +template<> struct TSqlTypeTraits { using Type = std::string; }; +template<> struct TSqlTypeTraits { using Type = SQLBIGINT; }; +template<> struct TSqlTypeTraits { using Type = SQLUBIGINT; }; +template<> struct TSqlTypeTraits { using Type = SQLINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLUINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLUSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLSCHAR; }; +template<> struct TSqlTypeTraits { using Type = SQLCHAR; }; +template<> struct TSqlTypeTraits { using Type = SQLDOUBLE; }; +template<> struct TSqlTypeTraits { using Type = SQLFLOAT; }; +template<> struct TSqlTypeTraits { using Type = SQLCHAR; }; + +template +struct TTypedValue { + using TSrcType = typename TSqlTypeTraits::Type; + + TSrcType Data; + + TTypedValue(const TBoundParam& param) { + Data = *static_cast(param.ParameterValuePtr); + } +}; + +template<> +TTypedValue::TTypedValue(const TBoundParam& param) { + if (param.StrLenOrIndPtr && *param.StrLenOrIndPtr == SQL_NULL_DATA) { + Data.clear(); + return; + } + + const char* ptr = static_cast(param.ParameterValuePtr); + if (!ptr) { + Data.clear(); + return; + } + + if (param.StrLenOrIndPtr) { + SQLLEN len = *param.StrLenOrIndPtr; + if (len == SQL_NTS) { + Data = std::string(ptr); + } else if (len >= 0) { + Data = std::string(ptr, static_cast(len)); + } else { + Data = std::string(ptr, param.BufferLength); + } + } else { + Data = std::string(ptr, param.BufferLength); + } +} + +template<> +TTypedValue::TTypedValue(const TBoundParam& param) { + if (param.StrLenOrIndPtr && *param.StrLenOrIndPtr == SQL_NULL_DATA) { + Data.clear(); + return; + } + + const char* ptr = static_cast(param.ParameterValuePtr); + if (!ptr) { + Data.clear(); + return; + } + + if (param.StrLenOrIndPtr && *param.StrLenOrIndPtr >= 0) { + Data = std::string(ptr, static_cast(*param.StrLenOrIndPtr)); + } else { + Data = std::string(ptr, param.BufferLength); + } +} + +class IConverter { +public: + virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) = 0; + + virtual ~IConverter() = default; +}; + +template +class TConverter : public IConverter { +public: + virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) override { + TTypedValue value(param); + Convert(param, std::move(value.Data), builder); + if (param.StrLenOrIndPtr && *param.StrLenOrIndPtr == SQL_NULL_DATA) { + builder.EmptyOptional(GetType()); + } + builder.Build(); + } + +private: + void Convert(const TBoundParam& param, TTypedValue::TSrcType&& data, TParamValueBuilder& builder); + TType GetType(); +}; + +class TConverterRegistry { +public: + static TConverterRegistry& GetInstance() { + static TConverterRegistry instance; + return instance; + } + + void RegisterConverter(SQLSMALLINT cType, SQLSMALLINT sqlType, std::unique_ptr converter) { + Converters_.emplace(std::make_pair(cType, sqlType), std::move(converter)); + } + + IConverter* GetConverter(SQLSMALLINT cType, SQLSMALLINT sqlType) { + auto it = Converters_.find(std::make_pair(cType, sqlType)); + if (it != Converters_.end()) { + return it->second.get(); + } + return nullptr; + } + +private: + std::map, std::unique_ptr> Converters_; +}; + +#define REGISTER_CONVERTER(CType, SqlType, YdbType) \ + struct TConverterRegistration##CType##SqlType { \ + TConverterRegistration##CType##SqlType() { \ + TConverterRegistry::GetInstance().RegisterConverter(CType, SqlType, std::make_unique>()); \ + } \ + }; \ + static const TConverterRegistration##CType##SqlType converterRegistration##CType##SqlType; \ + template<> \ + TType TConverter::GetType() { \ + return TTypeBuilder().Primitive(YdbType).Build(); \ + } \ + template<> \ + void TConverter::Convert(const TBoundParam& param, TTypedValue::TSrcType&& data, TParamValueBuilder& builder) + +// Integer types + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SLONG, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SLONG, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SLONG, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SLONG, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +// Floating point types + +REGISTER_CONVERTER(SQL_C_FLOAT, SQL_REAL, EPrimitiveType::Float) { + builder.OptionalFloat(data); +} + +REGISTER_CONVERTER(SQL_C_DOUBLE, SQL_FLOAT, EPrimitiveType::Double) { + builder.OptionalDouble(data); +} + +REGISTER_CONVERTER(SQL_C_DOUBLE, SQL_DOUBLE, EPrimitiveType::Double) { + builder.OptionalDouble(data); +} + +// String types + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_CHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_VARCHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_LONGVARCHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +// Binary types + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_BINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_VARBINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_LONGVARBINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +#undef REGISTER_CONVERTER + +SQLRETURN ConvertParam(const TBoundParam& param, TParamValueBuilder& builder) { + auto converter = TConverterRegistry::GetInstance().GetConverter(param.ValueType, param.ParameterType); + if (!converter) { + return SQL_ERROR; + } + + converter->AddToBuilder(param, builder); + return SQL_SUCCESS; +} + +SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (parser.IsNull()) { + if (strLenOrInd) { + *strLenOrInd = SQL_NULL_DATA; + } + return SQL_SUCCESS; + } + + if (parser.GetKind() == TTypeParser::ETypeKind::Optional) { + parser.OpenOptional(); + SQLRETURN ret = ConvertColumn(parser, targetType, targetValue, bufferLength, strLenOrInd); + parser.CloseOptional(); + return ret; + } + + if (parser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return SQL_ERROR; + } + + EPrimitiveType ydbType = parser.GetPrimitiveType(); + + switch (targetType) { + case SQL_C_SHORT: + case SQL_C_SSHORT: + { + SQLSMALLINT v = 0; + switch (ydbType) { + case EPrimitiveType::Int16: v = parser.GetInt16(); break; + case EPrimitiveType::Uint16: v = static_cast(parser.GetUint16()); break; + case EPrimitiveType::Int8: v = static_cast(parser.GetInt8()); break; + case EPrimitiveType::Uint8: v = static_cast(parser.GetUint8()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + case EPrimitiveType::Bool: v = parser.GetBool() ? 1 : 0; break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(SQLSMALLINT); + } + return SQL_SUCCESS; + } + case SQL_C_SLONG: + case SQL_C_LONG: + { + int32_t v = 0; + switch (ydbType) { + case EPrimitiveType::Int16: v = static_cast(parser.GetInt16()); break; + case EPrimitiveType::Uint16: v = static_cast(parser.GetUint16()); break; + case EPrimitiveType::Int8: v = static_cast(parser.GetInt8()); break; + case EPrimitiveType::Uint8: v = static_cast(parser.GetUint8()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + case EPrimitiveType::Int64: v = static_cast(parser.GetInt64()); break; + case EPrimitiveType::Uint64: v = static_cast(parser.GetUint64()); break; + case EPrimitiveType::Bool: v = parser.GetBool() ? 1 : 0; break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(int32_t); + } + return SQL_SUCCESS; + } + case SQL_C_SBIGINT: + { + SQLBIGINT v = 0; + switch (ydbType) { + case EPrimitiveType::Int64: v = parser.GetInt64(); break; + case EPrimitiveType::Uint64: v = static_cast(parser.GetUint64()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(SQLBIGINT); + } + return SQL_SUCCESS; + } + case SQL_C_DOUBLE: + { + double v = 0.0; + switch (ydbType) { + case EPrimitiveType::Double: v = parser.GetDouble(); break; + case EPrimitiveType::Float: v = parser.GetFloat(); break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(double); + } + return SQL_SUCCESS; + } + case SQL_C_CHAR: + { + std::string str; + switch (ydbType) { + case EPrimitiveType::Utf8: str = parser.GetUtf8(); break; + case EPrimitiveType::String: str = parser.GetString(); break; + case EPrimitiveType::Json: str = parser.GetJson(); break; + case EPrimitiveType::JsonDocument: str = parser.GetJsonDocument(); break; + case EPrimitiveType::Date: { + const TString t = parser.GetDate().FormatGmTime("%Y-%m-%d"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Date32: { + const auto days = parser.GetDate32().time_since_epoch(); + if (days.count() < 0) { + return SQL_ERROR; + } + const TString t = + TInstant::Days(static_cast(days.count())).FormatGmTime("%Y-%m-%d"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Datetime: { + const TString t = parser.GetDatetime().FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Datetime64: { + const auto secs = parser.GetDatetime64().time_since_epoch(); + if (secs.count() < 0) { + return SQL_ERROR; + } + const TString t = TInstant::Seconds(static_cast(static_cast(secs.count()))) + .FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Timestamp: { + const TString t = parser.GetTimestamp().FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::Timestamp64: { + const auto micros = parser.GetTimestamp64().time_since_epoch(); + if (micros.count() < 0) { + return SQL_ERROR; + } + const TString t = + TInstant::MicroSeconds(static_cast(static_cast(micros.count()))) + .FormatGmTime("%Y-%m-%d %H:%M:%S"); + str.assign(t.data(), t.size()); + break; + } + case EPrimitiveType::TzDate: str = parser.GetTzDate(); break; + case EPrimitiveType::TzDatetime: str = parser.GetTzDatetime(); break; + case EPrimitiveType::TzTimestamp: str = parser.GetTzTimestamp(); break; + default: return SQL_ERROR; + } + SQLLEN len = str.size(); + if (targetValue && bufferLength > 0) { + SQLLEN copyLen = std::min(len, bufferLength - 1); + memcpy(targetValue, str.data(), copyLen); + reinterpret_cast(targetValue)[copyLen] = 0; + } + if (strLenOrInd) { + *strLenOrInd = len; + } + return SQL_SUCCESS; + } + case SQL_C_BIT: + { + char v = parser.GetBool() ? 1 : 0; + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(char); + } + return SQL_SUCCESS; + } + default: + return SQL_ERROR; + } +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/convert.h b/odbc/src/utils/convert.h new file mode 100644 index 00000000000..8f8195ba1c8 --- /dev/null +++ b/odbc/src/utils/convert.h @@ -0,0 +1,18 @@ +#pragma once + +#include "bindings.h" + +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN ConvertParam(const TBoundParam& param, TParamValueBuilder& builder); +SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + +} // namespace NOdbc +} // namespace NYdb + diff --git a/odbc/src/utils/cursor.cpp b/odbc/src/utils/cursor.cpp new file mode 100644 index 00000000000..533f0b20217 --- /dev/null +++ b/odbc/src/utils/cursor.cpp @@ -0,0 +1,176 @@ +#include "cursor.h" + +#include "convert.h" +#include "types.h" + +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + +NYdb::TStatus StatusFrom(const NYdb::TStatus& ydbStatus) { + return NYdb::TStatus(ydbStatus.GetStatus(), NYdb::NIssue::TIssues(ydbStatus.GetIssues())); +} + +NYdb::TStatus PrefetchFirstResultSet( + NQuery::TExecuteQueryIterator& iterator, + std::optional* resultSet) { + resultSet->reset(); + while (true) { + auto part = iterator.ReadNext().ExtractValueSync(); + if (part.EOS()) { + break; + } + if (!part.IsSuccess()) { + return StatusFrom(part); + } + if (part.HasResultSet()) { + resultSet->emplace(part.ExtractResultSet()); + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); + } + } + return NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()); +} + +} // namespace + +class TExecCursor : public ICursor { +public: + TExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator, + std::optional firstResultSet) + : BindingFiller_(bindingFiller) + , Iterator_(std::move(iterator)) + { + if (firstResultSet) { + InitResultSet(std::move(*firstResultSet)); + } + } + + bool Fetch() override { + while (true) { + if (ResultSetParser_) { + if (ResultSetParser_->TryNextRow()) { + BindingFiller_->FillBoundColumns(); + return true; + } + ResultSetParser_.reset(); + } + NQuery::TExecuteQueryPart part = Iterator_.ReadNext().ExtractValueSync(); + if (part.EOS()) { + return false; + } + if (!part.IsSuccess()) { + BindingFiller_->OnStreamPartError(part); + return false; + } + if (part.HasResultSet()) { + InitResultSet(part.ExtractResultSet()); + } + } + return false; + } + + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) override { + if (!ResultSetParser_) { + return SQL_NO_DATA; + } + if (columnNumber < 1 || columnNumber > ResultSetParser_->ColumnsCount()) { + return SQL_ERROR; + } + return ConvertColumn(ResultSetParser_->ColumnParser(columnNumber - 1), targetType, targetValue, bufferLength, strLenOrInd); + } + + const std::vector& GetColumnMeta() const override { + return Columns_; + } + +private: + void InitResultSet(TResultSet resultSet) { + Columns_.clear(); + FillColumnsMeta(resultSet); + ResultSetParser_ = std::make_unique(std::move(resultSet)); + } + + void FillColumnsMeta(const TResultSet& resultSet) { + for (const auto& col : resultSet.GetColumnsMeta()) { + const SQLSMALLINT sqlType = GetTypeId(col.Type); + Columns_.push_back(TColumnMeta{ + col.Name, + sqlType, + GetColumnSize(sqlType), + IsNullable(col.Type), + GetDecimalDigits(col.Type).value_or(0)}); + } + } + + IBindingFiller* BindingFiller_; + NQuery::TExecuteQueryIterator Iterator_; + std::unique_ptr ResultSetParser_; + std::vector Columns_; +}; + +class TVirtualCursor : public ICursor { +public: + TVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) + : BindingFiller_(bindingFiller) + , Columns_(columns) + , Table_(table) + {} + + bool Fetch() override { + Cursor_++; + if (Cursor_ >= static_cast(Table_.size())) { + return false; + } + BindingFiller_->FillBoundColumns(); + return true; + } + + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) override { + if (Cursor_ >= static_cast(Table_.size())) { + return SQL_NO_DATA; + } + if (Cursor_ < 0 || columnNumber < 1 || columnNumber > Columns_.size()) { + return SQL_ERROR; + } + TValueParser parser{Table_[Cursor_][columnNumber - 1]}; + return ConvertColumn(parser, targetType, targetValue, bufferLength, strLenOrInd); + } + + const std::vector& GetColumnMeta() const override { + return Columns_; + } + +private: + IBindingFiller* BindingFiller_; + std::vector Columns_; + TTable Table_; + int64_t Cursor_ = -1; +}; + +TExecCursorCreateResult TryCreateExecCursor( + IBindingFiller* bindingFiller, + NQuery::TExecuteQueryIterator iterator) { + std::optional firstResultSet; + const NYdb::TStatus prefetchStatus = PrefetchFirstResultSet(iterator, &firstResultSet); + if (!prefetchStatus.IsSuccess()) { + return {prefetchStatus, nullptr}; + } + if (!firstResultSet) { + return {NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()), nullptr}; + } + return { + NYdb::TStatus(EStatus::SUCCESS, NYdb::NIssue::TIssues()), + std::make_unique(bindingFiller, std::move(iterator), std::move(firstResultSet))}; +} + +std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) { + return std::make_unique(bindingFiller, columns, table); +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/cursor.h b/odbc/src/utils/cursor.h new file mode 100644 index 00000000000..4fa2682e59b --- /dev/null +++ b/odbc/src/utils/cursor.h @@ -0,0 +1,49 @@ +#pragma once + +#include "bindings.h" + +#include +#include + +#include + +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +struct TColumnMeta { + std::string Name; + SQLSMALLINT SqlType; + SQLULEN Size; + SQLSMALLINT Nullable; + SQLSMALLINT DecimalDigits = 0; +}; + +using TTable = std::vector>; + +class ICursor { +public: + virtual ~ICursor() = default; + virtual bool Fetch() = 0; + virtual SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) = 0; + virtual const std::vector& GetColumnMeta() const = 0; +}; + +struct TExecCursorCreateResult { + NYdb::TStatus Status; + std::unique_ptr Cursor; +}; + +TExecCursorCreateResult TryCreateExecCursor( + IBindingFiller* bindingFiller, + NYdb::NQuery::TExecuteQueryIterator iterator); + +std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/diag.h b/odbc/src/utils/diag.h new file mode 100644 index 00000000000..5e2db740a07 --- /dev/null +++ b/odbc/src/utils/diag.h @@ -0,0 +1,33 @@ +#pragma once + +#include "error_manager.h" + +#include +#include + +namespace NYdb::NOdbc { +namespace Diag { + + inline SQLRETURN AddNullPointer(TErrorManager& errors) { + return errors.AddError("HY009", 0, "Invalid use of null pointer"); + } + + inline SQLRETURN AddNotImplemented(TErrorManager& errors) { + return errors.AddError("HYC00", 0, "Optional feature not implemented"); + } + + inline SQLRETURN AddInvalidAttrValue(TErrorManager& errors, std::string_view attrName) { + return errors.AddError("HY024", 0, "Invalid " + std::string(attrName) + " value"); + } + + inline SQLRETURN AddInvalidBufferLength(TErrorManager& errors) { + return errors.AddError("HY090", 0, "Invalid string or buffer length"); + } + + inline SQLRETURN AddRightTruncated(TErrorManager& errors) { + return errors.AddError("01004", 0, "String data, right truncated", SQL_SUCCESS_WITH_INFO); + } + +} + +} // namespace NYdb::NOdbc::Diag diff --git a/odbc/src/utils/error_manager.cpp b/odbc/src/utils/error_manager.cpp new file mode 100644 index 00000000000..8e540e20c83 --- /dev/null +++ b/odbc/src/utils/error_manager.cpp @@ -0,0 +1,226 @@ +#include "error_manager.h" + +#include +#include +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +namespace { + struct OdbcErrorMapping { + const char* sqlState; + const char* description; + SQLRETURN returnCode; + }; + + const std::unordered_map ERROR_MAPPINGS = { + {EStatus::SUCCESS, {"00000", "Success", SQL_SUCCESS}}, + {EStatus::BAD_REQUEST, {"42000", "Syntax error or access rule violation", SQL_ERROR}}, + {EStatus::UNAUTHORIZED, {"28000", "Invalid authorization specification", SQL_ERROR}}, + {EStatus::INTERNAL_ERROR, {"HY000", "General error", SQL_ERROR}}, + {EStatus::ABORTED, {"25000", "Invalid transaction state", SQL_ERROR}}, + {EStatus::UNAVAILABLE, {"08001", "Client unable to establish connection", SQL_ERROR}}, + {EStatus::OVERLOADED, {"HY000", "General error - server overloaded", SQL_ERROR}}, + {EStatus::SCHEME_ERROR, {"42S02", "Base table or view not found", SQL_ERROR}}, + {EStatus::GENERIC_ERROR, {"HY000", "General error", SQL_ERROR}}, + {EStatus::TIMEOUT, {"HYT00", "Timeout expired", SQL_ERROR}}, + {EStatus::BAD_SESSION, {"08003", "Connection does not exist", SQL_ERROR}}, + {EStatus::PRECONDITION_FAILED, {"23000", "Integrity constraint violation", SQL_ERROR}}, + {EStatus::ALREADY_EXISTS, {"23000", "Integrity constraint violation", SQL_ERROR}}, + {EStatus::NOT_FOUND, {"02000", "No data found", SQL_NO_DATA}}, + {EStatus::SESSION_EXPIRED, {"08003", "Connection does not exist", SQL_ERROR}}, + {EStatus::CANCELLED, {"HY008", "Operation canceled", SQL_ERROR}}, + {EStatus::UNDETERMINED, {"HY000", "General error", SQL_ERROR}}, + {EStatus::UNSUPPORTED, {"HYC00", "Optional feature not implemented", SQL_ERROR}}, + {EStatus::SESSION_BUSY, {"HY000", "General error - session busy", SQL_ERROR}}, + // Transport errors + {EStatus::TRANSPORT_UNAVAILABLE, {"08001", "Client unable to establish connection", SQL_ERROR}}, + {EStatus::CLIENT_RESOURCE_EXHAUSTED, {"HY000", "General error - resource exhausted", SQL_ERROR}}, + {EStatus::CLIENT_DEADLINE_EXCEEDED, {"HYT00", "Timeout expired", SQL_ERROR}}, + {EStatus::CLIENT_INTERNAL_ERROR, {"HY000", "General error", SQL_ERROR}}, + {EStatus::CLIENT_CANCELLED, {"HY008", "Operation canceled", SQL_ERROR}}, + {EStatus::CLIENT_UNAUTHENTICATED, {"28000", "Invalid authorization specification", SQL_ERROR}}, + {EStatus::CLIENT_LIMITS_REACHED, {"HY000", "General error - limits reached", SQL_ERROR}}, + {EStatus::CLIENT_DISCOVERY_FAILED, {"08001", "Client unable to establish connection", SQL_ERROR}}, + {EStatus::CLIENT_CALL_UNIMPLEMENTED, {"HYC00", "Optional feature not implemented", SQL_ERROR}}, + {EStatus::CLIENT_OUT_OF_RANGE, {"22003", "Numeric value out of range", SQL_ERROR}}, + }; + + const OdbcErrorMapping DEFAULT_ERROR_MAPPING = {"HY000", "Unknown YDB error", SQL_ERROR}; + + OdbcErrorMapping GetErrorMappingForStatus(EStatus status) { + auto it = ERROR_MAPPINGS.find(status); + if (it != ERROR_MAPPINGS.end()) { + return it->second; + } + return DEFAULT_ERROR_MAPPING; + } + + SQLRETURN WriteDiagCStr( + const std::string& str, + SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, + SQLSMALLINT* stringLengthPtr, + bool sqlStateField = false) { + std::string storage; + const std::string* src = &str; + if (sqlStateField) { + storage = str; + if (storage.size() < 5) { + storage.append(5U - storage.size(), ' '); + } else { + storage.resize(5U); + } + src = &storage; + } + const size_t fullLen = src->size(); + if (stringLengthPtr) { + *stringLengthPtr = static_cast( + std::min(fullLen, static_cast(std::numeric_limits::max()))); + } + if (!diagInfoPtr) { + return SQL_SUCCESS; + } + if (bufferLength < 0) { + return SQL_ERROR; + } + if (bufferLength == 0) { + return fullLen == 0 ? SQL_SUCCESS : SQL_SUCCESS_WITH_INFO; + } + auto* out = static_cast(diagInfoPtr); + const size_t maxData = static_cast(bufferLength - 1U); + const size_t copyLen = std::min(fullLen, maxData); + std::memcpy(out, src->data(), copyLen); + out[copyLen] = 0; + return (fullLen > maxData) ? SQL_SUCCESS_WITH_INFO : SQL_SUCCESS; + } + +} // namespace + +SQLRETURN TErrorManager::AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message, SQLRETURN returnCode) { + Errors_.push_back({sqlState, nativeError, message, returnCode}); + LastReturnCode_ = returnCode; + return returnCode; +} + +SQLRETURN TErrorManager::AddError(const TOdbcException& ex) { + Errors_.push_back({ex.GetSqlState(), ex.GetNativeError(), ex.GetMessage(), ex.GetReturnCode()}); + LastReturnCode_ = ex.GetReturnCode(); + return ex.GetReturnCode(); +} + +SQLRETURN TErrorManager::AddError(const TStatus& status) { + auto mapping = GetErrorMappingForStatus(status.GetStatus()); + std::string message = mapping.description; + if (!status.GetIssues().Empty()) { + message += ": " + status.GetIssues().ToString(); + } + Errors_.push_back({mapping.sqlState, static_cast(status.GetStatus()), message, mapping.returnCode}); + LastReturnCode_ = mapping.returnCode; + return mapping.returnCode; +} + +void TErrorManager::ClearErrors() { + Errors_.clear(); +} + +SQLRETURN TErrorManager::GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength) { + if (recNumber < 1 || recNumber > (SQLSMALLINT)Errors_.size()) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber-1]; + + if (sqlState) { + WriteDiagCStr(err.SqlState, sqlState, 6, nullptr, true); + } + + if (nativeError) { + *nativeError = err.NativeError; + } + + return WriteDiagCStr(err.Message, messageText, bufferLength, textLength, false); +} + +SQLRETURN TErrorManager::GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, + SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr) { + const SQLSMALLINT count = static_cast(Errors_.size()); + if (diagInfoPtr == nullptr) { + return SQL_ERROR; + } + if (recNumber == 0) { + switch (diagIdentifier) { + case SQL_DIAG_RETURNCODE: + *static_cast(diagInfoPtr) = LastReturnCode_; + return SQL_SUCCESS; + case SQL_DIAG_NUMBER: { + *static_cast(diagInfoPtr) = static_cast(count); + return SQL_SUCCESS; + } + case SQL_DIAG_ROW_COUNT: + return SQL_ERROR; + default: + return SQL_ERROR; + } + } + + if (recNumber < 1 || recNumber > count) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber - 1]; + switch (diagIdentifier) { + case SQL_DIAG_SQLSTATE: + return WriteDiagCStr(err.SqlState, diagInfoPtr, bufferLength, stringLengthPtr, true); + case SQL_DIAG_NATIVE: { + *static_cast(diagInfoPtr) = err.NativeError; + return SQL_SUCCESS; + } + case SQL_DIAG_MESSAGE_TEXT: + return WriteDiagCStr(err.Message, diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_CLASS_ORIGIN: + return WriteDiagCStr("ODBC 3.0", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_SUBCLASS_ORIGIN: + return WriteDiagCStr("ODBC 3.0", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_CONNECTION_NAME: + case SQL_DIAG_SERVER_NAME: + return WriteDiagCStr("", diagInfoPtr, bufferLength, stringLengthPtr); + case SQL_DIAG_COLUMN_NUMBER: + *static_cast(diagInfoPtr) = SQL_COLUMN_NUMBER_UNKNOWN; + return SQL_SUCCESS; + case SQL_DIAG_ROW_NUMBER: + *static_cast(diagInfoPtr) = SQL_ROW_NUMBER_UNKNOWN; + return SQL_SUCCESS; + default: + return SQL_ERROR; + } +} + +SQLRETURN HandleOdbcExceptions( + SQLHANDLE handlePtr, + std::function&& func, + ENullInputHandlePolicy nullInputPolicy) { + if (!handlePtr && nullInputPolicy != ENullInputHandlePolicy::Allow) { + return SQL_INVALID_HANDLE; + } + + try { + const SQLRETURN r = func(); + if (handlePtr) { + static_cast(handlePtr)->SetLastReturnCode(r); + } + return r; + } catch (...) { + if (handlePtr) { + static_cast(handlePtr)->SetLastReturnCode(SQL_ERROR); + } + return SQL_ERROR; + } +} + +} // namespace NOdbc +} // namespace NYdb \ No newline at end of file diff --git a/odbc/src/utils/error_manager.h b/odbc/src/utils/error_manager.h new file mode 100644 index 00000000000..e08083ed1f0 --- /dev/null +++ b/odbc/src/utils/error_manager.h @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +struct TErrorInfo { + std::string SqlState; + SQLINTEGER NativeError; + std::string Message; + SQLRETURN ReturnCode; +}; + +using TErrorList = std::vector; + +class TOdbcException : public std::exception { +public: + TOdbcException(const std::string& sqlState, SQLINTEGER nativeError, + const std::string& message, SQLRETURN returnCode = SQL_ERROR) + : SqlState_(sqlState) + , NativeError_(nativeError) + , Message_(message) + , ReturnCode_(returnCode) + {} + + const std::string& GetSqlState() const { + return SqlState_; + } + + SQLINTEGER GetNativeError() const { + return NativeError_; + } + + const std::string& GetMessage() const { + return Message_; + } + + SQLRETURN GetReturnCode() const { + return ReturnCode_; + } + + const char* what() const noexcept override { + return Message_.c_str(); + } + +private: + std::string SqlState_; + SQLINTEGER NativeError_; + std::string Message_; + SQLRETURN ReturnCode_; +}; + +class TErrorManager { +public: + SQLRETURN AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message, SQLRETURN returnCode = SQL_ERROR); + SQLRETURN AddError(const TOdbcException& ex); + SQLRETURN AddError(const TStatus& status); + + void ClearErrors(); + + void SetLastReturnCode(SQLRETURN code) { + LastReturnCode_ = code; + } + [[nodiscard]] SQLRETURN GetLastReturnCode() const { + return LastReturnCode_; + } + + SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); + virtual SQLRETURN GetDiagField(SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, + SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr); + +private: + TErrorList Errors_; + SQLRETURN LastReturnCode_ = SQL_SUCCESS; +}; + +enum class ENullInputHandlePolicy : unsigned char { + Reject, + Allow, +}; + +template +SQLRETURN HandleOdbcExceptions(SQLHANDLE handlePtr, std::function&& func) { + if (!handlePtr) { + return SQL_INVALID_HANDLE; + } + auto handle = static_cast(handlePtr); + + try { + const SQLRETURN ret = func(handle); + handle->SetLastReturnCode(ret); + return ret; + } catch (const NStatusHelpers::TYdbErrorException& ex) { + return handle->AddError(ex.GetStatus()); + } catch (const TOdbcException& ex) { + return handle->AddError(ex); + } catch (const std::exception& ex) { + return handle->AddError("HY000", 0, ex.what()); + } catch (...) { + return handle->AddError("HY000", 0, "Unknown error"); + } +} + +SQLRETURN HandleOdbcExceptions( + SQLHANDLE handlePtr, + std::function&& func, + ENullInputHandlePolicy nullInputPolicy = ENullInputHandlePolicy::Reject); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/escape.cpp b/odbc/src/utils/escape.cpp new file mode 100644 index 00000000000..5a9c643eb7a --- /dev/null +++ b/odbc/src/utils/escape.cpp @@ -0,0 +1,444 @@ +#include "escape.h" + +#include +#include +#include +#include +#include + +namespace NYdb::NOdbc { +namespace { + +bool EqualNoCase(std::string_view lhs, std::string_view rhs) { + return lhs.size() == rhs.size() && + std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char leftCh, char rightCh) { + return std::tolower(static_cast(leftCh)) == + std::tolower(static_cast(rightCh)); + }); +} + +void SkipLeadingWhitespace(std::string_view sql, size_t& cursor) { + const auto strEnd = sql.end(); + const auto firstNonSpace = std::find_if_not( + sql.begin() + static_cast(cursor), + strEnd, + [](unsigned char byte) { + return std::isspace(byte) != 0; + }); + cursor = static_cast(firstNonSpace - sql.begin()); +} + +bool ReadIdent(std::string_view sql, size_t& cursor, std::string_view* outIdent) { + SkipLeadingWhitespace(sql, cursor); + const size_t identStart = cursor; + const auto afterIdent = std::find_if_not( + sql.begin() + static_cast(cursor), + sql.end(), + [](unsigned char byte) { + return std::isalpha(byte) != 0 || byte == '_'; + }); + cursor = static_cast(afterIdent - sql.begin()); + if (cursor == identStart) { + return false; + } + *outIdent = std::string_view(sql.data() + identStart, cursor - identStart); + return true; +} + +bool ParseSingleQuoted(std::string_view sql, size_t& cursor, std::string* outValue) { + SkipLeadingWhitespace(sql, cursor); + if (cursor >= sql.size() || sql[cursor] != '\'') { + return false; + } + ++cursor; + outValue->clear(); + while (cursor < sql.size()) { + if (sql[cursor] == '\'') { + if (cursor + 1 < sql.size() && sql[cursor + 1] == '\'') { + outValue->push_back('\''); + cursor += 2; + continue; + } + ++cursor; + return true; + } + outValue->push_back(sql[cursor++]); + } + return false; +} + +size_t FindMatchingCloseBrace(std::string_view sql, size_t openBrace) { + if (openBrace >= sql.size() || sql[openBrace] != '{') { + return std::string_view::npos; + } + int braceDepth = 1; + for (size_t idx = openBrace + 1; idx < sql.size(); ++idx) { + if (sql[idx] == '{') { + ++braceDepth; + } else if (sql[idx] == '}') { + --braceDepth; + if (braceDepth == 0) { + return idx; + } + } + } + return std::string_view::npos; +} + +std::string NormalizeOdbcTimestampLiteral(const std::string& raw) { + std::string normalized = raw; + const auto firstSpace = std::find(normalized.begin(), normalized.end(), ' '); + if (firstSpace != normalized.end()) { + *firstSpace = 'T'; + } + if (std::find(normalized.begin(), normalized.end(), 'Z') == normalized.end()) { + normalized.push_back('Z'); + } + return normalized; +} + +std::string ToUpperAscii(std::string_view sv) { + std::string upper; + upper.resize(sv.size()); + std::transform(sv.begin(), sv.end(), upper.begin(), [](unsigned char byte) { + return static_cast(std::toupper(byte)); + }); + return upper; +} + +std::string MapSqlTypeToken(std::string_view sqlType) { + static const std::unordered_map kMap = { + {"CHAR", "Utf8"}, + {"VARCHAR", "Utf8"}, + {"LONGVARCHAR", "Utf8"}, + {"WCHAR", "Utf8"}, + {"WVARCHAR", "Utf8"}, + {"WLONGVARCHAR", "Utf8"}, + {"BIT", "Bool"}, + {"TINYINT", "Int8"}, + {"SMALLINT", "Int16"}, + {"INTEGER", "Int32"}, + {"BIGINT", "Int64"}, + {"REAL", "Float"}, + {"FLOAT", "Double"}, + {"DOUBLE", "Double"}, + {"DECIMAL", "Decimal(22, 9)"}, + {"NUMERIC", "Decimal(22, 9)"}, + {"BINARY", "String"}, + {"VARBINARY", "String"}, + {"LONGVARBINARY", "String"}, + {"DATE", "Date"}, + {"TIME", "Time"}, + {"TIMESTAMP", "Datetime"}, + {"TYPE_DATE", "Date"}, + {"TYPE_TIME", "Time"}, + {"TYPE_TIMESTAMP", "Datetime"}, + }; + std::string key = ToUpperAscii(sqlType); + const std::string kSql = "SQL_"; + if (key.size() > kSql.size() && key.compare(0, kSql.size(), kSql) == 0) { + key.erase(0, kSql.size()); + } + const auto mapped = kMap.find(key); + if (mapped != kMap.end()) { + return mapped->second; + } + return key; +} + +std::string RewriteOdbcEscapesImpl(std::string_view sql); + + +enum class OdbcBraceKind { + OutputProcedureCall, // {?= call ... } + FnBody, // {fn ...} + OjBody, // {oj ...} + DateLiteral, // {d '...'} + TimeLiteral, // {t '...'} + TimestampLiteral, // {ts '...'} + ProcedureCall, // {call ...} + LikeEscape, // {escape '...'} +}; + +struct OdbcBraceParsed { + OdbcBraceKind Kind; + std::string_view RecurseTail; + std::string QuotedValue; +}; + +std::optional TryParseOutputCallBrace(std::string_view sql, size_t parsePos, size_t closeBrace) { + if (parsePos + 1 >= sql.size() || sql[parsePos] != '?' || sql[parsePos + 1] != '=') { + return std::nullopt; + } + size_t inner = parsePos + 2; + SkipLeadingWhitespace(sql, inner); + std::string_view keyword; + if (!ReadIdent(sql, inner, &keyword) || !EqualNoCase(keyword, "call")) { + return std::nullopt; + } + SkipLeadingWhitespace(sql, inner); + if (inner > closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = OdbcBraceKind::OutputProcedureCall; + parsed.RecurseTail = std::string_view(sql.data() + inner, closeBrace - inner); + return parsed; +} + +std::optional MakeRecurseTailBrace(OdbcBraceKind kind, std::string_view sql, size_t& parsePos, size_t closeBrace) { + SkipLeadingWhitespace(sql, parsePos); + if (parsePos > closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = kind; + parsed.RecurseTail = std::string_view(sql.data() + parsePos, closeBrace - parsePos); + return parsed; +} + +std::optional MakeQuotedBrace(OdbcBraceKind kind, std::string_view sql, size_t& parsePos, size_t closeBrace) { + std::string quotedLit; + if (!ParseSingleQuoted(sql, parsePos, "edLit) || parsePos > closeBrace) { + return std::nullopt; + } + SkipLeadingWhitespace(sql, parsePos); + if (parsePos != closeBrace) { + return std::nullopt; + } + OdbcBraceParsed parsed; + parsed.Kind = kind; + parsed.QuotedValue = std::move(quotedLit); + return parsed; +} + +struct BraceKeywordSpec { + std::string_view Keyword; + OdbcBraceKind Kind; + bool IsQuotedLiteral; +}; + +static constexpr BraceKeywordSpec kBraceKeywordSpecs[] = { + {"fn", OdbcBraceKind::FnBody, false}, + {"oj", OdbcBraceKind::OjBody, false}, + {"d", OdbcBraceKind::DateLiteral, true}, + {"t", OdbcBraceKind::TimeLiteral, true}, + {"ts", OdbcBraceKind::TimestampLiteral, true}, + {"call", OdbcBraceKind::ProcedureCall, false}, + {"escape", OdbcBraceKind::LikeEscape, true}, +}; + +std::optional TryParseOdbcBrace(std::string_view sql, size_t openBrace, size_t closeBrace) { + size_t parsePos = openBrace + 1; + SkipLeadingWhitespace(sql, parsePos); + + if (std::optional outputCall = TryParseOutputCallBrace(sql, parsePos, closeBrace)) { + return outputCall; + } + if (parsePos + 1 < sql.size() && sql[parsePos] == '?' && sql[parsePos + 1] == '=') { + return std::nullopt; + } + + std::string_view token; + if (!ReadIdent(sql, parsePos, &token)) { + return std::nullopt; + } + + for (const BraceKeywordSpec& spec : kBraceKeywordSpecs) { + if (!EqualNoCase(token, spec.Keyword)) { + continue; + } + if (spec.IsQuotedLiteral) { + return MakeQuotedBrace(spec.Kind, sql, parsePos, closeBrace); + } + return MakeRecurseTailBrace(spec.Kind, sql, parsePos, closeBrace); + } + + return std::nullopt; +} + +void AppendRewrittenBrace(std::string& rewritten, const OdbcBraceParsed& parsed) { + switch (parsed.Kind) { + case OdbcBraceKind::OutputProcedureCall: + case OdbcBraceKind::ProcedureCall: + rewritten += "CALL "; + rewritten.append(RewriteOdbcEscapesImpl(parsed.RecurseTail)); + return; + case OdbcBraceKind::FnBody: + case OdbcBraceKind::OjBody: + rewritten.append(RewriteOdbcEscapesImpl(parsed.RecurseTail)); + return; + case OdbcBraceKind::DateLiteral: + rewritten += "CAST('"; + rewritten += parsed.QuotedValue; + rewritten += "' AS Date)"; + return; + case OdbcBraceKind::TimeLiteral: + rewritten += "CAST('"; + rewritten += parsed.QuotedValue; + rewritten += "' AS Time)"; + return; + case OdbcBraceKind::TimestampLiteral: { + const std::string normalizedTs = NormalizeOdbcTimestampLiteral(parsed.QuotedValue); + rewritten += "CAST('"; + rewritten += normalizedTs; + rewritten += "' AS Datetime)"; + return; + } + case OdbcBraceKind::LikeEscape: + rewritten += " ESCAPE '"; + rewritten += parsed.QuotedValue; + rewritten += '\''; + return; + } +} + +std::string RewriteOdbcEscapesImpl(std::string_view sql) { + std::string rewritten; + rewritten.reserve(sql.size()); + + for (size_t readPos = 0; readPos < sql.size();) { + if (sql[readPos] != '{') { + rewritten.push_back(sql[readPos++]); + continue; + } + + const size_t closeBrace = FindMatchingCloseBrace(sql, readPos); + if (closeBrace == std::string_view::npos) { + rewritten.push_back(sql[readPos++]); + continue; + } + + if (std::optional parsedBrace = TryParseOdbcBrace(sql, readPos, closeBrace)) { + AppendRewrittenBrace(rewritten, *parsedBrace); + readPos = closeBrace + 1; + continue; + } + + rewritten.push_back(sql[readPos++]); + } + + return rewritten; +} + +std::string RewriteOdbcConvertCalls(std::string_view sql); + +class TOdbcConvertCallRewriter { +public: + explicit TOdbcConvertCallRewriter(std::string_view sql) + : Sql_(sql) { + Rewritten_.reserve(sql.size()); + } + + std::string TakeResult() && { + return std::move(Rewritten_); + } + + void Run() { + while (SegmentStart_ < Sql_.size()) { + const std::optional convertKeywordPos = FindNextConvertKeyword(SegmentStart_); + if (!convertKeywordPos) { + Rewritten_.append(Sql_.substr(SegmentStart_)); + break; + } + Rewritten_.append(Sql_.substr(SegmentStart_, *convertKeywordPos - SegmentStart_)); + if (!TryRewriteConvertAt(*convertKeywordPos)) { + break; + } + } + } + +private: + static constexpr size_t kConvertTokenLen = 7; + + std::optional FindNextConvertKeyword(size_t from) const { + for (size_t probePos = from; probePos + kConvertTokenLen <= Sql_.size(); ++probePos) { + if (!EqualNoCase(Sql_.substr(probePos, kConvertTokenLen), "CONVERT")) { + continue; + } + size_t afterKeyword = probePos + kConvertTokenLen; + SkipLeadingWhitespace(Sql_, afterKeyword); + if (afterKeyword < Sql_.size() && Sql_[afterKeyword] == '(') { + return probePos; + } + } + return std::nullopt; + } + + bool TryRewriteConvertAt(size_t convertKeywordPos) { + size_t parsePos = convertKeywordPos + kConvertTokenLen; + SkipLeadingWhitespace(Sql_, parsePos); + if (parsePos >= Sql_.size() || Sql_[parsePos] != '(') { + Rewritten_.append(Sql_.substr(convertKeywordPos, kConvertTokenLen)); + SegmentStart_ = convertKeywordPos + kConvertTokenLen; + return true; + } + ++parsePos; + + int parenDepth = 1; + const size_t firstArgStart = parsePos; + std::optional typeCommaPos; + for (; parsePos < Sql_.size(); ++parsePos) { + if (Sql_[parsePos] == '(') { + ++parenDepth; + } else if (Sql_[parsePos] == ')') { + --parenDepth; + } else if (Sql_[parsePos] == ',' && parenDepth == 1) { + typeCommaPos = parsePos; + break; + } + } + if (!typeCommaPos) { + Rewritten_.append(Sql_.substr(convertKeywordPos)); + return false; + } + + const std::string_view firstArg(Sql_.data() + firstArgStart, *typeCommaPos - firstArgStart); + parsePos = *typeCommaPos + 1; + SkipLeadingWhitespace(Sql_, parsePos); + const size_t sqlTypeStart = parsePos; + const auto sqlTypeEnd = std::find_if_not( + Sql_.begin() + static_cast(parsePos), + Sql_.end(), + [](unsigned char byte) { + return std::isalpha(byte) != 0 || byte == '_'; + }); + parsePos = static_cast(sqlTypeEnd - Sql_.begin()); + const std::string_view sqlTypeToken(Sql_.data() + sqlTypeStart, parsePos - sqlTypeStart); + SkipLeadingWhitespace(Sql_, parsePos); + if (parsePos >= Sql_.size() || Sql_[parsePos] != ')') { + Rewritten_.append(Sql_.substr(convertKeywordPos)); + return false; + } + + const std::string yqlType = MapSqlTypeToken(sqlTypeToken); + Rewritten_ += "CAST("; + Rewritten_ += RewriteOdbcConvertCalls(RewriteOdbcEscapesImpl(firstArg)); + Rewritten_ += " AS "; + Rewritten_ += yqlType; + Rewritten_ += ')'; + SegmentStart_ = parsePos + 1; + return true; + } + + std::string_view Sql_; + std::string Rewritten_; + size_t SegmentStart_ = 0; +}; + +std::string RewriteOdbcConvertCalls(std::string_view sql) { + TOdbcConvertCallRewriter rewriter(sql); + rewriter.Run(); + return std::move(rewriter).TakeResult(); +} + +} // namespace + + + +std::string RewriteOdbcEscapes(const std::string& sql) { + std::string afterBraceRewrite = RewriteOdbcEscapesImpl(sql); + return RewriteOdbcConvertCalls(afterBraceRewrite); +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/escape.h b/odbc/src/utils/escape.h new file mode 100644 index 00000000000..7397a128450 --- /dev/null +++ b/odbc/src/utils/escape.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace NYdb::NOdbc { + +std::string RewriteOdbcEscapes(const std::string& sql); + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/sql_like.h b/odbc/src/utils/sql_like.h new file mode 100644 index 00000000000..f51c10ca28c --- /dev/null +++ b/odbc/src/utils/sql_like.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace NYdb::NOdbc { + +// SQL LIKE — '%' is any substring, '_' is any single character. +inline bool SqlLikeMatch(std::string_view text, std::string_view pattern) { + size_t textPos = 0; + size_t patPos = 0; + size_t lastPercentPat = std::string_view::npos; + size_t textStartAfterPercent = 0; + + const size_t textLen = text.size(); + const size_t patLen = pattern.size(); + + while (textPos < textLen) { + const bool morePat = patPos < patLen; + const char patCh = morePat ? pattern[patPos] : '\0'; + + if (morePat && patCh != '%' && (patCh == '_' || patCh == text[textPos])) { + ++textPos; + ++patPos; + continue; + } + + if (morePat && patCh == '%') { + lastPercentPat = patPos++; + textStartAfterPercent = textPos; + continue; + } + + if (lastPercentPat != std::string_view::npos) { + patPos = lastPercentPat + 1; + ++textStartAfterPercent; + textPos = textStartAfterPercent; + continue; + } + + return false; + } + + while (patPos < patLen && pattern[patPos] == '%') { + ++patPos; + } + return patPos == patLen; +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/types.cpp b/odbc/src/utils/types.cpp new file mode 100644 index 00000000000..3c9c70549fa --- /dev/null +++ b/odbc/src/utils/types.cpp @@ -0,0 +1,159 @@ +#include "types.h" + +namespace NYdb { +namespace NOdbc { + +SQLSMALLINT GetTypeId(const TType& type) { + TTypeParser typeParser(type); + size_t openedOptionals = 0; + while (typeParser.GetKind() == TTypeParser::ETypeKind::Optional) { + typeParser.OpenOptional(); + ++openedOptionals; + } + + auto closeOpenedOptionals = [&]() { + while (openedOptionals > 0) { + typeParser.CloseOptional(); + --openedOptionals; + } + }; + + const auto kind = typeParser.GetKind(); + if (kind == TTypeParser::ETypeKind::Primitive) { + const auto primitive = typeParser.GetPrimitive(); + closeOpenedOptionals(); + switch (primitive) { + case EPrimitiveType::Bool: + return SQL_BIT; + case EPrimitiveType::Int8: + case EPrimitiveType::Uint8: + return SQL_TINYINT; + case EPrimitiveType::Int16: + case EPrimitiveType::Uint16: + return SQL_SMALLINT; + case EPrimitiveType::Int32: + case EPrimitiveType::Uint32: + return SQL_INTEGER; + case EPrimitiveType::Int64: + case EPrimitiveType::Uint64: + return SQL_BIGINT; + case EPrimitiveType::Float: + return SQL_REAL; + case EPrimitiveType::Double: + return SQL_DOUBLE; + case EPrimitiveType::Date: + case EPrimitiveType::Date32: + case EPrimitiveType::TzDate: + return SQL_TYPE_DATE; + case EPrimitiveType::Datetime: + case EPrimitiveType::Timestamp: + case EPrimitiveType::Datetime64: + case EPrimitiveType::Timestamp64: + case EPrimitiveType::TzDatetime: + case EPrimitiveType::TzTimestamp: + return SQL_TYPE_TIMESTAMP; + case EPrimitiveType::Interval: + case EPrimitiveType::Interval64: + return SQL_BIGINT; + case EPrimitiveType::String: + return SQL_VARBINARY; + case EPrimitiveType::Utf8: + case EPrimitiveType::Yson: + case EPrimitiveType::Json: + case EPrimitiveType::JsonDocument: + case EPrimitiveType::DyNumber: + return SQL_VARCHAR; + case EPrimitiveType::Uuid: + return SQL_GUID; + } + } + + closeOpenedOptionals(); + if (kind == TTypeParser::ETypeKind::Decimal) { + return SQL_DECIMAL; + } + return SQL_UNKNOWN_TYPE; +} + +SQLSMALLINT IsNullable(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() == TTypeParser::ETypeKind::Optional || typeParser.GetKind() == TTypeParser::ETypeKind::Null) { + return SQL_NULLABLE; + } + + return SQL_NO_NULLS; +} + +SQLULEN GetColumnSize(SQLSMALLINT sqlType) { + switch (sqlType) { + case SQL_BIT: + return 1; + case SQL_TINYINT: + return 3; + case SQL_SMALLINT: + return 5; + case SQL_INTEGER: + return 10; + case SQL_BIGINT: + return 20; + case SQL_REAL: + return 7; + case SQL_DOUBLE: + return 15; + case SQL_TYPE_DATE: + return 10; + case SQL_TYPE_TIMESTAMP: + return 26; + case SQL_GUID: + return 36; + case SQL_VARCHAR: + case SQL_VARBINARY: + default: + return 4096; + } +} + +std::optional GetDecimalDigits(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return std::nullopt; + } + + switch (typeParser.GetPrimitive()) { + case EPrimitiveType::Int64: + case EPrimitiveType::Uint64: + case EPrimitiveType::Int32: + case EPrimitiveType::Uint32: + case EPrimitiveType::Int16: + case EPrimitiveType::Uint16: + case EPrimitiveType::Int8: + case EPrimitiveType::Uint8: + return 0; + default: + return std::nullopt; + } +} + +std::optional GetRadix(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return std::nullopt; + } + + switch (typeParser.GetPrimitive()) { + case EPrimitiveType::Int64: + case EPrimitiveType::Uint64: + case EPrimitiveType::Int32: + case EPrimitiveType::Uint32: + case EPrimitiveType::Int16: + case EPrimitiveType::Uint16: + case EPrimitiveType::Int8: + case EPrimitiveType::Uint8: + return 10; + default: + return std::nullopt; + } +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/types.h b/odbc/src/utils/types.h new file mode 100644 index 00000000000..9428cafebb0 --- /dev/null +++ b/odbc/src/utils/types.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +SQLSMALLINT GetTypeId(const TType& type); +SQLSMALLINT IsNullable(const TType& type); +SQLULEN GetColumnSize(SQLSMALLINT sqlType); + +std::optional GetDecimalDigits(const TType& type); +std::optional GetRadix(const TType& type); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/util.cpp b/odbc/src/utils/util.cpp new file mode 100644 index 00000000000..54700d06091 --- /dev/null +++ b/odbc/src/utils/util.cpp @@ -0,0 +1,18 @@ +#include "util.h" + +namespace NYdb::NOdbc { + +std::string GetString(SQLCHAR* str, SQLSMALLINT length) { + if (!str) { + return {}; + } + if (length == SQL_NTS) { + return std::string(reinterpret_cast(str)); + } + if (length <= 0) { + return {}; + } + return std::string(reinterpret_cast(str), length); +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/util.h b/odbc/src/utils/util.h new file mode 100644 index 00000000000..b17fe2c235f --- /dev/null +++ b/odbc/src/utils/util.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +#include + +namespace NYdb::NOdbc { + +std::string GetString(SQLCHAR* str, SQLSMALLINT length); + +} // namespace NYdb::NOdbc diff --git a/odbc/tests/CMakeLists.txt b/odbc/tests/CMakeLists.txt new file mode 100644 index 00000000000..8d9f3176aea --- /dev/null +++ b/odbc/tests/CMakeLists.txt @@ -0,0 +1,30 @@ +set(YDB_ODBC_TEST_CONFIG_DIR "${CMAKE_BINARY_DIR}/odbc") +file(MAKE_DIRECTORY "${YDB_ODBC_TEST_CONFIG_DIR}") + +set(YDB_ODBC_DSN_SERVER "localhost:2136" CACHE STRING + "YDB endpoint in odbc.ini generated for ODBC integration tests") +set(YDB_ODBC_DSN_DATABASE "/local" CACHE STRING + "YDB database path in odbc.ini generated for ODBC integration tests") + +file(GENERATE + OUTPUT "${YDB_ODBC_TEST_CONFIG_DIR}/odbcinst.ini" + CONTENT "[YDB] +Description=YDB ODBC Driver +Driver=$ +Setup=$ +" +) + +file(WRITE "${YDB_ODBC_TEST_CONFIG_DIR}/odbc.ini" +"[ODBC Data Sources] +YDB=YDB ODBC Driver + +[YDB] +Driver=YDB +Description=YDB Database Connection +Server=${YDB_ODBC_DSN_SERVER} +Database=${YDB_ODBC_DSN_DATABASE} +") + +add_subdirectory(integration) +add_subdirectory(unit) diff --git a/odbc/tests/integration/CMakeLists.txt b/odbc/tests/integration/CMakeLists.txt new file mode 100644 index 00000000000..19a7004f16a --- /dev/null +++ b/odbc/tests/integration/CMakeLists.txt @@ -0,0 +1,34 @@ +add_odbc_test(NAME odbc-basic_it + SOURCES + basic_it.cpp +) + +add_odbc_test(NAME odbc-environment_api_it + SOURCES + environment_api_it.cpp +) + +add_odbc_test(NAME odbc-connection_api_it + SOURCES + connection_api_it.cpp +) + +add_odbc_test(NAME odbc-statement_api_it + SOURCES + statement_api_it.cpp +) + +add_odbc_test(NAME odbc-transaction_api_it + SOURCES + transaction_api_it.cpp +) + +add_odbc_test(NAME odbc-error_handling_it + SOURCES + error_handling_it.cpp +) + +add_odbc_test(NAME odbc-metadata_api_it + SOURCES + metadata_api_it.cpp +) diff --git a/odbc/tests/integration/basic_it.cpp b/odbc/tests/integration/basic_it.cpp new file mode 100644 index 00000000000..e7af877b37a --- /dev/null +++ b/odbc/tests/integration/basic_it.cpp @@ -0,0 +1,126 @@ +#include "test_utils.h" + +TEST(OdbcBasic, SimpleQuery) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + // Simple query + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1 AS one, 'abc' AS str", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER ival = 0; + char sval[16] = {0}; + SQLLEN ival_ind = 0, sval_ind = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &ival, 0, &ival_ind), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 2, SQL_C_CHAR, sval, sizeof(sval), &sval_ind), SQL_SUCCESS); + ASSERT_EQ(ival, 1); + ASSERT_STREQ(sval, "abc"); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, ParameterizedQuery) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR query[] = R"( + DECLARE $p1 AS Int32?; + SELECT $p1 + 10 AS res; + )"; + + // Parameterized query + CHECK_ODBC_OK(SQLPrepare(stmt, query, SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLINTEGER param = 5; + CHECK_ODBC_OK(SQLBindParameter(stmt, 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, ¶m, 0, nullptr), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecute(stmt), stmt, SQL_HANDLE_STMT); + + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER res = 0; + SQLLEN res_ind = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &res, 0, &res_ind), SQL_SUCCESS); + ASSERT_EQ(res, 15); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, ColumnBinding) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR query_ddl[] = R"( + DROP TABLE IF EXISTS test_bind; + CREATE TABLE test_bind (id Int32, name Text, PRIMARY KEY (id)); + )"; + + SQLCHAR query[] = R"( + UPSERT INTO test_bind (id, name) VALUES (1, 'foo'), (2, 'bar'); + SELECT id, name FROM test_bind ORDER BY id; + )"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, query_ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, query, SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLINTEGER id = 0; + char name[16] = {0}; + SQLLEN id_ind = 0, name_ind = 0; + ASSERT_EQ(SQLBindCol(stmt, 1, SQL_C_LONG, &id, 0, &id_ind), SQL_SUCCESS); + ASSERT_EQ(SQLBindCol(stmt, 2, SQL_C_CHAR, name, sizeof(name), &name_ind), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(id, 1); + ASSERT_STREQ(name, "foo"); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(id, 2); + ASSERT_STREQ(name, "bar"); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, SQLConnect) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLConnect(dbc, (SQLCHAR*)"YDB", SQL_NTS, (SQLCHAR*)"", SQL_NTS, (SQLCHAR*)"", SQL_NTS), + dbc, SQL_HANDLE_DBC); + + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1", SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLINTEGER val; + SQLLEN ind; + SQLBindCol(stmt, 1, SQL_C_SLONG, &val, 0, &ind); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(val, 1); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/connection_api_it.cpp b/odbc/tests/integration/connection_api_it.cpp new file mode 100644 index 00000000000..aff067a9e20 --- /dev/null +++ b/odbc/tests/integration/connection_api_it.cpp @@ -0,0 +1,193 @@ +#include "test_utils.h" + +TEST(ConnectionApi, AllocFreeEnvHandle) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLFreeHandle(SQL_HANDLE_ENV, env), SQL_SUCCESS); +} + +TEST(ConnectionApi, AllocFreeDbcHandle) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + ASSERT_EQ(SQLFreeHandle(SQL_HANDLE_DBC, dbc), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, AllocFreeHandleInvalid) { + SQLHENV env; + SQLRETURN rc = SQLAllocHandle(999, SQL_NULL_HANDLE, &env); + ASSERT_TRUE(rc == SQL_ERROR || rc == SQL_INVALID_HANDLE); +} + +TEST(ConnectionApi, SQLConnectWithDSN) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLConnect(dbc, (SQLCHAR*)"YDB", SQL_NTS, (SQLCHAR*)"", SQL_NTS, (SQLCHAR*)"", SQL_NTS), + dbc, SQL_HANDLE_DBC); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, SQLDriverConnectComplete) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + SQLCHAR outStr[256]; + SQLSMALLINT outLen; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, + outStr, sizeof(outStr), &outLen, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc, SQL_HANDLE_DBC); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, SQLDriverConnectNoPrompt) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, + nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); + CHECK_ODBC_OK(rc, dbc, SQL_HANDLE_DBC); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, SQLDriverConnectInvalidConnString) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, (SQLCHAR*)"InvalidParam=test", SQL_NTS, + nullptr, 0, nullptr, SQL_DRIVER_NOPROMPT); + ASSERT_EQ(rc, SQL_ERROR); + + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, SQLConnectMissingDSN) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + + SQLRETURN rc = SQLConnect(dbc, (SQLCHAR*)"NONEXISTENT_DSN", SQL_NTS, (SQLCHAR*)"", SQL_NTS, (SQLCHAR*)"", SQL_NTS); + ASSERT_EQ(rc, SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "IM002")); + + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, EnvAttrOdbcVersion) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, EnvAttrOutputNts) { + SQLHENV env; + AllocEnv(&env); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_TRUE, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_FALSE, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, ConnAttrAccessMode) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLUINTEGER mode; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, &mode, sizeof(mode), nullptr), SQL_SUCCESS); + ASSERT_EQ(mode, SQL_MODE_READ_WRITE); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)SQL_MODE_READ_ONLY, 0), + dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, &mode, sizeof(mode), nullptr), SQL_SUCCESS); + ASSERT_EQ(mode, SQL_MODE_READ_ONLY); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)SQL_MODE_READ_WRITE, 0), + dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_ACCESS_MODE, (SQLPOINTER)9999, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HY024")); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, ConnAttrCurrentCatalog) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + char catalog[256]; + SQLINTEGER len; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &len), SQL_SUCCESS); + ASSERT_STREQ(catalog, "/local"); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)"/local/test", SQL_NTS), + dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, catalog, sizeof(catalog), &len), SQL_SUCCESS); + ASSERT_STREQ(catalog, "/local/test"); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ConnectionApi, ConnAttrCurrentCatalogAffectsQueries) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS `/local/cat_a/probe`", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE `/local/cat_a/probe` (id Int32, value Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"UPSERT INTO `/local/cat_a/probe` (id, value) VALUES (1, 100)", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS `/local/cat_b/probe`", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE `/local/cat_b/probe` (id Int32, value Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"UPSERT INTO `/local/cat_b/probe` (id, value) VALUES (1, 200)", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)"/local/cat_a", SQL_NTS), + dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT value FROM probe WHERE id = 1", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER value; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &value, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(value, 100); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_CURRENT_CATALOG, (SQLPOINTER)"/local/cat_b", SQL_NTS), + dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT value FROM probe WHERE id = 1", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &value, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(value, 200); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/environment_api_it.cpp b/odbc/tests/integration/environment_api_it.cpp new file mode 100644 index 00000000000..b3395dc620a --- /dev/null +++ b/odbc/tests/integration/environment_api_it.cpp @@ -0,0 +1,225 @@ +#include "test_utils.h" + +TEST(EnvironmentApi, AllocFreeEnv) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLFreeHandle(SQL_HANDLE_ENV, env), SQL_SUCCESS); +} + +TEST(EnvironmentApi, AllocEnvInvalidType) { + SQLHENV env; + SQLRETURN rc = SQLAllocHandle(999, SQL_NULL_HANDLE, &env); + ASSERT_TRUE(rc == SQL_ERROR || rc == SQL_INVALID_HANDLE); +} + +TEST(EnvironmentApi, FreeInvalidEnvHandle) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLFreeHandle(SQL_HANDLE_ENV, env), SQL_SUCCESS); + SQLRETURN rc = SQLFreeHandle(SQL_HANDLE_ENV, env); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_INVALID_HANDLE || rc == SQL_ERROR); +} + +TEST(EnvironmentApi, DoubleFreeEnv) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLFreeHandle(SQL_HANDLE_ENV, env), SQL_SUCCESS); + // Second free may return error or success depending on implementation + // but should not crash + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, SetOdbcVersion) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, SetOdbcVersionInvalid) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0), SQL_SUCCESS); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)9999, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, GetOdbcVersion) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + + SQLINTEGER version; + ASSERT_EQ(SQLGetEnvAttr(env, SQL_ATTR_ODBC_VERSION, &version, sizeof(version), nullptr), SQL_SUCCESS); + ASSERT_EQ(version, SQL_OV_ODBC3); + + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, SetOutputNtsTrue) { + SQLHENV env; + AllocEnv(&env); + + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_TRUE, 0), SQL_SUCCESS); + + SQLINTEGER outputNts; + ASSERT_EQ(SQLGetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, &outputNts, sizeof(outputNts), nullptr), SQL_SUCCESS); + ASSERT_EQ(outputNts, SQL_TRUE); + + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, SetOutputNtsFalse) { + SQLHENV env; + AllocEnv(&env); + ASSERT_NE(SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, (void*)SQL_FALSE, 0), SQL_SUCCESS); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, GetOutputNtsDefault) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + SQLINTEGER outputNts; + ASSERT_EQ(SQLGetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, &outputNts, sizeof(outputNts), nullptr), SQL_SUCCESS); + ASSERT_EQ(outputNts, SQL_TRUE); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, SetInvalidEnvAttr) { + SQLHENV env; + AllocEnv(&env); + ASSERT_EQ(SQLSetEnvAttr(env, 9999, (void*)1, 0), SQL_ERROR); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, GetInvalidEnvAttr) { + SQLHENV env; + AllocEnv(&env); + char buffer[256]; + SQLINTEGER len; + ASSERT_EQ(SQLGetEnvAttr(env, 9999, buffer, sizeof(buffer), &len), SQL_ERROR); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, MultipleConnectionsSequential) { + SQLHENV env; + AllocEnv(&env); + for (int i = 0; i < 3; ++i) { + SQLHDBC dbc; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLConnect(dbc, (SQLCHAR*)"YDB", SQL_NTS, nullptr, 0, nullptr, 0), dbc, SQL_HANDLE_DBC); + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + char query[32]; + snprintf(query, sizeof(query), "SELECT %d", i + 1); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)query, SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + } + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +namespace { + +void StartManualTx(SQLHDBC dbc, SQLHSTMT* stmt) { + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(*stmt, (SQLCHAR*)"SELECT 1", SQL_NTS), *stmt, SQL_HANDLE_STMT); +} + +} // namespace + +TEST(EnvironmentApi, EndTranCommitOnEnv) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + StartManualTx(dbc2, &stmt2); + + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT), env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, EndTranRollbackOnEnv) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + StartManualTx(dbc2, &stmt2); + + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_ROLLBACK), env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, EndTranPartialFailureReturnsInfo) { + SQLHENV env; + SQLHDBC dbc1, dbc2; + SQLHSTMT stmt1, stmt2; + + AllocEnvAndConnect(&env, &dbc1); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc2), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + dbc2, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, dbc2, SQL_HANDLE_DBC); + + StartManualTx(dbc1, &stmt1); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc2, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), dbc2, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc2, &stmt2), SQL_SUCCESS); + (void)SQLExecDirect(stmt2, (SQLCHAR*)"SELECT FROM", SQL_NTS); + + rc = SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO || rc == SQL_ERROR) + << GetOdbcError(env, SQL_HANDLE_ENV); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt1); + SQLFreeHandle(SQL_HANDLE_STMT, stmt2); + SQLDisconnect(dbc1); + SQLDisconnect(dbc2); + SQLFreeHandle(SQL_HANDLE_DBC, dbc1); + SQLFreeHandle(SQL_HANDLE_DBC, dbc2); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(EnvironmentApi, GetDiagRecEnv) { + SQLHENV env; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + (void)SQLSetEnvAttr(env, 9999, (void*)1, 0); + SQLCHAR sqlState[6]; + SQLINTEGER nativeError; + SQLCHAR msg[256]; + SQLSMALLINT msgLen; + (void)SQLGetDiagRec(SQL_HANDLE_ENV, env, 1, sqlState, &nativeError, msg, sizeof(msg), &msgLen); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/error_handling_it.cpp b/odbc/tests/integration/error_handling_it.cpp new file mode 100644 index 00000000000..96bd01d7ea2 --- /dev/null +++ b/odbc/tests/integration/error_handling_it.cpp @@ -0,0 +1,130 @@ +#include "test_utils.h" + +TEST(ErrorHandling, GetDiagRecAfterError) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + SQLRETURN rc = SQLConnect(dbc, (SQLCHAR*)"NONEXISTENT_DSN", SQL_NTS, + (SQLCHAR*)"", SQL_NTS, (SQLCHAR*)"", SQL_NTS); + ASSERT_EQ(rc, SQL_ERROR); + SQLCHAR sqlState[6]; + SQLINTEGER nativeError; + SQLCHAR msg[256]; + SQLSMALLINT msgLen; + SQLRETURN diagRc = SQLGetDiagRec(SQL_HANDLE_DBC, dbc, 1, sqlState, &nativeError, + msg, sizeof(msg), &msgLen); + ASSERT_TRUE(diagRc == SQL_SUCCESS || diagRc == SQL_SUCCESS_WITH_INFO); + const size_t copiedLen = std::strlen(reinterpret_cast(msg)); + ASSERT_EQ(msg[copiedLen], static_cast(0)); + if (diagRc == SQL_SUCCESS_WITH_INFO) { + ASSERT_GE(static_cast(msgLen), copiedLen); + } else { + ASSERT_EQ(static_cast(msgLen), copiedLen); + } + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ErrorHandling, GetDiagRecMultipleErrors) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLExecDirect(stmt, (SQLCHAR*)"INVALID SYNTAX HERE", SQL_NTS); + + SQLSMALLINT numRecs; + SQLGetDiagField(SQL_HANDLE_STMT, stmt, 0, SQL_DIAG_NUMBER, &numRecs, 0, nullptr); + + for (SQLSMALLINT i = 1; i <= numRecs; ++i) { + SQLCHAR sqlState[6]; + SQLINTEGER nativeError; + SQLCHAR msg[256]; + SQLSMALLINT msgLen; + SQLRETURN rc = SQLGetDiagRec(SQL_HANDLE_STMT, stmt, i, sqlState, &nativeError, + msg, sizeof(msg), &msgLen); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO); + } + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ErrorHandling, GetDiagFieldState) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"SELECT invalid_column FROM nonexistent_table", SQL_NTS); + SQLCHAR sqlState[6]; + SQLRETURN rc = SQLGetDiagField(SQL_HANDLE_STMT, stmt, 1, SQL_DIAG_SQLSTATE, + sqlState, sizeof(sqlState), nullptr); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ErrorHandling, GetDiagFieldNativeError) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"SELECT * FROM nonexistent_table", SQL_NTS); + SQLINTEGER nativeError; + SQLRETURN rc = SQLGetDiagField(SQL_HANDLE_STMT, stmt, 1, SQL_DIAG_NATIVE, + &nativeError, sizeof(nativeError), nullptr); + ASSERT_TRUE(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + + +TEST(ErrorHandling, SuccessWithInfo) { + SQLHENV env; + SQLHDBC dbc; + AllocEnv(&env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + SQLCHAR outStr[10]; + SQLSMALLINT outLen; + SQLRETURN rc = SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, + outStr, sizeof(outStr), &outLen, SQL_DRIVER_NOPROMPT); + if (rc == SQL_SUCCESS_WITH_INFO) { + SQLCHAR sqlState[6]; + SQLINTEGER nativeError; + SQLCHAR msg[256]; + SQLSMALLINT msgLen; + SQLGetDiagRec(SQL_HANDLE_DBC, dbc, 1, sqlState, &nativeError, msg, sizeof(msg), &msgLen); + } + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(ErrorHandling, ClearErrors) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"SELECT * FROM nonexistent_table", SQL_NTS); + SQLSMALLINT numRecs1; + SQLGetDiagField(SQL_HANDLE_STMT, stmt, 0, SQL_DIAG_NUMBER, &numRecs1, 0, nullptr); + ASSERT_GT(numRecs1, 0); + SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1", SQL_NTS); + SQLSMALLINT numRecs2; + SQLGetDiagField(SQL_HANDLE_STMT, stmt, 0, SQL_DIAG_NUMBER, &numRecs2, 0, nullptr); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/metadata_api_it.cpp b/odbc/tests/integration/metadata_api_it.cpp new file mode 100644 index 00000000000..d39710657fd --- /dev/null +++ b/odbc/tests/integration/metadata_api_it.cpp @@ -0,0 +1,228 @@ +#include "test_utils.h" + +#ifndef SQL_ATTR_METADATA_ID +#define SQL_ATTR_METADATA_ID 10029 +#endif + +TEST(MetadataApi, SQLTablesAll) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLTables(stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0), + stmt, SQL_HANDLE_STMT); + int rowCount = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++rowCount; + } + ASSERT_GT(rowCount, 0); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(MetadataApi, SQLTablesWithPattern) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_metadata_pattern_a", SQL_NTS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_metadata_pattern_b", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_metadata_pattern_a (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_metadata_pattern_b (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLTables(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)"%/test_metadata_pattern_%", SQL_NTS, nullptr, 0), + stmt, SQL_HANDLE_STMT); + int tableCount = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableCount; + } + ASSERT_EQ(tableCount, 2); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(MetadataApi, SQLTablesExactMatch) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_exact_table", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_exact_table (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, SQL_HANDLE_STMT); + const std::string exactPath = "/local/test_exact_table"; + CHECK_ODBC_OK(SQLTables(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)exactPath.c_str(), SQL_NTS, nullptr, 0), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(MetadataApi, SQLTablesLikePatternWithMetadataId) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_meta_table_1", SQL_NTS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_meta_table_2", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_meta_table_1 (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_meta_table_2 (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + SQLULEN metadataId = SQL_FALSE; + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_METADATA_ID, &metadataId, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(metadataId, SQL_FALSE); + const char* likePattern = "%/test_meta_table_%"; + CHECK_ODBC_OK(SQLTables(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)likePattern, SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, SQL_HANDLE_STMT); + int tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 2); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_METADATA_ID, &metadataId, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(metadataId, SQL_TRUE); + CHECK_ODBC_OK(SQLTables(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)likePattern, SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, SQL_HANDLE_STMT); + tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 0); + SQLFreeStmt(stmt, SQL_CLOSE); + const std::string exactPath = "/local/test_meta_table_1"; + CHECK_ODBC_OK(SQLTables(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)exactPath.c_str(), SQL_NTS, (SQLCHAR*)"TABLE", SQL_NTS), + stmt, SQL_HANDLE_STMT); + tableRows = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++tableRows; + } + ASSERT_EQ(tableRows, 1); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_FALSE, 0), + stmt, SQL_HANDLE_STMT); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(MetadataApi, SQLColumnsAll) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_columns_all", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_columns_all (id Int32, name Text, value Int32, PRIMARY KEY (id))", + SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLColumns(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)"/local/test_columns_all", SQL_NTS, nullptr, 0), + stmt, SQL_HANDLE_STMT); + int colCount = 0; + while (SQLFetch(stmt) == SQL_SUCCESS) { + ++colCount; + } + ASSERT_EQ(colCount, 3); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(MetadataApi, SQLColumnsWithPattern) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_columns_pattern", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_columns_pattern (id Int32, value_x Int32, value_y Int32, PRIMARY KEY (id))", + SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + constexpr SQLUSMALLINT kColumnNameCol = 4; + char colName[256] = {}; + SQLLEN colInd = 0; + CHECK_ODBC_OK(SQLColumns(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)"/local/test_columns_pattern", SQL_NTS, + (SQLCHAR*)"val%", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, kColumnNameCol, SQL_C_CHAR, colName, sizeof(colName), &colInd), SQL_SUCCESS); + ASSERT_STREQ(colName, "value_x"); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, kColumnNameCol, SQL_C_CHAR, colName, sizeof(colName), &colInd), SQL_SUCCESS); + ASSERT_STREQ(colName, "value_y"); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(MetadataApi, SQLColumnsMetadataId) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_columns_metadata", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"CREATE TABLE test_columns_metadata (id Int32, value_x Int32, PRIMARY KEY (id))", + SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + const std::string exactTable = "/local/test_columns_metadata"; + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_TRUE, 0), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLColumns(stmt, nullptr, 0, nullptr, 0, + (SQLCHAR*)exactTable.c_str(), SQL_NTS, + (SQLCHAR*)"val%", SQL_NTS), + SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "42S22")); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_METADATA_ID, (SQLPOINTER)(uintptr_t)SQL_FALSE, 0), + stmt, SQL_HANDLE_STMT); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/statement_api_it.cpp b/odbc/tests/integration/statement_api_it.cpp new file mode 100644 index 00000000000..1413e40229a --- /dev/null +++ b/odbc/tests/integration/statement_api_it.cpp @@ -0,0 +1,580 @@ +#include "test_utils.h" + +#include + +#ifndef SQL_ATTR_METADATA_ID +#define SQL_ATTR_METADATA_ID 10029 +#endif + +TEST(StatementApi, AllocFreeStmtHandle) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFreeHandle(SQL_HANDLE_STMT, stmt), SQL_SUCCESS); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, ExecDirectSimple) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1 AS value", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, ExecDirectMultipleColumns) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"SELECT 1 AS int_col, 'hello' AS str_col, CAST(3.14 AS Double) AS float_col", + SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, ExecDirectInvalidSyntax) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + ASSERT_EQ(SQLExecDirect(stmt, (SQLCHAR*)"INVALID SYNTAX HERE", SQL_NTS), SQL_ERROR); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, ExecDirectInvalidTable) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + ASSERT_EQ(SQLExecDirect(stmt, (SQLCHAR*)"SELECT * FROM nonexistent_table", SQL_NTS), SQL_ERROR); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, PrepareAndExecute) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLPrepare(stmt, (SQLCHAR*)"SELECT $p1 + $p2 AS result", SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLINTEGER p1 = 10, p2 = 20; + CHECK_ODBC_OK(SQLBindParameter(stmt, 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, + 0, 0, &p1, 0, nullptr), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLBindParameter(stmt, 2, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, + 0, 0, &p2, 0, nullptr), stmt, SQL_HANDLE_STMT); + + CHECK_ODBC_OK(SQLExecute(stmt), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, PrepareAndExecuteReused) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLPrepare(stmt, (SQLCHAR*)"SELECT $p1", SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLINTEGER param; + SQLBindParameter(stmt, 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, ¶m, 0, nullptr); + param = 100; + CHECK_ODBC_OK(SQLExecute(stmt), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER result; + SQLGetData(stmt, 1, SQL_C_LONG, &result, 0, nullptr); + ASSERT_EQ(result, 100); + SQLCloseCursor(stmt); + param = 200; + CHECK_ODBC_OK(SQLExecute(stmt), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLGetData(stmt, 1, SQL_C_LONG, &result, 0, nullptr); + ASSERT_EQ(result, 200); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, FetchSingleRow) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 42", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, FetchMultipleRows) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"SELECT * FROM AS_TABLE(ListMap(ListFromRange(1, 4), ($x) -> (AsStruct($x AS a)))) ORDER BY a", + SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLINTEGER value; + SQLLEN ind; + SQLBindCol(stmt, 1, SQL_C_LONG, &value, 0, &ind); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(value, 1); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(value, 2); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(value, 3); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, BindColMultipleTypes) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 42 AS col1, 'test' AS col2", SQL_NTS), + stmt, SQL_HANDLE_STMT); + + SQLINTEGER col1; + char col2[64]; + SQLLEN col1Ind, col2Ind; + + SQLBindCol(stmt, 1, SQL_C_LONG, &col1, 0, &col1Ind); + SQLBindCol(stmt, 2, SQL_C_CHAR, col2, sizeof(col2), &col2Ind); + + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(col1, 42); + ASSERT_STREQ(col2, "test"); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, BindColThenGetData) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 100", SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLINTEGER value; + SQLLEN ind; + SQLBindCol(stmt, 1, SQL_C_LONG, &value, 0, &ind); + + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(value, 100); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, GetDataWithoutBindCol) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 100", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER value; + SQLLEN ind; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &value, 0, &ind), SQL_SUCCESS); + ASSERT_EQ(value, 100); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, GetDataMultipleColumns) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1, 'hello world'", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER col1; + SQLLEN col1Ind; + SQLGetData(stmt, 1, SQL_C_LONG, &col1, 0, &col1Ind); + ASSERT_EQ(col1, 1); + + char col2[64]; + SQLLEN col2Ind; + SQLGetData(stmt, 2, SQL_C_CHAR, col2, sizeof(col2), &col2Ind); + ASSERT_STREQ(col2, "hello world"); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, CloseCursor) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLCloseCursor(stmt), stmt, SQL_HANDLE_STMT); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, FreeStmtClose) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1", SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLFreeStmt(stmt, SQL_CLOSE), stmt, SQL_HANDLE_STMT); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, FreeStmtResetParams) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLINTEGER param = 42; + SQLBindParameter(stmt, 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, ¶m, 0, nullptr); + + CHECK_ODBC_OK(SQLFreeStmt(stmt, SQL_RESET_PARAMS), stmt, SQL_HANDLE_STMT); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, NumResultCols) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1, 2, 3, 4, 5", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLSMALLINT numCols; + CHECK_ODBC_OK(SQLNumResultCols(stmt, &numCols), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(numCols, 5); + SQLFetch(stmt); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, RowCount) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + CHECK_ODBC_OK(SQLExecDirect(stmt, + (SQLCHAR*)"SELECT * FROM AS_TABLE(ListMap(ListFromRange(1, 4), ($x) -> (AsStruct($x AS v))))", + SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLLEN rowCount; + CHECK_ODBC_OK(SQLRowCount(stmt, &rowCount), stmt, SQL_HANDLE_STMT); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, AttrQueryTimeout) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLUINTEGER timeoutSec = 1; + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_QUERY_TIMEOUT, (SQLPOINTER)(uintptr_t)timeoutSec, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR longQuery[] = + "SELECT COUNT(*) FROM AS_TABLE(ListMap(ListFromRange(1u, 100000000u), ($x)->(AsStruct($x AS v))))"; + ASSERT_EQ(SQLExecDirect(stmt, longQuery, SQL_NTS), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(stmt, SQL_HANDLE_STMT), "HYT00")); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, AttrMaxRows) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_max_rows", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE test_max_rows (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"UPSERT INTO test_max_rows (id) VALUES (1), (2)", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + const SQLULEN maxRows = 1; + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_MAX_ROWS, (SQLPOINTER)(uintptr_t)maxRows, 0), + stmt, SQL_HANDLE_STMT); + SQLULEN maxRowsOut; + ASSERT_EQ(SQLGetStmtAttr(stmt, SQL_ATTR_MAX_ROWS, &maxRowsOut, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(maxRowsOut, maxRows); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT id FROM test_max_rows ORDER BY id", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, AttrNoScan) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLCHAR selectEscapeFnQuery[] = "SELECT {fn ABS(-12)} AS value"; + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, selectEscapeFnQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 12); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_ON, 0), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLExecDirect(stmt, selectEscapeFnQuery, SQL_NTS), SQL_ERROR); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, EscapeSequenceConvert) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR convertQuery[] = "SELECT {fn CONVERT(42, SQL_SMALLINT)} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, convertQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLSMALLINT valueSmall = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_SSHORT, &valueSmall, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueSmall, 42); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, EscapeSequenceDouble) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR convertDoubleQuery[] = "SELECT {fn CONVERT(2.5, SQL_DOUBLE)} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, convertDoubleQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + double valueDouble = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_DOUBLE, &valueDouble, 0, &valueInd), SQL_SUCCESS); + ASSERT_LT(std::fabs(valueDouble - 2.5), 1e-9); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, EscapeSequenceNested) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR nestedFnQuery[] = "SELECT {fn {fn ABS(-10)}} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, nestedFnQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER valueInt = 0; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &valueInt, 0, &valueInd), SQL_SUCCESS); + ASSERT_EQ(valueInt, 10); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, EscapeSequenceString) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR asciiLowerQuery[] = "SELECT {fn String::AsciiToLower('AbC')} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, asciiLowerQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[32] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "abc"); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, EscapeSequenceDate) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR dateQuery[] = "SELECT {d '2024-06-15'} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, dateQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[32] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "2024-06-15"); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(StatementApi, EscapeSequenceTimestamp) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + CHECK_ODBC_OK(SQLSetStmtAttr(stmt, SQL_ATTR_NOSCAN, (SQLPOINTER)SQL_NOSCAN_OFF, 0), + stmt, SQL_HANDLE_STMT); + + SQLCHAR tsQuery[] = "SELECT {ts '2024-06-15 14:30:00'} AS value"; + CHECK_ODBC_OK(SQLExecDirect(stmt, tsQuery, SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + char buf[64] = {}; + SQLLEN valueInd = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_CHAR, buf, sizeof(buf), &valueInd), SQL_SUCCESS); + ASSERT_STREQ(buf, "2024-06-15 14:30:00"); + + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/integration/test_utils.h b/odbc/tests/integration/test_utils.h new file mode 100644 index 00000000000..362a836991b --- /dev/null +++ b/odbc/tests/integration/test_utils.h @@ -0,0 +1,116 @@ +#pragma once + +#include + +#include +#include + +#include +#include + +inline std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { + SQLCHAR sqlState[6] = {0}; + SQLCHAR message[256] = {0}; + SQLINTEGER nativeError = 0; + SQLSMALLINT textLength = 0; + SQLRETURN rc = SQLGetDiagRec(type, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) { + return std::string((char*)sqlState) + ": " + (char*)message; + } + return "Unknown ODBC error"; +} + +#define CHECK_ODBC_OK(rc, handle, type) \ + ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) + +inline const char* kConnStr = "Driver=" ODBC_DRIVER_PATH ";Server=localhost:2136;Database=/local;"; + +inline bool SqlStatePrefix(const std::string& diag, const char* state5) { + return diag.size() >= 5 && std::strncmp(diag.c_str(), state5, 5) == 0; +} + +inline void AllocEnv(SQLHENV* env) { + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(*env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); +} + +inline void AllocEnvAndConnect(SQLHENV* env, SQLHDBC* dbc) { + AllocEnv(env); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, *env, dbc), SQL_SUCCESS); + SQLRETURN rc = SQLDriverConnect( + *dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE); + CHECK_ODBC_OK(rc, *dbc, SQL_HANDLE_DBC); +} + +// ============================================================================ +// Type and Parameter Utilities +// ============================================================================ + +// Bind integer parameter and return result +inline SQLRETURN BindIntParam(SQLHSTMT stmt, SQLUSMALLINT paramNum, SQLINTEGER* value) { + return SQLBindParameter(stmt, paramNum, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, value, 0, nullptr); +} + +inline SQLRETURN BindInt64Param(SQLHSTMT stmt, SQLUSMALLINT paramNum, SQLBIGINT* value) { + return SQLBindParameter(stmt, paramNum, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, value, 0, nullptr); +} + +inline SQLRETURN BindStringParam(SQLHSTMT stmt, SQLUSMALLINT paramNum, char* value, SQLLEN len) { + SQLLEN indicator = (len >= 0) ? len : SQL_NTS; + return SQLBindParameter(stmt, paramNum, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, + 0, 0, value, (indicator == SQL_NTS) ? 0 : indicator, &indicator); +} + +inline SQLRETURN BindDoubleParam(SQLHSTMT stmt, SQLUSMALLINT paramNum, double* value) { + return SQLBindParameter(stmt, paramNum, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, value, 0, nullptr); +} + +inline SQLRETURN BindNullParam(SQLHSTMT stmt, SQLUSMALLINT paramNum, SQLINTEGER* placeholder) { + static SQLLEN nullIndicator = SQL_NULL_DATA; + return SQLBindParameter(stmt, paramNum, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, + 0, 0, placeholder, 0, &nullIndicator); +} + +// Fetch and verify integer result +inline SQLINTEGER FetchIntResult(SQLHSTMT stmt, SQLUSMALLINT colNum = 1) { + SQLINTEGER result = 0; + SQLLEN indicator = 0; + SQLBindCol(stmt, colNum, SQL_C_LONG, &result, 0, &indicator); + SQLFetch(stmt); + return result; +} + +inline SQLBIGINT FetchInt64Result(SQLHSTMT stmt, SQLUSMALLINT colNum = 1) { + SQLBIGINT result = 0; + SQLLEN indicator = 0; + SQLBindCol(stmt, colNum, SQL_C_SBIGINT, &result, 0, &indicator); + SQLFetch(stmt); + return result; +} + +inline double FetchDoubleResult(SQLHSTMT stmt, SQLUSMALLINT colNum = 1) { + double result = 0.0; + SQLLEN indicator = 0; + SQLBindCol(stmt, colNum, SQL_C_DOUBLE, &result, 0, &indicator); + SQLFetch(stmt); + return result; +} + +inline std::string FetchStringResult(SQLHSTMT stmt, SQLUSMALLINT colNum = 1, size_t maxLen = 256) { + std::string result(maxLen, '\0'); + SQLLEN indicator = 0; + SQLBindCol(stmt, colNum, SQL_C_CHAR, &result[0], maxLen, &indicator); + SQLFetch(stmt); + if (indicator > 0 && indicator != SQL_NULL_DATA) { + result.resize(indicator); + } + return result; +} + +inline bool IsNullResult(SQLHSTMT stmt, SQLUSMALLINT colNum = 1) { + SQLINTEGER dummy; + SQLLEN indicator = 0; + SQLBindCol(stmt, colNum, SQL_C_LONG, &dummy, 0, &indicator); + SQLFetch(stmt); + return indicator == SQL_NULL_DATA; +} diff --git a/odbc/tests/integration/transaction_api_it.cpp b/odbc/tests/integration/transaction_api_it.cpp new file mode 100644 index 00000000000..b26953e4dd0 --- /dev/null +++ b/odbc/tests/integration/transaction_api_it.cpp @@ -0,0 +1,179 @@ +#include "test_utils.h" + +TEST(TransactionApi, AutocommitDefaultOn) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + + SQLUINTEGER autocommit; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, &autocommit, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(autocommit, SQL_AUTOCOMMIT_ON); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, AutocommitOnOffToggle) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), + dbc, SQL_HANDLE_DBC); + SQLUINTEGER autocommit; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, &autocommit, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(autocommit, SQL_AUTOCOMMIT_OFF); + + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0), + dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, &autocommit, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(autocommit, SQL_AUTOCOMMIT_ON); + + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, AutocommitOffRollback) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_rollback", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE test_rollback (id Int32, value Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), + dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"UPSERT INTO test_rollback (id, value) VALUES (1, 100)", SQL_NTS), + stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_ROLLBACK), dbc, SQL_HANDLE_DBC); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT value FROM test_rollback WHERE id = 1", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_NO_DATA); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, AutocommitOffCommit) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_commit", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE test_commit (id Int32, value Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), + dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"UPSERT INTO test_commit (id, value) VALUES (1, 200)", SQL_NTS), + stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_COMMIT), dbc, SQL_HANDLE_DBC); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT value FROM test_commit WHERE id = 1", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER value; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &value, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(value, 200); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, MultipleStatementsInManualTransaction) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_multi_stmt", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE test_multi_stmt (id Int32, value Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), + dbc, SQL_HANDLE_DBC); + for (int i = 1; i <= 5; ++i) { + char query[256]; + snprintf(query, sizeof(query), "UPSERT INTO test_multi_stmt (id, value) VALUES (%d, %d)", i, i * 10); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)query, SQL_NTS), stmt, SQL_HANDLE_STMT); + } + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_DBC, dbc, SQL_COMMIT), dbc, SQL_HANDLE_DBC); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT COUNT(*) FROM test_multi_stmt", SQL_NTS), + stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + SQLINTEGER count; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &count, 0, nullptr), SQL_SUCCESS); + ASSERT_EQ(count, 5); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, SQLEndTranOnEnv) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + SQLExecDirect(stmt, (SQLCHAR*)"DROP TABLE IF EXISTS test_env_tran", SQL_NTS); + SQLFreeStmt(stmt, SQL_CLOSE); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"CREATE TABLE test_env_tran (id Int32, PRIMARY KEY (id))", SQL_NTS), + stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLSetConnectAttr(dbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_OFF, 0), + dbc, SQL_HANDLE_DBC); + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"UPSERT INTO test_env_tran (id) VALUES (1)", SQL_NTS), + stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLEndTran(SQL_HANDLE_ENV, env, SQL_COMMIT), env, SQL_HANDLE_ENV); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, SQLEndTranInvalid) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLEndTran(SQL_HANDLE_DBC, dbc, 999), SQL_ERROR); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, TxnIsolationDefault) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + SQLUINTEGER isolation; + ASSERT_EQ(SQLGetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, &isolation, sizeof(isolation), nullptr), SQL_SUCCESS); + ASSERT_EQ(isolation, SQL_TXN_SERIALIZABLE); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(TransactionApi, TxnIsolationUnsupportedInReadWrite) { + SQLHENV env; + SQLHDBC dbc; + AllocEnvAndConnect(&env, &dbc); + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_READ_COMMITTED, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HYC00")); + ASSERT_EQ(SQLSetConnectAttr(dbc, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)9999, 0), SQL_ERROR); + EXPECT_TRUE(SqlStatePrefix(GetOdbcError(dbc, SQL_HANDLE_DBC), "HY024")); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/unit/CMakeLists.txt b/odbc/tests/unit/CMakeLists.txt new file mode 100644 index 00000000000..006671ce933 --- /dev/null +++ b/odbc/tests/unit/CMakeLists.txt @@ -0,0 +1,36 @@ +add_ydb_test(NAME odbc-convert_ut GTEST + SOURCES + convert_ut.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/utils/convert.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + YDB-CPP-SDK::Params + api-protos + LABELS + unit +) + +add_ydb_test(NAME odbc-escape_ut GTEST + SOURCES + escape_ut.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/utils/escape.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + LABELS + unit +) + +add_ydb_test(NAME odbc-sql_like_ut GTEST + SOURCES + sql_like_ut.cpp + INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/../../src + LINK_LIBRARIES + yutil + LABELS + unit +) diff --git a/odbc/tests/unit/convert_ut.cpp b/odbc/tests/unit/convert_ut.cpp new file mode 100644 index 00000000000..86b0d3d5be3 --- /dev/null +++ b/odbc/tests/unit/convert_ut.cpp @@ -0,0 +1,346 @@ +#include "utils/convert.h" +#undef BOOL + +#include + +#include + +#include + +#include + +using namespace NYdb::NOdbc; +using namespace NYdb; + +template +void CheckProto(const T& value, const std::string& expected) { + std::string protoStr; + google::protobuf::TextFormat::PrintToString(value, &protoStr); + ASSERT_EQ(protoStr, expected); +} + +TEST(OdbcConvert, Int64ToYdb) { + SQLBIGINT v = 42; + TBoundParam param{ + 1, // ParamNumber + SQL_PARAM_INPUT, // InputOutputType + SQL_C_SBIGINT, // ValueType + SQL_BIGINT, // ParameterType + 0, 0, // ColumnSize, DecimalDigits + &v, // ParameterValuePtr + sizeof(v), // BufferLength + nullptr // StrLenOrIndPtr + }; + + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n"); + CheckProto(value->GetProto(), "int64_value: 42\n"); +} + +TEST(OdbcConvert, Uint64ToYdb) { + SQLUBIGINT v = 123; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_UBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UINT64\n }\n}\n"); + CheckProto(value->GetProto(), "uint64_value: 123\n"); +} + +TEST(OdbcConvert, DoubleToYdb) { + SQLDOUBLE v = 3.14; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: DOUBLE\n }\n}\n"); + CheckProto(value->GetProto(), "double_value: 3.14\n"); +} + +TEST(OdbcConvert, StringToYdbUtf8) { + const char* str = "hello"; + SQLLEN len = 5; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, len, nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n"); + CheckProto(value->GetProto(), "text_value: \"hello\"\n"); +} + +TEST(OdbcConvert, StringToYdbBinary) { + const char* str = "bin\x01\x02"; + SQLLEN len = 5; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_BINARY, 0, 0, (SQLPOINTER)str, len, nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: STRING\n }\n}\n"); + CheckProto(value->GetProto(), "bytes_value: \"bin\\001\\002\"\n"); +} + +TEST(OdbcConvert, Int64NullToYdb) { + SQLBIGINT v = 42; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n"); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, StringNullToYdb) { + const char* str = "test"; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, 4, &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n"); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, Int32ToYdb) { + SQLINTEGER v = 42; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT32\n }\n}\n"); + CheckProto(value->GetProto(), "int32_value: 42\n"); +} + +TEST(OdbcConvert, Int32NegativeToYdb) { + SQLINTEGER v = -999; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "int32_value: -999\n"); +} + +TEST(OdbcConvert, Int32ZeroToYdb) { + SQLINTEGER v = 0; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "int32_value: 0\n"); +} + +TEST(OdbcConvert, Int32MaxToYdb) { + SQLINTEGER v = 2147483647; // INT32_MAX + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "int32_value: 2147483647\n"); +} + +TEST(OdbcConvert, Int32NullToYdb) { + SQLINTEGER v = 42; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, &v, sizeof(v), &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, Int64NegativeToYdb) { + SQLBIGINT v = -123456789012345LL; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "int64_value: -123456789012345\n"); +} + +TEST(OdbcConvert, Int64ZeroToYdb) { + SQLBIGINT v = 0; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "int64_value: 0\n"); +} + +TEST(OdbcConvert, DoubleNegativeToYdb) { + SQLDOUBLE v = -2.71828; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); +} + +TEST(OdbcConvert, DoubleZeroToYdb) { + SQLDOUBLE v = 0.0; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "double_value: 0\n"); +} + +TEST(OdbcConvert, DoubleNullToYdb) { + SQLDOUBLE v = 3.14; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, &v, sizeof(v), &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, StringEmptyToYdb) { + const char* str = ""; + SQLLEN len = 0; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, len, &len + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "text_value: \"\"\n"); +} + +TEST(OdbcConvert, StringUnicodeToYdb) { + const char* str = "Привет"; + SQLLEN len = SQL_NTS; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, 0, &len + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); +} + +TEST(OdbcConvert, StringWithLengthToYdb) { + const char* str = "hello world"; + SQLLEN len = 5; // Only "hello" + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, len, &len + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "text_value: \"hello\"\n"); +} + +TEST(OdbcConvert, StringNullTerminatedToYdb) { + const char* str = "test"; + SQLLEN len = SQL_NTS; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, 0, &len + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "text_value: \"test\"\n"); +} + + +TEST(OdbcConvert, BinaryNullToYdb) { + const char* data = "\x01\x02\x03"; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_BINARY, 0, 0, (SQLPOINTER)data, 3, &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, BinaryEmptyToYdb) { + const char* data = ""; + SQLLEN len = 0; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_BINARY, 0, 0, (SQLPOINTER)data, len, &len + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetProto(), "bytes_value: \"\"\n"); +} diff --git a/odbc/tests/unit/escape_ut.cpp b/odbc/tests/unit/escape_ut.cpp new file mode 100644 index 00000000000..60b3e582e69 --- /dev/null +++ b/odbc/tests/unit/escape_ut.cpp @@ -0,0 +1,71 @@ +#include "utils/escape.h" + +#include + +using NYdb::NOdbc::RewriteOdbcEscapes; + +TEST(OdbcEscapeRewrite, FnUnwraps) { + EXPECT_EQ(RewriteOdbcEscapes("SELECT {fn ABS(-12)} AS v"), "SELECT ABS(-12) AS v"); +} + +TEST(OdbcEscapeRewrite, FnCaseInsensitive) { + EXPECT_EQ(RewriteOdbcEscapes("{FN LOWER('A')}"), "LOWER('A')"); +} + +TEST(OdbcEscapeRewrite, OjUnwraps) { + EXPECT_EQ(RewriteOdbcEscapes("{oj LEFT OUTER JOIN t ON a=b}"), "LEFT OUTER JOIN t ON a=b"); +} + +TEST(OdbcEscapeRewrite, DateLiteral) { + EXPECT_EQ(RewriteOdbcEscapes("SELECT {d '2024-01-01'}"), "SELECT CAST('2024-01-01' AS Date)"); +} + +TEST(OdbcEscapeRewrite, TimeLiteral) { + EXPECT_EQ(RewriteOdbcEscapes("{t '14:30:00'}"), "CAST('14:30:00' AS Time)"); +} + +TEST(OdbcEscapeRewrite, TimestampLiteralNormalizesSpaceToT) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {ts '2024-06-15 14:30:00'} AS v"), + "SELECT CAST('2024-06-15T14:30:00Z' AS Datetime) AS v"); +} + +TEST(OdbcEscapeRewrite, TimestampLiteralKeepsExistingZ) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {ts '2024-06-15T14:30:00Z'} AS v"), + "SELECT CAST('2024-06-15T14:30:00Z' AS Datetime) AS v"); +} + +TEST(OdbcEscapeRewrite, Call) { + EXPECT_EQ(RewriteOdbcEscapes("{call sp_demo(1, 2)}"), "CALL sp_demo(1, 2)"); +} + +TEST(OdbcEscapeRewrite, OutputCallBecomesPlainCall) { + EXPECT_EQ(RewriteOdbcEscapes("{?= call sp(1)}"), "CALL sp(1)"); +} + +TEST(OdbcEscapeRewrite, EscapeClause) { + EXPECT_EQ(RewriteOdbcEscapes("LIKE 'a%' {escape '\\'}"), "LIKE 'a%' ESCAPE '\\'"); +} + +TEST(OdbcEscapeRewrite, ConvertOdbcToYqlCast) { + EXPECT_EQ( + RewriteOdbcEscapes("SELECT {fn CONVERT(42, SQL_SMALLINT)} AS v"), + "SELECT CAST(42 AS Int16) AS v"); +} + +TEST(OdbcEscapeRewrite, ConvertNestedInFn) { + EXPECT_EQ(RewriteOdbcEscapes("{fn CONVERT(x, SQL_INTEGER)}"), "CAST(x AS Int32)"); +} + +TEST(OdbcEscapeRewrite, NestedFnEscapes) { + EXPECT_EQ(RewriteOdbcEscapes("{fn {fn ABS(1)}}"), "ABS(1)"); +} + +TEST(OdbcEscapeRewrite, UnknownBraceLeftUnchanged) { + EXPECT_EQ(RewriteOdbcEscapes("{not_a_keyword 1}"), "{not_a_keyword 1}"); +} + +TEST(OdbcEscapeRewrite, EmptyInput) { + EXPECT_EQ(RewriteOdbcEscapes(""), ""); +} diff --git a/odbc/tests/unit/sql_like_ut.cpp b/odbc/tests/unit/sql_like_ut.cpp new file mode 100644 index 00000000000..e0b8d87ee01 --- /dev/null +++ b/odbc/tests/unit/sql_like_ut.cpp @@ -0,0 +1,28 @@ +#include "utils/sql_like.h" + +#include + +using NYdb::NOdbc::SqlLikeMatch; + +TEST(SqlLikeMatch, PercentMatchesSubstring) { + EXPECT_TRUE(SqlLikeMatch("/local/foo_bar", "%foo%")); + EXPECT_TRUE(SqlLikeMatch("/local/pfx_foo_sfx", "%foo%")); + EXPECT_FALSE(SqlLikeMatch("/local/other", "%foo%")); +} + +TEST(SqlLikeMatch, UnderscoreMatchesSingleChar) { + EXPECT_TRUE(SqlLikeMatch("a_c", "a_c")); + EXPECT_TRUE(SqlLikeMatch("abc", "a_c")); + EXPECT_FALSE(SqlLikeMatch("abbc", "a_c")); +} + +TEST(SqlLikeMatch, EmptyPatternMatchesOnlyEmptyText) { + EXPECT_TRUE(SqlLikeMatch("", "")); + EXPECT_FALSE(SqlLikeMatch("anything", "")); +} + +TEST(SqlLikeMatch, PercentAtEnds) { + EXPECT_TRUE(SqlLikeMatch("hello", "%hello%")); + EXPECT_TRUE(SqlLikeMatch("hello", "hel%")); + EXPECT_TRUE(SqlLikeMatch("hello", "%llo")); +} diff --git a/tests/unit/library/operation_id/CMakeLists.txt b/tests/unit/library/operation_id/CMakeLists.txt index a6f2143949a..63d77da600a 100644 --- a/tests/unit/library/operation_id/CMakeLists.txt +++ b/tests/unit/library/operation_id/CMakeLists.txt @@ -5,6 +5,7 @@ add_ydb_test(NAME operation_id_ut GTEST yutil lib-operation_id-protos library-operation_id + cpp-testing-unittest LABELS unit )