diff --git a/lib/cpp/README.md b/lib/cpp/README.md index a9959af8c2..7827dc385e 100644 --- a/lib/cpp/README.md +++ b/lib/cpp/README.md @@ -138,7 +138,11 @@ TSSLSocket are always created from TSSLSocketFactory. The default TSSLSocketFactory context uses OpenSSL's version-flexible TLS method and sets TLS 1.2 as the minimum negotiated protocol version. Applications that need a different protocol range can provide a custom SSLContext factory and -adjust the OpenSSL context options before creating sockets. +adjust the OpenSSL context options before creating sockets. Applications that +link against an OpenSSL-compatible TLS library can also create and configure an +SSL_CTX externally, wrap it with `SSLContext`, and pass it through the factory +(for example, protocol-specific or dual-certificate setups that the default +factory methods cannot express). ## How to use SSL APIs diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp index ac88acd926..ac0b85945f 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp @@ -179,7 +179,7 @@ static bool matchName(const char* host, const char* pattern, int size); static char uppercase(char c); // SSLContext implementation -SSLContext::SSLContext(const SSLProtocol& protocol) { +SSLContext::SSLContext(const SSLProtocol& protocol) : takeOwnership_(true) { if (protocol == SSLTLS) { ctx_ = SSL_CTX_new(SSLv23_method()); #ifndef OPENSSL_NO_SSL3 @@ -213,8 +213,15 @@ SSLContext::SSLContext(const SSLProtocol& protocol) { } } +SSLContext::SSLContext(SSL_CTX* ctx, bool takeOwnership) + : ctx_(ctx), takeOwnership_(takeOwnership) { + if (ctx_ == nullptr) { + throw TSSLException("SSLContext: ctx must not be null"); + } +} + SSLContext::~SSLContext() { - if (ctx_ != nullptr) { + if (ctx_ != nullptr && takeOwnership_) { SSL_CTX_free(ctx_); ctx_ = nullptr; } diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h index e6b8992451..b331c624d9 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.h +++ b/lib/cpp/src/thrift/transport/TSSLSocket.h @@ -365,12 +365,20 @@ class TSSLException : public TTransportException { class SSLContext { public: SSLContext(const SSLProtocol& protocol = SSLTLS); + /** + * Wrap an existing OpenSSL SSL_CTX. + * + * @param ctx OpenSSL context to wrap + * @param takeOwnership If true (default), SSLContext frees ctx on destruction + */ + explicit SSLContext(SSL_CTX* ctx, bool takeOwnership = true); virtual ~SSLContext(); SSL* createSSL(); SSL_CTX* get() { return ctx_; } private: SSL_CTX* ctx_; + bool takeOwnership_; }; /** diff --git a/lib/cpp/test/SecurityTest.cpp b/lib/cpp/test/SecurityTest.cpp index d64b0da844..3fb1f18be3 100644 --- a/lib/cpp/test/SecurityTest.cpp +++ b/lib/cpp/test/SecurityTest.cpp @@ -34,6 +34,7 @@ #endif using apache::thrift::transport::TSSLServerSocket; +using apache::thrift::transport::SSLContext; using apache::thrift::transport::SSLContextFactory; using apache::thrift::transport::TSSLException; using apache::thrift::transport::TServerTransport; @@ -274,6 +275,34 @@ BOOST_AUTO_TEST_CASE(custom_ssl_context_options) context.reset(); } +BOOST_AUTO_TEST_CASE(wrapped_ssl_context) +{ + SSL_CTX* raw = SSL_CTX_new(TLS_method()); + BOOST_REQUIRE(raw != nullptr); + SSL_CTX_set_mode(raw, SSL_MODE_AUTO_RETRY); + + std::shared_ptr context; + TSSLSocketFactory factory([&context, raw]() { + context = std::make_shared(raw, true); + return context; + }); + BOOST_CHECK(context->get() == raw); + context.reset(); +} + +BOOST_AUTO_TEST_CASE(wrapped_ssl_context_null) +{ + try + { + std::make_shared(nullptr); + BOOST_FAIL("Expected null SSL_CTX to throw"); + } + catch (const TSSLException& ex) + { + BOOST_CHECK_EQUAL("SSLContext: ctx must not be null", std::string(ex.what())); + } +} + BOOST_AUTO_TEST_CASE(custom_ssl_context_factory_validation) { try