Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.impl.DefaultConnection;
import javasabr.rlib.network.impl.StringDataConnection;
import javasabr.rlib.network.impl.StringDataMtlsServerConnection;
import javasabr.rlib.network.impl.StringDataSslConnection;
import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket;
import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry;
Expand Down Expand Up @@ -140,7 +141,11 @@ public static ClientNetwork<StringDataSslConnection> stringDataSslClientNetwork(
SSLContext sslContext) {
return clientNetwork(
networkConfig,
(network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true));
(network, channel) -> {
StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true);
connection.beginHandshake();
return connection;
});
}

/**
Expand Down Expand Up @@ -196,7 +201,11 @@ public static ServerNetwork<StringDataSslConnection> stringDataSslServerNetwork(
SSLContext sslContext) {
return serverNetwork(
networkConfig,
(network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false));
(network, channel) -> {
StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false);
connection.beginHandshake();
return connection;
});
}

/**
Expand Down Expand Up @@ -231,4 +240,26 @@ public static ServerNetwork<DefaultConnection> defaultServerNetwork(
networkConfig,
(network, channel) -> new DefaultConnection(network, channel, bufferAllocator, packetRegistry));
}

/**
* Create string packet based asynchronous Mutual TLS server network.
*
* @param networkConfig the server network configuration
* @param bufferAllocator the buffer allocator
* @param sslContext SSL context
* @return a new mTLS server network
* @since 10.0.0
*/
public static ServerNetwork<StringDataMtlsServerConnection> stringDataMtlsServerNetwork(
ServerNetworkConfig networkConfig,
BufferAllocator bufferAllocator,
SSLContext sslContext) {
return serverNetwork(
networkConfig,
(network, channel) -> {
StringDataMtlsServerConnection connection = new StringDataMtlsServerConnection(network, channel, bufferAllocator, sslContext);
connection.beginHandshake();
return connection;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public AbstractSslConnection(
super(network, channel, bufferAllocator, maxPacketsByRead);
this.sslEngine = sslContext.createSSLEngine();
this.sslEngine.setUseClientMode(clientMode);
}

public void beginHandshake() {
try {
this.sslEngine.beginHandshake();
} catch (SSLException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package javasabr.rlib.network.impl;

import javasabr.rlib.network.BufferAllocator;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket;

import javax.net.ssl.SSLContext;
import java.nio.channels.AsynchronousSocketChannel;

/**
Comment thread
crazyrokr marked this conversation as resolved.
* @author crazyrokr
*/
public class StringDataMtlsServerConnection extends DefaultDataSslConnection<StringDataMtlsServerConnection> {

public StringDataMtlsServerConnection(
Network<StringDataMtlsServerConnection> network,
AsynchronousSocketChannel channel,
BufferAllocator bufferAllocator,
SSLContext sslContext) {
super(network, channel, bufferAllocator, sslContext, 100, 2, false);
sslEngine.setNeedClientAuth(true);
}

@Override
protected StringReadableNetworkPacket<StringDataMtlsServerConnection> createReadablePacket() {
return new StringReadableNetworkPacket<>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) {
case NEED_WRAP: {
log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted);
packetWriter.accept(SslWrapRequestNetworkPacket.getInstance());
if (networkBuffer.hasRemaining()) {
return decryptAndRead(networkBuffer);
}
NetworkUtils.cleanNetworkBuffer(networkBuffer);
return SKIP_READ_PACKETS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ protected ByteBuffer doHandshake(HandshakeStatus handshakeStatus) {
break;
}
case NEED_UNWRAP: {
break;
return EMPTY_BUFFER;
}
default: {
throw new IllegalStateException("Invalid SSL status:" + handshakeStatus);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import javasabr.rlib.common.util.Utils;
import javasabr.rlib.network.client.ClientNetwork;
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.impl.StringDataMtlsServerConnection;
import javasabr.rlib.network.impl.StringDataSslConnection;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket;
Expand Down Expand Up @@ -328,6 +329,46 @@ void shouldReceiveManyPacketsFromSmallToBigSize() {
}
}

@Test
@SneakyThrows
void shouldRejectClientWithoutCertificateWithinMutualTls() {
InputStream serverKeystoreFile = StringSslNetworkTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12");
SSLContext serverSslContext = NetworkUtils.createSslContext(serverKeystoreFile, "test");
ServerNetworkConfig serverConfig = ServerNetworkConfig.SimpleServerNetworkConfig.builder().build();
BufferAllocator bufferAllocator = new DefaultBufferAllocator(serverConfig);

ServerNetwork<StringDataMtlsServerConnection> serverNetwork =
NetworkFactory.stringDataMtlsServerNetwork(serverConfig, bufferAllocator, serverSslContext);

InetSocketAddress serverAddress = serverNetwork.start();
CountDownLatch dataReceivedByServer = new CountDownLatch(1);

serverNetwork
.accepted()
.flatMap(Connection::receivedEvents)
.subscribe(event -> dataReceivedByServer.countDown());

SSLContext clientWithoutCertContext = NetworkUtils.createAllTrustedClientSslContext();
ClientNetwork<StringDataSslConnection> clientNetwork = NetworkFactory.stringDataSslClientNetwork(
NetworkConfig.DEFAULT_CLIENT,
new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT),
clientWithoutCertContext);

try {
clientNetwork
.connectReactive(serverAddress)
.doOnNext(connection -> connection.sendInBackground(new StringWritableNetworkPacket<>("no cert")))
.subscribe();

assertThat(dataReceivedByServer.await(5, TimeUnit.SECONDS))
.as("Server must reject a client that presents no certificate when requireClientAuth=true.")
.isFalse();
} finally {
serverNetwork.shutdown();
clientNetwork.shutdown();
}
}

private static StringWritableNetworkPacket<StringDataSslConnection> newMessage(int minMessageLength, int maxMessageLength) {
return new StringWritableNetworkPacket<>(StringUtils.generate(minMessageLength, maxMessageLength));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package javasabr.rlib.network.packet.impl;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import javasabr.rlib.network.BufferAllocator;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.NetworkConfig;
import javasabr.rlib.network.UnsafeConnection;
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
import javasabr.rlib.network.packet.WritableNetworkPacket;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLSession;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

/**
* The tests of SSL packet reader
*
* @author crazyrokr
*/
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public class SslPacketReaderTest {

private interface TestConnection extends UnsafeConnection<TestConnection> {}

@Mock
private TestConnection connection;

@Mock
private Network<TestConnection> network;

@Mock
private SSLEngine sslEngine;

@Mock
private SSLSession sslSession;

@Mock
private Consumer<ReadableNetworkPacket<TestConnection>> packetHandler;

@Mock
private Consumer<WritableNetworkPacket<TestConnection>> packetWriter;

private BufferAllocator bufferAllocator;

@BeforeEach
void setUp() {
bufferAllocator = new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT);
when(connection.bufferAllocator()).thenReturn(bufferAllocator);
when(connection.network()).thenReturn((Network) network);
when(connection.remoteAddress()).thenReturn("test-address");
when(network.config()).thenReturn(NetworkConfig.DEFAULT_CLIENT);
when(sslEngine.getSession()).thenReturn(sslSession);
when(sslSession.getApplicationBufferSize()).thenReturn(1024);
when(sslSession.getPacketBufferSize()).thenReturn(1024);
}

private static class TestSslPacketReader extends
AbstractSslNetworkPacketReader<ReadableNetworkPacket<TestConnection>, TestConnection> {

private final AtomicInteger readPacketsCount = new AtomicInteger();

protected TestSslPacketReader(
TestConnection connection,
Consumer<? super ReadableNetworkPacket<TestConnection>> packetHandler,
SSLEngine sslEngine,
Consumer<WritableNetworkPacket<TestConnection>> packetWriter) {
super(connection, () -> {}, packetHandler, packetHandler, sslEngine, packetWriter, 100);
}

@Override
protected boolean canStartReadPacket(ByteBuffer buffer) {
return buffer.remaining() >= 1;
}

@Override
protected int readFullPacketLength(ByteBuffer buffer) {
return 1;
}

@Override
protected ReadableNetworkPacket<TestConnection> createPacketFor(
ByteBuffer buffer,
int startPacketPosition,
int packetFullLength,
int packetDataLength) {
buffer.get(); // consume 1 byte
readPacketsCount.incrementAndGet();
return mock(ReadableNetworkPacket.class);
}
}

@Test
void testShouldNotLoseDataOnNeedWrapDuringHandshake() throws Exception {
// given
var reader = new TestSslPacketReader(connection, packetHandler, sslEngine, packetWriter);

// Initial state: NEED_UNWRAP
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_UNWRAP);

// First unwrap will result in NEED_WRAP and status OK, consuming some data
// MQTT broker received 10 bytes, first 5 bytes are handshake, and the last 5 bytes are application data
ByteBuffer networkData = ByteBuffer.allocate(10);
networkData.put(new byte[10]);
networkData.flip();

// doHandshake calls unwrap in NEED_UNWRAP, consumes first 5 bytes, then returns OK
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer[].class))).thenAnswer(invocation -> {
ByteBuffer in = invocation.getArgument(0);
in.position(in.position() + 5); // consume 5 bytes of handshake
// Change status to NEED_WRAP for next getHandshakeStatus() call
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NEED_WRAP);
return new SSLEngineResult(Status.OK, HandshakeStatus.NEED_WRAP, 5, 0);
});

// decryptAndRead calls unwrap, consumes the remaining 5 bytes, then return FINISHED or NOT_HANDSHAKING
when(sslEngine.unwrap(any(ByteBuffer.class), any(ByteBuffer.class))).thenAnswer(invocation -> {
ByteBuffer in = invocation.getArgument(0);
ByteBuffer out = invocation.getArgument(1);
int remaining = in.remaining();
in.position(in.limit()); // consume all
out.put(new byte[remaining]); // put decrypted data (mocked)
when(sslEngine.getHandshakeStatus()).thenReturn(HandshakeStatus.NOT_HANDSHAKING);
return new SSLEngineResult(Status.OK, HandshakeStatus.NOT_HANDSHAKING, remaining, remaining);
});

// when
reader.readPackets(networkData);

// then
// readPackets should have been called for the remaining 5 bytes,
// since each packet is 1 byte, it should have read 5 packets
assertThat(reader.readPacketsCount.get()).isEqualTo(5);
verify(packetWriter).accept(any(SslWrapRequestNetworkPacket.class));
}
}
Loading