diff --git a/Package.swift b/Package.swift
index 201ba8e56..cd9be155f 100644
--- a/Package.swift
+++ b/Package.swift
@@ -72,6 +72,10 @@ let packageDependencies: [Package.Dependency] = [
url: "https://github.com/apple/swift-docc-plugin",
from: "1.0.0"
),
+ .package(
+ url: "https://github.com/apple/swift-distributed-tracing.git",
+ from: "1.0.0"
+ ),
].appending(
.package(
url: "https://github.com/apple/swift-nio-ssl.git",
@@ -131,9 +135,11 @@ extension Target.Dependency {
)
static let dequeModule: Self = .product(name: "DequeModule", package: "swift-collections")
static let atomics: Self = .product(name: "Atomics", package: "swift-atomics")
+ static let tracing: Self = .product(name: "Tracing", package: "swift-distributed-tracing")
static let grpcCore: Self = .target(name: "GRPCCore")
static let grpcInProcessTransport: Self = .target(name: "GRPCInProcessTransport")
+ static let grpcInterceptors: Self = .target(name: "GRPCInterceptors")
static let grpcHTTP2Core: Self = .target(name: "GRPCHTTP2Core")
static let grpcHTTP2TransportNIOPosix: Self = .target(name: "GRPCHTTP2TransportNIOPosix")
static let grpcHTTP2TransportNIOTransportServices: Self = .target(name: "GRPCHTTP2TransportNIOTransportServices")
@@ -181,6 +187,14 @@ extension Target {
]
)
+ static let grpcInterceptors: Target = .target(
+ name: "GRPCInterceptors",
+ dependencies: [
+ .grpcCore,
+ .tracing
+ ]
+ )
+
static let grpcHTTP2Core: Target = .target(
name: "GRPCHTTP2Core",
dependencies: [
@@ -274,10 +288,20 @@ extension Target {
name: "GRPCInProcessTransportTests",
dependencies: [
.grpcCore,
- .grpcInProcessTransport,
+ .grpcInProcessTransport
]
)
+ static let grpcInterceptorsTests: Target = .testTarget(
+ name: "GRPCInterceptorsTests",
+ dependencies: [
+ .grpcCore,
+ .tracing,
+ .nioCore,
+ .grpcInterceptors
+ ]
+ )
+
static let grpcHTTP2CoreTests: Target = .testTarget(
name: "GRPCHTTP2CoreTests",
dependencies: [
@@ -638,6 +662,7 @@ let package = Package(
.grpcCore,
.grpcInProcessTransport,
.grpcCodeGen,
+ .grpcInterceptors,
.grpcHTTP2Core,
.grpcHTTP2TransportNIOPosix,
.grpcHTTP2TransportNIOTransportServices,
@@ -646,6 +671,7 @@ let package = Package(
.grpcCoreTests,
.grpcInProcessTransportTests,
.grpcCodeGenTests,
+ .grpcInterceptorsTests,
.grpcHTTP2CoreTests,
.grpcHTTP2TransportNIOPosixTests,
.grpcHTTP2TransportNIOTransportServicesTests
diff --git a/Sources/GRPCInterceptors/ClientTracingInterceptor.swift b/Sources/GRPCInterceptors/ClientTracingInterceptor.swift
new file mode 100644
index 000000000..2bb9395c5
--- /dev/null
+++ b/Sources/GRPCInterceptors/ClientTracingInterceptor.swift
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Tracing
+
+/// A client interceptor that injects tracing information into the request.
+///
+/// The tracing information is taken from the current `ServiceContext`, and injected into the request's
+/// metadata. It will then be picked up by the server-side ``ServerTracingInterceptor``.
+///
+/// For more information, refer to the documentation for `swift-distributed-tracing`.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct ClientTracingInterceptor: ClientInterceptor {
+ private let injector: ClientRequestInjector
+ private let emitEventOnEachWrite: Bool
+
+ /// Create a new instance of a ``ClientTracingInterceptor``.
+ ///
+ /// - Parameter emitEventOnEachWrite: If `true`, each request part sent and response part
+ /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response
+ /// start and end will be recorded as events.
+ public init(emitEventOnEachWrite: Bool = false) {
+ self.injector = ClientRequestInjector()
+ self.emitEventOnEachWrite = emitEventOnEachWrite
+ }
+
+ /// This interceptor will inject as the request's metadata whatever `ServiceContext` key-value pairs
+ /// have been made available by the tracing implementation bootstrapped in your application.
+ ///
+ /// Which key-value pairs are injected will depend on the specific tracing implementation
+ /// that has been configured when bootstrapping `swift-distributed-tracing` in your application.
+ public func intercept (
+ request: ClientRequest.Stream ,
+ context: ClientInterceptorContext,
+ next: @Sendable (ClientRequest.Stream , ClientInterceptorContext) async throws ->
+ ClientResponse.Stream
+ ) async throws -> ClientResponse.Stream where Input: Sendable, Output: Sendable {
+ var request = request
+ let tracer = InstrumentationSystem.tracer
+ let serviceContext = ServiceContext.current ?? .topLevel
+
+ tracer.inject(
+ serviceContext,
+ into: &request.metadata,
+ using: self.injector
+ )
+
+ return try await tracer.withSpan(
+ context.descriptor.fullyQualifiedMethod,
+ context: serviceContext,
+ ofKind: .client
+ ) { span in
+ span.addEvent("Request started")
+
+ if self.emitEventOnEachWrite {
+ let wrappedProducer = request.producer
+ request.producer = { writer in
+ let eventEmittingWriter = HookedWriter(
+ wrapping: writer,
+ beforeEachWrite: {
+ span.addEvent("Sending request part")
+ },
+ afterEachWrite: {
+ span.addEvent("Sent request part")
+ }
+ )
+
+ do {
+ try await wrappedProducer(RPCWriter(wrapping: eventEmittingWriter))
+ } catch {
+ span.addEvent("Error encountered")
+ throw error
+ }
+
+ span.addEvent("Request end")
+ }
+ }
+
+ var response: ClientResponse.Stream
+ do {
+ response = try await next(request, context)
+ } catch {
+ span.addEvent("Error encountered")
+ throw error
+ }
+
+ switch response.accepted {
+ case .success(var success):
+ if self.emitEventOnEachWrite {
+ let onEachPartRecordingSequence = success.bodyParts.map { element in
+ span.addEvent("Received response part")
+ return element
+ }
+ let onFinishRecordingSequence = OnFinishAsyncSequence(
+ wrapping: onEachPartRecordingSequence
+ ) {
+ span.addEvent("Received response end")
+ }
+ success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence)
+ response.accepted = .success(success)
+ } else {
+ let onFinishRecordingSequence = OnFinishAsyncSequence(wrapping: success.bodyParts) {
+ span.addEvent("Received response end")
+ }
+ success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence)
+ response.accepted = .success(success)
+ }
+ case .failure:
+ span.addEvent("Received error response")
+ }
+
+ return response
+ }
+ }
+}
+
+/// An injector responsible for injecting the required instrumentation keys from the `ServiceContext` into
+/// the request metadata.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct ClientRequestInjector: Instrumentation.Injector {
+ typealias Carrier = Metadata
+
+ func inject(_ value: String, forKey key: String, into carrier: inout Carrier) {
+ carrier.addString(value, forKey: key)
+ }
+}
diff --git a/Sources/GRPCInterceptors/HookedWriter.swift b/Sources/GRPCInterceptors/HookedWriter.swift
new file mode 100644
index 000000000..b4bb52eed
--- /dev/null
+++ b/Sources/GRPCInterceptors/HookedWriter.swift
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import GRPCCore
+import Tracing
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct HookedWriter: RPCWriterProtocol {
+ private let writer: any RPCWriterProtocol
+ private let beforeEachWrite: @Sendable () -> Void
+ private let afterEachWrite: @Sendable () -> Void
+
+ init(
+ wrapping other: some RPCWriterProtocol,
+ beforeEachWrite: @Sendable @escaping () -> Void,
+ afterEachWrite: @Sendable @escaping () -> Void
+ ) {
+ self.writer = other
+ self.beforeEachWrite = beforeEachWrite
+ self.afterEachWrite = afterEachWrite
+ }
+
+ func write(contentsOf elements: some Sequence) async throws {
+ self.beforeEachWrite()
+ try await self.writer.write(contentsOf: elements)
+ self.afterEachWrite()
+ }
+}
diff --git a/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift b/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift
new file mode 100644
index 000000000..311e1fcd8
--- /dev/null
+++ b/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct OnFinishAsyncSequence: AsyncSequence, Sendable {
+ private let _makeAsyncIterator: @Sendable () -> AsyncIterator
+
+ init(
+ wrapping other: S,
+ onFinish: @escaping () -> Void
+ ) where S.Element == Element {
+ self._makeAsyncIterator = {
+ AsyncIterator(wrapping: other.makeAsyncIterator(), onFinish: onFinish)
+ }
+ }
+
+ func makeAsyncIterator() -> AsyncIterator {
+ self._makeAsyncIterator()
+ }
+
+ struct AsyncIterator: AsyncIteratorProtocol {
+ private var iterator: any AsyncIteratorProtocol
+ private var onFinish: (() -> Void)?
+
+ fileprivate init(
+ wrapping other: Iterator,
+ onFinish: @escaping () -> Void
+ ) where Iterator: AsyncIteratorProtocol, Iterator.Element == Element {
+ self.iterator = other
+ self.onFinish = onFinish
+ }
+
+ mutating func next() async throws -> Element? {
+ let elem = try await self.iterator.next()
+
+ if elem == nil {
+ self.onFinish?()
+ self.onFinish = nil
+ }
+
+ return elem as? Element
+ }
+ }
+}
diff --git a/Sources/GRPCInterceptors/ServerTracingInterceptor.swift b/Sources/GRPCInterceptors/ServerTracingInterceptor.swift
new file mode 100644
index 000000000..d91da7051
--- /dev/null
+++ b/Sources/GRPCInterceptors/ServerTracingInterceptor.swift
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Tracing
+
+/// A server interceptor that extracts tracing information from the request.
+///
+/// The extracted tracing information is made available to user code via the current `ServiceContext`.
+/// For more information, refer to the documentation for `swift-distributed-tracing`.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+public struct ServerTracingInterceptor: ServerInterceptor {
+ private let extractor: ServerRequestExtractor
+ private let emitEventOnEachWrite: Bool
+
+ /// Create a new instance of a ``ServerTracingInterceptor``.
+ ///
+ /// - Parameter emitEventOnEachWrite: If `true`, each response part sent and request part
+ /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response
+ /// start and end will be recorded as events.
+ public init(emitEventOnEachWrite: Bool = false) {
+ self.extractor = ServerRequestExtractor()
+ self.emitEventOnEachWrite = emitEventOnEachWrite
+ }
+
+ /// This interceptor will extract whatever `ServiceContext` key-value pairs have been inserted into the
+ /// request's metadata, and will make them available to user code via the `ServiceContext/current`
+ /// context.
+ ///
+ /// Which key-value pairs are extracted and made available will depend on the specific tracing implementation
+ /// that has been configured when bootstrapping `swift-distributed-tracing` in your application.
+ public func intercept (
+ request: ServerRequest.Stream ,
+ context: ServerInterceptorContext,
+ next: @Sendable (ServerRequest.Stream , ServerInterceptorContext) async throws ->
+ ServerResponse.Stream
+ ) async throws -> ServerResponse.Stream where Input: Sendable, Output: Sendable {
+ var serviceContext = ServiceContext.topLevel
+ let tracer = InstrumentationSystem.tracer
+
+ tracer.extract(
+ request.metadata,
+ into: &serviceContext,
+ using: self.extractor
+ )
+
+ return try await ServiceContext.withValue(serviceContext) {
+ try await tracer.withSpan(
+ context.descriptor.fullyQualifiedMethod,
+ context: serviceContext,
+ ofKind: .server
+ ) { span in
+ span.addEvent("Received request start")
+
+ var request = request
+
+ if self.emitEventOnEachWrite {
+ request.messages = RPCAsyncSequence(
+ wrapping: request.messages.map { element in
+ span.addEvent("Received request part")
+ return element
+ }
+ )
+ }
+
+ var response = try await next(request, context)
+
+ span.addEvent("Received request end")
+
+ switch response.accepted {
+ case .success(var success):
+ let wrappedProducer = success.producer
+
+ if self.emitEventOnEachWrite {
+ success.producer = { writer in
+ let eventEmittingWriter = HookedWriter(
+ wrapping: writer,
+ beforeEachWrite: {
+ span.addEvent("Sending response part")
+ },
+ afterEachWrite: {
+ span.addEvent("Sent response part")
+ }
+ )
+
+ let wrappedResult: Metadata
+ do {
+ wrappedResult = try await wrappedProducer(
+ RPCWriter(wrapping: eventEmittingWriter)
+ )
+ } catch {
+ span.addEvent("Error encountered")
+ throw error
+ }
+
+ span.addEvent("Sent response end")
+ return wrappedResult
+ }
+ } else {
+ success.producer = { writer in
+ let wrappedResult: Metadata
+ do {
+ wrappedResult = try await wrappedProducer(writer)
+ } catch {
+ span.addEvent("Error encountered")
+ throw error
+ }
+
+ span.addEvent("Sent response end")
+ return wrappedResult
+ }
+ }
+
+ response = .init(accepted: .success(success))
+ case .failure:
+ span.addEvent("Sent error response")
+ }
+
+ return response
+ }
+ }
+ }
+}
+
+/// An extractor responsible for extracting the required instrumentation keys from request metadata.
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+struct ServerRequestExtractor: Instrumentation.Extractor {
+ typealias Carrier = Metadata
+
+ func extract(key: String, from carrier: Carrier) -> String? {
+ var values = carrier[stringValues: key].makeIterator()
+ // There should only be one value for each key. If more, pick just one.
+ return values.next()
+ }
+}
diff --git a/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift b/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift
new file mode 100644
index 000000000..98fbb2558
--- /dev/null
+++ b/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift
@@ -0,0 +1,333 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import Tracing
+import XCTest
+
+@testable import GRPCInterceptors
+
+@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
+final class TracingInterceptorTests: XCTestCase {
+ override class func setUp() {
+ InstrumentationSystem.bootstrap(TestTracer())
+ }
+
+ #if swift(>=5.8) // Compiling these tests fails in 5.7
+ func testClientInterceptor() async throws {
+ var serviceContext = ServiceContext.topLevel
+ let traceIDString = UUID().uuidString
+ let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false)
+ let (stream, continuation) = AsyncStream.makeStream()
+ serviceContext.traceID = traceIDString
+
+ try await ServiceContext.withValue(serviceContext) {
+ let methodDescriptor = MethodDescriptor(
+ service: "TracingInterceptorTests",
+ method: "testClientInterceptor"
+ )
+ let response = try await interceptor.intercept(
+ request: .init(producer: { writer in
+ try await writer.write(contentsOf: ["request1"])
+ try await writer.write(contentsOf: ["request2"])
+ }),
+ context: .init(descriptor: methodDescriptor)
+ ) { stream, _ in
+ // Assert the metadata contains the injected context key-value.
+ XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"])
+
+ // Write into the response stream to make sure the `producer` closure's called.
+ let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+ try await stream.producer(writer)
+ continuation.finish()
+
+ return .init(
+ metadata: [],
+ bodyParts: .init(
+ wrapping: AsyncStream { cont in
+ cont.yield(.message(["response"]))
+ cont.finish()
+ }
+ )
+ )
+ }
+
+ var streamIterator = stream.makeAsyncIterator()
+ var element = await streamIterator.next()
+ XCTAssertEqual(element, "request1")
+ element = await streamIterator.next()
+ XCTAssertEqual(element, "request2")
+ element = await streamIterator.next()
+ XCTAssertNil(element)
+
+ var messages = response.messages.makeAsyncIterator()
+ var message = try await messages.next()
+ XCTAssertEqual(message, ["response"])
+ message = try await messages.next()
+ XCTAssertNil(message)
+
+ let tracer = InstrumentationSystem.tracer as! TestTracer
+ XCTAssertEqual(
+ tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+ $0.name
+ },
+ [
+ "Request started",
+ "Received response end",
+ ]
+ )
+ }
+ }
+
+ func testClientInterceptorAllEventsRecorded() async throws {
+ let methodDescriptor = MethodDescriptor(
+ service: "TracingInterceptorTests",
+ method: "testClientInterceptorAllEventsRecorded"
+ )
+ var serviceContext = ServiceContext.topLevel
+ let traceIDString = UUID().uuidString
+ let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true)
+ let (stream, continuation) = AsyncStream.makeStream()
+ serviceContext.traceID = traceIDString
+
+ try await ServiceContext.withValue(serviceContext) {
+ let response = try await interceptor.intercept(
+ request: .init(producer: { writer in
+ try await writer.write(contentsOf: ["request1"])
+ try await writer.write(contentsOf: ["request2"])
+ }),
+ context: .init(descriptor: methodDescriptor)
+ ) { stream, _ in
+ // Assert the metadata contains the injected context key-value.
+ XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"])
+
+ // Write into the response stream to make sure the `producer` closure's called.
+ let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+ try await stream.producer(writer)
+ continuation.finish()
+
+ return .init(
+ metadata: [],
+ bodyParts: .init(
+ wrapping: AsyncStream { cont in
+ cont.yield(.message(["response"]))
+ cont.finish()
+ }
+ )
+ )
+ }
+
+ var streamIterator = stream.makeAsyncIterator()
+ var element = await streamIterator.next()
+ XCTAssertEqual(element, "request1")
+ element = await streamIterator.next()
+ XCTAssertEqual(element, "request2")
+ element = await streamIterator.next()
+ XCTAssertNil(element)
+
+ var messages = response.messages.makeAsyncIterator()
+ var message = try await messages.next()
+ XCTAssertEqual(message, ["response"])
+ message = try await messages.next()
+ XCTAssertNil(message)
+
+ let tracer = InstrumentationSystem.tracer as! TestTracer
+ XCTAssertEqual(
+ tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+ $0.name
+ },
+ [
+ "Request started",
+ // Recorded when `request1` is sent
+ "Sending request part",
+ "Sent request part",
+ // Recorded when `request2` is sent
+ "Sending request part",
+ "Sent request part",
+ // Recorded after all request parts have been sent
+ "Request end",
+ // Recorded when receiving response part
+ "Received response part",
+ // Recorded at end of response
+ "Received response end",
+ ]
+ )
+ }
+ }
+ #endif // swift >= 5.7
+
+ func testServerInterceptorErrorResponse() async throws {
+ let methodDescriptor = MethodDescriptor(
+ service: "TracingInterceptorTests",
+ method: "testServerInterceptorErrorResponse"
+ )
+ let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
+ let response = try await interceptor.intercept(
+ request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
+ context: .init(descriptor: methodDescriptor)
+ ) { _, _ in
+ ServerResponse.Stream(error: .init(code: .unknown, message: "Test error"))
+ }
+ XCTAssertThrowsError(try response.accepted.get())
+
+ let tracer = InstrumentationSystem.tracer as! TestTracer
+ XCTAssertEqual(
+ tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+ $0.name
+ },
+ [
+ "Received request start",
+ "Received request end",
+ "Sent error response",
+ ]
+ )
+ }
+
+ func testServerInterceptor() async throws {
+ let methodDescriptor = MethodDescriptor(
+ service: "TracingInterceptorTests",
+ method: "testServerInterceptor"
+ )
+ let (stream, continuation) = AsyncStream.makeStream()
+ let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false)
+ let response = try await interceptor.intercept(
+ request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
+ context: .init(descriptor: methodDescriptor)
+ ) { _, _ in
+ { [serviceContext = ServiceContext.current] in
+ return ServerResponse.Stream(
+ accepted: .success(
+ .init(
+ metadata: [],
+ producer: { writer in
+ guard let serviceContext else {
+ XCTFail("There should be a service context present.")
+ return ["Result": "Test failed"]
+ }
+
+ let traceID = serviceContext.traceID
+ XCTAssertEqual("some-trace-id", traceID)
+
+ try await writer.write("response1")
+ try await writer.write("response2")
+
+ return ["Result": "Trailing metadata"]
+ }
+ )
+ )
+ )
+ }()
+ }
+
+ let responseContents = try response.accepted.get()
+ let trailingMetadata = try await responseContents.producer(
+ RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+ )
+ continuation.finish()
+ XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])
+
+ var streamIterator = stream.makeAsyncIterator()
+ var element = await streamIterator.next()
+ XCTAssertEqual(element, "response1")
+ element = await streamIterator.next()
+ XCTAssertEqual(element, "response2")
+ element = await streamIterator.next()
+ XCTAssertNil(element)
+
+ let tracer = InstrumentationSystem.tracer as! TestTracer
+ XCTAssertEqual(
+ tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+ $0.name
+ },
+ [
+ "Received request start",
+ "Received request end",
+ "Sent response end",
+ ]
+ )
+ }
+
+ func testServerInterceptorAllEventsRecorded() async throws {
+ let methodDescriptor = MethodDescriptor(
+ service: "TracingInterceptorTests",
+ method: "testServerInterceptorAllEventsRecorded"
+ )
+ let (stream, continuation) = AsyncStream.makeStream()
+ let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: true)
+ let response = try await interceptor.intercept(
+ request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])),
+ context: .init(descriptor: methodDescriptor)
+ ) { _, _ in
+ { [serviceContext = ServiceContext.current] in
+ return ServerResponse.Stream(
+ accepted: .success(
+ .init(
+ metadata: [],
+ producer: { writer in
+ guard let serviceContext else {
+ XCTFail("There should be a service context present.")
+ return ["Result": "Test failed"]
+ }
+
+ let traceID = serviceContext.traceID
+ XCTAssertEqual("some-trace-id", traceID)
+
+ try await writer.write("response1")
+ try await writer.write("response2")
+
+ return ["Result": "Trailing metadata"]
+ }
+ )
+ )
+ )
+ }()
+ }
+
+ let responseContents = try response.accepted.get()
+ let trailingMetadata = try await responseContents.producer(
+ RPCWriter(wrapping: TestWriter(streamContinuation: continuation))
+ )
+ continuation.finish()
+ XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"])
+
+ var streamIterator = stream.makeAsyncIterator()
+ var element = await streamIterator.next()
+ XCTAssertEqual(element, "response1")
+ element = await streamIterator.next()
+ XCTAssertEqual(element, "response2")
+ element = await streamIterator.next()
+ XCTAssertNil(element)
+
+ let tracer = InstrumentationSystem.tracer as! TestTracer
+ XCTAssertEqual(
+ tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map {
+ $0.name
+ },
+ [
+ "Received request start",
+ "Received request end",
+ // Recorded when `response1` is sent
+ "Sending response part",
+ "Sent response part",
+ // Recorded when `response2` is sent
+ "Sending response part",
+ "Sent response part",
+ // Recorded when we're done sending response
+ "Sent response end",
+ ]
+ )
+ }
+}
diff --git a/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift b/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift
new file mode 100644
index 000000000..ffed5213a
--- /dev/null
+++ b/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2024, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import GRPCCore
+import NIOConcurrencyHelpers
+import Tracing
+
+final class TestTracer: Tracer {
+ typealias Span = TestSpan
+
+ private var testSpans: NIOLockedValueBox<[String: TestSpan]> = .init([:])
+
+ func getEventsForTestSpan(ofOperationName operationName: String) -> [SpanEvent] {
+ self.testSpans.withLockedValue({ $0[operationName] })?.events ?? []
+ }
+
+ func extract(
+ _ carrier: Carrier,
+ into context: inout ServiceContextModule.ServiceContext,
+ using extractor: Extract
+ ) where Carrier == Extract.Carrier, Extract: Instrumentation.Extractor {
+ let traceID = extractor.extract(key: TraceID.keyName, from: carrier)
+ context[TraceID.self] = traceID
+ }
+
+ func inject(
+ _ context: ServiceContextModule.ServiceContext,
+ into carrier: inout Carrier,
+ using injector: Inject
+ ) where Carrier == Inject.Carrier, Inject: Instrumentation.Injector {
+ if let traceID = context.traceID {
+ injector.inject(traceID, forKey: TraceID.keyName, into: &carrier)
+ }
+ }
+
+ func forceFlush() {
+ // no-op
+ }
+
+ func startSpan(
+ _ operationName: String,
+ context: @autoclosure () -> ServiceContext,
+ ofKind kind: SpanKind,
+ at instant: @autoclosure () -> Instant,
+ function: String,
+ file fileID: String,
+ line: UInt
+ ) -> TestSpan where Instant: TracerInstant {
+ return self.testSpans.withLockedValue { testSpans in
+ let span = TestSpan(context: context(), operationName: operationName)
+ testSpans[operationName] = span
+ return span
+ }
+ }
+}
+
+class TestSpan: Span {
+ var context: ServiceContextModule.ServiceContext
+ var operationName: String
+ var attributes: Tracing.SpanAttributes
+ var isRecording: Bool
+ private(set) var status: Tracing.SpanStatus?
+ private(set) var events: [Tracing.SpanEvent] = []
+
+ init(
+ context: ServiceContextModule.ServiceContext,
+ operationName: String,
+ attributes: Tracing.SpanAttributes = [:],
+ isRecording: Bool = true
+ ) {
+ self.context = context
+ self.operationName = operationName
+ self.attributes = attributes
+ self.isRecording = isRecording
+ }
+
+ func setStatus(_ status: Tracing.SpanStatus) {
+ self.status = status
+ }
+
+ func addEvent(_ event: Tracing.SpanEvent) {
+ self.events.append(event)
+ }
+
+ func recordError(
+ _ error: any Error,
+ attributes: Tracing.SpanAttributes,
+ at instant: @autoclosure () -> Instant
+ ) where Instant: Tracing.TracerInstant {
+ self.setStatus(
+ .init(
+ code: .error,
+ message: "Error: \(error), attributes: \(attributes), at instant: \(instant())"
+ )
+ )
+ }
+
+ func addLink(_ link: Tracing.SpanLink) {
+ self.context.spanLinks?.append(link)
+ }
+
+ func end(at instant: @autoclosure () -> Instant) where Instant: Tracing.TracerInstant {
+ self.setStatus(.init(code: .ok, message: "Ended at instant: \(instant())"))
+ }
+}
+
+enum TraceID: ServiceContextModule.ServiceContextKey {
+ typealias Value = String
+
+ static let keyName = "trace-id"
+}
+
+enum ServiceContextSpanLinksKey: ServiceContextModule.ServiceContextKey {
+ typealias Value = [SpanLink]
+
+ static let keyName = "span-links"
+}
+
+extension ServiceContext {
+ var traceID: String? {
+ get {
+ self[TraceID.self]
+ }
+ set {
+ self[TraceID.self] = newValue
+ }
+ }
+
+ var spanLinks: [SpanLink]? {
+ get {
+ self[ServiceContextSpanLinksKey.self]
+ }
+ set {
+ self[ServiceContextSpanLinksKey.self] = newValue
+ }
+ }
+}
+
+struct TestWriter: RPCWriterProtocol {
+ typealias Element = WriterElement
+
+ private let streamContinuation: AsyncStream.Continuation
+
+ init(streamContinuation: AsyncStream.Continuation) {
+ self.streamContinuation = streamContinuation
+ }
+
+ func write(contentsOf elements: some Sequence) async throws {
+ elements.forEach { element in
+ self.streamContinuation.yield(element)
+ }
+ }
+}
+
+#if swift(<5.9)
+@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
+extension AsyncStream {
+ static func makeStream(
+ of elementType: Element.Type = Element.self,
+ bufferingPolicy limit: AsyncStream.Continuation.BufferingPolicy = .unbounded
+ ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) {
+ var continuation: AsyncStream.Continuation!
+ let stream = AsyncStream(Element.self, bufferingPolicy: limit) {
+ continuation = $0
+ }
+ return (stream, continuation)
+ }
+}
+#endif