From eed1207bbb92f987613fad1be9af69d24e371ff5 Mon Sep 17 00:00:00 2001 From: yanfeng Date: Wed, 27 May 2026 11:00:14 +0800 Subject: [PATCH 1/2] Add gflag to manage ConcurrencyRemover lifecycle for CallAfterRpcResp In SendRpcResponse, ConcurrencyRemover was destroyed before CallAfterRpcResp was called, meaning concurrency control didn't cover the after-response callback. This could lead to inaccurate concurrency tracking and latency measurements. This change adds FLAGS_concurrency_remover_manages_after_rpc_resp (default: false) and automatically sets it to controller when set_after_rpc_resp_fn is called. Implementation: - Add _concurrency_remover_manages_after_rpc_resp flag to Controller - In set_after_rpc_resp_fn(), read gflag value and set to controller instance - In baidu_rpc_protocol, use controller flag instead of global gflag - Use unique_ptr with explicit reset() for clear control flow When false (default): Original behavior - ConcurrencyRemover is released before CallAfterRpcResp via explicit reset(). When true (gflag enabled when callback set): ConcurrencyRemover lives until the end of BRPC_SCOPE_EXIT, covering the entire response lifecycle. Note: HTTP protocol not modified in this change due to its more complex async flow. Can be addressed separately if needed. Co-Authored-By: Claude Opus 4.7 --- src/brpc/controller.cpp | 8 ++++++++ src/brpc/controller.h | 8 +++++++- src/brpc/policy/baidu_rpc_protocol.cpp | 18 +++++++++++++++--- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp index 15c8c91887..dd67435b71 100644 --- a/src/brpc/controller.cpp +++ b/src/brpc/controller.cpp @@ -127,6 +127,7 @@ const Controller* GetSubControllerOfSelectiveChannel( DECLARE_bool(usercode_in_pthread); DECLARE_bool(usercode_in_coroutine); +DECLARE_bool(concurrency_remover_manages_after_rpc_resp); static const int MAX_RETRY_COUNT = 1000; static bvar::Adder* g_ncontroller = NULL; @@ -298,6 +299,7 @@ void Controller::ResetPods() { _response_streams.clear(); _remote_stream_settings = NULL; _auth_flags = 0; + _concurrency_remover_manages_after_rpc_resp = false; _rpc_received_us = 0; } @@ -1593,6 +1595,12 @@ int Controller::GetSockOption(int level, int optname, void* optval, socklen_t* o } } +void Controller::set_after_rpc_resp_fn(AfterRpcRespFnType&& fn) { + _after_rpc_resp_fn = fn; + // Set the flag from global gflag when after_rpc_resp_fn is set + _concurrency_remover_manages_after_rpc_resp = FLAGS_concurrency_remover_manages_after_rpc_resp; +} + void Controller::CallAfterRpcResp(const google::protobuf::Message* req, const google::protobuf::Message* res) { if (_after_rpc_resp_fn) { _after_rpc_resp_fn(this, req, res); diff --git a/src/brpc/controller.h b/src/brpc/controller.h index 45f71b72f6..c0b59f5b07 100644 --- a/src/brpc/controller.h +++ b/src/brpc/controller.h @@ -621,10 +621,15 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); const google::protobuf::Message* req, const google::protobuf::Message* res)>; - void set_after_rpc_resp_fn(AfterRpcRespFnType&& fn) { _after_rpc_resp_fn = fn; } + void set_after_rpc_resp_fn(AfterRpcRespFnType&& fn); void CallAfterRpcResp(const google::protobuf::Message* req, const google::protobuf::Message* res); + // Check whether ConcurrencyRemover should manage the lifecycle of CallAfterRpcResp. + bool concurrency_remover_manages_after_rpc_resp() const { + return _concurrency_remover_manages_after_rpc_resp; + } + void set_request_content_type(ContentType type) { _request_content_type = type; } @@ -921,6 +926,7 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); uint32_t _auth_flags; AfterRpcRespFnType _after_rpc_resp_fn; + bool _concurrency_remover_manages_after_rpc_resp; // The point in time when the rpc is read from the socket int64_t _rpc_received_us; diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 2c5a7e7224..2707c100cd 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -62,6 +62,11 @@ DEFINE_bool(baidu_protocol_use_fullname, true, DEFINE_bool(baidu_std_protocol_deliver_timeout_ms, false, "If this flag is true, baidu_std puts timeout_ms in requests."); +DEFINE_bool(concurrency_remover_manages_after_rpc_resp, false, + "If this flag is true, ConcurrencyRemover will manage the lifecycle " + "of CallAfterRpcResp, ensuring concurrency control covers the entire " + "response processing including after-response callbacks."); + DECLARE_bool(pb_enum_as_number); // Notes: @@ -285,9 +290,14 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, // Recycle resources at the end of this function. BRPC_SCOPE_EXIT { - { - // Remove concurrency and record latency at first. - ConcurrencyRemover concurrency_remover(method_status, cntl, received_us); + std::unique_ptr concurrency_remover_ptr( + new ConcurrencyRemover(method_status, cntl, received_us)); + + // Only manage CallAfterRpcResp lifecycle if the flag is set + // (which happens when set_after_rpc_resp_fn is called with the gflag enabled) + if (!cntl->concurrency_remover_manages_after_rpc_resp()) { + // Original behavior: remove concurrency before CallAfterRpcResp + concurrency_remover_ptr.reset(); } std::unique_ptr recycle_cntl(cntl); @@ -302,6 +312,8 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, } else { BaiduProxyPBMessages::Return(static_cast(messages)); } + // If concurrency_remover_manages_after_rpc_resp() is true, + // concurrency_remover_ptr will be destroyed here }; StreamIds response_stream_ids = accessor.response_streams(); From e35ddeabf03ef5369327a78529ae64d680514015 Mon Sep 17 00:00:00 2001 From: yanfeng Date: Sun, 31 May 2026 10:09:18 +0800 Subject: [PATCH 2/2] fix after-rpc response cleanup ordering Move CallAfterRpcResp to the response-written point and split ConcurrencyRemover accounting from concurrency release so after-response callbacks can safely reset the controller without breaking cleanup. Co-Authored-By: Claude Opus 4.7 --- example/asynchronous_echo_c++/server.cpp | 3 +- example/echo_c++/server.cpp | 3 +- example/http_c++/http_server.cpp | 3 +- src/brpc/controller.cpp | 9 +--- src/brpc/controller.h | 14 ++--- .../details/controller_private_accessor.h | 12 +++++ src/brpc/details/method_status.cpp | 24 +++++++-- src/brpc/details/method_status.h | 7 ++- src/brpc/details/server_private_accessor.h | 8 ++- src/brpc/policy/baidu_rpc_protocol.cpp | 51 +++++++++--------- src/brpc/policy/http_rpc_protocol.cpp | 53 +++++++++++++------ test/brpc_channel_unittest.cpp | 37 +++++++++++-- 12 files changed, 153 insertions(+), 71 deletions(-) diff --git a/example/asynchronous_echo_c++/server.cpp b/example/asynchronous_echo_c++/server.cpp index dcce19a80b..8715be9455 100644 --- a/example/asynchronous_echo_c++/server.cpp +++ b/example/asynchronous_echo_c++/server.cpp @@ -47,7 +47,8 @@ class EchoServiceImpl : public example::EchoService { // optional: set a callback function which is called after response is sent // and before cntl/req/res is destructed. cntl->set_after_rpc_resp_fn(std::bind(&EchoServiceImpl::CallAfterRpc, - std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), + true); // The purpose of following logs is to help you to understand // how clients interact with servers more intuitively. You should diff --git a/example/echo_c++/server.cpp b/example/echo_c++/server.cpp index 4113114629..e5c71d59ad 100644 --- a/example/echo_c++/server.cpp +++ b/example/echo_c++/server.cpp @@ -53,7 +53,8 @@ class EchoServiceImpl : public EchoService { // optional: set a callback function which is called after response is sent // and before cntl/req/res is destructed. cntl->set_after_rpc_resp_fn(std::bind(&EchoServiceImpl::CallAfterRpc, - std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), + true); // The purpose of following logs is to help you to understand // how clients interact with servers more intuitively. You should diff --git a/example/http_c++/http_server.cpp b/example/http_c++/http_server.cpp index 05c9a0ee4c..d0a712a820 100644 --- a/example/http_c++/http_server.cpp +++ b/example/http_c++/http_server.cpp @@ -53,7 +53,8 @@ class HttpServiceImpl : public HttpService { // optional: set a callback function which is called after response is sent // and before cntl/req/res is destructed. cntl->set_after_rpc_resp_fn(std::bind(&HttpServiceImpl::CallAfterRpc, - std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), + true); // Fill response. cntl->http_response().set_content_type("text/plain"); diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp index dd67435b71..0718359772 100644 --- a/src/brpc/controller.cpp +++ b/src/brpc/controller.cpp @@ -127,7 +127,6 @@ const Controller* GetSubControllerOfSelectiveChannel( DECLARE_bool(usercode_in_pthread); DECLARE_bool(usercode_in_coroutine); -DECLARE_bool(concurrency_remover_manages_after_rpc_resp); static const int MAX_RETRY_COUNT = 1000; static bvar::Adder* g_ncontroller = NULL; @@ -299,7 +298,6 @@ void Controller::ResetPods() { _response_streams.clear(); _remote_stream_settings = NULL; _auth_flags = 0; - _concurrency_remover_manages_after_rpc_resp = false; _rpc_received_us = 0; } @@ -1595,16 +1593,11 @@ int Controller::GetSockOption(int level, int optname, void* optval, socklen_t* o } } -void Controller::set_after_rpc_resp_fn(AfterRpcRespFnType&& fn) { - _after_rpc_resp_fn = fn; - // Set the flag from global gflag when after_rpc_resp_fn is set - _concurrency_remover_manages_after_rpc_resp = FLAGS_concurrency_remover_manages_after_rpc_resp; -} - void Controller::CallAfterRpcResp(const google::protobuf::Message* req, const google::protobuf::Message* res) { if (_after_rpc_resp_fn) { _after_rpc_resp_fn(this, req, res); _after_rpc_resp_fn = nullptr; + clear_flag(FLAGS_MANAGE_AFTER_RPC_RESP); } } diff --git a/src/brpc/controller.h b/src/brpc/controller.h index c0b59f5b07..9cd0dfc3aa 100644 --- a/src/brpc/controller.h +++ b/src/brpc/controller.h @@ -152,6 +152,7 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); static const uint32_t FLAGS_PB_SINGLE_REPEATED_TO_ARRAY = (1 << 20); static const uint32_t FLAGS_MANAGE_HTTP_BODY_ON_ERROR = (1 << 21); static const uint32_t FLAGS_WRITE_TO_SOCKET_IN_BACKGROUND = (1 << 22); + static const uint32_t FLAGS_MANAGE_AFTER_RPC_RESP = (1 << 23); public: struct Inheritable { @@ -621,15 +622,15 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); const google::protobuf::Message* req, const google::protobuf::Message* res)>; - void set_after_rpc_resp_fn(AfterRpcRespFnType&& fn); + void set_after_rpc_resp_fn(AfterRpcRespFnType&& fn, + bool manage_concurrency_remover = false) { + _after_rpc_resp_fn = fn; + set_flag(FLAGS_MANAGE_AFTER_RPC_RESP, + manage_concurrency_remover && !!_after_rpc_resp_fn); + } void CallAfterRpcResp(const google::protobuf::Message* req, const google::protobuf::Message* res); - // Check whether ConcurrencyRemover should manage the lifecycle of CallAfterRpcResp. - bool concurrency_remover_manages_after_rpc_resp() const { - return _concurrency_remover_manages_after_rpc_resp; - } - void set_request_content_type(ContentType type) { _request_content_type = type; } @@ -926,7 +927,6 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); uint32_t _auth_flags; AfterRpcRespFnType _after_rpc_resp_fn; - bool _concurrency_remover_manages_after_rpc_resp; // The point in time when the rpc is read from the socket int64_t _rpc_received_us; diff --git a/src/brpc/details/controller_private_accessor.h b/src/brpc/details/controller_private_accessor.h index 0ad1aba640..5e65b4b234 100644 --- a/src/brpc/details/controller_private_accessor.h +++ b/src/brpc/details/controller_private_accessor.h @@ -102,6 +102,18 @@ class ControllerPrivateAccessor { std::shared_ptr span() const; + bool has_after_rpc_resp_fn() const { + return !!_cntl->_after_rpc_resp_fn; + } + + bool has_added_concurrency() const { + return _cntl->has_flag(Controller::FLAGS_ADDED_CONCURRENCY); + } + + bool manages_after_rpc_resp() const { + return _cntl->has_flag(Controller::FLAGS_MANAGE_AFTER_RPC_RESP); + } + uint32_t pipelined_count() const { return _cntl->_pipelined_count; } void set_pipelined_count(uint32_t count) { _cntl->_pipelined_count = count; } diff --git a/src/brpc/details/method_status.cpp b/src/brpc/details/method_status.cpp index 3bed6bf209..ea1e62133e 100644 --- a/src/brpc/details/method_status.cpp +++ b/src/brpc/details/method_status.cpp @@ -16,9 +16,8 @@ // under the License. -#include -#include "butil/macros.h" #include "brpc/controller.h" +#include "brpc/details/controller_private_accessor.h" #include "brpc/details/server_private_accessor.h" #include "brpc/details/method_status.h" @@ -156,12 +155,27 @@ int HandleResponseWritten(bthread_id_t id, void* data, int /*error_code*/) { return 0; } -ConcurrencyRemover::~ConcurrencyRemover() { +ConcurrencyRemover::ConcurrencyRemover( + MethodStatus* status, Controller* c, int64_t received_us) + : _status(status) + , _c(c) + , _server(c->server()) + , _added_concurrency(ControllerPrivateAccessor(c).has_added_concurrency()) + , _received_us(received_us) { +} + +void ConcurrencyRemover::OnResponded(int error_code) { if (_status) { - _status->OnResponded(_c->ErrorCode(), butil::cpuwide_time_us() - _received_us); + _status->OnResponded(error_code, butil::cpuwide_time_us() - _received_us); _status = NULL; } - ServerPrivateAccessor(_c->server()).RemoveConcurrency(_c); +} + +ConcurrencyRemover::~ConcurrencyRemover() { + OnResponded(_c->ErrorCode()); + if (_server) { + ServerPrivateAccessor(_server).RemoveConcurrency(_added_concurrency); + } } } // namespace brpc diff --git a/src/brpc/details/method_status.h b/src/brpc/details/method_status.h index 9b7f070991..9507a03015 100644 --- a/src/brpc/details/method_status.h +++ b/src/brpc/details/method_status.h @@ -83,14 +83,17 @@ int HandleResponseWritten(bthread_id_t id, void* data, int error_code); class ConcurrencyRemover { public: - ConcurrencyRemover(MethodStatus* status, Controller* c, int64_t received_us) - : _status(status) , _c(c) , _received_us(received_us) {} + ConcurrencyRemover(MethodStatus* status, Controller* c, int64_t received_us); ~ConcurrencyRemover(); + void OnResponded(int error_code); + private: DISALLOW_COPY_AND_ASSIGN(ConcurrencyRemover); MethodStatus* _status; Controller* _c; + const Server* _server; + bool _added_concurrency; int64_t _received_us; }; diff --git a/src/brpc/details/server_private_accessor.h b/src/brpc/details/server_private_accessor.h index aacf283564..d502ba7054 100644 --- a/src/brpc/details/server_private_accessor.h +++ b/src/brpc/details/server_private_accessor.h @@ -50,12 +50,16 @@ class ServerPrivateAccessor { <= _server->options().max_concurrency); } - void RemoveConcurrency(const Controller* c) { - if (c->has_flag(Controller::FLAGS_ADDED_CONCURRENCY)) { + void RemoveConcurrency(bool added_concurrency) { + if (added_concurrency) { butil::subtle::NoBarrier_AtomicIncrement(&_server->_concurrency, -1); } } + void RemoveConcurrency(const Controller* c) { + RemoveConcurrency(c->has_flag(Controller::FLAGS_ADDED_CONCURRENCY)); + } + // Find by MethodDescriptor::full_name const Server::MethodProperty* FindMethodPropertyByFullName(const butil::StringPiece &fullname) { diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 2707c100cd..396fdf53bd 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -62,11 +62,6 @@ DEFINE_bool(baidu_protocol_use_fullname, true, DEFINE_bool(baidu_std_protocol_deliver_timeout_ms, false, "If this flag is true, baidu_std puts timeout_ms in requests."); -DEFINE_bool(concurrency_remover_manages_after_rpc_resp, false, - "If this flag is true, ConcurrencyRemover will manage the lifecycle " - "of CallAfterRpcResp, ensuring concurrency control covers the entire " - "response processing including after-response callbacks."); - DECLARE_bool(pb_enum_as_number); // Notes: @@ -279,6 +274,7 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, RpcPBMessages* messages, const Server* server, MethodStatus* method_status, int64_t received_us, std::shared_ptr span) { + std::unique_ptr recycle_cntl(cntl); ControllerPrivateAccessor accessor(cntl); if (span) { span->set_start_send_us(butil::cpuwide_time_us()); @@ -287,33 +283,20 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, const google::protobuf::Message* req = NULL == messages ? NULL : messages->Request(); const google::protobuf::Message* res = NULL == messages ? NULL : messages->Response(); + const bool has_after_rpc_resp_fn = accessor.has_after_rpc_resp_fn(); + const bool manages_after_rpc_resp = accessor.manages_after_rpc_resp(); + std::unique_ptr concurrency_remover( + new ConcurrencyRemover(method_status, cntl, received_us)); - // Recycle resources at the end of this function. BRPC_SCOPE_EXIT { - std::unique_ptr concurrency_remover_ptr( - new ConcurrencyRemover(method_status, cntl, received_us)); - - // Only manage CallAfterRpcResp lifecycle if the flag is set - // (which happens when set_after_rpc_resp_fn is called with the gflag enabled) - if (!cntl->concurrency_remover_manages_after_rpc_resp()) { - // Original behavior: remove concurrency before CallAfterRpcResp - concurrency_remover_ptr.reset(); - } - - std::unique_ptr recycle_cntl(cntl); - if (NULL == messages) { return; } - - cntl->CallAfterRpcResp(req, res); if (NULL == server->options().baidu_master_service) { server->options().rpc_pb_message_factory->Return(messages); } else { BaiduProxyPBMessages::Return(static_cast(messages)); } - // If concurrency_remover_manages_after_rpc_resp() is true, - // concurrency_remover_ptr will be destroyed here }; StreamIds response_stream_ids = accessor.response_streams(); @@ -402,8 +385,11 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, ResponseWriteInfo args; bthread_id_t response_id = INVALID_BTHREAD_ID; + const bool wait_for_response = (span || has_after_rpc_resp_fn); if (span) { span->set_response_size(res_buf.size()); + } + if (wait_for_response) { CHECK_EQ(0, bthread_id_create(&response_id, &args, HandleResponseWritten)); } @@ -463,12 +449,25 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, } } - if (span) { + if (wait_for_response) { bthread_id_join(response_id); - // Do not care about the result of background writing. - // TODO: this is not sent - span->set_sent_us(args.sent_us); + if (span) { + // Do not care about the result of background writing. + // TODO: this is not sent + span->set_sent_us(args.sent_us); + } + } + const int responded_error_code = cntl->ErrorCode(); + if (!manages_after_rpc_resp) { + concurrency_remover.reset(); + } + if (has_after_rpc_resp_fn) { + cntl->CallAfterRpcResp(req, res); + } + if (manages_after_rpc_resp) { + concurrency_remover->OnResponded(responded_error_code); } + concurrency_remover.reset(); } namespace { diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp index e5e22d924a..71e154f835 100644 --- a/src/brpc/policy/http_rpc_protocol.cpp +++ b/src/brpc/policy/http_rpc_protocol.cpp @@ -812,13 +812,7 @@ friend class HttpResponseSenderAsDone; class HttpResponseSenderAsDone : public google::protobuf::Closure { public: explicit HttpResponseSenderAsDone(HttpResponseSender* s) : _sender(std::move(*s)) {} - void Run() override { - if (NULL != _sender._messages) { - _sender._cntl->CallAfterRpcResp(_sender._messages->Request(), - _sender._messages->Response()); - } - delete this; - } + void Run() override { delete this; } private: HttpResponseSender _sender; @@ -840,7 +834,10 @@ HttpResponseSender::~HttpResponseSender() { if (span) { span->set_start_send_us(butil::cpuwide_time_us()); } - ConcurrencyRemover concurrency_remover(_method_status, cntl, _received_us); + const bool has_after_rpc_resp_fn = accessor.has_after_rpc_resp_fn(); + const bool manages_after_rpc_resp = accessor.manages_after_rpc_resp(); + std::unique_ptr concurrency_remover( + new ConcurrencyRemover(_method_status, cntl, _received_us)); Socket* socket = accessor.get_sending_socket(); const google::protobuf::Message* res = NULL != _messages ? _messages->Response() : NULL; @@ -851,6 +848,14 @@ HttpResponseSender::~HttpResponseSender() { const HttpHeader* req_header = &cntl->http_request(); HttpHeader* res_header = &cntl->http_response(); + HttpHeader original_response_header; + butil::IOBuf original_response_attachment; + if (has_after_rpc_resp_fn) { + original_response_header.Swap(*res_header); + original_response_attachment.swap(cntl->response_attachment()); + res_header->Swap(original_response_header); + cntl->response_attachment().swap(original_response_attachment); + } res_header->set_version(req_header->major_version(), req_header->minor_version()); @@ -1000,7 +1005,8 @@ HttpResponseSender::~HttpResponseSender() { Socket::WriteOptions wopt; wopt.ignore_eovercrowded = true; bthread_id_t response_id = INVALID_BTHREAD_ID; - if (span) { + const bool wait_for_response = (span || has_after_rpc_resp_fn); + if (wait_for_response) { CHECK_EQ(0, bthread_id_create(&response_id, &args, HandleResponseWritten)); wopt.id_wait = response_id; wopt.notify_on_success = true; @@ -1020,8 +1026,10 @@ HttpResponseSender::~HttpResponseSender() { if (FLAGS_http_verbose) { LOG(INFO) << '\n' << *h2_response; } - if (span) { - span->set_response_size(h2_response->EstimatedByteSize()); + if (span || has_after_rpc_resp_fn) { + if (span) { + span->set_response_size(h2_response->EstimatedByteSize()); + } } rc = socket->Write(h2_response, &wopt); } @@ -1050,12 +1058,27 @@ HttpResponseSender::~HttpResponseSender() { return; } - if (span) { + if (wait_for_response) { bthread_id_join(response_id); - // Do not care about the result of background writing. - // TODO: this is not sent - span->set_sent_us(args.sent_us); + if (span) { + // Do not care about the result of background writing. + // TODO: this is not sent + span->set_sent_us(args.sent_us); + } + } + const int responded_error_code = cntl->ErrorCode(); + if (!manages_after_rpc_resp) { + concurrency_remover.reset(); + } + if (has_after_rpc_resp_fn && NULL != _messages) { + cntl->http_response().Swap(original_response_header); + cntl->response_attachment().swap(original_response_attachment); + cntl->CallAfterRpcResp(_messages->Request(), _messages->Response()); + } + if (manages_after_rpc_resp) { + concurrency_remover->OnResponded(responded_error_code); } + concurrency_remover.reset(); } // Normalize the sub string of `uri_path' covered by `splitter' and diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp index db6e2ac777..6f1e33349f 100644 --- a/test/brpc_channel_unittest.cpp +++ b/test/brpc_channel_unittest.cpp @@ -134,7 +134,7 @@ static bool VerifyMyRequest(const brpc::InputMessageBase* msg_base) { class CallAfterRpcObject { public: - explicit CallAfterRpcObject() {} + explicit CallAfterRpcObject() : reset_in_callback(false) {} ~CallAfterRpcObject() { EXPECT_EQ(str, "CallAfterRpcRespTest"); @@ -144,10 +144,14 @@ class CallAfterRpcObject { str.append(s); } + bool reset_in_callback; + private: std::string str; }; +static bool g_reset_cntl_in_after_rpc_resp = false; + class MyEchoService : public ::test::EchoService { void Echo(google::protobuf::RpcController* cntl_base, const ::test::EchoRequest* req, @@ -156,8 +160,10 @@ class MyEchoService : public ::test::EchoService { brpc::Controller* cntl = static_cast(cntl_base); std::shared_ptr str_test(new CallAfterRpcObject()); + str_test->reset_in_callback = g_reset_cntl_in_after_rpc_resp; cntl->set_after_rpc_resp_fn(std::bind(&MyEchoService::CallAfterRpc, str_test, - std::placeholders::_1, std::placeholders::_2, std::placeholders::_3)); + std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), + true); brpc::ClosureGuard done_guard(done); if (req->server_fail()) { cntl->SetFailed(req->server_fail(), "Server fail1"); @@ -200,6 +206,9 @@ class MyEchoService : public ::test::EchoService { EXPECT_TRUE(nullptr != cntl); EXPECT_TRUE(nullptr != request); EXPECT_TRUE(nullptr != response); + if (str->reset_in_callback) { + cntl->Reset(); + } } public: @@ -2605,7 +2614,7 @@ TEST_F(ChannelTest, connection_failed_selective) { } TEST_F(ChannelTest, success) { - for (int i = 0; i <= 1; ++i) { // Flag SingleServer + for (int i = 0; i <= 1; ++i) { // Flag SingleServer for (int j = 0; j <= 1; ++j) { // Flag Asynchronous for (int k = 0; k <=1; ++k) { // Flag ShortConnection TestSuccess(i, j, k); @@ -2614,6 +2623,28 @@ TEST_F(ChannelTest, success) { } } +TEST_F(ChannelTest, reset_in_after_rpc_resp) { + brpc::Server server; + MyEchoService service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions opt; + ASSERT_EQ(0, server.Start(_ep, &opt)); + + brpc::Channel channel; + SetUpChannel(&channel, true, false); + + g_reset_cntl_in_after_rpc_resp = true; + brpc::Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + CallMethod(&channel, &cntl, &req, &res, false); + g_reset_cntl_in_after_rpc_resp = false; + + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + ASSERT_EQ("received reset_in_after_rpc_resp", res.message()); +} + TEST_F(ChannelTest, success_parallel) { for (int i = 0; i <= 1; ++i) { // Flag SingleServer for (int j = 0; j <= 1; ++j) { // Flag Asynchronous