Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 72 additions & 57 deletions src/AsyncWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ AsyncWebSocketClient::AsyncWebSocketClient(AsyncClient *client, AsyncWebSocket *

AsyncWebSocketClient::~AsyncWebSocketClient() {
{
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
_messageQueue.clear();
_controlQueue.clear();
}
Expand All @@ -313,7 +313,7 @@ void AsyncWebSocketClient::_clearQueue() {
void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
_lastMessageTime = millis();

asyncsrv::unique_lock_type lock(_lock);
asyncsrv::unique_lock_type lock(_queue_lock);

async_ws_log_v("[%s][%" PRIu32 "] START ACK(%u, %" PRIu32 ") Q:%u", _server->url(), _clientId, len, time, _messageQueue.size());

Expand All @@ -325,14 +325,12 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
_controlQueue.pop_front();
_status = WS_DISCONNECTED;
async_ws_log_v("[%s][%" PRIu32 "] ACK WS_DISCONNECTED", _server->url(), _clientId);
if (_client) {
/*
Unlocking has to be called before return execution otherwise std::unique_lock ::~unique_lock() will get an exception pthread_mutex_unlock.
Due to _client->close() shall call the callback function _onDisconnect()
The calling flow _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient()
*/
// Capture _client before unlocking: _client->close() triggers the _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient() chain,
// so we must not access any member after unlock.
AsyncClient *c = _client;
if (c) {
lock.unlock();
_client->close();
c->close();
}
return;
}
Expand All @@ -357,7 +355,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
}

void AsyncWebSocketClient::_onPoll() {
asyncsrv::unique_lock_type lock(_lock);
asyncsrv::unique_lock_type lock(_queue_lock);

if (!_client) {
return;
Expand Down Expand Up @@ -430,22 +428,22 @@ void AsyncWebSocketClient::_runQueue() {
}

bool AsyncWebSocketClient::queueIsFull() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
return (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED);
}

size_t AsyncWebSocketClient::queueLen() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
return _messageQueue.size();
}

bool AsyncWebSocketClient::canSend() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);
return _messageQueue.size() < WS_MAX_QUEUED_MESSAGES;
}

bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_queue_lock);

if (!_client) {
return false;
Expand All @@ -462,7 +460,7 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si
}

bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) {
asyncsrv::unique_lock_type lock(_lock);
asyncsrv::unique_lock_type lock(_queue_lock);

if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) {
return false;
Expand All @@ -472,18 +470,16 @@ bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint
if (closeWhenFull) {
_status = WS_DISCONNECTED;

if (_client) {
/*
Unlocking has to be called before return execution otherwise std::unique_lock ::~unique_lock() will get an exception pthread_mutex_unlock.
Due to _client->close() shall call the callback function _onDisconnect()
The calling flow _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient()
*/
async_ws_log_w("[%s][%" PRIu32 "] Too many messages queued: closing connection", _server->url(), _clientId);

// Capture _client before unlocking: _client->close() triggers the _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient() chain,
// so we must not access any member after unlock.
AsyncClient *c = _client;
if (c) {
lock.unlock();
_client->close();
c->close();
}

async_ws_log_w("[%s][%" PRIu32 "] Too many messages queued: closing connection", _server->url(), _clientId);

} else {
async_ws_log_w("[%s][%" PRIu32 "] Too many messages queued: discarding new message", _server->url(), _clientId);
}
Expand Down Expand Up @@ -531,7 +527,14 @@ void AsyncWebSocketClient::close(uint16_t code, const char *message) {
return;
} else {
async_ws_log_e("Failed to allocate");
_client->abort();
// Reads _client, then dereference it without any lock.
// A concurrent _onDisconnect could null + delete the client between the check and the use.
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
// (TOCTOU)
AsyncClient *c = _client;
if (c) {
c->abort();
}
}
}
_queueControl(WS_DISCONNECT);
Expand All @@ -546,19 +549,27 @@ void AsyncWebSocketClient::_onError(int8_t err) {
}

void AsyncWebSocketClient::_onTimeout(uint32_t time) {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)

if (!_client) {
// Reads _client, then dereference it without any lock.
// A concurrent _onDisconnect could null + delete the client between the check and the use.
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
// (TOCTOU)
AsyncClient *c = _client;
if (!c) {
return;
}
async_ws_log_v("[%s][%" PRIu32 "] TIMEOUT %" PRIu32, _server->url(), _clientId, time);
_client->close();
c->close();
}

void AsyncWebSocketClient::_onDisconnect() {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)

async_ws_log_v("[%s][%" PRIu32 "] DISCONNECT", _server->url(), _clientId);
_status = WS_DISCONNECTED;
_client = nullptr;
{
// Every queue method (_queueControl, _queueMessage, _runQueue, _onPoll, _onAck) reads _client while holding _queue_lock.
// For those guarded reads to be meaningful, the write must also be synchronized. This doesn't change _queue_lock's purpose — it still guards queue integrity — but ensures the "is client alive?" checks that protect queue operations see a consistent value.
asyncsrv::lock_guard_type lock(_queue_lock);
_client = nullptr;
}
_server->_handleDisconnect(this);
}

Expand Down Expand Up @@ -951,23 +962,27 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) {
#endif

IPAddress AsyncWebSocketClient::remoteIP() const {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)


if (!_client) {
// Reads _client, then dereference it without any lock.
// A concurrent _onDisconnect could null + delete the client between the check and the use.
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
// (TOCTOU)
AsyncClient *c = _client;
if (!c) {
return IPAddress((uint32_t)0U);
}

return _client->remoteIP();
return c->remoteIP();
}

uint16_t AsyncWebSocketClient::remotePort() const {
asyncsrv::lock_guard_type lock(_lock);
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because this is wrong to use the _queue_lock to protect the _client ptr (which is not guarded anywhere else it should be)


if (!_client) {
// Reads _client, then dereference it without any lock.
// A concurrent _onDisconnect could null + delete the client between the check and the use.
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
// (TOCTOU)
AsyncClient *c = _client;
if (!c) {
return 0;
}

return _client->remotePort();
return c->remotePort();
}

/*
Expand All @@ -981,7 +996,7 @@ void AsyncWebSocket::_handleEvent(AsyncWebSocketClient *client, AwsEventType typ
}

AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
_clients.emplace_back(request, this);
// we've just detached AsyncTCP client from AsyncWebServerRequest
_handleEvent(&_clients.back(), WS_EVT_CONNECT, request, NULL, 0);
Expand All @@ -991,7 +1006,7 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
}

void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const auto client_id = client->id();
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) {
return c.id() == client_id;
Expand All @@ -1002,14 +1017,14 @@ void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
}

bool AsyncWebSocket::availableForWriteAll() {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
return std::none_of(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
return c.queueIsFull();
});
}

bool AsyncWebSocket::availableForWrite(uint32_t id) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [id](const AsyncWebSocketClient &c) {
return c.id() == id;
});
Expand All @@ -1020,14 +1035,14 @@ bool AsyncWebSocket::availableForWrite(uint32_t id) {
}

size_t AsyncWebSocket::count() const {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
return std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
return c.status() == WS_CONNECTED;
});
}

AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const auto iter = std::find_if(_clients.begin(), _clients.end(), [id](const AsyncWebSocketClient &c) {
return c.id() == id && c.status() == WS_CONNECTED;
});
Expand All @@ -1039,14 +1054,14 @@ AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
}

void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
if (AsyncWebSocketClient *c = client(id)) {
c->close(code, message);
}
}

void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
for (auto &c : _clients) {
if (c.status() == WS_CONNECTED) {
c.close(code, message);
Expand All @@ -1055,7 +1070,7 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
}

void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
const size_t c = count();
if (c > maxClients) {
async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients);
Expand All @@ -1071,13 +1086,13 @@ void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
}

bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->ping(data, len);
}

AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
size_t hit = 0;
size_t miss = 0;
for (auto &c : _clients) {
Expand All @@ -1091,7 +1106,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t l
}

bool AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->text(makeSharedBuffer(message, len));
}
Expand Down Expand Up @@ -1138,7 +1153,7 @@ bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
return enqueued;
}
bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->text(buffer);
}
Expand Down Expand Up @@ -1188,7 +1203,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer *
}

AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
size_t hit = 0;
size_t miss = 0;
for (auto &c : _clients) {
Expand All @@ -1202,7 +1217,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer bu
}

bool AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->binary(makeSharedBuffer(message, len));
}
Expand Down Expand Up @@ -1239,7 +1254,7 @@ bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
return enqueued;
}
bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
AsyncWebSocketClient *c = client(id);
return c && c->binary(buffer);
}
Expand Down Expand Up @@ -1280,7 +1295,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer
return status;
}
AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketSharedBuffer buffer) {
asyncsrv::lock_guard_type lock(_lock);
asyncsrv::lock_guard_type lock(_ws_clients_lock);
size_t hit = 0;
size_t miss = 0;
for (auto &c : _clients) {
Expand Down
5 changes: 2 additions & 3 deletions src/AsyncWebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class AsyncWebSocketClient {
uint8_t _pstate;
uint32_t _lastMessageTime;
uint32_t _keepAlivePeriod;
mutable asyncsrv::mutex_type _lock;
mutable asyncsrv::mutex_type _queue_lock;
std::deque<AsyncWebSocketControl> _controlQueue;
std::deque<AsyncWebSocketMessage> _messageQueue;
bool closeWhenFull = true;
Expand Down Expand Up @@ -303,7 +303,6 @@ class AsyncWebSocketClient {
uint16_t remotePort() const;

bool shouldBeDeleted() const {
asyncsrv::lock_guard_type lock(_lock);
return !_client;
}

Expand Down Expand Up @@ -371,7 +370,7 @@ class AsyncWebSocket : public AsyncWebHandler {
AwsEventHandler _eventHandler;
AwsHandshakeHandler _handshakeHandler;
bool _enabled;
mutable asyncsrv::mutex_type _lock;
mutable asyncsrv::mutex_type _ws_clients_lock;

public:
typedef enum {
Expand Down
Loading