diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelPipelineAdviceUtil.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelPipelineAdviceUtil.java index bcfb4349fb0..45d4a34f9cb 100644 --- a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelPipelineAdviceUtil.java +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelPipelineAdviceUtil.java @@ -1,5 +1,6 @@ package datadog.trace.instrumentation.netty38; +import datadog.trace.api.InstrumenterConfig; import datadog.trace.bootstrap.CallDepthThreadLocalMap; import datadog.trace.bootstrap.ContextStore; import datadog.trace.instrumentation.netty38.client.HttpClientRequestTracingHandler; @@ -9,6 +10,9 @@ import datadog.trace.instrumentation.netty38.server.HttpServerResponseTracingHandler; import datadog.trace.instrumentation.netty38.server.HttpServerTracingHandler; import datadog.trace.instrumentation.netty38.server.MaybeBlockResponseHandler; +import datadog.trace.instrumentation.netty38.server.websocket.WebSocketServerRequestTracingHandler; +import datadog.trace.instrumentation.netty38.server.websocket.WebSocketServerResponseTracingHandler; +import datadog.trace.instrumentation.netty38.server.websocket.WebSocketServerTracingHandler; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandler; import org.jboss.netty.channel.ChannelPipeline; @@ -18,6 +22,9 @@ import org.jboss.netty.handler.codec.http.HttpResponseDecoder; import org.jboss.netty.handler.codec.http.HttpResponseEncoder; import org.jboss.netty.handler.codec.http.HttpServerCodec; +import org.jboss.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder; +import org.jboss.netty.handler.codec.http.websocketx.WebSocket13FrameEncoder; +import org.jboss.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; /** * When certain handlers are added to the pipeline, we want to add our corresponding tracing @@ -46,6 +53,33 @@ public static void wrapHandler( new HttpServerResponseTracingHandler(contextStore)); pipeline.addLast( MaybeBlockResponseHandler.class.getName(), new MaybeBlockResponseHandler(contextStore)); + } else if (handler instanceof WebSocketServerProtocolHandler) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled()) { + if (pipeline.get(HttpServerTracingHandler.class) != null) { + addHandlerAfter( + pipeline, + "datadog.trace.instrumentation.netty38.server.HttpServerTracingHandler", + new WebSocketServerTracingHandler(contextStore)); + } + } + } else if (handler instanceof WebSocket13FrameEncoder) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled()) { + if (pipeline.get(HttpServerRequestTracingHandler.class) != null) { + addHandlerAfter( + pipeline, + "datadog.trace.instrumentation.netty38.server.HttpServerRequestTracingHandler", + new WebSocketServerRequestTracingHandler(contextStore)); + } + } + } else if (handler instanceof WebSocket13FrameDecoder) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled()) { + if (pipeline.get(HttpServerResponseTracingHandler.class) != null) { + addHandlerAfter( + pipeline, + "datadog.trace.instrumentation.netty38.server.HttpServerResponseTracingHandler", + new WebSocketServerResponseTracingHandler(contextStore)); + } + } } else // Client pipeline handlers if (handler instanceof HttpClientCodec) { @@ -64,4 +98,13 @@ public static void wrapHandler( CallDepthThreadLocalMap.reset(ChannelPipeline.class); } } + + private static void addHandlerAfter( + final ChannelPipeline pipeline, final String name, final ChannelHandler handler) { + ChannelHandler existing = pipeline.get(handler.getClass()); + if (existing != null) { + pipeline.remove(existing); + } + pipeline.addAfter(name, handler.getClass().getName(), handler); + } } diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelTraceContext.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelTraceContext.java index 51b277f958d..24985e8cd40 100644 --- a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelTraceContext.java +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/ChannelTraceContext.java @@ -3,6 +3,7 @@ import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; import org.jboss.netty.handler.codec.http.HttpHeaders; public class ChannelTraceContext { @@ -23,6 +24,9 @@ public ChannelTraceContext create() { boolean analyzedResponse; boolean blockedResponse; + HandlerContext.Sender senderHandlerContext; + HandlerContext.Receiver receiverHandlerContext; + public void reset() { this.connectionContinuation = null; this.serverSpan = null; @@ -88,4 +92,20 @@ public void setClientSpan(AgentSpan clientSpan) { public void setClientParentSpan(AgentSpan clientParentSpan) { this.clientParentSpan = clientParentSpan; } + + public HandlerContext.Sender getSenderHandlerContext() { + return senderHandlerContext; + } + + public void setSenderHandlerContext(HandlerContext.Sender senderHandlerContext) { + this.senderHandlerContext = senderHandlerContext; + } + + public HandlerContext.Receiver getReceiverHandlerContext() { + return receiverHandlerContext; + } + + public void setReceiverHandlerContext(HandlerContext.Receiver receiverHandlerContext) { + this.receiverHandlerContext = receiverHandlerContext; + } } diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/NettyChannelPipelineInstrumentation.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/NettyChannelPipelineInstrumentation.java index 30c38f089fb..46586895a20 100644 --- a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/NettyChannelPipelineInstrumentation.java +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/NettyChannelPipelineInstrumentation.java @@ -68,6 +68,9 @@ public String[] helperClassNames() { packageName + ".server.HttpServerResponseTracingHandler", packageName + ".server.HttpServerTracingHandler", packageName + ".server.MaybeBlockResponseHandler", + packageName + ".server.websocket.WebSocketServerTracingHandler", + packageName + ".server.websocket.WebSocketServerRequestTracingHandler", + packageName + ".server.websocket.WebSocketServerResponseTracingHandler", }; } diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/HttpServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/HttpServerResponseTracingHandler.java index 78762a88104..a32891cb820 100644 --- a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/HttpServerResponseTracingHandler.java +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/HttpServerResponseTracingHandler.java @@ -6,6 +6,7 @@ import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; import datadog.trace.instrumentation.netty38.ChannelTraceContext; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandlerContext; @@ -17,6 +18,7 @@ public class HttpServerResponseTracingHandler extends SimpleChannelDownstreamHandler { private final ContextStore contextStore; + private static final String UPGRADE_HEADER = "upgrade"; public HttpServerResponseTracingHandler( final ContextStore contextStore) { @@ -45,7 +47,16 @@ public void writeRequested(final ChannelHandlerContext ctx, final MessageEvent m span.finish(); // Finish the span manually since finishSpanOnClose was false throw throwable; } - if (response.getStatus() != HttpResponseStatus.CONTINUE) { + final boolean isWebsocketUpgrade = + response.getStatus() == HttpResponseStatus.SWITCHING_PROTOCOLS + && "websocket".equals(response.headers().get(UPGRADE_HEADER)); + if (isWebsocketUpgrade) { + String channelId = ctx.getChannel().getId().toString(); + channelTraceContext.setSenderHandlerContext(new HandlerContext.Sender(span, channelId)); + } + if (response.getStatus() != HttpResponseStatus.CONTINUE + && (response.getStatus() != HttpResponseStatus.SWITCHING_PROTOCOLS + || isWebsocketUpgrade)) { DECORATE.onResponse(span, response); DECORATE.beforeFinish(span); span.finish(); // Finish the span manually since finishSpanOnClose was false diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerRequestTracingHandler.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerRequestTracingHandler.java new file mode 100644 index 00000000000..70c990db900 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerRequestTracingHandler.java @@ -0,0 +1,135 @@ +package datadog.trace.instrumentation.netty38.server.websocket; + +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.bootstrap.instrumentation.decorator.WebsocketDecorator.DECORATE; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_TEXT; + +import datadog.trace.bootstrap.ContextStore; +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; +import datadog.trace.instrumentation.netty38.ChannelTraceContext; +import org.jboss.netty.channel.Channel; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelUpstreamHandler; +import org.jboss.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.WebSocketFrame; + +public class WebSocketServerRequestTracingHandler extends SimpleChannelUpstreamHandler { + + private final ContextStore contextStore; + + public WebSocketServerRequestTracingHandler( + final ContextStore contextStore) { + this.contextStore = contextStore; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageEvent event) throws Exception { + Object frame = event.getMessage(); + if (frame instanceof WebSocketFrame) { + Channel channel = ctx.getChannel(); + + ChannelTraceContext traceContext = this.contextStore.get(channel); + if (traceContext != null) { + + HandlerContext.Receiver receiverContext = traceContext.getReceiverHandlerContext(); + if (receiverContext == null) { + HandlerContext.Sender sessionState = traceContext.getSenderHandlerContext(); + if (sessionState != null) { + receiverContext = + new HandlerContext.Receiver( + sessionState.getHandshakeSpan(), channel.getId().toString()); + traceContext.setReceiverHandlerContext(receiverContext); + } + } + if (receiverContext != null) { + if (frame instanceof TextWebSocketFrame) { + // WebSocket Read Text Start + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, textFrame.getText(), textFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendUpstream(event); + // WebSocket Read Text Start + } finally { + if (textFrame.isFinalFragment()) { + traceContext.setReceiverHandlerContext(null); + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + + if (frame instanceof BinaryWebSocketFrame) { + // WebSocket Read Binary Start + BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, + binaryFrame.getBinaryData().array(), + binaryFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendUpstream(event); + } finally { + // WebSocket Read Binary End + if (binaryFrame.isFinalFragment()) { + traceContext.setReceiverHandlerContext(null); + DECORATE.onFrameEnd(receiverContext); + } + } + + return; + } + + if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame continuationWebSocketFrame = + (ContinuationWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, + MESSAGE_TYPE_TEXT.equals(receiverContext.getMessageType()) + ? continuationWebSocketFrame.getText() + : continuationWebSocketFrame.getBinaryData().array(), + continuationWebSocketFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendUpstream(event); + } finally { + if (continuationWebSocketFrame.isFinalFragment()) { + traceContext.setReceiverHandlerContext(null); + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + + if (frame instanceof CloseWebSocketFrame) { + // WebSocket Closed by client + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame; + int statusCode = closeFrame.getStatusCode(); + String reasonText = closeFrame.getReasonText(); + traceContext.setSenderHandlerContext(null); + traceContext.setReceiverHandlerContext(null); + final AgentSpan span = + DECORATE.onSessionCloseReceived(receiverContext, reasonText, statusCode); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendUpstream(event); + if (closeFrame.isFinalFragment()) { + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + } + } + } + + ctx.sendUpstream(event); // superclass does not throw + } +} diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerResponseTracingHandler.java new file mode 100644 index 00000000000..637d1a4abca --- /dev/null +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerResponseTracingHandler.java @@ -0,0 +1,123 @@ +package datadog.trace.instrumentation.netty38.server.websocket; + +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.bootstrap.instrumentation.decorator.WebsocketDecorator.DECORATE; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_BINARY; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_TEXT; + +import datadog.trace.bootstrap.ContextStore; +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; +import datadog.trace.instrumentation.netty38.ChannelTraceContext; +import org.jboss.netty.channel.Channel; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelDownstreamHandler; +import org.jboss.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import org.jboss.netty.handler.codec.http.websocketx.WebSocketFrame; + +public class WebSocketServerResponseTracingHandler extends SimpleChannelDownstreamHandler { + + private final ContextStore contextStore; + + public WebSocketServerResponseTracingHandler( + final ContextStore contextStore) { + this.contextStore = contextStore; + } + + @Override + public void writeRequested(ChannelHandlerContext ctx, MessageEvent event) throws Exception { + Object frame = event.getMessage(); + if (frame instanceof WebSocketFrame) { + Channel channel = ctx.getChannel(); + + ChannelTraceContext traceContext = this.contextStore.get(channel); + if (traceContext != null) { + HandlerContext.Sender handlerContext = traceContext.getSenderHandlerContext(); + if (handlerContext != null) { + + if (frame instanceof TextWebSocketFrame) { + // WebSocket Write Text Start + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, MESSAGE_TYPE_TEXT, textFrame.getText().length()); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendDownstream(event); + } finally { + // WebSocket Write Text End + if (textFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof BinaryWebSocketFrame) { + // WebSocket Write Binary Start + BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, + MESSAGE_TYPE_BINARY, + binaryFrame.getBinaryData().readableBytes()); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendDownstream(event); + } finally { + // WebSocket Write Binary End + if (binaryFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame continuationWebSocketFrame = + (ContinuationWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, + handlerContext.getMessageType(), + MESSAGE_TYPE_TEXT.equals(handlerContext.getMessageType()) + ? continuationWebSocketFrame.getText().length() + : continuationWebSocketFrame.getBinaryData().readableBytes()); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendDownstream(event); + } finally { + // WebSocket Write Binary End + if (continuationWebSocketFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof CloseWebSocketFrame) { + // WebSocket Closed by Server + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame; + int statusCode = closeFrame.getStatusCode(); + String reasonText = closeFrame.getReasonText(); + traceContext.setSenderHandlerContext(null); + final AgentSpan span = + DECORATE.onSessionCloseIssued(handlerContext, reasonText, statusCode); + try (final AgentScope scope = activateSpan(span)) { + ctx.sendDownstream(event); + } finally { + if (closeFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + } + } + } + // can be other messages we do not handle like ping, pong + ctx.sendDownstream(event); + } +} diff --git a/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerTracingHandler.java b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerTracingHandler.java new file mode 100644 index 00000000000..a1c817bb674 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-3.8/src/main/java/datadog/trace/instrumentation/netty38/server/websocket/WebSocketServerTracingHandler.java @@ -0,0 +1,18 @@ +package datadog.trace.instrumentation.netty38.server.websocket; + +import datadog.trace.bootstrap.ContextStore; +import datadog.trace.instrumentation.netty38.ChannelTraceContext; +import datadog.trace.instrumentation.netty38.util.CombinedSimpleChannelHandler; +import org.jboss.netty.channel.Channel; + +public class WebSocketServerTracingHandler + extends CombinedSimpleChannelHandler< + WebSocketServerRequestTracingHandler, WebSocketServerResponseTracingHandler> { + + public WebSocketServerTracingHandler( + final ContextStore contextStore) { + super( + new WebSocketServerRequestTracingHandler(contextStore), + new WebSocketServerResponseTracingHandler(contextStore)); + } +} diff --git a/dd-java-agent/instrumentation/netty-3.8/src/test/groovy/datadog/trace/instrumentation/netty38/Netty38ServerTest.groovy b/dd-java-agent/instrumentation/netty-3.8/src/test/groovy/datadog/trace/instrumentation/netty38/Netty38ServerTest.groovy index 820df9c7d46..6facf8bcd37 100644 --- a/dd-java-agent/instrumentation/netty-3.8/src/test/groovy/datadog/trace/instrumentation/netty38/Netty38ServerTest.groovy +++ b/dd-java-agent/instrumentation/netty-3.8/src/test/groovy/datadog/trace/instrumentation/netty38/Netty38ServerTest.groovy @@ -1,5 +1,21 @@ package datadog.trace.instrumentation.netty38 +import datadog.trace.agent.test.base.WebsocketServer +import org.jboss.netty.channel.ChannelFutureListener +import org.jboss.netty.channel.ChannelStateEvent +import org.jboss.netty.channel.SimpleChannelHandler +import org.jboss.netty.handler.codec.http.HttpChunkAggregator +import org.jboss.netty.handler.codec.http.HttpRequestDecoder +import org.jboss.netty.handler.codec.http.HttpResponseEncoder +import org.jboss.netty.handler.codec.http.websocketx.BinaryWebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.PingWebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.PongWebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.TextWebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.WebSocketFrame +import org.jboss.netty.handler.codec.http.websocketx.WebSocketServerHandshaker +import org.jboss.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.ERROR import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.EXCEPTION @@ -12,6 +28,7 @@ import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.REDIRE import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.SUCCESS import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.USER_BLOCK import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.forPath +import static datadog.trace.agent.test.utils.TraceUtils.runUnderTrace import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_LENGTH import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.LOCATION @@ -34,7 +51,6 @@ import org.jboss.netty.channel.DownstreamMessageEvent import org.jboss.netty.channel.ExceptionEvent import org.jboss.netty.channel.FailedChannelFuture import org.jboss.netty.channel.MessageEvent -import org.jboss.netty.channel.SimpleChannelHandler import org.jboss.netty.channel.SucceededChannelFuture import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.http.DefaultHttpResponse @@ -42,7 +58,6 @@ import org.jboss.netty.handler.codec.http.HttpHeaders import org.jboss.netty.handler.codec.http.HttpRequest import org.jboss.netty.handler.codec.http.HttpResponse import org.jboss.netty.handler.codec.http.HttpResponseStatus -import org.jboss.netty.handler.codec.http.HttpServerCodec import org.jboss.netty.handler.logging.LoggingHandler import org.jboss.netty.logging.InternalLogLevel import org.jboss.netty.logging.InternalLoggerFactory @@ -62,12 +77,34 @@ abstract class Netty38ServerTest extends HttpServerTest { ChannelPipeline channelPipeline = new DefaultChannelPipeline() channelPipeline.addFirst("logger", LOGGING_HANDLER) - channelPipeline.addLast("http-codec", new HttpServerCodec()) + channelPipeline.addLast("decoder", new HttpRequestDecoder()) + channelPipeline.addLast("encoder", new HttpResponseEncoder()) + channelPipeline.addLast("aggregator", new HttpChunkAggregator(65536)) channelPipeline.addLast("controller", new SimpleChannelHandler() { + WebSocketServerHandshaker handshaker + @Override void messageReceived(ChannelHandlerContext ctx, MessageEvent msg) throws Exception { if (msg.getMessage() instanceof HttpRequest) { def request = msg.getMessage() as HttpRequest + + def upgradeHeader = request.headers().get("Upgrade") + if (upgradeHeader && upgradeHeader.equalsIgnoreCase("websocket")) { + // Handshake + def host = request.headers().get("Host") + String wsLocation = "ws://" + host + request.uri + WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory( + wsLocation, null, false) + this.handshaker = wsFactory.newHandshaker(request) + if (this.handshaker == null) { + wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel()) + } else { + this.handshaker.handshake(ctx.getChannel(), request) + WsEndpoint.onOpen(ctx) + } + return + } + if (HttpHeaders.is100ContinueExpected(request)) { ctx.sendDownstream(new DownstreamMessageEvent(ctx.getChannel(), new SucceededChannelFuture(ctx.getChannel()), new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.CONTINUE), @@ -131,6 +168,20 @@ abstract class Netty38ServerTest extends HttpServerTest { response, ctx.getChannel().getRemoteAddress()) } + } else if (msg.getMessage() instanceof WebSocketFrame) { + def frame = msg.getMessage() as WebSocketFrame + + if (frame instanceof CloseWebSocketFrame) { + this.handshaker.close(ctx.getChannel(), (CloseWebSocketFrame) frame) + } else if (frame instanceof PingWebSocketFrame) { + ctx.getChannel().write(new PongWebSocketFrame(frame.getBinaryData())) + } else if (frame instanceof TextWebSocketFrame || frame instanceof BinaryWebSocketFrame || frame instanceof ContinuationWebSocketFrame) { + // generate a child span. The websocket test expects this way + runUnderTrace("onRead", {}) + } else { + throw new UnsupportedOperationException(String.format("%s frame types not supported", frame.getClass() + .getName())) + } } } @@ -148,12 +199,18 @@ abstract class Netty38ServerTest extends HttpServerTest { response, ctx.getChannel().getRemoteAddress())) } + + @Override + void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { + WsEndpoint.onClose() + ctx.sendDownstream(e) + } }) return channelPipeline } - private class NettyServer implements HttpServer { + private class NettyServer implements WebsocketServer { final ServerBootstrap server = new ServerBootstrap(new NioServerSocketChannelFactory()) int port = 0 @@ -182,6 +239,46 @@ abstract class Netty38ServerTest extends HttpServerTest { URI address() { return new URI("http://localhost:$port/") } + + @Override + void awaitConnected() { + while (WsEndpoint.activeSession == null) { + synchronized (WsEndpoint) { + WsEndpoint.wait() + } + } + } + + @Override + void serverSendText(String[] messages) { + WsEndpoint.activeSession.getChannel().write(new TextWebSocketFrame(messages.length == 1, 0, messages[0])) + for (def i = 1; i < messages.length; i++) { + WsEndpoint.activeSession.getChannel().write(new ContinuationWebSocketFrame(messages.length - 1 == i, 0, messages[i])) + } + } + + @Override + void serverSendBinary(byte[][] binaries) { + WsEndpoint.activeSession.getChannel().write(new BinaryWebSocketFrame(binaries.length == 1, 0, ChannelBuffers.copiedBuffer(binaries[0]))) + for (def i = 1; i < binaries.length; i++) { + WsEndpoint.activeSession.getChannel().write(new ContinuationWebSocketFrame(binaries.length - 1 == i, 0, ChannelBuffers.copiedBuffer(binaries[i]))) + } + } + + @Override + void serverClose() { + WsEndpoint.activeSession.getChannel().write(new CloseWebSocketFrame(1000, null)).addListener(ChannelFutureListener.CLOSE) + } + + @Override + void setMaxPayloadSize(int size) { + // not applicable + } + + @Override + boolean canSplitLargeWebsocketPayloads() { + false + } } @Override @@ -226,3 +323,18 @@ class Netty38ServerV0Test extends Netty38ServerTest implements TestingNettyHttpN class Netty38ServerV1ForkedTest extends Netty38ServerTest implements TestingNettyHttpNamingConventions.ServerV1 { } + +class WsEndpoint { + static volatile ChannelHandlerContext activeSession + + static void onOpen(ChannelHandlerContext session) { + activeSession = session + synchronized (WsEndpoint) { + WsEndpoint.notifyAll() + } + } + + static void onClose() { + activeSession = null + } +} diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/AttributeKeys.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/AttributeKeys.java index eff307e153e..e08f930e495 100644 --- a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/AttributeKeys.java +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/AttributeKeys.java @@ -5,6 +5,7 @@ import datadog.trace.api.GenericClassValue; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; import io.netty.handler.codec.http.HttpHeaders; import io.netty.util.AttributeKey; import java.util.concurrent.ConcurrentHashMap; @@ -33,6 +34,14 @@ public final class AttributeKeys { public static final AttributeKey BLOCKED_RESPONSE_KEY = new AttributeKey<>("datadog.server.blocked_response"); + public static final AttributeKey CHANNEL_ID = attributeKey("io.netty.channel.id"); + + public static final AttributeKey WEBSOCKET_SENDER_HANDLER_CONTEXT = + attributeKey("datadog.server.websocket.sender.handler_context"); + + public static final AttributeKey WEBSOCKET_RECEIVER_HANDLER_CONTEXT = + attributeKey("datadog.server.websocket.receiver.handler_context"); + /** * Generate an attribute key or reuse the one existing in the global app map. This implementation * creates attributes only once even if the current class is loaded by several class loaders and diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyChannelPipelineInstrumentation.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyChannelPipelineInstrumentation.java index 9fd97dcc600..04b930fcb03 100644 --- a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyChannelPipelineInstrumentation.java +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyChannelPipelineInstrumentation.java @@ -13,6 +13,7 @@ import com.google.auto.service.AutoService; import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.InstrumenterConfig; import datadog.trace.bootstrap.CallDepthThreadLocalMap; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.instrumentation.netty40.client.HttpClientRequestTracingHandler; @@ -22,8 +23,10 @@ import datadog.trace.instrumentation.netty40.server.HttpServerResponseTracingHandler; import datadog.trace.instrumentation.netty40.server.HttpServerTracingHandler; import datadog.trace.instrumentation.netty40.server.MaybeBlockResponseHandler; +import datadog.trace.instrumentation.netty40.server.websocket.WebSocketServerRequestTracingHandler; +import datadog.trace.instrumentation.netty40.server.websocket.WebSocketServerResponseTracingHandler; +import datadog.trace.instrumentation.netty40.server.websocket.WebSocketServerTracingHandler; import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpRequestDecoder; @@ -31,6 +34,7 @@ import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.util.Attribute; import net.bytebuddy.asm.Advice; import net.bytebuddy.description.type.TypeDescription; @@ -77,6 +81,10 @@ public String[] helperClassNames() { packageName + ".server.HttpServerResponseTracingHandler", packageName + ".server.HttpServerTracingHandler", packageName + ".server.MaybeBlockResponseHandler", + packageName + ".server.websocket.WebSocketServerTracingHandler", + packageName + ".server.websocket.WebSocketServerResponseTracingHandler", + packageName + ".server.websocket.WebSocketServerRequestTracingHandler", + packageName + ".NettyPipelineHelper" }; } @@ -136,44 +144,53 @@ public static void addHandler( handler2 instanceof ChannelHandler ? (ChannelHandler) handler2 : handler3; try { - ChannelHandler toAdd = null; - ChannelHandler toAdd2 = null; // Server pipeline handlers if (handler instanceof HttpServerCodec) { - toAdd = new HttpServerTracingHandler(); - toAdd2 = MaybeBlockResponseHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, + handler, + new HttpServerTracingHandler(), + MaybeBlockResponseHandler.INSTANCE); } else if (handler instanceof HttpRequestDecoder) { - toAdd = HttpServerRequestTracingHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, HttpServerRequestTracingHandler.INSTANCE); } else if (handler instanceof HttpResponseEncoder) { - toAdd = HttpServerResponseTracingHandler.INSTANCE; - toAdd2 = MaybeBlockResponseHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, + handler, + HttpServerResponseTracingHandler.INSTANCE, + MaybeBlockResponseHandler.INSTANCE); + } else if (handler instanceof WebSocketServerProtocolHandler) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled()) { + if (pipeline.get(HttpServerTracingHandler.class) != null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, + "datadog.trace.instrumentation.netty40.server.HttpServerTracingHandler", + new WebSocketServerTracingHandler()); + } + if (pipeline.get(HttpRequestDecoder.class) != null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, + "datadog.trace.instrumentation.netty40.server.HttpServerRequestTracingHandler", + WebSocketServerRequestTracingHandler.INSTANCE); + } + if (pipeline.get(HttpResponseEncoder.class) != null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, + "datadog.trace.instrumentation.netty40.server.HttpServerResponseTracingHandler", + WebSocketServerResponseTracingHandler.INSTANCE); + } + } } else // Client pipeline handlers if (handler instanceof HttpClientCodec) { - toAdd = new HttpClientTracingHandler(); + NettyPipelineHelper.addHandlerAfter(pipeline, handler, new HttpClientTracingHandler()); } else if (handler instanceof HttpRequestEncoder) { - toAdd = HttpClientRequestTracingHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, HttpClientRequestTracingHandler.INSTANCE); } else if (handler instanceof HttpResponseDecoder) { - toAdd = HttpClientResponseTracingHandler.INSTANCE; - } - if (toAdd != null) { - // Get the name so we can add immediately following - ChannelHandlerContext handlerContext = pipeline.context(handler); - if (handlerContext != null) { - String handlerName = handlerContext.name(); - ChannelHandler existing = pipeline.get(toAdd.getClass()); - if (existing != null) { - pipeline.remove(existing); - } - pipeline.addAfter(handlerName, toAdd.getClass().getName(), toAdd); - if (toAdd2 != null) { - ChannelHandler existing2 = pipeline.get(toAdd2.getClass()); - if (existing2 != null) { - pipeline.remove(existing2); - } - pipeline.addAfter(toAdd.getClass().getName(), toAdd2.getClass().getName(), toAdd2); - } - } + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, HttpClientResponseTracingHandler.INSTANCE); } } catch (final IllegalArgumentException e) { // Prevented adding duplicate handlers. diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyPipelineHelper.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyPipelineHelper.java new file mode 100644 index 00000000000..1bda9760ef4 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/NettyPipelineHelper.java @@ -0,0 +1,32 @@ +package datadog.trace.instrumentation.netty40; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; + +public class NettyPipelineHelper { + public static void addHandlerAfter( + final ChannelPipeline pipeline, final String name, final ChannelHandler... toAdd) { + String targetHandlerName = name; + for (ChannelHandler handler : toAdd) { + ChannelHandler existing = pipeline.get(handler.getClass()); + if (existing != null) { + pipeline.remove(existing); + } + pipeline.addAfter(targetHandlerName, handler.getClass().getName(), handler); + ChannelHandlerContext handlerContext = pipeline.context(handler); + if (handlerContext != null) { + targetHandlerName = handlerContext.name(); + } + } + } + + public static void addHandlerAfter( + final ChannelPipeline pipeline, final ChannelHandler handler, final ChannelHandler... toAdd) { + ChannelHandlerContext handlerContext = pipeline.context(handler); + if (handlerContext != null) { + String handlerName = handlerContext.name(); + addHandlerAfter(pipeline, handlerName, toAdd); + } + } +} diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/HttpServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/HttpServerResponseTracingHandler.java index eb30caa5043..47399950b44 100644 --- a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/HttpServerResponseTracingHandler.java +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/HttpServerResponseTracingHandler.java @@ -1,21 +1,26 @@ package datadog.trace.instrumentation.netty40.server; import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.instrumentation.netty40.AttributeKeys.CHANNEL_ID; import static datadog.trace.instrumentation.netty40.AttributeKeys.SPAN_ATTRIBUTE_KEY; +import static datadog.trace.instrumentation.netty40.AttributeKeys.WEBSOCKET_SENDER_HANDLER_CONTEXT; import static datadog.trace.instrumentation.netty40.server.NettyHttpServerDecorator.DECORATE; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; +import java.util.UUID; @ChannelHandler.Sharable public class HttpServerResponseTracingHandler extends ChannelOutboundHandlerAdapter { public static HttpServerResponseTracingHandler INSTANCE = new HttpServerResponseTracingHandler(); + private static final String UPGRADE_HEADER = "upgrade"; @Override public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise prm) { @@ -37,7 +42,21 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann ctx.channel().attr(SPAN_ATTRIBUTE_KEY).remove(); throw throwable; } - if (response.getStatus() != HttpResponseStatus.CONTINUE) { + final boolean isWebsocketUpgrade = + response.getStatus() == HttpResponseStatus.SWITCHING_PROTOCOLS + && "websocket".equals(response.headers().get(UPGRADE_HEADER)); + if (isWebsocketUpgrade) { + String channelId = + ctx.channel() + .attr(CHANNEL_ID) + .setIfAbsent(UUID.randomUUID().toString().substring(0, 8)); + ctx.channel() + .attr(WEBSOCKET_SENDER_HANDLER_CONTEXT) + .set(new HandlerContext.Sender(span, channelId)); + } + if (response.getStatus() != HttpResponseStatus.CONTINUE + && (response.getStatus() != HttpResponseStatus.SWITCHING_PROTOCOLS + || isWebsocketUpgrade)) { DECORATE.onResponse(span, response); DECORATE.beforeFinish(span); ctx.channel().attr(SPAN_ATTRIBUTE_KEY).remove(); diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerRequestTracingHandler.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerRequestTracingHandler.java new file mode 100644 index 00000000000..6ec281b3ae9 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerRequestTracingHandler.java @@ -0,0 +1,122 @@ +package datadog.trace.instrumentation.netty40.server.websocket; + +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.bootstrap.instrumentation.decorator.WebsocketDecorator.DECORATE; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_TEXT; +import static datadog.trace.instrumentation.netty40.AttributeKeys.*; + +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.*; + +@ChannelHandler.Sharable +public class WebSocketServerRequestTracingHandler extends ChannelInboundHandlerAdapter { + public static WebSocketServerRequestTracingHandler INSTANCE = + new WebSocketServerRequestTracingHandler(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object frame) { + + if (frame instanceof WebSocketFrame) { + Channel channel = ctx.channel(); + HandlerContext.Receiver receiverContext = + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).get(); + + if (receiverContext == null) { + HandlerContext.Sender sessionState = channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).get(); + if (sessionState != null) { + String channelId = ctx.channel().attr(CHANNEL_ID).get(); + receiverContext = new HandlerContext.Receiver(sessionState.getHandshakeSpan(), channelId); + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).set(receiverContext); + } + } + if (receiverContext != null) { + if (frame instanceof TextWebSocketFrame) { + // WebSocket Read Text Start + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, textFrame.text(), textFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(textFrame); + // WebSocket Read Text Start + } finally { + if (textFrame.isFinalFragment()) { + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + + if (frame instanceof BinaryWebSocketFrame) { + // WebSocket Read Binary Start + BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, + binaryFrame.content().nioBuffer(), + binaryFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(binaryFrame); + } finally { + // WebSocket Read Binary End + if (binaryFrame.isFinalFragment()) { + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + DECORATE.onFrameEnd(receiverContext); + } + } + + return; + } + + if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame continuationWebSocketFrame = + (ContinuationWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, + MESSAGE_TYPE_TEXT.equals(receiverContext.getMessageType()) + ? continuationWebSocketFrame.text() + : continuationWebSocketFrame.content().nioBuffer(), + continuationWebSocketFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(continuationWebSocketFrame); + } finally { + if (continuationWebSocketFrame.isFinalFragment()) { + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + + if (frame instanceof CloseWebSocketFrame) { + // WebSocket Closed by client + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame; + int statusCode = closeFrame.statusCode(); + String reasonText = closeFrame.reasonText(); + channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).remove(); + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + final AgentSpan span = + DECORATE.onSessionCloseReceived(receiverContext, reasonText, statusCode); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(closeFrame); + if (closeFrame.isFinalFragment()) { + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + } + } + // can be other messages we do not handle like ping, pong + ctx.fireChannelRead(frame); + } +} diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerResponseTracingHandler.java new file mode 100644 index 00000000000..bdea83f8bc2 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerResponseTracingHandler.java @@ -0,0 +1,106 @@ +package datadog.trace.instrumentation.netty40.server.websocket; + +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.bootstrap.instrumentation.decorator.WebsocketDecorator.DECORATE; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_BINARY; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_TEXT; +import static datadog.trace.instrumentation.netty40.AttributeKeys.WEBSOCKET_SENDER_HANDLER_CONTEXT; + +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; +import io.netty.channel.*; +import io.netty.handler.codec.http.websocketx.*; + +@ChannelHandler.Sharable +public class WebSocketServerResponseTracingHandler extends ChannelOutboundHandlerAdapter { + public static WebSocketServerResponseTracingHandler INSTANCE = + new WebSocketServerResponseTracingHandler(); + + @Override + public void write(ChannelHandlerContext ctx, Object frame, ChannelPromise promise) + throws Exception { + + if (frame instanceof WebSocketFrame) { + Channel channel = ctx.channel(); + HandlerContext.Sender handlerContext = channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).get(); + if (handlerContext != null) { + + if (frame instanceof TextWebSocketFrame) { + // WebSocket Write Text Start + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, MESSAGE_TYPE_TEXT, textFrame.text().length()); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + // WebSocket Write Text End + if (textFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof BinaryWebSocketFrame) { + // WebSocket Write Binary Start + BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, MESSAGE_TYPE_BINARY, binaryFrame.content().readableBytes()); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + // WebSocket Write Binary End + if (binaryFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame continuationWebSocketFrame = + (ContinuationWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, + handlerContext.getMessageType(), + MESSAGE_TYPE_TEXT.equals(handlerContext.getMessageType()) + ? continuationWebSocketFrame.text().length() + : continuationWebSocketFrame.content().readableBytes()); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + // WebSocket Write Binary End + if (continuationWebSocketFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof CloseWebSocketFrame) { + // WebSocket Closed by Server + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame; + int statusCode = closeFrame.statusCode(); + String reasonText = closeFrame.reasonText(); + channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).remove(); + final AgentSpan span = + DECORATE.onSessionCloseIssued(handlerContext, reasonText, statusCode); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + if (closeFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + } + } + // can be other messages we do not handle like ping, pong + ctx.write(frame, promise); + } +} diff --git a/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerTracingHandler.java b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerTracingHandler.java new file mode 100644 index 00000000000..6e453c75dd1 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.0/src/main/java/datadog/trace/instrumentation/netty40/server/websocket/WebSocketServerTracingHandler.java @@ -0,0 +1,14 @@ +package datadog.trace.instrumentation.netty40.server.websocket; + +import io.netty.channel.CombinedChannelDuplexHandler; + +public class WebSocketServerTracingHandler + extends CombinedChannelDuplexHandler< + WebSocketServerRequestTracingHandler, WebSocketServerResponseTracingHandler> { + + public WebSocketServerTracingHandler() { + super( + WebSocketServerRequestTracingHandler.INSTANCE, + WebSocketServerResponseTracingHandler.INSTANCE); + } +} diff --git a/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/Netty40ServerTest.groovy b/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/Netty40ServerTest.groovy index d478bfe4ea3..c4b007be56d 100644 --- a/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/Netty40ServerTest.groovy +++ b/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/Netty40ServerTest.groovy @@ -1,12 +1,15 @@ import datadog.appsec.api.blocking.Blocking import datadog.trace.agent.test.base.HttpServer import datadog.trace.agent.test.base.HttpServerTest +import datadog.trace.agent.test.base.WebsocketClient +import datadog.trace.agent.test.base.WebsocketServer import datadog.trace.agent.test.naming.TestingNettyHttpNamingConventions import datadog.trace.bootstrap.instrumentation.api.URIUtils import datadog.trace.instrumentation.netty40.server.NettyHttpServerDecorator import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled +import io.netty.channel.ChannelFutureListener import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInitializer import io.netty.channel.ChannelPipeline @@ -17,10 +20,16 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.http.DefaultFullHttpResponse import io.netty.handler.codec.http.FullHttpResponse import io.netty.handler.codec.http.HttpHeaders +import io.netty.handler.codec.http.HttpObjectAggregator import io.netty.handler.codec.http.HttpRequest import io.netty.handler.codec.http.HttpRequestDecoder import io.netty.handler.codec.http.HttpResponseEncoder import io.netty.handler.codec.http.HttpResponseStatus +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler import io.netty.util.CharsetUtil @@ -36,6 +45,7 @@ import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.REDIRECT import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.SUCCESS import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.USER_BLOCK +import static datadog.trace.agent.test.utils.TraceUtils.runUnderTrace import static io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_LENGTH import static io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE import static io.netty.handler.codec.http.HttpHeaders.is100ContinueExpected @@ -47,7 +57,7 @@ abstract class Netty40ServerTest extends HttpServerTest { static final LoggingHandler LOGGING_HANDLER = new LoggingHandler(SERVER_LOGGER.name, LogLevel.DEBUG) - private class NettyServer implements HttpServer { + private class NettyServer implements WebsocketServer { final eventLoopGroup = new NioEventLoopGroup() int port = 0 @@ -61,8 +71,10 @@ abstract class Netty40ServerTest extends HttpServerTest { ChannelPipeline pipeline = ch.pipeline() pipeline.addFirst("logger", LOGGING_HANDLER) - def handlers = [new HttpRequestDecoder(), new HttpResponseEncoder()] - handlers.each { pipeline.addLast(it) } + pipeline.addLast(new HttpRequestDecoder()) + pipeline.addLast(new HttpResponseEncoder()) + pipeline.addLast(new HttpObjectAggregator(1024)) + pipeline.addLast(new WebSocketServerProtocolHandler("/websocket")) pipeline.addLast([ channelRead0 : { ChannelHandlerContext ctx, msg -> if (msg instanceof HttpRequest) { @@ -120,6 +132,10 @@ abstract class Netty40ServerTest extends HttpServerTest { return response } } + if (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame || msg instanceof ContinuationWebSocketFrame) { + // generate a child span. The websocket test expects this way + runUnderTrace("onRead", {}) + } }, exceptionCaught : { ChannelHandlerContext ctx, Throwable cause -> ByteBuf content = Unpooled.copiedBuffer(cause.message, CharsetUtil.UTF_8) @@ -130,6 +146,14 @@ abstract class Netty40ServerTest extends HttpServerTest { }, channelReadComplete: { it.flush() + }, + userEventTriggered : { ChannelHandlerContext ctx, Object evt -> + if (evt == WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { + WsEndpoint.onOpen(ctx) + } + }, + channelInactive : { ChannelHandlerContext ctx -> + WsEndpoint.onClose() } ] as SimpleChannelInboundHandler) } @@ -147,6 +171,46 @@ abstract class Netty40ServerTest extends HttpServerTest { URI address() { return new URI("http://localhost:$port/") } + + @Override + void awaitConnected() { + while (WsEndpoint.activeSession == null) { + synchronized (WsEndpoint) { + WsEndpoint.wait() + } + } + } + + @Override + void serverSendText(String[] messages) { + WsEndpoint.activeSession.writeAndFlush(new TextWebSocketFrame(messages.length == 1, 0, messages[0])) + for (def i = 1; i < messages.length; i++) { + WsEndpoint.activeSession.writeAndFlush(new ContinuationWebSocketFrame(messages.length - 1 == i, 0, messages[i])) + } + } + + @Override + void serverSendBinary(byte[][] binaries) { + WsEndpoint.activeSession.writeAndFlush(new BinaryWebSocketFrame(binaries.length == 1, 0, Unpooled.wrappedBuffer(binaries[0]))) + for (def i = 1; i < binaries.length; i++) { + WsEndpoint.activeSession.writeAndFlush(new ContinuationWebSocketFrame(binaries.length - 1 == i, 0, Unpooled.wrappedBuffer(binaries[i]))) + } + } + + @Override + void serverClose() { + WsEndpoint.activeSession.writeAndFlush(new CloseWebSocketFrame(1000, null)).addListener(ChannelFutureListener.CLOSE) + } + + @Override + void setMaxPayloadSize(int size) { + // not applicable + } + + @Override + boolean canSplitLargeWebsocketPayloads() { + false + } } @Override @@ -154,6 +218,11 @@ abstract class Netty40ServerTest extends HttpServerTest { return new NettyServer() } + @Override + WebsocketClient websocketClient() { + return new NettyWebsocketClient() + } + @Override String component() { NettyHttpServerDecorator.DECORATE.component() @@ -191,3 +260,18 @@ class Netty40ServerV0Test extends Netty40ServerTest implements TestingNettyHttpN class Netty40ServerV1ForkedTest extends Netty40ServerTest implements TestingNettyHttpNamingConventions.ServerV1 { } + +class WsEndpoint { + static volatile ChannelHandlerContext activeSession + + static void onOpen(ChannelHandlerContext session) { + activeSession = session + synchronized (WsEndpoint) { + WsEndpoint.notifyAll() + } + } + + static void onClose() { + activeSession = null + } +} diff --git a/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/NettyWebsocketClient.groovy b/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/NettyWebsocketClient.groovy new file mode 100644 index 00000000000..22ec350366e --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.0/src/test/groovy/NettyWebsocketClient.groovy @@ -0,0 +1,131 @@ +import datadog.trace.agent.test.base.WebsocketClient +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.Unpooled +import io.netty.channel.Channel +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInitializer +import io.netty.channel.SimpleChannelInboundHandler +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.handler.codec.http.DefaultHttpHeaders +import io.netty.handler.codec.http.FullHttpResponse +import io.netty.handler.codec.http.HttpClientCodec +import io.netty.handler.codec.http.HttpObjectAggregator +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker +import io.netty.handler.codec.http.websocketx.WebSocketVersion +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler + +import java.nio.charset.StandardCharsets +import java.util.concurrent.CountDownLatch + +import static io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory.newHandshaker + +class NettyWebsocketClient implements WebsocketClient { + static final LoggingHandler LOGGING_HANDLER = new LoggingHandler(NettyWebsocketClient.class.getName(), LogLevel.DEBUG) + static class WebsocketHandler extends SimpleChannelInboundHandler { + final URI uri + WebSocketClientHandshaker handshaker + def handshaken = new CountDownLatch(1) + + WebsocketHandler(uri) { + this.uri = uri + } + + @Override + void channelActive(ChannelHandlerContext ctx) throws Exception { + handshaker = newHandshaker( + uri, WebSocketVersion.V13, null, false, new DefaultHttpHeaders() + .add("User-Agent", "dd-trace-java"), // keep me + 1280000) + handshaker.handshake(ctx.channel()) + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + final Channel ch = ctx.channel() + if (!handshaker.isHandshakeComplete()) { + // web socket client connected + handshaker.finishHandshake(ch, (FullHttpResponse) msg) + } + handshaken.countDown() + } + } + final eventLoopGroup = new NioEventLoopGroup() + + Channel channel + def chunkSize = -1 + + @Override + void connect(String url) { + def uri = new URI(url) + def wsHandler = new WebsocketHandler(uri) + Bootstrap b = new Bootstrap() + b.group(eventLoopGroup) + .handler(LOGGING_HANDLER) + .handler(new ChannelInitializer() { + protected void initChannel(Channel ch) throws Exception { + def pipeline = ch.pipeline() + pipeline.addLast(new HttpClientCodec()) + pipeline.addLast(new HttpObjectAggregator(1024)) + pipeline.addLast(wsHandler) + // remove our handler since we do not want to trace that client + pipeline.names().findAll { it.contains("HttpClientTracingHandler") }.each { pipeline.remove(it) } + } + }).channel(NioSocketChannel) + channel = b.connect(uri.host, uri.port).sync().channel() + //wait for the handshake to complete properly + wsHandler.handshaken.await() + } + + @Override + void send(String text) { + def chunks = split(text.getBytes(StandardCharsets.UTF_8)) + channel.writeAndFlush(new TextWebSocketFrame(chunks.length == 1, 0, Unpooled.wrappedBuffer(chunks[0]))) + for (def i = 1; i < chunks.length; i++) { + channel.writeAndFlush(new ContinuationWebSocketFrame(chunks.length - 1 == i, 0, Unpooled.wrappedBuffer(chunks[i]))) + } + } + + @Override + void send(byte[] bytes) { + def chunks = split(bytes) + channel.writeAndFlush(new BinaryWebSocketFrame(chunks.length == 1, 0, Unpooled.wrappedBuffer(chunks[0]))).sync() + for (def i = 1; i < chunks.length; i++) { + channel.writeAndFlush(new ContinuationWebSocketFrame(chunks.length - 1 == i, 0, Unpooled.wrappedBuffer(chunks[i]))).sync() + } + } + + byte[][] split(byte[] src) { + if (chunkSize <= 0) { + return new byte[][]{src} + } + def ret = new byte[(int) Math.ceil(src.length / chunkSize)][] + def offset = 0 + for (def i = 0; i < ret.length; i++) { + ret[i] = new byte[Math.min(src.length - offset, chunkSize)] + System.arraycopy(src, offset, ret[i], 0, ret[i].length) + } + ret + } + + @Override + void close(int code, String reason) { + channel.writeAndFlush(new CloseWebSocketFrame(code, reason)).sync() + channel.close() + } + + @Override + boolean supportMessageChunks() { + true + } + + @Override + void setSplitChunksAfter(int size) { + chunkSize = size + } +} diff --git a/dd-java-agent/instrumentation/netty-4.1-shared/src/main/java/datadog/trace/instrumentation/netty41/AttributeKeys.java b/dd-java-agent/instrumentation/netty-4.1-shared/src/main/java/datadog/trace/instrumentation/netty41/AttributeKeys.java index 74d124fb866..41ab0146d3d 100644 --- a/dd-java-agent/instrumentation/netty-4.1-shared/src/main/java/datadog/trace/instrumentation/netty41/AttributeKeys.java +++ b/dd-java-agent/instrumentation/netty-4.1-shared/src/main/java/datadog/trace/instrumentation/netty41/AttributeKeys.java @@ -5,6 +5,7 @@ import datadog.trace.api.GenericClassValue; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; import io.netty.handler.codec.http.HttpHeaders; import io.netty.util.AttributeKey; import java.util.concurrent.ConcurrentHashMap; @@ -33,6 +34,12 @@ public final class AttributeKeys { public static final AttributeKey BLOCKED_RESPONSE_KEY = attributeKey("datadog.server.blocked_response"); + public static final AttributeKey WEBSOCKET_SENDER_HANDLER_CONTEXT = + attributeKey("datadog.server.websocket.sender.handler_context"); + + public static final AttributeKey WEBSOCKET_RECEIVER_HANDLER_CONTEXT = + attributeKey("datadog.server.websocket.receiver.handler_context"); + /** * Generate an attribute key or reuse the one existing in the global app map. This implementation * creates attributes only once even if the current class is loaded by several class loaders and diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/ChannelFutureListenerInstrumentation.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/ChannelFutureListenerInstrumentation.java index df528d3cf44..c9ecc54c19e 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/ChannelFutureListenerInstrumentation.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/ChannelFutureListenerInstrumentation.java @@ -73,7 +73,7 @@ public void methodAdvice(MethodTransformer transformer) { } public static class OperationCompleteAdvice { - @Advice.OnMethodEnter + @Advice.OnMethodEnter(suppress = Throwable.class) public static AgentScope activateScope(@Advice.Argument(0) final ChannelFuture future) { /* Idea here is: diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java index f824b6f63f1..4cfe703e194 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java @@ -13,6 +13,7 @@ import com.google.auto.service.AutoService; import datadog.trace.agent.tooling.Instrumenter; import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.InstrumenterConfig; import datadog.trace.bootstrap.CallDepthThreadLocalMap; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.instrumentation.netty41.client.HttpClientRequestTracingHandler; @@ -22,8 +23,10 @@ import datadog.trace.instrumentation.netty41.server.HttpServerResponseTracingHandler; import datadog.trace.instrumentation.netty41.server.HttpServerTracingHandler; import datadog.trace.instrumentation.netty41.server.MaybeBlockResponseHandler; +import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerRequestTracingHandler; +import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerResponseTracingHandler; +import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerTracingHandler; import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpRequestDecoder; @@ -31,6 +34,7 @@ import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.util.Attribute; import net.bytebuddy.asm.Advice; import net.bytebuddy.description.type.TypeDescription; @@ -77,7 +81,11 @@ public String[] helperClassNames() { packageName + ".server.HttpServerResponseTracingHandler", packageName + ".server.HttpServerTracingHandler", packageName + ".server.MaybeBlockResponseHandler", + packageName + ".server.websocket.WebSocketServerTracingHandler", + packageName + ".server.websocket.WebSocketServerResponseTracingHandler", + packageName + ".server.websocket.WebSocketServerRequestTracingHandler", packageName + ".NettyHttp2Helper", + packageName + ".NettyPipelineHelper", }; } @@ -137,50 +145,60 @@ public static void addHandler( handler2 instanceof ChannelHandler ? (ChannelHandler) handler2 : handler3; try { - ChannelHandler toAdd = null; - ChannelHandler toAdd2 = null; // Server pipeline handlers if (handler instanceof HttpServerCodec) { - toAdd = new HttpServerTracingHandler(); - toAdd2 = MaybeBlockResponseHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, + handler, + new HttpServerTracingHandler(), + MaybeBlockResponseHandler.INSTANCE); } else if (handler instanceof HttpRequestDecoder) { - toAdd = HttpServerRequestTracingHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, HttpServerRequestTracingHandler.INSTANCE); } else if (handler instanceof HttpResponseEncoder) { - toAdd = HttpServerResponseTracingHandler.INSTANCE; - toAdd2 = MaybeBlockResponseHandler.INSTANCE; - } else + NettyPipelineHelper.addHandlerAfter( + pipeline, + handler, + HttpServerResponseTracingHandler.INSTANCE, + MaybeBlockResponseHandler.INSTANCE); + } else if (handler instanceof WebSocketServerProtocolHandler) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled()) { + if (pipeline.get(HttpServerTracingHandler.class) != null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, "HttpServerTracingHandler#0", new WebSocketServerTracingHandler()); + } + if (pipeline.get(HttpServerRequestTracingHandler.class) != null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, + "HttpServerRequestTracingHandler#0", + WebSocketServerRequestTracingHandler.INSTANCE); + } + if (pipeline.get(HttpServerResponseTracingHandler.class) != null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, + "HttpServerResponseTracingHandler#0", + WebSocketServerResponseTracingHandler.INSTANCE); + } + } + } // Client pipeline handlers - if (handler instanceof HttpClientCodec) { - toAdd = new HttpClientTracingHandler(); + else if (handler instanceof HttpClientCodec) { + NettyPipelineHelper.addHandlerAfter(pipeline, handler, new HttpClientTracingHandler()); } else if (handler instanceof HttpRequestEncoder) { - toAdd = HttpClientRequestTracingHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, HttpClientRequestTracingHandler.INSTANCE); } else if (handler instanceof HttpResponseDecoder) { - toAdd = HttpClientResponseTracingHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, HttpClientResponseTracingHandler.INSTANCE); } else if (NettyHttp2Helper.isHttp2FrameCodec(handler)) { if (NettyHttp2Helper.isServer(handler)) { - toAdd = new HttpServerTracingHandler(); - toAdd2 = MaybeBlockResponseHandler.INSTANCE; + NettyPipelineHelper.addHandlerAfter( + pipeline, + handler, + new HttpServerTracingHandler(), + MaybeBlockResponseHandler.INSTANCE); } else { - toAdd = new HttpClientTracingHandler(); - } - } - if (toAdd != null) { - // Get the name so we can add immediately following - ChannelHandlerContext handlerContext = pipeline.context(handler); - if (handlerContext != null) { - String handlerName = handlerContext.name(); - ChannelHandler existing = pipeline.get(toAdd.getClass()); - if (existing != null) { - pipeline.remove(existing); - } - pipeline.addAfter(handlerName, null, toAdd); - if (toAdd2 != null) { - ChannelHandler existing2 = pipeline.get(toAdd2.getClass()); - if (existing2 != null) { - pipeline.remove(existing2); - } - pipeline.addAfter(pipeline.context(toAdd).name(), null, toAdd2); - } + NettyPipelineHelper.addHandlerAfter(pipeline, handler, new HttpClientTracingHandler()); } } } catch (final IllegalArgumentException e) { diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyPipelineHelper.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyPipelineHelper.java new file mode 100644 index 00000000000..63947626b86 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyPipelineHelper.java @@ -0,0 +1,32 @@ +package datadog.trace.instrumentation.netty41; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; + +public class NettyPipelineHelper { + public static void addHandlerAfter( + final ChannelPipeline pipeline, final String name, final ChannelHandler... toAdd) { + String handlerName = name; + for (ChannelHandler handler : toAdd) { + ChannelHandler existing = pipeline.get(handler.getClass()); + if (existing != null) { + pipeline.remove(existing); + } + pipeline.addAfter(handlerName, null, handler); + ChannelHandlerContext handlerContext = pipeline.context(handler); + if (handlerContext != null) { + handlerName = handlerContext.name(); + } + } + } + + public static void addHandlerAfter( + final ChannelPipeline pipeline, final ChannelHandler handler, final ChannelHandler... toAdd) { + ChannelHandlerContext handlerContext = pipeline.context(handler); + if (handlerContext != null) { + String handlerName = handlerContext.name(); + addHandlerAfter(pipeline, handlerName, toAdd); + } + } +} diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/HttpServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/HttpServerResponseTracingHandler.java index 6a1b638dd11..07235193fc8 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/HttpServerResponseTracingHandler.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/HttpServerResponseTracingHandler.java @@ -2,10 +2,12 @@ import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; import static datadog.trace.instrumentation.netty41.AttributeKeys.SPAN_ATTRIBUTE_KEY; +import static datadog.trace.instrumentation.netty41.AttributeKeys.WEBSOCKET_SENDER_HANDLER_CONTEXT; import static datadog.trace.instrumentation.netty41.server.NettyHttpServerDecorator.DECORATE; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; @@ -39,9 +41,16 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann ctx.channel().attr(SPAN_ATTRIBUTE_KEY).remove(); throw throwable; } + final boolean isWebsocketUpgrade = + response.status() == HttpResponseStatus.SWITCHING_PROTOCOLS + && "websocket".equals(response.headers().get(HttpHeaderNames.UPGRADE)); + if (isWebsocketUpgrade) { + ctx.channel() + .attr(WEBSOCKET_SENDER_HANDLER_CONTEXT) + .set(new HandlerContext.Sender(span, ctx.channel().id().asShortText())); + } if (response.status() != HttpResponseStatus.CONTINUE - && (response.status() != HttpResponseStatus.SWITCHING_PROTOCOLS - || "websocket".equals(response.headers().get(HttpHeaderNames.UPGRADE)))) { + && (response.status() != HttpResponseStatus.SWITCHING_PROTOCOLS || isWebsocketUpgrade)) { DECORATE.onResponse(span, response); DECORATE.beforeFinish(span); span.finish(); // Finish the span manually since finishSpanOnClose was false diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerRequestTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerRequestTracingHandler.java new file mode 100644 index 00000000000..5556af0d2b2 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerRequestTracingHandler.java @@ -0,0 +1,124 @@ +package datadog.trace.instrumentation.netty41.server.websocket; + +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.bootstrap.instrumentation.decorator.WebsocketDecorator.DECORATE; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_TEXT; +import static datadog.trace.instrumentation.netty41.AttributeKeys.WEBSOCKET_RECEIVER_HANDLER_CONTEXT; +import static datadog.trace.instrumentation.netty41.AttributeKeys.WEBSOCKET_SENDER_HANDLER_CONTEXT; + +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; +import io.netty.channel.*; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +@ChannelHandler.Sharable +public class WebSocketServerRequestTracingHandler extends ChannelInboundHandlerAdapter { + public static WebSocketServerRequestTracingHandler INSTANCE = + new WebSocketServerRequestTracingHandler(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object frame) { + + if (frame instanceof WebSocketFrame) { + Channel channel = ctx.channel(); + HandlerContext.Receiver receiverContext = + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).get(); + if (receiverContext == null) { + HandlerContext.Sender sessionState = channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).get(); + if (sessionState != null) { + receiverContext = + new HandlerContext.Receiver( + sessionState.getHandshakeSpan(), channel.id().asShortText()); + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).set(receiverContext); + } + } + if (receiverContext != null) { + if (frame instanceof TextWebSocketFrame) { + // WebSocket Read Text Start + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, textFrame.text(), textFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(textFrame); + // WebSocket Read Text Start + } finally { + if (textFrame.isFinalFragment()) { + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + + if (frame instanceof BinaryWebSocketFrame) { + // WebSocket Read Binary Start + BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, + binaryFrame.content().nioBuffer(), + binaryFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(binaryFrame); + } finally { + // WebSocket Read Binary End + if (binaryFrame.isFinalFragment()) { + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + DECORATE.onFrameEnd(receiverContext); + } + } + + return; + } + + if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame continuationWebSocketFrame = + (ContinuationWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onReceiveFrameStart( + receiverContext, + MESSAGE_TYPE_TEXT.equals(receiverContext.getMessageType()) + ? continuationWebSocketFrame.text() + : continuationWebSocketFrame.content().nioBuffer(), + continuationWebSocketFrame.isFinalFragment()); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(continuationWebSocketFrame); + } finally { + if (continuationWebSocketFrame.isFinalFragment()) { + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + + if (frame instanceof CloseWebSocketFrame) { + // WebSocket Closed by client + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame; + int statusCode = closeFrame.statusCode(); + String reasonText = closeFrame.reasonText(); + channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).remove(); + channel.attr(WEBSOCKET_RECEIVER_HANDLER_CONTEXT).remove(); + final AgentSpan span = + DECORATE.onSessionCloseReceived(receiverContext, reasonText, statusCode); + try (final AgentScope scope = activateSpan(span)) { + ctx.fireChannelRead(closeFrame); + if (closeFrame.isFinalFragment()) { + DECORATE.onFrameEnd(receiverContext); + } + } + return; + } + } + } + // can be other messages we do not handle like ping, pong + ctx.fireChannelRead(frame); + } +} diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerResponseTracingHandler.java new file mode 100644 index 00000000000..cc073f6aa1b --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerResponseTracingHandler.java @@ -0,0 +1,110 @@ +package datadog.trace.instrumentation.netty41.server.websocket; + +import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.activateSpan; +import static datadog.trace.bootstrap.instrumentation.decorator.WebsocketDecorator.DECORATE; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_BINARY; +import static datadog.trace.bootstrap.instrumentation.websocket.HandlersExtractor.MESSAGE_TYPE_TEXT; +import static datadog.trace.instrumentation.netty41.AttributeKeys.WEBSOCKET_SENDER_HANDLER_CONTEXT; + +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.websocket.HandlerContext; +import io.netty.channel.*; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +@ChannelHandler.Sharable +public class WebSocketServerResponseTracingHandler extends ChannelOutboundHandlerAdapter { + public static WebSocketServerResponseTracingHandler INSTANCE = + new WebSocketServerResponseTracingHandler(); + + @Override + public void write(ChannelHandlerContext ctx, Object frame, ChannelPromise promise) + throws Exception { + + if (frame instanceof WebSocketFrame) { + Channel channel = ctx.channel(); + HandlerContext.Sender handlerContext = channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).get(); + if (handlerContext != null) { + + if (frame instanceof TextWebSocketFrame) { + // WebSocket Write Text Start + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, MESSAGE_TYPE_TEXT, textFrame.text().length()); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + // WebSocket Write Text End + if (textFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof BinaryWebSocketFrame) { + // WebSocket Write Binary Start + BinaryWebSocketFrame binaryFrame = (BinaryWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, MESSAGE_TYPE_BINARY, binaryFrame.content().readableBytes()); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + // WebSocket Write Binary End + if (binaryFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof ContinuationWebSocketFrame) { + ContinuationWebSocketFrame continuationWebSocketFrame = + (ContinuationWebSocketFrame) frame; + final AgentSpan span = + DECORATE.onSendFrameStart( + handlerContext, + handlerContext.getMessageType(), + MESSAGE_TYPE_TEXT.equals(handlerContext.getMessageType()) + ? continuationWebSocketFrame.text().length() + : continuationWebSocketFrame.content().readableBytes()); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + // WebSocket Write Binary End + if (continuationWebSocketFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + + if (frame instanceof CloseWebSocketFrame) { + // WebSocket Closed by Server + CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame; + int statusCode = closeFrame.statusCode(); + String reasonText = closeFrame.reasonText(); + channel.attr(WEBSOCKET_SENDER_HANDLER_CONTEXT).remove(); + final AgentSpan span = + DECORATE.onSessionCloseIssued(handlerContext, reasonText, statusCode); + try (final AgentScope scope = activateSpan(span)) { + ctx.write(frame, promise); + } finally { + if (closeFrame.isFinalFragment()) { + DECORATE.onFrameEnd(handlerContext); + } + } + return; + } + } + } + // can be other messages we do not handle like ping, pong + ctx.write(frame, promise); + } +} diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java new file mode 100644 index 00000000000..8f6f4b2e6c4 --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java @@ -0,0 +1,14 @@ +package datadog.trace.instrumentation.netty41.server.websocket; + +import io.netty.channel.CombinedChannelDuplexHandler; + +public class WebSocketServerTracingHandler + extends CombinedChannelDuplexHandler< + WebSocketServerRequestTracingHandler, WebSocketServerResponseTracingHandler> { + + public WebSocketServerTracingHandler() { + super( + WebSocketServerRequestTracingHandler.INSTANCE, + WebSocketServerResponseTracingHandler.INSTANCE); + } +} diff --git a/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/Netty41ServerTest.groovy b/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/Netty41ServerTest.groovy index 65fa7a2a15e..36d479207d9 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/Netty41ServerTest.groovy +++ b/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/Netty41ServerTest.groovy @@ -1,30 +1,15 @@ -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.BODY_URLENCODED -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.ERROR -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.EXCEPTION -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.FORWARDED -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.NOT_FOUND -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ENCODED_BOTH -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ENCODED_QUERY -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_PARAM -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.REDIRECT -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.SUCCESS -import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.USER_BLOCK -import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH -import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE -import static io.netty.handler.codec.http.HttpResponseStatus.CONTINUE -import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR -import static io.netty.handler.codec.http.HttpUtil.is100ContinueExpected -import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1 - import datadog.appsec.api.blocking.Blocking import datadog.trace.agent.test.base.HttpServer import datadog.trace.agent.test.base.HttpServerTest +import datadog.trace.agent.test.base.WebsocketClient +import datadog.trace.agent.test.base.WebsocketServer import datadog.trace.agent.test.naming.TestingNettyHttpNamingConventions import datadog.trace.bootstrap.instrumentation.api.URIUtils import datadog.trace.instrumentation.netty41.server.NettyHttpServerDecorator import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled +import io.netty.channel.ChannelFutureListener import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInitializer import io.netty.channel.ChannelPipeline @@ -42,15 +27,39 @@ import io.netty.handler.codec.http.HttpResponseStatus import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.codec.http.multipart.Attribute import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler import io.netty.util.CharsetUtil +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.BODY_URLENCODED +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.ERROR +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.EXCEPTION +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.FORWARDED +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.NOT_FOUND +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ENCODED_BOTH +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ENCODED_QUERY +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_PARAM +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.REDIRECT +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.SUCCESS +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.USER_BLOCK +import static datadog.trace.agent.test.utils.TraceUtils.runUnderTrace +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE +import static io.netty.handler.codec.http.HttpResponseStatus.CONTINUE +import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR +import static io.netty.handler.codec.http.HttpUtil.is100ContinueExpected +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1 + abstract class Netty41ServerTest extends HttpServerTest { static final LoggingHandler LOGGING_HANDLER = new LoggingHandler(SERVER_LOGGER.name, LogLevel.DEBUG) - private class NettyServer implements HttpServer { + private class NettyServer implements WebsocketServer { final eventLoopGroup = new NioEventLoopGroup() int port = 0 @@ -63,9 +72,9 @@ abstract class Netty41ServerTest extends HttpServerTest { initChannel: { ch -> ChannelPipeline pipeline = ch.pipeline() pipeline.addFirst("logger", LOGGING_HANDLER) - - def handlers = [new HttpServerCodec(), new HttpObjectAggregator(1024)] - handlers.each { pipeline.addLast(it) } + pipeline.addLast(new HttpServerCodec()) + pipeline.addLast(new HttpObjectAggregator(1024)) + pipeline.addLast(new WebSocketServerProtocolHandler("/websocket")) pipeline.addLast([ channelRead0 : { ChannelHandlerContext ctx, msg -> if (msg instanceof HttpRequest) { @@ -75,7 +84,7 @@ abstract class Netty41ServerTest extends HttpServerTest { } def uri = URIUtils.safeParse(request.uri) if (uri == null) { - ctx.write(new DefaultFullHttpResponse(request.protocolVersion(),HttpResponseStatus.BAD_REQUEST)) + ctx.write(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.BAD_REQUEST)) return } HttpServerTest.ServerEndpoint endpoint = HttpServerTest.ServerEndpoint.forPath(uri.path) @@ -114,7 +123,7 @@ abstract class Netty41ServerTest extends HttpServerTest { decoder.offer(msg) m = decoder.bodyHttpDatas.collectEntries { d -> - [d.name, [((Attribute)d).value]] + [d.name, [((Attribute) d).value]] } } finally { decoder.destroy() @@ -149,6 +158,10 @@ abstract class Netty41ServerTest extends HttpServerTest { return response } } + if (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame || msg instanceof ContinuationWebSocketFrame) { + // generate a child span. The websocket test expects this way + runUnderTrace("onRead", {}) + } }, exceptionCaught : { ChannelHandlerContext ctx, Throwable cause -> ByteBuf content = Unpooled.copiedBuffer(cause.message, CharsetUtil.UTF_8) @@ -159,6 +172,14 @@ abstract class Netty41ServerTest extends HttpServerTest { }, channelReadComplete: { it.flush() + }, + userEventTriggered : { ChannelHandlerContext ctx, Object evt -> + if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + WsEndpoint.onOpen(ctx) + } + }, + channelInactive : { ChannelHandlerContext ctx -> + WsEndpoint.onClose() } ] as SimpleChannelInboundHandler) } @@ -176,6 +197,46 @@ abstract class Netty41ServerTest extends HttpServerTest { URI address() { return new URI("http://localhost:$port/") } + + @Override + void awaitConnected() { + while (WsEndpoint.activeSession == null) { + synchronized (WsEndpoint) { + WsEndpoint.wait() + } + } + } + + @Override + void serverSendText(String[] messages) { + WsEndpoint.activeSession.writeAndFlush(new TextWebSocketFrame(messages.length == 1, 0, messages[0])) + for (def i = 1; i < messages.length; i++) { + WsEndpoint.activeSession.writeAndFlush(new ContinuationWebSocketFrame(messages.length - 1 == i, 0, messages[i])) + } + } + + @Override + void serverSendBinary(byte[][] binaries) { + WsEndpoint.activeSession.writeAndFlush(new BinaryWebSocketFrame(binaries.length == 1, 0, Unpooled.wrappedBuffer(binaries[0]))) + for (def i = 1; i < binaries.length; i++) { + WsEndpoint.activeSession.writeAndFlush(new ContinuationWebSocketFrame(binaries.length - 1 == i, 0, Unpooled.wrappedBuffer(binaries[i]))) + } + } + + @Override + void serverClose() { + WsEndpoint.activeSession.writeAndFlush(new CloseWebSocketFrame(1000, null)).addListener(ChannelFutureListener.CLOSE) + } + + @Override + void setMaxPayloadSize(int size) { + // not applicable + } + + @Override + boolean canSplitLargeWebsocketPayloads() { + false + } } @Override @@ -183,6 +244,11 @@ abstract class Netty41ServerTest extends HttpServerTest { return new NettyServer() } + @Override + WebsocketClient websocketClient() { + return new NettyWebsocketClient() + } + @Override String component() { NettyHttpServerDecorator.DECORATE.component() @@ -227,3 +293,21 @@ class Netty41ServerV0Test extends Netty41ServerTest implements TestingNettyHttpN class Netty41ServerV1ForkedTest extends Netty41ServerTest implements TestingNettyHttpNamingConventions.ServerV1 { } + + +class WsEndpoint { + static volatile ChannelHandlerContext activeSession + + static void onOpen(ChannelHandlerContext session) { + activeSession = session + synchronized (WsEndpoint) { + WsEndpoint.notifyAll() + } + } + + static void onClose() { + activeSession = null + } +} + + diff --git a/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/NettyWebsocketClient.groovy b/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/NettyWebsocketClient.groovy new file mode 100644 index 00000000000..22ec350366e --- /dev/null +++ b/dd-java-agent/instrumentation/netty-4.1/src/test/groovy/NettyWebsocketClient.groovy @@ -0,0 +1,131 @@ +import datadog.trace.agent.test.base.WebsocketClient +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.Unpooled +import io.netty.channel.Channel +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInitializer +import io.netty.channel.SimpleChannelInboundHandler +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.handler.codec.http.DefaultHttpHeaders +import io.netty.handler.codec.http.FullHttpResponse +import io.netty.handler.codec.http.HttpClientCodec +import io.netty.handler.codec.http.HttpObjectAggregator +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker +import io.netty.handler.codec.http.websocketx.WebSocketVersion +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler + +import java.nio.charset.StandardCharsets +import java.util.concurrent.CountDownLatch + +import static io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory.newHandshaker + +class NettyWebsocketClient implements WebsocketClient { + static final LoggingHandler LOGGING_HANDLER = new LoggingHandler(NettyWebsocketClient.class.getName(), LogLevel.DEBUG) + static class WebsocketHandler extends SimpleChannelInboundHandler { + final URI uri + WebSocketClientHandshaker handshaker + def handshaken = new CountDownLatch(1) + + WebsocketHandler(uri) { + this.uri = uri + } + + @Override + void channelActive(ChannelHandlerContext ctx) throws Exception { + handshaker = newHandshaker( + uri, WebSocketVersion.V13, null, false, new DefaultHttpHeaders() + .add("User-Agent", "dd-trace-java"), // keep me + 1280000) + handshaker.handshake(ctx.channel()) + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + final Channel ch = ctx.channel() + if (!handshaker.isHandshakeComplete()) { + // web socket client connected + handshaker.finishHandshake(ch, (FullHttpResponse) msg) + } + handshaken.countDown() + } + } + final eventLoopGroup = new NioEventLoopGroup() + + Channel channel + def chunkSize = -1 + + @Override + void connect(String url) { + def uri = new URI(url) + def wsHandler = new WebsocketHandler(uri) + Bootstrap b = new Bootstrap() + b.group(eventLoopGroup) + .handler(LOGGING_HANDLER) + .handler(new ChannelInitializer() { + protected void initChannel(Channel ch) throws Exception { + def pipeline = ch.pipeline() + pipeline.addLast(new HttpClientCodec()) + pipeline.addLast(new HttpObjectAggregator(1024)) + pipeline.addLast(wsHandler) + // remove our handler since we do not want to trace that client + pipeline.names().findAll { it.contains("HttpClientTracingHandler") }.each { pipeline.remove(it) } + } + }).channel(NioSocketChannel) + channel = b.connect(uri.host, uri.port).sync().channel() + //wait for the handshake to complete properly + wsHandler.handshaken.await() + } + + @Override + void send(String text) { + def chunks = split(text.getBytes(StandardCharsets.UTF_8)) + channel.writeAndFlush(new TextWebSocketFrame(chunks.length == 1, 0, Unpooled.wrappedBuffer(chunks[0]))) + for (def i = 1; i < chunks.length; i++) { + channel.writeAndFlush(new ContinuationWebSocketFrame(chunks.length - 1 == i, 0, Unpooled.wrappedBuffer(chunks[i]))) + } + } + + @Override + void send(byte[] bytes) { + def chunks = split(bytes) + channel.writeAndFlush(new BinaryWebSocketFrame(chunks.length == 1, 0, Unpooled.wrappedBuffer(chunks[0]))).sync() + for (def i = 1; i < chunks.length; i++) { + channel.writeAndFlush(new ContinuationWebSocketFrame(chunks.length - 1 == i, 0, Unpooled.wrappedBuffer(chunks[i]))).sync() + } + } + + byte[][] split(byte[] src) { + if (chunkSize <= 0) { + return new byte[][]{src} + } + def ret = new byte[(int) Math.ceil(src.length / chunkSize)][] + def offset = 0 + for (def i = 0; i < ret.length; i++) { + ret[i] = new byte[Math.min(src.length - offset, chunkSize)] + System.arraycopy(src, offset, ret[i], 0, ret[i].length) + } + ret + } + + @Override + void close(int code, String reason) { + channel.writeAndFlush(new CloseWebSocketFrame(code, reason)).sync() + channel.close() + } + + @Override + boolean supportMessageChunks() { + true + } + + @Override + void setSplitChunksAfter(int size) { + chunkSize = size + } +} diff --git a/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/HttpServerTest.groovy b/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/HttpServerTest.groovy index 529a41510e7..17685c2318a 100644 --- a/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/HttpServerTest.groovy +++ b/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/HttpServerTest.groovy @@ -48,8 +48,6 @@ import okhttp3.MultipartBody import okhttp3.Request import okhttp3.RequestBody import okhttp3.Response -import okhttp3.WebSocketListener -import okio.ByteString import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -409,6 +407,10 @@ abstract class HttpServerTest extends WithHttpServer { server instanceof WebsocketServer } + WebsocketClient websocketClient() { + new OkHttpWebsocketClient() + } + boolean testEndpointDiscovery() { false } @@ -1934,12 +1936,9 @@ abstract class HttpServerTest extends WithHttpServer { setup: assumeTrue(testWebsockets()) def wsServer = getServer() as WebsocketServer - + def client = websocketClient() when: - def request = new Request.Builder().url(HttpUrl.get(WEBSOCKET.resolve(address))) - .get().build() - - client.newWebSocket(request, new WebSocketListener() {}) + client.connect(WEBSOCKET.resolve(address).toString()) wsServer.awaitConnected() runUnderTrace("parent", { if (messages[0] instanceof String) { @@ -1981,21 +1980,21 @@ abstract class HttpServerTest extends WithHttpServer { setup: assumeTrue(testWebsockets()) def wsServer = getServer() as WebsocketServer - assumeTrue(chunks == 1 || wsServer.canSplitLargeWebsocketPayloads()) + def client = websocketClient() + assumeTrue(chunks == 1 || wsServer.canSplitLargeWebsocketPayloads() || client.supportMessageChunks()) when: - def request = new Request.Builder().url(HttpUrl.get(WEBSOCKET.resolve(address))) - .get().build() - - def ws = client.newWebSocket(request, new WebSocketListener() {}) + client.connect(WEBSOCKET.resolve(address).toString()) wsServer.awaitConnected() wsServer.setMaxPayloadSize(10) + // in case the client can also send partial fragments + client.setSplitChunksAfter(10) if (message instanceof String) { - ws.send(message as String) + client.send(message as String) } else { - ws.send(ByteString.of(message as byte[])) + client.send(message as byte[]) } - ws.close(1000, "goodbye") + client.close(1000, "goodbye") then: assertTraces(3, { diff --git a/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/WebsocketClient.groovy b/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/WebsocketClient.groovy new file mode 100644 index 00000000000..6a30ece981f --- /dev/null +++ b/dd-java-agent/testing/src/main/groovy/datadog/trace/agent/test/base/WebsocketClient.groovy @@ -0,0 +1,50 @@ +package datadog.trace.agent.test.base + +import datadog.trace.agent.test.utils.OkHttpUtils +import okhttp3.Request +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okio.ByteString + +interface WebsocketClient { + void connect(String url) + void send(String text) + void send(byte[] bytes) + void close(int code, String reason) + boolean supportMessageChunks() + void setSplitChunksAfter(int size) +} + +class OkHttpWebsocketClient implements WebsocketClient { + WebSocket session + + @Override + void connect(String url) { + session = OkHttpUtils.client().newWebSocket(new Request.Builder().url(url).get().build(), new WebSocketListener() {}) + } + + @Override + void send(String text) { + session.send(text) + } + + @Override + void send(byte[] bytes) { + session.send(ByteString.of(bytes)) + } + + @Override + void close(int code, String reason) { + session.close(code, reason) + } + + @Override + boolean supportMessageChunks() { + false + } + + @Override + void setSplitChunksAfter(int size) { + // not supported + } +}