diff --git a/Package.swift b/Package.swift index 201ba8e56..cd9be155f 100644 --- a/Package.swift +++ b/Package.swift @@ -72,6 +72,10 @@ let packageDependencies: [Package.Dependency] = [ url: "https://github.com/apple/swift-docc-plugin", from: "1.0.0" ), + .package( + url: "https://github.com/apple/swift-distributed-tracing.git", + from: "1.0.0" + ), ].appending( .package( url: "https://github.com/apple/swift-nio-ssl.git", @@ -131,9 +135,11 @@ extension Target.Dependency { ) static let dequeModule: Self = .product(name: "DequeModule", package: "swift-collections") static let atomics: Self = .product(name: "Atomics", package: "swift-atomics") + static let tracing: Self = .product(name: "Tracing", package: "swift-distributed-tracing") static let grpcCore: Self = .target(name: "GRPCCore") static let grpcInProcessTransport: Self = .target(name: "GRPCInProcessTransport") + static let grpcInterceptors: Self = .target(name: "GRPCInterceptors") static let grpcHTTP2Core: Self = .target(name: "GRPCHTTP2Core") static let grpcHTTP2TransportNIOPosix: Self = .target(name: "GRPCHTTP2TransportNIOPosix") static let grpcHTTP2TransportNIOTransportServices: Self = .target(name: "GRPCHTTP2TransportNIOTransportServices") @@ -181,6 +187,14 @@ extension Target { ] ) + static let grpcInterceptors: Target = .target( + name: "GRPCInterceptors", + dependencies: [ + .grpcCore, + .tracing + ] + ) + static let grpcHTTP2Core: Target = .target( name: "GRPCHTTP2Core", dependencies: [ @@ -274,10 +288,20 @@ extension Target { name: "GRPCInProcessTransportTests", dependencies: [ .grpcCore, - .grpcInProcessTransport, + .grpcInProcessTransport ] ) + static let grpcInterceptorsTests: Target = .testTarget( + name: "GRPCInterceptorsTests", + dependencies: [ + .grpcCore, + .tracing, + .nioCore, + .grpcInterceptors + ] + ) + static let grpcHTTP2CoreTests: Target = .testTarget( name: "GRPCHTTP2CoreTests", dependencies: [ @@ -638,6 +662,7 @@ let package = Package( .grpcCore, .grpcInProcessTransport, .grpcCodeGen, + .grpcInterceptors, .grpcHTTP2Core, .grpcHTTP2TransportNIOPosix, .grpcHTTP2TransportNIOTransportServices, @@ -646,6 +671,7 @@ let package = Package( .grpcCoreTests, .grpcInProcessTransportTests, .grpcCodeGenTests, + .grpcInterceptorsTests, .grpcHTTP2CoreTests, .grpcHTTP2TransportNIOPosixTests, .grpcHTTP2TransportNIOTransportServicesTests diff --git a/Sources/GRPCInterceptors/ClientTracingInterceptor.swift b/Sources/GRPCInterceptors/ClientTracingInterceptor.swift new file mode 100644 index 000000000..2bb9395c5 --- /dev/null +++ b/Sources/GRPCInterceptors/ClientTracingInterceptor.swift @@ -0,0 +1,140 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Tracing + +/// A client interceptor that injects tracing information into the request. +/// +/// The tracing information is taken from the current `ServiceContext`, and injected into the request's +/// metadata. It will then be picked up by the server-side ``ServerTracingInterceptor``. +/// +/// For more information, refer to the documentation for `swift-distributed-tracing`. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct ClientTracingInterceptor: ClientInterceptor { + private let injector: ClientRequestInjector + private let emitEventOnEachWrite: Bool + + /// Create a new instance of a ``ClientTracingInterceptor``. + /// + /// - Parameter emitEventOnEachWrite: If `true`, each request part sent and response part + /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response + /// start and end will be recorded as events. + public init(emitEventOnEachWrite: Bool = false) { + self.injector = ClientRequestInjector() + self.emitEventOnEachWrite = emitEventOnEachWrite + } + + /// This interceptor will inject as the request's metadata whatever `ServiceContext` key-value pairs + /// have been made available by the tracing implementation bootstrapped in your application. + /// + /// Which key-value pairs are injected will depend on the specific tracing implementation + /// that has been configured when bootstrapping `swift-distributed-tracing` in your application. + public func intercept( + request: ClientRequest.Stream, + context: ClientInterceptorContext, + next: @Sendable (ClientRequest.Stream, ClientInterceptorContext) async throws -> + ClientResponse.Stream + ) async throws -> ClientResponse.Stream where Input: Sendable, Output: Sendable { + var request = request + let tracer = InstrumentationSystem.tracer + let serviceContext = ServiceContext.current ?? .topLevel + + tracer.inject( + serviceContext, + into: &request.metadata, + using: self.injector + ) + + return try await tracer.withSpan( + context.descriptor.fullyQualifiedMethod, + context: serviceContext, + ofKind: .client + ) { span in + span.addEvent("Request started") + + if self.emitEventOnEachWrite { + let wrappedProducer = request.producer + request.producer = { writer in + let eventEmittingWriter = HookedWriter( + wrapping: writer, + beforeEachWrite: { + span.addEvent("Sending request part") + }, + afterEachWrite: { + span.addEvent("Sent request part") + } + ) + + do { + try await wrappedProducer(RPCWriter(wrapping: eventEmittingWriter)) + } catch { + span.addEvent("Error encountered") + throw error + } + + span.addEvent("Request end") + } + } + + var response: ClientResponse.Stream + do { + response = try await next(request, context) + } catch { + span.addEvent("Error encountered") + throw error + } + + switch response.accepted { + case .success(var success): + if self.emitEventOnEachWrite { + let onEachPartRecordingSequence = success.bodyParts.map { element in + span.addEvent("Received response part") + return element + } + let onFinishRecordingSequence = OnFinishAsyncSequence( + wrapping: onEachPartRecordingSequence + ) { + span.addEvent("Received response end") + } + success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence) + response.accepted = .success(success) + } else { + let onFinishRecordingSequence = OnFinishAsyncSequence(wrapping: success.bodyParts) { + span.addEvent("Received response end") + } + success.bodyParts = RPCAsyncSequence(wrapping: onFinishRecordingSequence) + response.accepted = .success(success) + } + case .failure: + span.addEvent("Received error response") + } + + return response + } + } +} + +/// An injector responsible for injecting the required instrumentation keys from the `ServiceContext` into +/// the request metadata. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +struct ClientRequestInjector: Instrumentation.Injector { + typealias Carrier = Metadata + + func inject(_ value: String, forKey key: String, into carrier: inout Carrier) { + carrier.addString(value, forKey: key) + } +} diff --git a/Sources/GRPCInterceptors/HookedWriter.swift b/Sources/GRPCInterceptors/HookedWriter.swift new file mode 100644 index 000000000..b4bb52eed --- /dev/null +++ b/Sources/GRPCInterceptors/HookedWriter.swift @@ -0,0 +1,40 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import GRPCCore +import Tracing + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +struct HookedWriter: RPCWriterProtocol { + private let writer: any RPCWriterProtocol + private let beforeEachWrite: @Sendable () -> Void + private let afterEachWrite: @Sendable () -> Void + + init( + wrapping other: some RPCWriterProtocol, + beforeEachWrite: @Sendable @escaping () -> Void, + afterEachWrite: @Sendable @escaping () -> Void + ) { + self.writer = other + self.beforeEachWrite = beforeEachWrite + self.afterEachWrite = afterEachWrite + } + + func write(contentsOf elements: some Sequence) async throws { + self.beforeEachWrite() + try await self.writer.write(contentsOf: elements) + self.afterEachWrite() + } +} diff --git a/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift b/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift new file mode 100644 index 000000000..311e1fcd8 --- /dev/null +++ b/Sources/GRPCInterceptors/OnFinishAsyncSequence.swift @@ -0,0 +1,57 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +struct OnFinishAsyncSequence: AsyncSequence, Sendable { + private let _makeAsyncIterator: @Sendable () -> AsyncIterator + + init( + wrapping other: S, + onFinish: @escaping () -> Void + ) where S.Element == Element { + self._makeAsyncIterator = { + AsyncIterator(wrapping: other.makeAsyncIterator(), onFinish: onFinish) + } + } + + func makeAsyncIterator() -> AsyncIterator { + self._makeAsyncIterator() + } + + struct AsyncIterator: AsyncIteratorProtocol { + private var iterator: any AsyncIteratorProtocol + private var onFinish: (() -> Void)? + + fileprivate init( + wrapping other: Iterator, + onFinish: @escaping () -> Void + ) where Iterator: AsyncIteratorProtocol, Iterator.Element == Element { + self.iterator = other + self.onFinish = onFinish + } + + mutating func next() async throws -> Element? { + let elem = try await self.iterator.next() + + if elem == nil { + self.onFinish?() + self.onFinish = nil + } + + return elem as? Element + } + } +} diff --git a/Sources/GRPCInterceptors/ServerTracingInterceptor.swift b/Sources/GRPCInterceptors/ServerTracingInterceptor.swift new file mode 100644 index 000000000..d91da7051 --- /dev/null +++ b/Sources/GRPCInterceptors/ServerTracingInterceptor.swift @@ -0,0 +1,148 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Tracing + +/// A server interceptor that extracts tracing information from the request. +/// +/// The extracted tracing information is made available to user code via the current `ServiceContext`. +/// For more information, refer to the documentation for `swift-distributed-tracing`. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct ServerTracingInterceptor: ServerInterceptor { + private let extractor: ServerRequestExtractor + private let emitEventOnEachWrite: Bool + + /// Create a new instance of a ``ServerTracingInterceptor``. + /// + /// - Parameter emitEventOnEachWrite: If `true`, each response part sent and request part + /// received will be recorded as a separate event in a tracing span. Otherwise, only the request/response + /// start and end will be recorded as events. + public init(emitEventOnEachWrite: Bool = false) { + self.extractor = ServerRequestExtractor() + self.emitEventOnEachWrite = emitEventOnEachWrite + } + + /// This interceptor will extract whatever `ServiceContext` key-value pairs have been inserted into the + /// request's metadata, and will make them available to user code via the `ServiceContext/current` + /// context. + /// + /// Which key-value pairs are extracted and made available will depend on the specific tracing implementation + /// that has been configured when bootstrapping `swift-distributed-tracing` in your application. + public func intercept( + request: ServerRequest.Stream, + context: ServerInterceptorContext, + next: @Sendable (ServerRequest.Stream, ServerInterceptorContext) async throws -> + ServerResponse.Stream + ) async throws -> ServerResponse.Stream where Input: Sendable, Output: Sendable { + var serviceContext = ServiceContext.topLevel + let tracer = InstrumentationSystem.tracer + + tracer.extract( + request.metadata, + into: &serviceContext, + using: self.extractor + ) + + return try await ServiceContext.withValue(serviceContext) { + try await tracer.withSpan( + context.descriptor.fullyQualifiedMethod, + context: serviceContext, + ofKind: .server + ) { span in + span.addEvent("Received request start") + + var request = request + + if self.emitEventOnEachWrite { + request.messages = RPCAsyncSequence( + wrapping: request.messages.map { element in + span.addEvent("Received request part") + return element + } + ) + } + + var response = try await next(request, context) + + span.addEvent("Received request end") + + switch response.accepted { + case .success(var success): + let wrappedProducer = success.producer + + if self.emitEventOnEachWrite { + success.producer = { writer in + let eventEmittingWriter = HookedWriter( + wrapping: writer, + beforeEachWrite: { + span.addEvent("Sending response part") + }, + afterEachWrite: { + span.addEvent("Sent response part") + } + ) + + let wrappedResult: Metadata + do { + wrappedResult = try await wrappedProducer( + RPCWriter(wrapping: eventEmittingWriter) + ) + } catch { + span.addEvent("Error encountered") + throw error + } + + span.addEvent("Sent response end") + return wrappedResult + } + } else { + success.producer = { writer in + let wrappedResult: Metadata + do { + wrappedResult = try await wrappedProducer(writer) + } catch { + span.addEvent("Error encountered") + throw error + } + + span.addEvent("Sent response end") + return wrappedResult + } + } + + response = .init(accepted: .success(success)) + case .failure: + span.addEvent("Sent error response") + } + + return response + } + } + } +} + +/// An extractor responsible for extracting the required instrumentation keys from request metadata. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +struct ServerRequestExtractor: Instrumentation.Extractor { + typealias Carrier = Metadata + + func extract(key: String, from carrier: Carrier) -> String? { + var values = carrier[stringValues: key].makeIterator() + // There should only be one value for each key. If more, pick just one. + return values.next() + } +} diff --git a/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift b/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift new file mode 100644 index 000000000..98fbb2558 --- /dev/null +++ b/Tests/GRPCInterceptorsTests/TracingInterceptorTests.swift @@ -0,0 +1,333 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import Tracing +import XCTest + +@testable import GRPCInterceptors + +@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *) +final class TracingInterceptorTests: XCTestCase { + override class func setUp() { + InstrumentationSystem.bootstrap(TestTracer()) + } + + #if swift(>=5.8) // Compiling these tests fails in 5.7 + func testClientInterceptor() async throws { + var serviceContext = ServiceContext.topLevel + let traceIDString = UUID().uuidString + let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: false) + let (stream, continuation) = AsyncStream.makeStream() + serviceContext.traceID = traceIDString + + try await ServiceContext.withValue(serviceContext) { + let methodDescriptor = MethodDescriptor( + service: "TracingInterceptorTests", + method: "testClientInterceptor" + ) + let response = try await interceptor.intercept( + request: .init(producer: { writer in + try await writer.write(contentsOf: ["request1"]) + try await writer.write(contentsOf: ["request2"]) + }), + context: .init(descriptor: methodDescriptor) + ) { stream, _ in + // Assert the metadata contains the injected context key-value. + XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"]) + + // Write into the response stream to make sure the `producer` closure's called. + let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation)) + try await stream.producer(writer) + continuation.finish() + + return .init( + metadata: [], + bodyParts: .init( + wrapping: AsyncStream { cont in + cont.yield(.message(["response"])) + cont.finish() + } + ) + ) + } + + var streamIterator = stream.makeAsyncIterator() + var element = await streamIterator.next() + XCTAssertEqual(element, "request1") + element = await streamIterator.next() + XCTAssertEqual(element, "request2") + element = await streamIterator.next() + XCTAssertNil(element) + + var messages = response.messages.makeAsyncIterator() + var message = try await messages.next() + XCTAssertEqual(message, ["response"]) + message = try await messages.next() + XCTAssertNil(message) + + let tracer = InstrumentationSystem.tracer as! TestTracer + XCTAssertEqual( + tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + $0.name + }, + [ + "Request started", + "Received response end", + ] + ) + } + } + + func testClientInterceptorAllEventsRecorded() async throws { + let methodDescriptor = MethodDescriptor( + service: "TracingInterceptorTests", + method: "testClientInterceptorAllEventsRecorded" + ) + var serviceContext = ServiceContext.topLevel + let traceIDString = UUID().uuidString + let interceptor = ClientTracingInterceptor(emitEventOnEachWrite: true) + let (stream, continuation) = AsyncStream.makeStream() + serviceContext.traceID = traceIDString + + try await ServiceContext.withValue(serviceContext) { + let response = try await interceptor.intercept( + request: .init(producer: { writer in + try await writer.write(contentsOf: ["request1"]) + try await writer.write(contentsOf: ["request2"]) + }), + context: .init(descriptor: methodDescriptor) + ) { stream, _ in + // Assert the metadata contains the injected context key-value. + XCTAssertEqual(stream.metadata, ["trace-id": "\(traceIDString)"]) + + // Write into the response stream to make sure the `producer` closure's called. + let writer = RPCWriter(wrapping: TestWriter(streamContinuation: continuation)) + try await stream.producer(writer) + continuation.finish() + + return .init( + metadata: [], + bodyParts: .init( + wrapping: AsyncStream { cont in + cont.yield(.message(["response"])) + cont.finish() + } + ) + ) + } + + var streamIterator = stream.makeAsyncIterator() + var element = await streamIterator.next() + XCTAssertEqual(element, "request1") + element = await streamIterator.next() + XCTAssertEqual(element, "request2") + element = await streamIterator.next() + XCTAssertNil(element) + + var messages = response.messages.makeAsyncIterator() + var message = try await messages.next() + XCTAssertEqual(message, ["response"]) + message = try await messages.next() + XCTAssertNil(message) + + let tracer = InstrumentationSystem.tracer as! TestTracer + XCTAssertEqual( + tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + $0.name + }, + [ + "Request started", + // Recorded when `request1` is sent + "Sending request part", + "Sent request part", + // Recorded when `request2` is sent + "Sending request part", + "Sent request part", + // Recorded after all request parts have been sent + "Request end", + // Recorded when receiving response part + "Received response part", + // Recorded at end of response + "Received response end", + ] + ) + } + } + #endif // swift >= 5.7 + + func testServerInterceptorErrorResponse() async throws { + let methodDescriptor = MethodDescriptor( + service: "TracingInterceptorTests", + method: "testServerInterceptorErrorResponse" + ) + let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false) + let response = try await interceptor.intercept( + request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])), + context: .init(descriptor: methodDescriptor) + ) { _, _ in + ServerResponse.Stream(error: .init(code: .unknown, message: "Test error")) + } + XCTAssertThrowsError(try response.accepted.get()) + + let tracer = InstrumentationSystem.tracer as! TestTracer + XCTAssertEqual( + tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + $0.name + }, + [ + "Received request start", + "Received request end", + "Sent error response", + ] + ) + } + + func testServerInterceptor() async throws { + let methodDescriptor = MethodDescriptor( + service: "TracingInterceptorTests", + method: "testServerInterceptor" + ) + let (stream, continuation) = AsyncStream.makeStream() + let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: false) + let response = try await interceptor.intercept( + request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])), + context: .init(descriptor: methodDescriptor) + ) { _, _ in + { [serviceContext = ServiceContext.current] in + return ServerResponse.Stream( + accepted: .success( + .init( + metadata: [], + producer: { writer in + guard let serviceContext else { + XCTFail("There should be a service context present.") + return ["Result": "Test failed"] + } + + let traceID = serviceContext.traceID + XCTAssertEqual("some-trace-id", traceID) + + try await writer.write("response1") + try await writer.write("response2") + + return ["Result": "Trailing metadata"] + } + ) + ) + ) + }() + } + + let responseContents = try response.accepted.get() + let trailingMetadata = try await responseContents.producer( + RPCWriter(wrapping: TestWriter(streamContinuation: continuation)) + ) + continuation.finish() + XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"]) + + var streamIterator = stream.makeAsyncIterator() + var element = await streamIterator.next() + XCTAssertEqual(element, "response1") + element = await streamIterator.next() + XCTAssertEqual(element, "response2") + element = await streamIterator.next() + XCTAssertNil(element) + + let tracer = InstrumentationSystem.tracer as! TestTracer + XCTAssertEqual( + tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + $0.name + }, + [ + "Received request start", + "Received request end", + "Sent response end", + ] + ) + } + + func testServerInterceptorAllEventsRecorded() async throws { + let methodDescriptor = MethodDescriptor( + service: "TracingInterceptorTests", + method: "testServerInterceptorAllEventsRecorded" + ) + let (stream, continuation) = AsyncStream.makeStream() + let interceptor = ServerTracingInterceptor(emitEventOnEachWrite: true) + let response = try await interceptor.intercept( + request: .init(single: .init(metadata: ["trace-id": "some-trace-id"], message: [])), + context: .init(descriptor: methodDescriptor) + ) { _, _ in + { [serviceContext = ServiceContext.current] in + return ServerResponse.Stream( + accepted: .success( + .init( + metadata: [], + producer: { writer in + guard let serviceContext else { + XCTFail("There should be a service context present.") + return ["Result": "Test failed"] + } + + let traceID = serviceContext.traceID + XCTAssertEqual("some-trace-id", traceID) + + try await writer.write("response1") + try await writer.write("response2") + + return ["Result": "Trailing metadata"] + } + ) + ) + ) + }() + } + + let responseContents = try response.accepted.get() + let trailingMetadata = try await responseContents.producer( + RPCWriter(wrapping: TestWriter(streamContinuation: continuation)) + ) + continuation.finish() + XCTAssertEqual(trailingMetadata, ["Result": "Trailing metadata"]) + + var streamIterator = stream.makeAsyncIterator() + var element = await streamIterator.next() + XCTAssertEqual(element, "response1") + element = await streamIterator.next() + XCTAssertEqual(element, "response2") + element = await streamIterator.next() + XCTAssertNil(element) + + let tracer = InstrumentationSystem.tracer as! TestTracer + XCTAssertEqual( + tracer.getEventsForTestSpan(ofOperationName: methodDescriptor.fullyQualifiedMethod).map { + $0.name + }, + [ + "Received request start", + "Received request end", + // Recorded when `response1` is sent + "Sending response part", + "Sent response part", + // Recorded when `response2` is sent + "Sending response part", + "Sent response part", + // Recorded when we're done sending response + "Sent response end", + ] + ) + } +} diff --git a/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift b/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift new file mode 100644 index 000000000..ffed5213a --- /dev/null +++ b/Tests/GRPCInterceptorsTests/TracingTestsUtilities.swift @@ -0,0 +1,182 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import GRPCCore +import NIOConcurrencyHelpers +import Tracing + +final class TestTracer: Tracer { + typealias Span = TestSpan + + private var testSpans: NIOLockedValueBox<[String: TestSpan]> = .init([:]) + + func getEventsForTestSpan(ofOperationName operationName: String) -> [SpanEvent] { + self.testSpans.withLockedValue({ $0[operationName] })?.events ?? [] + } + + func extract( + _ carrier: Carrier, + into context: inout ServiceContextModule.ServiceContext, + using extractor: Extract + ) where Carrier == Extract.Carrier, Extract: Instrumentation.Extractor { + let traceID = extractor.extract(key: TraceID.keyName, from: carrier) + context[TraceID.self] = traceID + } + + func inject( + _ context: ServiceContextModule.ServiceContext, + into carrier: inout Carrier, + using injector: Inject + ) where Carrier == Inject.Carrier, Inject: Instrumentation.Injector { + if let traceID = context.traceID { + injector.inject(traceID, forKey: TraceID.keyName, into: &carrier) + } + } + + func forceFlush() { + // no-op + } + + func startSpan( + _ operationName: String, + context: @autoclosure () -> ServiceContext, + ofKind kind: SpanKind, + at instant: @autoclosure () -> Instant, + function: String, + file fileID: String, + line: UInt + ) -> TestSpan where Instant: TracerInstant { + return self.testSpans.withLockedValue { testSpans in + let span = TestSpan(context: context(), operationName: operationName) + testSpans[operationName] = span + return span + } + } +} + +class TestSpan: Span { + var context: ServiceContextModule.ServiceContext + var operationName: String + var attributes: Tracing.SpanAttributes + var isRecording: Bool + private(set) var status: Tracing.SpanStatus? + private(set) var events: [Tracing.SpanEvent] = [] + + init( + context: ServiceContextModule.ServiceContext, + operationName: String, + attributes: Tracing.SpanAttributes = [:], + isRecording: Bool = true + ) { + self.context = context + self.operationName = operationName + self.attributes = attributes + self.isRecording = isRecording + } + + func setStatus(_ status: Tracing.SpanStatus) { + self.status = status + } + + func addEvent(_ event: Tracing.SpanEvent) { + self.events.append(event) + } + + func recordError( + _ error: any Error, + attributes: Tracing.SpanAttributes, + at instant: @autoclosure () -> Instant + ) where Instant: Tracing.TracerInstant { + self.setStatus( + .init( + code: .error, + message: "Error: \(error), attributes: \(attributes), at instant: \(instant())" + ) + ) + } + + func addLink(_ link: Tracing.SpanLink) { + self.context.spanLinks?.append(link) + } + + func end(at instant: @autoclosure () -> Instant) where Instant: Tracing.TracerInstant { + self.setStatus(.init(code: .ok, message: "Ended at instant: \(instant())")) + } +} + +enum TraceID: ServiceContextModule.ServiceContextKey { + typealias Value = String + + static let keyName = "trace-id" +} + +enum ServiceContextSpanLinksKey: ServiceContextModule.ServiceContextKey { + typealias Value = [SpanLink] + + static let keyName = "span-links" +} + +extension ServiceContext { + var traceID: String? { + get { + self[TraceID.self] + } + set { + self[TraceID.self] = newValue + } + } + + var spanLinks: [SpanLink]? { + get { + self[ServiceContextSpanLinksKey.self] + } + set { + self[ServiceContextSpanLinksKey.self] = newValue + } + } +} + +struct TestWriter: RPCWriterProtocol { + typealias Element = WriterElement + + private let streamContinuation: AsyncStream.Continuation + + init(streamContinuation: AsyncStream.Continuation) { + self.streamContinuation = streamContinuation + } + + func write(contentsOf elements: some Sequence) async throws { + elements.forEach { element in + self.streamContinuation.yield(element) + } + } +} + +#if swift(<5.9) +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension AsyncStream { + static func makeStream( + of elementType: Element.Type = Element.self, + bufferingPolicy limit: AsyncStream.Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { + var continuation: AsyncStream.Continuation! + let stream = AsyncStream(Element.self, bufferingPolicy: limit) { + continuation = $0 + } + return (stream, continuation) + } +} +#endif