diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 32ed5507..aaf172b9 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -85,8 +85,16 @@ void Connection::connect(const py::dict& attrs_before) { #else connStrPtr = const_cast(_connStr.c_str()); #endif - SQLRETURN ret = SQLDriverConnect_ptr(_dbcHandle->get(), nullptr, connStrPtr, SQL_NTS, nullptr, - 0, nullptr, SQL_DRIVER_NOPROMPT); + SQLRETURN ret; + { + // Release the GIL during the blocking ODBC connect call. + // SQLDriverConnect involves DNS resolution, TCP handshake, TLS negotiation, + // and SQL Server authentication — all pure I/O that doesn't need the GIL. + // This allows other Python threads to run concurrently. + py::gil_scoped_release release; + ret = SQLDriverConnect_ptr(_dbcHandle->get(), nullptr, connStrPtr, SQL_NTS, nullptr, + 0, nullptr, SQL_DRIVER_NOPROMPT); + } checkError(ret); updateLastUsed(); } @@ -95,6 +103,11 @@ void Connection::disconnect() { if (_dbcHandle) { LOG("Disconnecting from database"); + // Check if we hold the GIL so we can conditionally release it. + // The GIL is held when called from pybind11-bound methods but may NOT + // be held in destructor paths (C++ shared_ptr ref-count drop, shutdown). + bool hasGil = PyGILState_Check() != 0; + // CRITICAL FIX: Mark all child statement handles as implicitly freed // When we free the DBC handle below, the ODBC driver will automatically free // all child STMT handles. We need to tell the SqlHandle objects about this @@ -135,8 +148,26 @@ void Connection::disconnect() { _allocationsSinceCompaction = 0; } // Release lock before potentially slow SQLDisconnect call - SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); - checkError(ret); + SQLRETURN ret; + if (hasGil) { + // Release the GIL during the blocking ODBC disconnect call. + // This allows other Python threads to run while the network + // round-trip completes. + py::gil_scoped_release release; + ret = SQLDisconnect_ptr(_dbcHandle->get()); + } else { + // Destructor / shutdown path — GIL is not held, call directly. + ret = SQLDisconnect_ptr(_dbcHandle->get()); + } + // In destructor/shutdown paths, suppress errors to avoid + // std::terminate() if this throws during stack unwinding. + if (hasGil) { + checkError(ret); + } else if (!SQL_SUCCEEDED(ret)) { + // Intentionally no LOG() here: LOG() acquires the GIL internally + // via py::gil_scoped_acquire, which is unsafe during interpreter + // shutdown or stack unwinding (can deadlock or call std::terminate). + } // triggers SQLFreeHandle via destructor, if last owner _dbcHandle.reset(); } else { @@ -375,7 +406,6 @@ bool Connection::reset() { (SQLPOINTER)SQL_RESET_CONNECTION_YES, SQL_IS_INTEGER); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to reset connection (ret=%d). Marking as dead.", ret); - disconnect(); return false; } @@ -387,7 +417,6 @@ bool Connection::reset() { (SQLPOINTER)SQL_TXN_READ_COMMITTED, SQL_IS_INTEGER); if (!SQL_SUCCEEDED(ret)) { LOG("Failed to reset transaction isolation level (ret=%d). Marking as dead.", ret); - disconnect(); return false; } diff --git a/mssql_python/pybind/connection/connection_pool.cpp b/mssql_python/pybind/connection/connection_pool.cpp index 3000a970..7c6b7f70 100644 --- a/mssql_python/pybind/connection/connection_pool.cpp +++ b/mssql_python/pybind/connection/connection_pool.cpp @@ -16,6 +16,7 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, const py::dict& attrs_before) { std::vector> to_disconnect; std::shared_ptr valid_conn = nullptr; + bool needs_connect = false; { std::lock_guard lock(_mutex); auto now = std::chrono::steady_clock::now(); @@ -57,16 +58,33 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, } } - // Create new connection if none reusable + // Reserve a slot for a new connection if none reusable. + // The actual connect() call happens outside the mutex to avoid + // holding the mutex during the blocking ODBC call (which releases + // the GIL and could otherwise cause a mutex/GIL deadlock). if (!valid_conn && _current_size < _max_size) { valid_conn = std::make_shared(connStr, true); - valid_conn->connect(attrs_before); ++_current_size; + needs_connect = true; } else if (!valid_conn) { throw std::runtime_error("ConnectionPool::acquire: pool size limit reached"); } } + // Phase 2.5: Connect the new connection outside the mutex. + if (needs_connect) { + try { + valid_conn->connect(attrs_before); + } catch (...) { + // Connect failed — release the reserved slot + { + std::lock_guard lock(_mutex); + if (_current_size > 0) --_current_size; + } + throw; + } + } + // Phase 3: Disconnect expired/bad connections outside lock for (auto& conn : to_disconnect) { try { @@ -79,12 +97,25 @@ std::shared_ptr ConnectionPool::acquire(const std::wstring& connStr, } void ConnectionPool::release(std::shared_ptr conn) { - std::lock_guard lock(_mutex); - if (_pool.size() < _max_size) { - conn->updateLastUsed(); - _pool.push_back(conn); - } else { - conn->disconnect(); + bool should_disconnect = false; + { + std::lock_guard lock(_mutex); + if (_pool.size() < _max_size) { + conn->updateLastUsed(); + _pool.push_back(conn); + } else { + should_disconnect = true; + } + } + // Disconnect outside the mutex to avoid holding it during the + // blocking ODBC call (which releases the GIL). + if (should_disconnect) { + try { + conn->disconnect(); + } catch (const std::exception& ex) { + LOG("ConnectionPool::release: disconnect failed: %s", ex.what()); + } + std::lock_guard lock(_mutex); if (_current_size > 0) --_current_size; } @@ -116,21 +147,35 @@ ConnectionPoolManager& ConnectionPoolManager::getInstance() { std::shared_ptr ConnectionPoolManager::acquireConnection(const std::wstring& connStr, const py::dict& attrs_before) { - std::lock_guard lock(_manager_mutex); - - auto& pool = _pools[connStr]; - if (!pool) { - LOG("Creating new connection pool"); - pool = std::make_shared(_default_max_size, _default_idle_secs); + std::shared_ptr pool; + { + std::lock_guard lock(_manager_mutex); + auto& pool_ref = _pools[connStr]; + if (!pool_ref) { + LOG("Creating new connection pool"); + pool_ref = std::make_shared(_default_max_size, _default_idle_secs); + } + pool = pool_ref; } + // Call acquire() outside _manager_mutex. acquire() may release the GIL + // during the ODBC connect call; holding _manager_mutex across that would + // create a mutex/GIL lock-ordering deadlock. return pool->acquire(connStr, attrs_before); } void ConnectionPoolManager::returnConnection(const std::wstring& conn_str, const std::shared_ptr conn) { - std::lock_guard lock(_manager_mutex); - if (_pools.find(conn_str) != _pools.end()) { - _pools[conn_str]->release((conn)); + std::shared_ptr pool; + { + std::lock_guard lock(_manager_mutex); + auto it = _pools.find(conn_str); + if (it != _pools.end()) { + pool = it->second; + } + } + // Call release() outside _manager_mutex to avoid deadlock. + if (pool) { + pool->release(conn); } } @@ -142,6 +187,9 @@ void ConnectionPoolManager::configure(int max_size, int idle_timeout_secs) { void ConnectionPoolManager::closePools() { std::lock_guard lock(_manager_mutex); + // Keep _manager_mutex held for the full close operation so that + // acquireConnection()/returnConnection() cannot create or use pools + // while closePools() is in progress. for (auto& [conn_str, pool] : _pools) { if (pool) { pool->close(); diff --git a/tests/test_009_pooling.py b/tests/test_009_pooling.py index 1a3e5f09..d5b34346 100644 --- a/tests/test_009_pooling.py +++ b/tests/test_009_pooling.py @@ -278,6 +278,42 @@ def try_overflow(): c.close() +def test_pool_release_overflow_disconnects_outside_mutex(conn_str): + """Test that releasing a connection when pool is full disconnects it correctly. + + When a connection is returned to a pool that is already at max_size, + the connection must be disconnected. This exercises the overflow path in + ConnectionPool::release() (connection_pool.cpp) where should_disconnect + is set and disconnect happens outside the mutex. + + With the current pool semantics, max_size limits total concurrent + connections, so we acquire two connections with max_size=2, then shrink + the pool to max_size=1 before returning them. The second close hits + the overflow path. + """ + pooling(max_size=2, idle_timeout=30) + + conn1 = connect(conn_str) + conn2 = connect(conn_str) + + # Shrink idle capacity so first close fills the pool and second overflows + pooling(max_size=1, idle_timeout=30) + + # Close conn1 — returned to the pool (pool now has 1 idle entry) + conn1.close() + + # Close conn2 — pool is full (1 idle already), so this connection + # must be disconnected rather than pooled (overflow path). + conn2.close() + + # Verify the pool is still functional + conn3 = connect(conn_str) + cursor = conn3.cursor() + cursor.execute("SELECT 1") + assert cursor.fetchone()[0] == 1 + conn3.close() + + @pytest.mark.skip("Flaky test - idle timeout behavior needs investigation") def test_pool_idle_timeout_removes_connections(conn_str): """Test that idle_timeout removes connections from the pool after the timeout.""" diff --git a/tests/test_021_concurrent_connection_perf.py b/tests/test_021_concurrent_connection_perf.py new file mode 100644 index 00000000..0d4d0753 --- /dev/null +++ b/tests/test_021_concurrent_connection_perf.py @@ -0,0 +1,248 @@ +""" +Concurrent connection performance test for GIL release during ODBC operations. + +This test verifies that the GIL is properly released during blocking ODBC +connection establishment (SQLDriverConnect) and teardown (SQLDisconnect), +allowing multiple Python threads to establish connections in parallel. + +Without GIL release, N threads would take ~N * single_connection_time +(serialized by the GIL). With GIL release, threads overlap their I/O and +total wall-clock time is close to single_connection_time. + +Marked with @pytest.mark.stress to run in dedicated performance pipelines. +""" + +import os +import time +import threading +import statistics + +import pytest +import mssql_python +from mssql_python import connect + + +@pytest.fixture(scope="module") +def perf_conn_str(): + """Get connection string from environment.""" + conn_str = os.getenv("DB_CONNECTION_STRING") + if not conn_str: + pytest.skip("DB_CONNECTION_STRING environment variable not set") + return conn_str + + +def _connect_and_close(conn_str: str) -> float: + """Open a connection, close it, and return the elapsed time in seconds.""" + start = time.perf_counter() + conn = connect(conn_str) + conn.close() + return time.perf_counter() - start + + +# ============================================================================ +# GIL Release Performance Tests +# ============================================================================ + + +@pytest.mark.stress +def test_concurrent_connection_gil_release(perf_conn_str): + """ + Verify that concurrent connection establishment achieves parallelism, + proving the GIL is released during SQLDriverConnect. + + Approach: + 1. Measure the baseline: average time for a single connection (no contention). + 2. Launch NUM_THREADS threads, each creating a fresh connection (no pooling). + 3. The wall-clock time for all threads should be much less than + NUM_THREADS * baseline, because connections overlap their I/O. + + We require a speedup > 2x (very conservative). In practice, with proper + GIL release, the speedup approaches NUM_THREADS for I/O-bound work. + """ + NUM_THREADS = 10 + WARMUP_ROUNDS = 2 + BASELINE_ROUNDS = 5 + + # Disable pooling so every connect() creates a brand-new ODBC connection. + mssql_python.pooling(enabled=False) + + # ---- warm-up (prime DNS cache, driver loading, etc.) ---- + for _ in range(WARMUP_ROUNDS): + _connect_and_close(perf_conn_str) + + # ---- baseline: serial single-connection time ---- + serial_times = [_connect_and_close(perf_conn_str) for _ in range(BASELINE_ROUNDS)] + baseline = statistics.median(serial_times) + print(f"\n[BASELINE] Single connection (median of {BASELINE_ROUNDS}): {baseline*1000:.1f} ms") + + # ---- concurrent: N threads each opening a connection ---- + barrier = threading.Barrier(NUM_THREADS) + thread_times = [None] * NUM_THREADS + errors = [] + + def worker(idx): + try: + barrier.wait(timeout=30) + thread_times[idx] = _connect_and_close(perf_conn_str) + except Exception as exc: + errors.append((idx, str(exc))) + + threads = [threading.Thread(target=worker, args=(i,), daemon=True) for i in range(NUM_THREADS)] + + wall_start = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join(timeout=120) + wall_time = time.perf_counter() - wall_start + + # ---- assertions ---- + assert not errors, f"Thread errors: {errors}" + assert all(t is not None for t in thread_times), "Some threads did not complete" + + serial_estimate = NUM_THREADS * baseline + speedup = serial_estimate / wall_time + + print(f"[CONCURRENT] {NUM_THREADS} threads wall-clock: {wall_time*1000:.1f} ms") + print(f"[SERIAL EST] {NUM_THREADS} × baseline: {serial_estimate*1000:.1f} ms") + print(f"[SPEEDUP] {speedup:.2f}x (>{NUM_THREADS}x means full parallelism)") + + # Conservative threshold: even modest parallelism should beat 2x. + # Without GIL release this would be ~1.0x (fully serialized). + assert speedup > 2.0, ( + f"Concurrent connections are not running in parallel (speedup={speedup:.2f}x). " + f"Expected >2x, got wall_time={wall_time*1000:.1f}ms vs serial_estimate={serial_estimate*1000:.1f}ms. " + f"This likely indicates the GIL is not being released during SQLDriverConnect." + ) + + print(f"[PASSED] GIL release verified — {speedup:.1f}x speedup with {NUM_THREADS} threads") + + +@pytest.mark.stress +def test_concurrent_disconnect_gil_release(perf_conn_str): + """ + Verify that concurrent disconnection works correctly with GIL release. + + Opens N connections serially, then closes them all concurrently. + On localhost, disconnect is sub-millisecond so thread overhead dominates + and speedup ratios are not meaningful. Instead, we verify that all + concurrent disconnects complete without errors or deadlocks. + """ + NUM_THREADS = 10 + + mssql_python.pooling(enabled=False) + + # warm-up + for _ in range(2): + _connect_and_close(perf_conn_str) + + # open N connections serially + connections = [connect(perf_conn_str) for _ in range(NUM_THREADS)] + + # concurrent close + barrier = threading.Barrier(NUM_THREADS) + thread_times = [None] * NUM_THREADS + errors = [] + + def close_worker(idx, conn): + try: + barrier.wait(timeout=30) + start = time.perf_counter() + conn.close() + thread_times[idx] = time.perf_counter() - start + except Exception as exc: + errors.append((idx, str(exc))) + + threads = [ + threading.Thread(target=close_worker, args=(i, connections[i]), daemon=True) + for i in range(NUM_THREADS) + ] + + wall_start = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join(timeout=60) + wall_time = time.perf_counter() - wall_start + + assert not errors, f"Thread errors: {errors}" + assert all(t is not None for t in thread_times), "Some threads did not complete" + + print(f"\n[CONCURRENT] {NUM_THREADS} threads close wall-clock: {wall_time*1000:.1f} ms") + print(f"[PASSED] All {NUM_THREADS} concurrent disconnects completed without errors") + + +@pytest.mark.stress +def test_mixed_connect_disconnect_under_load(perf_conn_str): + """ + Stress test: threads continuously connect and disconnect while + other threads do CPU-bound Python work. Verifies that GIL release + during ODBC I/O does not starve or deadlock Python threads. + """ + NUM_IO_THREADS = 5 + NUM_CPU_THREADS = 3 + DURATION_SECS = 5 + + mssql_python.pooling(enabled=False) + + stop_event = threading.Event() + io_counts = [0] * NUM_IO_THREADS + cpu_counts = [0] * NUM_CPU_THREADS + errors = [] + + def io_worker(idx): + """Repeatedly connect/disconnect.""" + try: + while not stop_event.is_set(): + conn = connect(perf_conn_str) + conn.close() + io_counts[idx] += 1 + except Exception as exc: + errors.append((f"io-{idx}", str(exc))) + + def cpu_worker(idx): + """Do CPU-bound work (must be able to acquire GIL).""" + try: + while not stop_event.is_set(): + # Busy work that requires the GIL + total = sum(range(10000)) + _ = [x**2 for x in range(100)] + cpu_counts[idx] += 1 + except Exception as exc: + errors.append((f"cpu-{idx}", str(exc))) + + threads = [] + for i in range(NUM_IO_THREADS): + threads.append(threading.Thread(target=io_worker, args=(i,), daemon=True)) + for i in range(NUM_CPU_THREADS): + threads.append(threading.Thread(target=cpu_worker, args=(i,), daemon=True)) + + for t in threads: + t.start() + + time.sleep(DURATION_SECS) + stop_event.set() + + for t in threads: + t.join(timeout=30) + + total_io = sum(io_counts) + total_cpu = sum(cpu_counts) + + print(f"\n[MIXED LOAD] Duration: {DURATION_SECS}s") + print(f" I/O threads ({NUM_IO_THREADS}): {total_io} connect/disconnect cycles") + print(f" CPU threads ({NUM_CPU_THREADS}): {total_cpu} iterations") + + assert not errors, f"Errors during mixed load: {errors}" + + # CPU threads must have made progress — if the GIL were held during + # ODBC I/O, CPU threads would be starved. + assert total_cpu > 0, ( + "CPU threads made no progress — GIL may be held during ODBC I/O, " + "starving Python threads." + ) + + # I/O threads must have completed at least a few cycles + assert total_io > 0, "I/O threads made no progress." + + print(f"[PASSED] Mixed I/O + CPU load: no starvation or deadlocks")