Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"type": "prerelease",
"comment": "Fix WebSocket binaryType handling — stop unconditional Blob interception of binary messages",
"packageName": "react-native-windows",
"email": "gordomacmaster@gmail.com",
"dependentChangeType": "patch"
}
30 changes: 30 additions & 0 deletions vnext/Desktop.IntegrationTests/RNTesterHeadlessTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,36 @@ TEST_CLASS (RNTesterHeadlessTests) {
auto status = TestModule::AwaitCompletion();
Assert::IsTrue(status == TestStatus::Passed, L"Test did not pass (JS did not call markTestPassed within timeout)");
}

BEGIN_TEST_METHOD_ATTRIBUTE(WebSocketArrayBuffer)
TEST_IGNORE()
END_TEST_METHOD_ATTRIBUTE()
TEST_METHOD(WebSocketArrayBuffer) {
TestModule::Reset();

winrt::handle instanceLoadedEvent{CreateEvent(nullptr, TRUE, FALSE, nullptr)};
bool instanceFailed{false};

auto holder = TestReactNativeHostHolder(
L"IntegrationTests/WebSocketArrayBufferTest",
[&instanceLoadedEvent, &instanceFailed](msrn::ReactNativeHost const &host) noexcept {
host.InstanceSettings().InstanceLoaded(
[&instanceLoadedEvent, &instanceFailed](auto const &, msrn::InstanceLoadedEventArgs args) noexcept {
instanceFailed = args.Failed();
SetEvent(instanceLoadedEvent.get());
});
});

WaitForSingleObject(instanceLoadedEvent.get(), INFINITE);
if (instanceFailed) {
auto err = holder.GetLastError();
auto msg = L"InstanceLoaded reported failure: " + (err.empty() ? L"(no error captured)" : err);
Assert::Fail(msg.c_str());
}

auto status = TestModule::AwaitCompletion();
Assert::IsTrue(status == TestStatus::Passed, L"Test did not pass (JS did not call markTestPassed within timeout)");
}
};

} // namespace Microsoft::React::Test
15 changes: 15 additions & 0 deletions vnext/Shared/Modules/IWebSocketModuleContentHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,26 @@ namespace Microsoft::React {
struct IWebSocketModuleContentHandler {
virtual ~IWebSocketModuleContentHandler() noexcept {}

/// Returns true if this handler should process messages for the given socket.
virtual bool CanHandleSocket(int64_t socketId) noexcept = 0;

virtual void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;

virtual void ProcessMessage(
std::vector<uint8_t> &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;

/// Check CanHandleSocket() then ProcessMessage() in one call.
/// Returns true if the message was handled.
virtual bool TryProcessMessage(
int64_t socketId,
std::string &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;

virtual bool TryProcessMessage(
int64_t socketId,
std::vector<uint8_t> &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;
};

} // namespace Microsoft::React
11 changes: 8 additions & 3 deletions vnext/Shared/Modules/WebSocketModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,23 @@ shared_ptr<IWebSocketResource> WebSocketTurboModule::CreateResource(int64_t id,
if (auto prop = propBag.Get(BlobModuleContentHandlerPropertyId()))
contentHandler = prop.Value().lock();

bool handled = false;
if (contentHandler) {
if (isBinary) {
auto buffer = CryptographicBuffer::DecodeFromBase64String(winrt::to_hstring(message));
winrt::com_array<uint8_t> arr;
CryptographicBuffer::CopyToByteArray(buffer, arr);
auto data = vector<uint8_t>(arr.begin(), arr.end());

contentHandler->ProcessMessage(std::move(data), args);
handled = contentHandler->TryProcessMessage(id, std::move(data), args);
} else {
contentHandler->ProcessMessage(string{message}, args);
handled = contentHandler->TryProcessMessage(id, string{message}, args);
}
} else {
}
// When the content handler processes the message, it takes ownership of the
// payload and populates args itself (e.g. as a blob reference), so we only
// fall back to setting args["data"] when no handler claimed the message.
if (!handled) {
args["data"] = message;
Comment thread
gmacmaster marked this conversation as resolved.
}

Expand Down
37 changes: 37 additions & 0 deletions vnext/Shared/Networking/DefaultBlobResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ BlobWebSocketModuleContentHandler::BlobWebSocketModuleContentHandler(shared_ptr<

#pragma region IWebSocketModuleContentHandler

bool BlobWebSocketModuleContentHandler::CanHandleSocket(int64_t socketId) noexcept /*override*/ {
scoped_lock lock{m_mutex};
return m_socketIds.find(socketId) != m_socketIds.end();
}

void BlobWebSocketModuleContentHandler::ProcessMessage(
string &&message,
msrn::JSValueObject &params) noexcept /*override*/
Expand All @@ -241,6 +246,38 @@ void BlobWebSocketModuleContentHandler::ProcessMessage(
params[blobKeys.Type] = blobKeys.Blob;
}

bool BlobWebSocketModuleContentHandler::TryProcessMessage(
int64_t socketId,
string &&message,
msrn::JSValueObject &params) noexcept /*override*/
{
scoped_lock lock{m_mutex};
if (m_socketIds.find(socketId) == m_socketIds.end())
return false;

params[blobKeys.Data] = std::move(message);
return true;
}

bool BlobWebSocketModuleContentHandler::TryProcessMessage(
int64_t socketId,
vector<uint8_t> &&message,
msrn::JSValueObject &params) noexcept /*override*/
{
scoped_lock lock{m_mutex};
if (m_socketIds.find(socketId) == m_socketIds.end())
return false;

auto blob = msrn::JSValueObject{
{blobKeys.Offset, 0},
{blobKeys.Size, message.size()},
{blobKeys.BlobId, m_blobPersistor->StoreMessage(std::move(message))}};

params[blobKeys.Data] = std::move(blob);
params[blobKeys.Type] = blobKeys.Blob;
return true;
}

#pragma endregion IWebSocketModuleContentHandler

void BlobWebSocketModuleContentHandler::Register(int64_t socketID) noexcept {
Expand Down
12 changes: 12 additions & 0 deletions vnext/Shared/Networking/DefaultBlobResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,23 @@ class BlobWebSocketModuleContentHandler final : public IWebSocketModuleContentHa

#pragma region IWebSocketModuleContentHandler

bool CanHandleSocket(int64_t socketId) noexcept override;

void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject &params) noexcept override;

void ProcessMessage(std::vector<uint8_t> &&message, winrt::Microsoft::ReactNative::JSValueObject &params) noexcept
override;

bool TryProcessMessage(
int64_t socketId,
std::string &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept override;

bool TryProcessMessage(
int64_t socketId,
std::vector<uint8_t> &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept override;

#pragma endregion IWebSocketModuleContentHandler

void Register(int64_t socketID) noexcept;
Expand Down
4 changes: 4 additions & 0 deletions vnext/overrides.json
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@
"type": "platform",
"file": "src-win/IntegrationTests/websocket_integration_test_server_blob.js"
},
{
"type": "platform",
"file": "src-win/IntegrationTests/WebSocketArrayBufferTest.js"
},
{
"type": "platform",
"file": "src-win/IntegrationTests/WebSocketBinaryTest.js"
Expand Down
76 changes: 76 additions & 0 deletions vnext/src-win/IntegrationTests/WebSocketArrayBufferTest.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/**
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
* @format
*/

'use strict';

const {TurboModuleRegistry} = require('react-native');
const TestModule = TurboModuleRegistry.get('TestModule');

if (!TestModule) {
throw new Error('TestModule is not available');
}

// eslint-disable-next-line @microsoft/sdl/no-insecure-url
const WS_URL = 'ws://localhost:5555/rnw/rntester/websocketbinarytest';

const socket = new WebSocket(WS_URL);
socket.binaryType = 'arraybuffer';

socket.addEventListener('open', () => {
socket.send('hello');
});

socket.addEventListener('message', event => {
const data = event.data;

if (!(data instanceof ArrayBuffer)) {
console.log(
'WebSocketArrayBufferTest FAIL: expected ArrayBuffer, got ' + typeof data,
);
TestModule.markTestPassed(false);
socket.close();
return;
}

const bytes = new Uint8Array(data);
const expected = new Uint8Array([4, 5, 6, 7]);

if (bytes.length !== expected.length) {
console.log(
'WebSocketArrayBufferTest FAIL: expected ' +
expected.length +
' bytes, got ' +
bytes.length,
);
TestModule.markTestPassed(false);
socket.close();
return;
}

for (let i = 0; i < expected.length; i++) {
if (bytes[i] !== expected[i]) {
console.log(
'WebSocketArrayBufferTest FAIL: byte[' +
i +
'] expected ' +
expected[i] +
' got ' +
bytes[i],
);
TestModule.markTestPassed(false);
socket.close();
return;
}
}

TestModule.markTestPassed(true);
socket.close();
});

socket.addEventListener('error', () => {
console.log('WebSocketArrayBufferTest FAIL: WebSocket error');
TestModule.markTestPassed(false);
});