Skip to content

Commit 3264f44

Browse files
authored
feat: Add complete server/client TLS support (#158)
BREAKING CHANGE: TLS client API now matches NodeJS official tls API.
1 parent 755d7cb commit 3264f44

36 files changed

+2609
-1445
lines changed

README.md

+193-59
Large diffs are not rendered by default.

android/src/main/java/com/asterinet/react/tcpsocket/SSLCertificateHelper.java

+24-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import android.annotation.SuppressLint;
44
import android.content.Context;
55

6+
import androidx.annotation.NonNull;
7+
import androidx.annotation.RawRes;
8+
69
import java.io.IOException;
710
import java.io.InputStream;
811
import java.net.URI;
@@ -13,14 +16,16 @@
1316
import java.security.cert.CertificateFactory;
1417
import java.security.cert.X509Certificate;
1518

19+
import javax.net.ssl.KeyManagerFactory;
1620
import javax.net.ssl.SSLContext;
21+
import javax.net.ssl.SSLServerSocketFactory;
22+
import javax.net.ssl.SSLSocket;
1723
import javax.net.ssl.SSLSocketFactory;
1824
import javax.net.ssl.TrustManager;
1925
import javax.net.ssl.TrustManagerFactory;
26+
import javax.net.ssl.X509ExtendedKeyManager;
2027
import javax.net.ssl.X509TrustManager;
2128

22-
import androidx.annotation.NonNull;
23-
import androidx.annotation.RawRes;
2429

2530
final class SSLCertificateHelper {
2631
/**
@@ -34,6 +39,23 @@ static SSLSocketFactory createBlindSocketFactory() throws GeneralSecurityExcepti
3439
return ctx.getSocketFactory();
3540
}
3641

42+
static SSLServerSocketFactory createServerSocketFactory(Context context, @NonNull final String keyStoreResourceUri) throws GeneralSecurityException, IOException {
43+
char[] password = "".toCharArray();
44+
45+
InputStream keyStoreInput = getRawResourceStream(context, keyStoreResourceUri);
46+
KeyStore keyStore = KeyStore.getInstance("PKCS12");
47+
keyStore.load(keyStoreInput, password);
48+
keyStoreInput.close();
49+
50+
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance("X509");
51+
keyManagerFactory.init(keyStore, password);
52+
53+
SSLContext sslContext = SSLContext.getInstance("TLS");
54+
sslContext.init(keyManagerFactory.getKeyManagers(), new TrustManager[]{new BlindTrustManager()}, null);
55+
56+
return sslContext.getServerSocketFactory();
57+
}
58+
3759
/**
3860
* Creates an SSLSocketFactory instance for use with the CA provided in the resource file.
3961
*

android/src/main/java/com/asterinet/react/tcpsocket/TcpEventListener.java

+23-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.asterinet.react.tcpsocket;
22

33
import android.util.Base64;
4+
import android.util.Log;
45

56
import com.facebook.react.bridge.Arguments;
67
import com.facebook.react.bridge.ReactContext;
@@ -24,6 +25,14 @@ public TcpEventListener(final ReactContext reactContext) {
2425
}
2526

2627
public void onConnection(int serverId, int clientId, Socket socket) {
28+
onSocketConnection("connection", serverId, clientId, socket);
29+
}
30+
31+
public void onSecureConnection(int serverId, int clientId, Socket socket) {
32+
onSocketConnection("secureConnection", serverId, clientId, socket);
33+
}
34+
35+
private void onSocketConnection(String connectionType, int serverId, int clientId, Socket socket) {
2736
WritableMap eventParams = Arguments.createMap();
2837
eventParams.putInt("id", serverId);
2938

@@ -42,7 +51,7 @@ public void onConnection(int serverId, int clientId, Socket socket) {
4251
infoParams.putMap("connection", connectionParams);
4352
eventParams.putMap("info", infoParams);
4453

45-
sendEvent("connection", eventParams);
54+
sendEvent(connectionType, eventParams);
4655
}
4756

4857
public void onConnect(int id, TcpSocketClient client) {
@@ -83,7 +92,12 @@ public void onData(int id, byte[] data) {
8392
sendEvent("data", eventParams);
8493
}
8594

86-
public void onWritten(int id, int msgId, @Nullable String error) {
95+
public void onWritten(int id, int msgId, @Nullable Exception e) {
96+
String error = null;
97+
if (e != null) {
98+
Log.e(TcpSocketModule.TAG, "Exception on socket " + id, e);
99+
error = e.getMessage();
100+
}
87101
WritableMap eventParams = Arguments.createMap();
88102
eventParams.putInt("id", id);
89103
eventParams.putInt("msgId", msgId);
@@ -92,18 +106,20 @@ public void onWritten(int id, int msgId, @Nullable String error) {
92106
sendEvent("written", eventParams);
93107
}
94108

95-
public void onClose(int id, String error) {
96-
if (error != null) {
97-
onError(id, error);
109+
public void onClose(int id, Exception e) {
110+
if (e != null) {
111+
onError(id, e);
98112
}
99113
WritableMap eventParams = Arguments.createMap();
100114
eventParams.putInt("id", id);
101-
eventParams.putBoolean("hadError", error != null);
115+
eventParams.putBoolean("hadError", e != null);
102116

103117
sendEvent("close", eventParams);
104118
}
105119

106-
public void onError(int id, String error) {
120+
public void onError(int id, Exception e) {
121+
Log.e(TcpSocketModule.TAG, "Exception on socket " + id, e);
122+
String error = e.getMessage();
107123
WritableMap eventParams = Arguments.createMap();
108124
eventParams.putInt("id", id);
109125
eventParams.putString("error", error);

android/src/main/java/com/asterinet/react/tcpsocket/TcpSocketClient.java

+33-20
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import java.util.concurrent.ExecutorService;
1616
import java.util.concurrent.Executors;
1717

18-
import javax.net.SocketFactory;
1918
import javax.net.ssl.SSLSocket;
2019
import javax.net.ssl.SSLSocketFactory;
2120

@@ -25,6 +24,7 @@ class TcpSocketClient extends TcpSocket {
2524
private final TcpEventListener receiverListener;
2625
private TcpReceiverTask receiverTask;
2726
private Socket socket;
27+
private boolean closed = true;
2828

2929
TcpSocketClient(TcpEventListener receiverListener, Integer id, Socket socket) {
3030
super(id);
@@ -38,20 +38,12 @@ public Socket getSocket() {
3838
return socket;
3939
}
4040

41-
public void connect(Context context, String address, final Integer port, ReadableMap options, Network network) throws IOException, GeneralSecurityException {
41+
public void connect(Context context, String address, final Integer port, ReadableMap options, Network network, ReadableMap tlsOptions) throws IOException, GeneralSecurityException {
4242
if (socket != null) throw new IOException("Already connected");
43-
final boolean isTls = options.hasKey("tls") && options.getBoolean("tls");
44-
if (isTls) {
45-
SocketFactory sf;
46-
if (options.hasKey("tlsCheckValidity") && !options.getBoolean("tlsCheckValidity")) {
47-
sf = SSLCertificateHelper.createBlindSocketFactory();
48-
} else {
49-
final String customTlsCert = options.hasKey("tlsCert") ? options.getString("tlsCert") : null;
50-
sf = customTlsCert != null ? SSLCertificateHelper.createCustomTrustedSocketFactory(context, customTlsCert) : SSLSocketFactory.getDefault();
51-
}
52-
final SSLSocket sslSocket = (SSLSocket) sf.createSocket();
53-
sslSocket.setUseClientMode(true);
54-
socket = sslSocket;
43+
if (tlsOptions != null) {
44+
SSLSocketFactory ssf = getSSLSocketFactory(context, tlsOptions);
45+
socket = ssf.createSocket();
46+
((SSLSocket) socket).setUseClientMode(true);
5547
} else {
5648
socket = new Socket();
5749
}
@@ -73,10 +65,30 @@ public void connect(Context context, String address, final Integer port, Readabl
7365
// bind
7466
socket.bind(new InetSocketAddress(localInetAddress, localPort));
7567
socket.connect(new InetSocketAddress(remoteInetAddress, port));
76-
if (isTls) ((SSLSocket) socket).startHandshake();
68+
if (socket instanceof SSLSocket) ((SSLSocket) socket).startHandshake();
7769
startListening();
7870
}
7971

72+
public void startTLS(Context context, ReadableMap tlsOptions) throws IOException, GeneralSecurityException {
73+
if (socket instanceof SSLSocket) return;
74+
SSLSocketFactory ssf = getSSLSocketFactory(context, tlsOptions);
75+
SSLSocket sslSocket = (SSLSocket) ssf.createSocket(socket, socket.getInetAddress().getHostAddress(), socket.getPort(), true);
76+
sslSocket.setUseClientMode(true);
77+
sslSocket.startHandshake();
78+
socket = sslSocket;
79+
}
80+
81+
private SSLSocketFactory getSSLSocketFactory(Context context, ReadableMap tlsOptions) throws GeneralSecurityException, IOException {
82+
SSLSocketFactory ssf;
83+
if (tlsOptions.hasKey("rejectUnauthorized") && !tlsOptions.getBoolean("rejectUnauthorized")) {
84+
ssf = SSLCertificateHelper.createBlindSocketFactory();
85+
} else {
86+
final String customTlsCert = tlsOptions.hasKey("ca") ? tlsOptions.getString("ca") : null;
87+
ssf = customTlsCert != null ? SSLCertificateHelper.createCustomTrustedSocketFactory(context, customTlsCert) : (SSLSocketFactory) SSLSocketFactory.getDefault();
88+
}
89+
return ssf;
90+
}
91+
8092
public void startListening() {
8193
receiverTask = new TcpReceiverTask(this, receiverListener);
8294
listenExecutor.execute(receiverTask);
@@ -95,8 +107,8 @@ public void run() {
95107
socket.getOutputStream().write(data);
96108
receiverListener.onWritten(getId(), msgId, null);
97109
} catch (IOException e) {
98-
receiverListener.onWritten(getId(), msgId, e.toString());
99-
receiverListener.onError(getId(), e.toString());
110+
receiverListener.onWritten(getId(), msgId, e);
111+
receiverListener.onError(getId(), e);
100112
}
101113
}
102114
});
@@ -109,12 +121,13 @@ public void destroy() {
109121
try {
110122
// close the socket
111123
if (socket != null && !socket.isClosed()) {
124+
closed = true;
112125
socket.close();
113126
receiverListener.onClose(getId(), null);
114127
socket = null;
115128
}
116129
} catch (IOException e) {
117-
receiverListener.onClose(getId(), e.getMessage());
130+
receiverListener.onClose(getId(), e);
118131
}
119132
}
120133

@@ -183,8 +196,8 @@ public void run() {
183196
}
184197
}
185198
} catch (IOException | InterruptedException ioe) {
186-
if (receiverListener != null && !socket.isClosed()) {
187-
receiverListener.onError(socketId, ioe.getMessage());
199+
if (receiverListener != null && !socket.isClosed() && !clientSocket.closed) {
200+
receiverListener.onError(socketId, ioe);
188201
}
189202
}
190203
}

android/src/main/java/com/asterinet/react/tcpsocket/TcpSocketModule.java

+36-17
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import android.annotation.SuppressLint;
55
import android.content.Context;
66
import android.net.ConnectivityManager;
7+
import android.net.Network;
78
import android.net.NetworkCapabilities;
89
import android.net.NetworkRequest;
910
import android.util.Base64;
10-
import android.net.Network;
11+
12+
import androidx.annotation.NonNull;
13+
import androidx.annotation.Nullable;
1114

1215
import com.facebook.react.bridge.ReactApplicationContext;
1316
import com.facebook.react.bridge.ReactContextBaseJavaModule;
@@ -22,14 +25,12 @@
2225
import java.util.concurrent.ScheduledThreadPoolExecutor;
2326
import java.util.concurrent.TimeUnit;
2427

25-
import androidx.annotation.NonNull;
26-
import androidx.annotation.Nullable;
27-
2828
public class TcpSocketModule extends ReactContextBaseJavaModule {
29-
private static final String TAG = "TcpSockets";
29+
public static final String TAG = "TcpSockets";
3030
private static final int N_THREADS = 2;
3131
private final ReactApplicationContext mReactContext;
3232
private final ConcurrentHashMap<Integer, TcpSocket> socketMap = new ConcurrentHashMap<>();
33+
private final ConcurrentHashMap<Integer, ReadableMap> pendingTLS = new ConcurrentHashMap<>();
3334
private final ConcurrentHashMap<String, Network> mNetworkMap = new ConcurrentHashMap<>();
3435
private final CurrentNetwork currentNetwork = new CurrentNetwork();
3536
private final ExecutorService executorService = Executors.newFixedThreadPool(N_THREADS);
@@ -68,7 +69,7 @@ public void connect(@NonNull final Integer cId, @NonNull final String host, @Non
6869
@Override
6970
public void run() {
7071
if (socketMap.get(cId) != null) {
71-
tcpEvtListener.onError(cId, TAG + "createSocket called twice with the same id.");
72+
tcpEvtListener.onError(cId, new Exception("connect() called twice with the same id."));
7273
return;
7374
}
7475
try {
@@ -78,15 +79,33 @@ public void run() {
7879
selectNetwork(iface, localAddress);
7980
TcpSocketClient client = new TcpSocketClient(tcpEvtListener, cId, null);
8081
socketMap.put(cId, client);
81-
client.connect(mReactContext, host, port, options, currentNetwork.getNetwork());
82+
ReadableMap tlsOptions = pendingTLS.get(cId);
83+
client.connect(mReactContext, host, port, options, currentNetwork.getNetwork(), tlsOptions);
8284
tcpEvtListener.onConnect(cId, client);
8385
} catch (Exception e) {
84-
tcpEvtListener.onError(cId, e.getMessage());
86+
tcpEvtListener.onError(cId, e);
8587
}
8688
}
8789
});
8890
}
8991

92+
@SuppressLint("StaticFieldLeak")
93+
@SuppressWarnings("unused")
94+
@ReactMethod
95+
public void startTLS(final int cId, @NonNull final ReadableMap tlsOptions) {
96+
TcpSocketClient socketClient = (TcpSocketClient) socketMap.get(cId);
97+
// Not yet connected
98+
if (socketClient == null) {
99+
pendingTLS.put(cId, tlsOptions);
100+
} else {
101+
try {
102+
socketClient.startTLS(mReactContext, tlsOptions);
103+
} catch (Exception e) {
104+
tcpEvtListener.onError(cId, e);
105+
}
106+
}
107+
}
108+
90109
@SuppressLint("StaticFieldLeak")
91110
@SuppressWarnings("unused")
92111
@ReactMethod
@@ -137,11 +156,11 @@ public void listen(final Integer cId, final ReadableMap options) {
137156
@Override
138157
public void run() {
139158
try {
140-
TcpSocketServer server = new TcpSocketServer(socketMap, tcpEvtListener, cId, options);
159+
TcpSocketServer server = new TcpSocketServer(mReactContext, socketMap, tcpEvtListener, cId, options);
141160
socketMap.put(cId, server);
142161
tcpEvtListener.onListen(cId, server);
143162
} catch (Exception uhe) {
144-
tcpEvtListener.onError(cId, uhe.getMessage());
163+
tcpEvtListener.onError(cId, uhe);
145164
}
146165
}
147166
});
@@ -154,7 +173,7 @@ public void setNoDelay(@NonNull final Integer cId, final boolean noDelay) {
154173
try {
155174
client.setNoDelay(noDelay);
156175
} catch (IOException e) {
157-
tcpEvtListener.onError(cId, e.getMessage());
176+
tcpEvtListener.onError(cId, e);
158177
}
159178
}
160179

@@ -165,7 +184,7 @@ public void setKeepAlive(@NonNull final Integer cId, final boolean enable, final
165184
try {
166185
client.setKeepAlive(enable, initialDelay);
167186
} catch (IOException e) {
168-
tcpEvtListener.onError(cId, e.getMessage());
187+
tcpEvtListener.onError(cId, e);
169188
}
170189
}
171190

@@ -182,7 +201,7 @@ public void resume(final int cId) {
182201
TcpSocketClient client = getTcpClient(cId);
183202
client.resume();
184203
}
185-
204+
186205
@SuppressWarnings("unused")
187206
@ReactMethod
188207
public void addListener(String eventName) {
@@ -260,21 +279,21 @@ private void selectNetwork(@Nullable final String iface, @Nullable final String
260279
private TcpSocketClient getTcpClient(final int id) {
261280
TcpSocket socket = socketMap.get(id);
262281
if (socket == null) {
263-
throw new IllegalArgumentException(TAG + "No socket with id " + id);
282+
throw new IllegalArgumentException("No socket with id " + id);
264283
}
265284
if (!(socket instanceof TcpSocketClient)) {
266-
throw new IllegalArgumentException(TAG + "Socket with id " + id + " is not a client");
285+
throw new IllegalArgumentException("Socket with id " + id + " is not a client");
267286
}
268287
return (TcpSocketClient) socket;
269288
}
270289

271290
private TcpSocketServer getTcpServer(final int id) {
272291
TcpSocket socket = socketMap.get(id);
273292
if (socket == null) {
274-
throw new IllegalArgumentException(TAG + "No socket with id " + id);
293+
throw new IllegalArgumentException("No server socket with id " + id);
275294
}
276295
if (!(socket instanceof TcpSocketServer)) {
277-
throw new IllegalArgumentException(TAG + "Socket with id " + id + " is not a server");
296+
throw new IllegalArgumentException("Server socket with id " + id + " is not a server");
278297
}
279298
return (TcpSocketServer) socket;
280299
}

0 commit comments

Comments
 (0)