diff --git a/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java b/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java index 0c09b9a5..e0b0b00b 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/NetworkFactory.java @@ -7,6 +7,7 @@ import javasabr.rlib.network.impl.DefaultBufferAllocator; import javasabr.rlib.network.impl.DefaultConnection; import javasabr.rlib.network.impl.StringDataConnection; +import javasabr.rlib.network.impl.StringDataMtlsServerConnection; import javasabr.rlib.network.impl.StringDataSslConnection; import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket; import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry; @@ -140,7 +141,11 @@ public static ClientNetwork stringDataSslClientNetwork( SSLContext sslContext) { return clientNetwork( networkConfig, - (network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true)); + (network, channel) -> { + StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true); + connection.beginHandshake(); + return connection; + }); } /** @@ -196,7 +201,11 @@ public static ServerNetwork stringDataSslServerNetwork( SSLContext sslContext) { return serverNetwork( networkConfig, - (network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false)); + (network, channel) -> { + StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false); + connection.beginHandshake(); + return connection; + }); } /** @@ -231,4 +240,26 @@ public static ServerNetwork defaultServerNetwork( networkConfig, (network, channel) -> new DefaultConnection(network, channel, bufferAllocator, packetRegistry)); } + + /** + * Create string packet based asynchronous Mutual TLS server network. + * + * @param networkConfig the server network configuration + * @param bufferAllocator the buffer allocator + * @param sslContext SSL context + * @return a new mTLS server network + * @since 10.0.0 + */ + public static ServerNetwork stringDataMtlsServerNetwork( + ServerNetworkConfig networkConfig, + BufferAllocator bufferAllocator, + SSLContext sslContext) { + return serverNetwork( + networkConfig, + (network, channel) -> { + StringDataMtlsServerConnection connection = new StringDataMtlsServerConnection(network, channel, bufferAllocator, sslContext); + connection.beginHandshake(); + return connection; + }); + } } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java index 9ae853d7..f015070b 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractSslConnection.java @@ -26,6 +26,9 @@ public AbstractSslConnection( super(network, channel, bufferAllocator, maxPacketsByRead); this.sslEngine = sslContext.createSSLEngine(); this.sslEngine.setUseClientMode(clientMode); + } + + public void beginHandshake() { try { this.sslEngine.beginHandshake(); } catch (SSLException e) { diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/StringDataMtlsServerConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/StringDataMtlsServerConnection.java new file mode 100644 index 00000000..9bd21203 --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/StringDataMtlsServerConnection.java @@ -0,0 +1,28 @@ +package javasabr.rlib.network.impl; + +import javasabr.rlib.network.BufferAllocator; +import javasabr.rlib.network.Network; +import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket; + +import javax.net.ssl.SSLContext; +import java.nio.channels.AsynchronousSocketChannel; + +/** + * @author crazyrokr + */ +public class StringDataMtlsServerConnection extends DefaultDataSslConnection { + + public StringDataMtlsServerConnection( + Network network, + AsynchronousSocketChannel channel, + BufferAllocator bufferAllocator, + SSLContext sslContext) { + super(network, channel, bufferAllocator, sslContext, 100, 2, false); + sslEngine.setNeedClientAuth(true); + } + + @Override + protected StringReadableNetworkPacket createReadablePacket() { + return new StringReadableNetworkPacket<>(); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java index 6ab75309..1225f4c4 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java @@ -159,6 +159,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) { case NEED_WRAP: { log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted); packetWriter.accept(SslWrapRequestNetworkPacket.getInstance()); + if (networkBuffer.hasRemaining()) { + return decryptAndRead(networkBuffer); + } NetworkUtils.cleanNetworkBuffer(networkBuffer); return SKIP_READ_PACKETS; } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java index f15f9aee..b2b0fb57 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketWriter.java @@ -197,7 +197,7 @@ protected ByteBuffer doHandshake(HandshakeStatus handshakeStatus) { break; } case NEED_UNWRAP: { - break; + return EMPTY_BUFFER; } default: { throw new IllegalStateException("Invalid SSL status:" + handshakeStatus); diff --git a/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java b/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java index a91a4419..552de2d0 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java @@ -23,6 +23,7 @@ import javasabr.rlib.common.util.Utils; import javasabr.rlib.network.client.ClientNetwork; import javasabr.rlib.network.impl.DefaultBufferAllocator; +import javasabr.rlib.network.impl.StringDataMtlsServerConnection; import javasabr.rlib.network.impl.StringDataSslConnection; import javasabr.rlib.network.packet.ReadableNetworkPacket; import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket; @@ -328,6 +329,46 @@ void shouldReceiveManyPacketsFromSmallToBigSize() { } } + @Test + @SneakyThrows + void shouldRejectClientWithoutCertificateWithinMutualTls() { + InputStream serverKeystoreFile = StringSslNetworkTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12"); + SSLContext serverSslContext = NetworkUtils.createSslContext(serverKeystoreFile, "test"); + ServerNetworkConfig serverConfig = ServerNetworkConfig.SimpleServerNetworkConfig.builder().build(); + BufferAllocator bufferAllocator = new DefaultBufferAllocator(serverConfig); + + ServerNetwork serverNetwork = + NetworkFactory.stringDataMtlsServerNetwork(serverConfig, bufferAllocator, serverSslContext); + + InetSocketAddress serverAddress = serverNetwork.start(); + CountDownLatch dataReceivedByServer = new CountDownLatch(1); + + serverNetwork + .accepted() + .flatMap(Connection::receivedEvents) + .subscribe(event -> dataReceivedByServer.countDown()); + + SSLContext clientWithoutCertContext = NetworkUtils.createAllTrustedClientSslContext(); + ClientNetwork clientNetwork = NetworkFactory.stringDataSslClientNetwork( + NetworkConfig.DEFAULT_CLIENT, + new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT), + clientWithoutCertContext); + + try { + clientNetwork + .connectReactive(serverAddress) + .doOnNext(connection -> connection.sendInBackground(new StringWritableNetworkPacket<>("no cert"))) + .subscribe(); + + assertThat(dataReceivedByServer.await(5, TimeUnit.SECONDS)) + .as("Server must reject a client that presents no certificate when requireClientAuth=true.") + .isFalse(); + } finally { + serverNetwork.shutdown(); + clientNetwork.shutdown(); + } + } + private static StringWritableNetworkPacket newMessage(int minMessageLength, int maxMessageLength) { return new StringWritableNetworkPacket<>(StringUtils.generate(minMessageLength, maxMessageLength)); } diff --git a/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/SslPacketReaderTest.java b/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/SslPacketReaderTest.java new file mode 100644 index 00000000..47eb7482 --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/packet/impl/SslPacketReaderTest.java @@ -0,0 +1,153 @@ +package javasabr.rlib.network.packet.impl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import javasabr.rlib.network.BufferAllocator; +import javasabr.rlib.network.Network; +import javasabr.rlib.network.NetworkConfig; +import javasabr.rlib.network.UnsafeConnection; +import javasabr.rlib.network.impl.DefaultBufferAllocator; +import javasabr.rlib.network.packet.ReadableNetworkPacket; +import javasabr.rlib.network.packet.WritableNetworkPacket; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLSession; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +/** + * The tests of SSL packet reader + * + * @author crazyrokr + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class SslPacketReaderTest { + + private interface TestConnection extends UnsafeConnection {} + + @Mock + private TestConnection connection; + + @Mock + private Network network; + + @Mock + private SSLEngine sslEngine; + + @Mock + private SSLSession sslSession; + + @Mock + private Consumer> packetHandler; + + @Mock + private Consumer> packetWriter; + + private BufferAllocator bufferAllocator; + + @BeforeEach + void setUp() { + bufferAllocator = new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT); + when(connection.bufferAllocator()).thenReturn(bufferAllocator); + when(connection.network()).thenReturn((Network) network); + when(connection.remoteAddress()).thenReturn("test-address"); + when(network.config()).thenReturn(NetworkConfig.DEFAULT_CLIENT); + when(sslEngine.getSession()).thenReturn(sslSession); + when(sslSession.getApplicationBufferSize()).thenReturn(1024); + when(sslSession.getPacketBufferSize()).thenReturn(1024); + } + + private static class TestSslPacketReader extends + AbstractSslNetworkPacketReader, TestConnection> { + + private final AtomicInteger readPacketsCount = new AtomicInteger(); + + protected TestSslPacketReader( + TestConnection connection, + Consumer> packetHandler, + SSLEngine sslEngine, + Consumer> packetWriter) { + super(connection, () -> {}, packetHandler, packetHandler, sslEngine, packetWriter, 100); + } + + @Override + protected boolean canStartReadPacket(ByteBuffer buffer) { + return buffer.remaining() >= 1; + } + + @Override + protected int readFullPacketLength(ByteBuffer buffer) { + return 1; + } + + @Override + protected ReadableNetworkPacket createPacketFor( + ByteBuffer buffer, + int startPacketPosition, + int packetFullLength, + int packetDataLength) { + buffer.get(); // consume 1 byte + readPacketsCount.incrementAndGet(); + return mock(ReadableNetworkPacket.class); + } + } + + @Test + void testShouldNotLoseDataOnNeedWrapDuringHandshake() throws Exception { + // given + var reader = new TestSslPacketReader(connection, packetHandler, sslEngine, packetWriter); + + // Initial state: NEED_UNWRAP + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_UNWRAP); + + // First unwrap will result in NEED_WRAP and status OK, consuming some data + // MQTT broker received 10 bytes, first 5 bytes are handshake, and the last 5 bytes are application data + ByteBuffer networkData = ByteBuffer.allocate(10); + networkData.put(new byte[10]); + networkData.flip(); + + // doHandshake calls unwrap in NEED_UNWRAP, consumes first 5 bytes, then returns OK + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer[].class))).thenAnswer(invocation -> { + ByteBuffer in = invocation.getArgument(0); + in.position(in.position() + 5); // consume 5 bytes of handshake + // Change status to NEED_WRAP for next getHandshakeStatus() call + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP); + return new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 5, 0); + }); + + // decryptAndRead calls unwrap, consumes the remaining 5 bytes, then return FINISHED or NOT_HANDSHAKING + when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenAnswer(invocation -> { + ByteBuffer in = invocation.getArgument(0); + ByteBuffer out = invocation.getArgument(1); + int remaining = in.remaining(); + in.position(in.limit()); // consume all + out.put(new byte[remaining]); // put decrypted data (mocked) + when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NOT_HANDSHAKING); + return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, remaining, remaining); + }); + + // when + reader.readPackets(networkData); + + // then + // readPackets should have been called for the remaining 5 bytes, + // since each packet is 1 byte, it should have read 5 packets + assertThat(reader.readPacketsCount.get()).isEqualTo(5); + verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class)); + } +}