From cc91f87f5ac275f0027acaa422a0b0fdde3393cc Mon Sep 17 00:00:00 2001
From: Chris <cgenrich@gmail.com>
Date: Wed, 30 Jun 2021 14:46:30 -0700
Subject: [PATCH] Add HTTPS proxy support and fix broken test.

---
 .../channel/ChannelPoolPartitioning.java      |  2 +-
 .../netty/channel/ChannelManager.java         | 13 ++-
 .../netty/channel/NettyConnectListener.java   |  5 +-
 .../intercept/ConnectSuccessInterceptor.java  |  2 +-
 .../org/asynchttpclient/proxy/ProxyType.java  |  2 +-
 .../AsyncStreamHandlerTest.java               |  2 +-
 .../asynchttpclient/proxy/HttpsProxyTest.java | 79 ++++++++++++-------
 7 files changed, 69 insertions(+), 36 deletions(-)

diff --git a/client/src/main/java/org/asynchttpclient/channel/ChannelPoolPartitioning.java b/client/src/main/java/org/asynchttpclient/channel/ChannelPoolPartitioning.java
index fb00ba4803..2fa1a43e60 100644
--- a/client/src/main/java/org/asynchttpclient/channel/ChannelPoolPartitioning.java
+++ b/client/src/main/java/org/asynchttpclient/channel/ChannelPoolPartitioning.java
@@ -42,7 +42,7 @@ public Object getPartitionKey(Uri uri, String virtualHost, ProxyServer proxyServ
                 targetHostBaseUrl,
                 virtualHost,
                 proxyServer.getHost(),
-                uri.isSecured() && proxyServer.getProxyType() == ProxyType.HTTP ?
+                uri.isSecured() && proxyServer.getProxyType().isHttp() ?
                         proxyServer.getSecuredPort() :
                         proxyServer.getPort(),
                 proxyServer.getProxyType());
diff --git a/client/src/main/java/org/asynchttpclient/netty/channel/ChannelManager.java b/client/src/main/java/org/asynchttpclient/netty/channel/ChannelManager.java
index b93dfb380e..2cbe960bfb 100755
--- a/client/src/main/java/org/asynchttpclient/netty/channel/ChannelManager.java
+++ b/client/src/main/java/org/asynchttpclient/netty/channel/ChannelManager.java
@@ -51,6 +51,7 @@
 import org.asynchttpclient.netty.request.NettyRequestSender;
 import org.asynchttpclient.netty.ssl.DefaultSslEngineFactory;
 import org.asynchttpclient.proxy.ProxyServer;
+import org.asynchttpclient.proxy.ProxyType;
 import org.asynchttpclient.uri.Uri;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -71,6 +72,7 @@ public class ChannelManager {
 
   public static final String HTTP_CLIENT_CODEC = "http";
   public static final String SSL_HANDLER = "ssl";
+  public static final String SSL_TUNNEL_HANDLER = "ssl-tunnel";
   public static final String SOCKS_HANDLER = "socks";
   public static final String INFLATER_HANDLER = "inflater";
   public static final String CHUNKED_WRITER_HANDLER = "chunked-writer";
@@ -156,6 +158,10 @@ public ChannelManager(final AsyncHttpClientConfig config, Timer nettyTimer) {
   public static boolean isSslHandlerConfigured(ChannelPipeline pipeline) {
     return pipeline.get(SSL_HANDLER) != null;
   }
+  
+  public static boolean isSslTunnelHandlerConfigured(ChannelPipeline pipeline) {
+    return pipeline.get(SSL_TUNNEL_HANDLER) != null;
+  }
 
   private Bootstrap newBootstrap(ChannelFactory<? extends Channel> channelFactory, EventLoopGroup eventLoopGroup, AsyncHttpClientConfig config) {
     @SuppressWarnings("deprecation")
@@ -340,7 +346,7 @@ private SslHandler createSslHandler(String peerHost, int peerPort) {
     return sslHandler;
   }
 
-  public Future<Channel> updatePipelineForHttpTunneling(ChannelPipeline pipeline, Uri requestUri) {
+  public Future<Channel> updatePipelineForHttpTunneling(ChannelPipeline pipeline, Uri requestUri, ProxyType proxyType) {
 
     Future<Channel> whenHandshaked = null;
 
@@ -354,6 +360,11 @@ public Future<Channel> updatePipelineForHttpTunneling(ChannelPipeline pipeline,
         pipeline.addBefore(INFLATER_HANDLER, SSL_HANDLER, sslHandler);
       }
       pipeline.addAfter(SSL_HANDLER, HTTP_CLIENT_CODEC, newHttpClientCodec());
+      if(ProxyType.HTTPS.equals(proxyType) && !isSslTunnelHandlerConfigured(pipeline)) {
+        SslHandler sslHandler = createSslHandler(requestUri.getHost(), requestUri.getExplicitPort());
+        whenHandshaked = sslHandler.handshakeFuture();
+        pipeline.addAfter(SSL_HANDLER, SSL_TUNNEL_HANDLER, sslHandler);
+      }
 
     } else {
       pipeline.addBefore(AHC_HTTP_HANDLER, HTTP_CLIENT_CODEC, newHttpClientCodec());
diff --git a/client/src/main/java/org/asynchttpclient/netty/channel/NettyConnectListener.java b/client/src/main/java/org/asynchttpclient/netty/channel/NettyConnectListener.java
index 4a6f4dce20..5551ef7ac7 100755
--- a/client/src/main/java/org/asynchttpclient/netty/channel/NettyConnectListener.java
+++ b/client/src/main/java/org/asynchttpclient/netty/channel/NettyConnectListener.java
@@ -24,6 +24,7 @@
 import org.asynchttpclient.netty.request.NettyRequestSender;
 import org.asynchttpclient.netty.timeout.TimeoutsHolder;
 import org.asynchttpclient.proxy.ProxyServer;
+import org.asynchttpclient.proxy.ProxyType;
 import org.asynchttpclient.uri.Uri;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -107,10 +108,10 @@ public void onSuccess(Channel channel, InetSocketAddress remoteAddress) {
     ProxyServer proxyServer = future.getProxyServer();
 
     // in case of proxy tunneling, we'll add the SslHandler later, after the CONNECT request
-    if ((proxyServer == null || proxyServer.getProxyType().isSocks()) && uri.isSecured()) {
+    if ((proxyServer == null || ProxyType.HTTPS.equals(proxyServer.getProxyType()) || proxyServer.getProxyType().isSocks()) && uri.isSecured()) {
       SslHandler sslHandler;
       try {
-        sslHandler = channelManager.addSslHandler(channel.pipeline(), uri, request.getVirtualHost(), proxyServer != null);
+        sslHandler = channelManager.addSslHandler(channel.pipeline(), uri, request.getVirtualHost(), proxyServer != null && proxyServer.getProxyType().isSocks());
       } catch (Exception sslError) {
         onFailure(channel, sslError);
         return;
diff --git a/client/src/main/java/org/asynchttpclient/netty/handler/intercept/ConnectSuccessInterceptor.java b/client/src/main/java/org/asynchttpclient/netty/handler/intercept/ConnectSuccessInterceptor.java
index eb2e98e36f..fd31d1bd28 100644
--- a/client/src/main/java/org/asynchttpclient/netty/handler/intercept/ConnectSuccessInterceptor.java
+++ b/client/src/main/java/org/asynchttpclient/netty/handler/intercept/ConnectSuccessInterceptor.java
@@ -48,7 +48,7 @@ public boolean exitAfterHandlingConnect(Channel channel,
     Uri requestUri = request.getUri();
     LOGGER.debug("Connecting to proxy {} for scheme {}", proxyServer, requestUri.getScheme());
 
-    Future<Channel> whenHandshaked =  channelManager.updatePipelineForHttpTunneling(channel.pipeline(), requestUri);
+    Future<Channel> whenHandshaked =  channelManager.updatePipelineForHttpTunneling(channel.pipeline(), requestUri, proxyServer.getProxyType());
 
     future.setReuseChannel(true);
     future.setConnectAllowed(false);
diff --git a/client/src/main/java/org/asynchttpclient/proxy/ProxyType.java b/client/src/main/java/org/asynchttpclient/proxy/ProxyType.java
index bf680018a7..fc09e11b06 100644
--- a/client/src/main/java/org/asynchttpclient/proxy/ProxyType.java
+++ b/client/src/main/java/org/asynchttpclient/proxy/ProxyType.java
@@ -14,7 +14,7 @@
 package org.asynchttpclient.proxy;
 
 public enum ProxyType {
-  HTTP(true), SOCKS_V4(false), SOCKS_V5(false);
+  HTTP(true), HTTPS(true), SOCKS_V4(false), SOCKS_V5(false);
 
   private final boolean http;
 
diff --git a/client/src/test/java/org/asynchttpclient/AsyncStreamHandlerTest.java b/client/src/test/java/org/asynchttpclient/AsyncStreamHandlerTest.java
index 17dc2213ba..8c8deb5a84 100644
--- a/client/src/test/java/org/asynchttpclient/AsyncStreamHandlerTest.java
+++ b/client/src/test/java/org/asynchttpclient/AsyncStreamHandlerTest.java
@@ -442,7 +442,7 @@ public void asyncOptionsTest() throws Throwable {
           // FIXME: Actually refactor this test to account for both cases
         final String[] expected = {"GET", "HEAD", "OPTIONS", "POST"};
         final String[] expectedWithTrace = {"GET", "HEAD", "OPTIONS", "POST", "TRACE"};
-        Future<String> f = client.prepareOptions("http://www.apache.org/").execute(new AsyncHandlerAdapter() {
+        Future<String> f = client.prepareOptions("https://www.apache.org/").execute(new AsyncHandlerAdapter() {
 
           @Override
           public State onHeadersReceived(HttpHeaders headers) {
diff --git a/client/src/test/java/org/asynchttpclient/proxy/HttpsProxyTest.java b/client/src/test/java/org/asynchttpclient/proxy/HttpsProxyTest.java
index a8a1e8d3d3..b9004a4094 100644
--- a/client/src/test/java/org/asynchttpclient/proxy/HttpsProxyTest.java
+++ b/client/src/test/java/org/asynchttpclient/proxy/HttpsProxyTest.java
@@ -12,15 +12,19 @@
  */
 package org.asynchttpclient.proxy;
 
+import java.util.ArrayList;
+import java.util.List;
 import org.asynchttpclient.*;
 import org.asynchttpclient.request.body.generator.ByteArrayBodyGenerator;
 import org.asynchttpclient.test.EchoHandler;
 import org.eclipse.jetty.proxy.ConnectHandler;
+import org.eclipse.jetty.server.Handler;
 import org.eclipse.jetty.server.Server;
 import org.eclipse.jetty.server.ServerConnector;
 import org.eclipse.jetty.server.handler.AbstractHandler;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import static org.asynchttpclient.Dsl.*;
@@ -34,50 +38,67 @@
  */
 public class HttpsProxyTest extends AbstractBasicTest {
 
-  private Server server2;
+  private List<Server> servers;
+  private int httpsProxyPort;
 
   public AbstractHandler configureHandler() throws Exception {
     return new ConnectHandler();
   }
+  
+  @DataProvider (name = "serverPorts")
+  public Object[][] serverPorts() {
+    return new Object[][] {{port1, ProxyType.HTTP}, {httpsProxyPort, ProxyType.HTTPS}};
+  }
+  
 
   @BeforeClass(alwaysRun = true)
   public void setUpGlobal() throws Exception {
-    server = new Server();
-    ServerConnector connector = addHttpConnector(server);
-    server.setHandler(configureHandler());
-    server.start();
-    port1 = connector.getLocalPort();
+    servers = new ArrayList<>();
+    port1 = startServer(configureHandler(), false);
+
+    port2 = startServer(new EchoHandler(), true);
 
-    server2 = new Server();
-    ServerConnector connector2 = addHttpsConnector(server2);
-    server2.setHandler(new EchoHandler());
-    server2.start();
-    port2 = connector2.getLocalPort();
+    httpsProxyPort = startServer(configureHandler(), true);
 
     logger.info("Local HTTP server started successfully");
   }
+  
+  private int startServer(Handler handler, boolean secure) throws Exception {
+    Server server = new Server();
+    @SuppressWarnings("resource")
+    ServerConnector connector = secure ? addHttpsConnector(server) : addHttpConnector(server);
+    server.setHandler(handler);
+    server.start();
+    servers.add(server);
+    return connector.getLocalPort();
+  }
 
   @AfterClass(alwaysRun = true)
-  public void tearDownGlobal() throws Exception {
-    server.stop();
-    server2.stop();
+  public void tearDownGlobal() {
+    servers.forEach(t -> {
+      try {
+        t.stop();
+      } catch (Exception e) {
+        // couldn't stop server
+      }
+    });
   }
 
-  @Test
-  public void testRequestProxy() throws Exception {
+  @Test(dataProvider = "serverPorts")
+  public void testRequestProxy(int proxyPort, ProxyType type) throws Exception {
 
     try (AsyncHttpClient asyncHttpClient = asyncHttpClient(config().setFollowRedirect(true).setUseInsecureTrustManager(true))) {
-      RequestBuilder rb = get(getTargetUrl2()).setProxyServer(proxyServer("localhost", port1));
+      RequestBuilder rb = get(getTargetUrl2()).setProxyServer(proxyServer("localhost", proxyPort).setProxyType(type));
       Response r = asyncHttpClient.executeRequest(rb.build()).get();
       assertEquals(r.getStatusCode(), 200);
     }
   }
 
-  @Test
-  public void testConfigProxy() throws Exception {
+  @Test(dataProvider = "serverPorts")
+  public void testConfigProxy(int proxyPort, ProxyType type) throws Exception {
     AsyncHttpClientConfig config = config()
             .setFollowRedirect(true)
-            .setProxyServer(proxyServer("localhost", port1).build())
+            .setProxyServer(proxyServer("localhost", proxyPort).setProxyType(type).build())
             .setUseInsecureTrustManager(true)
             .build();
     try (AsyncHttpClient asyncHttpClient = asyncHttpClient(config)) {
@@ -86,11 +107,11 @@ public void testConfigProxy() throws Exception {
     }
   }
 
-  @Test
-  public void testNoDirectRequestBodyWithProxy() throws Exception {
+  @Test(dataProvider = "serverPorts")
+  public void testNoDirectRequestBodyWithProxy(int proxyPort, ProxyType type) throws Exception {
     AsyncHttpClientConfig config = config()
       .setFollowRedirect(true)
-      .setProxyServer(proxyServer("localhost", port1).build())
+      .setProxyServer(proxyServer("localhost", proxyPort).setProxyType(type).build())
       .setUseInsecureTrustManager(true)
       .build();
     try (AsyncHttpClient asyncHttpClient = asyncHttpClient(config)) {
@@ -99,11 +120,11 @@ public void testNoDirectRequestBodyWithProxy() throws Exception {
     }
   }
 
-  @Test
-  public void testDecompressBodyWithProxy() throws Exception {
+  @Test(dataProvider = "serverPorts")
+  public void testDecompressBodyWithProxy(int proxyPort, ProxyType type) throws Exception {
     AsyncHttpClientConfig config = config()
       .setFollowRedirect(true)
-      .setProxyServer(proxyServer("localhost", port1).build())
+      .setProxyServer(proxyServer("localhost", proxyPort).setProxyType(type).build())
       .setUseInsecureTrustManager(true)
       .build();
     try (AsyncHttpClient asyncHttpClient = asyncHttpClient(config)) {
@@ -116,11 +137,11 @@ public void testDecompressBodyWithProxy() throws Exception {
     }
   }
 
-  @Test
-  public void testPooledConnectionsWithProxy() throws Exception {
+  @Test(dataProvider = "serverPorts")
+  public void testPooledConnectionsWithProxy(int proxyPort, ProxyType type) throws Exception {
 
     try (AsyncHttpClient asyncHttpClient = asyncHttpClient(config().setFollowRedirect(true).setUseInsecureTrustManager(true).setKeepAlive(true))) {
-      RequestBuilder rb = get(getTargetUrl2()).setProxyServer(proxyServer("localhost", port1));
+      RequestBuilder rb = get(getTargetUrl2()).setProxyServer(proxyServer("localhost", proxyPort).setProxyType(type));
 
       Response r1 = asyncHttpClient.executeRequest(rb.build()).get();
       assertEquals(r1.getStatusCode(), 200);