Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package com.linkedin.venice.router;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.linkedin.venice.router.stats.HealthCheckStats;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import java.net.InetSocketAddress;
import org.mockito.ArgumentCaptor;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;


public class HealthCheckHandlerTest {
private HealthCheckStats healthCheckStats;
private HealthCheckHandler handler;
private ChannelHandlerContext ctx;
private Channel channel;

@BeforeMethod
public void setUp() {
healthCheckStats = mock(HealthCheckStats.class);
handler = new HealthCheckHandler(healthCheckStats);
ctx = mock(ChannelHandlerContext.class);
channel = mock(Channel.class);
doReturn(channel).when(ctx).channel();
doReturn(new InetSocketAddress("localhost", 1234)).when(channel).remoteAddress();
}

@Test
public void testOptionsRequestReturnsOK() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/any/path");

handler.channelRead0(ctx, request);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
try {
Assert.assertEquals(response.status(), HttpResponseStatus.OK);
verify(healthCheckStats, times(1)).recordHealthCheck();
} finally {
response.release();
}
}

@Test
public void testGetAdminWithoutResourceReturnsOK() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/admin");

handler.channelRead0(ctx, request);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
try {
Assert.assertEquals(response.status(), HttpResponseStatus.OK);
verify(healthCheckStats, times(1)).recordHealthCheck();
} finally {
response.release();
}
}

@Test
public void testGetAdminWithTrailingSlashReturnsOK() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/admin/");

handler.channelRead0(ctx, request);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
try {
Assert.assertEquals(response.status(), HttpResponseStatus.OK);
verify(healthCheckStats, times(1)).recordHealthCheck();
} finally {
response.release();
}
}

@Test
public void testGetAdminWithResourcePassesThrough() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/admin/storeName");

handler.channelRead0(ctx, request);

// Should pass through to next handler, not write response
verify(ctx, never()).writeAndFlush(any());
verify(ctx, times(1)).fireChannelRead(any());
verify(healthCheckStats, never()).recordHealthCheck();
}

@Test
public void testGetStoragePassesThrough() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/storeName/key");

handler.channelRead0(ctx, request);

// Should pass through to next handler
verify(ctx, never()).writeAndFlush(any());
verify(ctx, times(1)).fireChannelRead(any());
verify(healthCheckStats, never()).recordHealthCheck();
}

@Test
public void testPostRequestPassesThrough() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/admin");

handler.channelRead0(ctx, request);

// POST should pass through even to /admin
verify(ctx, never()).writeAndFlush(any());
verify(ctx, times(1)).fireChannelRead(any());
verify(healthCheckStats, never()).recordHealthCheck();
}

@Test
public void testExceptionCaughtReturnsInternalServerError() {
Exception testException = new RuntimeException("Test exception");

handler.exceptionCaught(ctx, testException);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
try {
Assert.assertEquals(response.status(), HttpResponseStatus.INTERNAL_SERVER_ERROR);
verify(healthCheckStats, times(1)).recordErrorHealthCheck();
verify(ctx, times(1)).close();
} finally {
response.release();
}
}

@Test
public void testExceptionCaughtWithRedundantExceptionStillRecordsMetric() {
// Even redundant exceptions should record the error metric
Exception testException = new RuntimeException("Connection reset by peer");

handler.exceptionCaught(ctx, testException);

verify(healthCheckStats, times(1)).recordErrorHealthCheck();
verify(ctx, times(1)).close();
}

@Test
public void testHealthCheckResponseHasEmptyBody() {
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/");

handler.channelRead0(ctx, request);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
try {
ByteBuf content = response.content();
Assert.assertEquals(content.readableBytes(), 0);
} finally {
response.release();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package com.linkedin.venice.router;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.linkedin.venice.router.stats.SecurityStats;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import java.io.IOException;
import java.net.InetSocketAddress;
import org.mockito.ArgumentCaptor;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;


public class RouterSslVerificationHandlerTest {
private SecurityStats securityStats;
private ChannelHandlerContext ctx;
private Channel channel;
private ChannelPipeline pipeline;
private Channel parentChannel;
private ChannelPipeline parentPipeline;

@BeforeMethod
public void setUp() {
securityStats = mock(SecurityStats.class);
ctx = mock(ChannelHandlerContext.class);
channel = mock(Channel.class);
pipeline = mock(ChannelPipeline.class);
parentChannel = mock(Channel.class);
parentPipeline = mock(ChannelPipeline.class);

doReturn(channel).when(ctx).channel();
doReturn(pipeline).when(ctx).pipeline();
doReturn(new InetSocketAddress("localhost", 1234)).when(channel).remoteAddress();
doReturn(parentChannel).when(channel).parent();
doReturn(parentPipeline).when(parentChannel).pipeline();
}

@Test
public void testRequestWithSslHandlerPassesThrough() throws IOException {
SslHandler sslHandler = mock(SslHandler.class);
doReturn(sslHandler).when(pipeline).get(SslHandler.class);

RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/store/key");

handler.channelRead0(ctx, request);

// Should pass through to next handler
verify(ctx, times(1)).fireChannelRead(any());
verify(ctx, never()).writeAndFlush(any());
verify(securityStats, times(1)).updateConnectionCountInCurrentMetricTimeWindow();
verify(securityStats, never()).recordNonSslRequest();
}

@Test
public void testRequestWithSslHandlerInParentPipelinePassesThrough() throws IOException {
// No SSL in direct pipeline
doReturn(null).when(pipeline).get(SslHandler.class);
// SSL in parent pipeline (HTTP/2 case)
SslHandler sslHandler = mock(SslHandler.class);
doReturn(sslHandler).when(parentPipeline).get(SslHandler.class);

RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/store/key");

handler.channelRead0(ctx, request);

// Should pass through to next handler
verify(ctx, times(1)).fireChannelRead(any());
verify(ctx, never()).writeAndFlush(any());
verify(securityStats, never()).recordNonSslRequest();
}

@Test
public void testRequestWithoutSslHandlerReturnsForbiddenWhenRequired() throws IOException {
doReturn(null).when(pipeline).get(SslHandler.class);
doReturn(null).when(parentPipeline).get(SslHandler.class);

RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/store/key");

handler.channelRead0(ctx, request);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
Assert.assertEquals(response.status(), HttpResponseStatus.FORBIDDEN);
verify(securityStats, times(1)).recordNonSslRequest();
verify(ctx, times(1)).close();
verify(ctx, never()).fireChannelRead(any());
Comment on lines +104 to +107
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The FullHttpResponse object captured from the ArgumentCaptor is a Netty reference-counted object that must be explicitly released to prevent memory leaks. Wrap the assertions in a try-finally block and call response.release() in the finally block, following the pattern used consistently in HealthCheckHandlerTest.

Suggested change
Assert.assertEquals(response.status(), HttpResponseStatus.FORBIDDEN);
verify(securityStats, times(1)).recordNonSslRequest();
verify(ctx, times(1)).close();
verify(ctx, never()).fireChannelRead(any());
try {
Assert.assertEquals(response.status(), HttpResponseStatus.FORBIDDEN);
verify(securityStats, times(1)).recordNonSslRequest();
verify(ctx, times(1)).close();
verify(ctx, never()).fireChannelRead(any());
} finally {
response.release();
}

Copilot uses AI. Check for mistakes.
}

@Test
public void testRequestWithoutSslHandlerPassesThroughWhenNotRequired() throws IOException {
doReturn(null).when(pipeline).get(SslHandler.class);
doReturn(null).when(parentPipeline).get(SslHandler.class);

RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, false);
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/store/key");

handler.channelRead0(ctx, request);

// Should pass through even without SSL when not required
verify(ctx, times(1)).fireChannelRead(any());
verify(ctx, never()).writeAndFlush(any());
verify(ctx, never()).close();
verify(securityStats, times(1)).recordNonSslRequest();
}

@Test
public void testSslHandshakeSuccessRecordsMetric() {
RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
SslHandshakeCompletionEvent successEvent = SslHandshakeCompletionEvent.SUCCESS;

handler.userEventTriggered(ctx, successEvent);

verify(securityStats, times(1)).recordSslSuccess();
verify(securityStats, never()).recordSslError();
verify(ctx, times(1)).fireUserEventTriggered(successEvent);
verify(ctx, never()).close();
}

@Test
public void testSslHandshakeFailureRecordsMetric() {
doReturn(pipeline).when(ctx).pipeline();

RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
Exception cause = new RuntimeException("SSL handshake failed");
SslHandshakeCompletionEvent failureEvent = new SslHandshakeCompletionEvent(cause);

handler.userEventTriggered(ctx, failureEvent);

verify(securityStats, times(1)).recordSslError();
verify(securityStats, never()).recordSslSuccess();
verify(ctx, times(1)).close();
verify(pipeline, times(1)).remove(handler);
}

@Test
public void testNonSslHandshakeEventPassesThrough() {
RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
Object otherEvent = new Object();

handler.userEventTriggered(ctx, otherEvent);

verify(ctx, times(1)).fireUserEventTriggered(otherEvent);
verify(securityStats, never()).recordSslSuccess();
verify(securityStats, never()).recordSslError();
}

@Test
public void testDefaultConstructorRequiresSsl() throws IOException {
doReturn(null).when(pipeline).get(SslHandler.class);
doReturn(null).when(parentPipeline).get(SslHandler.class);

// Use default constructor which should default to requireSsl=true
RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats);
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/store/key");

handler.channelRead0(ctx, request);

// Should return 403 because SSL is required by default
ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(ctx).writeAndFlush(captor.capture());

FullHttpResponse response = (FullHttpResponse) captor.getValue();
Assert.assertEquals(response.status(), HttpResponseStatus.FORBIDDEN);
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The FullHttpResponse object captured from the ArgumentCaptor is a Netty reference-counted object that must be explicitly released to prevent memory leaks. Wrap the assertions in a try-finally block and call response.release() in the finally block, following the pattern used consistently in HealthCheckHandlerTest.

Suggested change
Assert.assertEquals(response.status(), HttpResponseStatus.FORBIDDEN);
try {
Assert.assertEquals(response.status(), HttpResponseStatus.FORBIDDEN);
} finally {
response.release();
}

Copilot uses AI. Check for mistakes.
}

@Test
public void testConnectionCountUpdatedOnEveryRequest() throws IOException {
SslHandler sslHandler = mock(SslHandler.class);
doReturn(sslHandler).when(pipeline).get(SslHandler.class);

RouterSslVerificationHandler handler = new RouterSslVerificationHandler(securityStats, true);
HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/storage/store/key");

// Make multiple requests
handler.channelRead0(ctx, request);
handler.channelRead0(ctx, request);
handler.channelRead0(ctx, request);

// Connection count should be updated for each request
verify(securityStats, times(3)).updateConnectionCountInCurrentMetricTimeWindow();
}
}
Loading
Loading