diff --git a/Sources/NIOPosix/ControlMessage.swift b/Sources/NIOPosix/ControlMessage.swift index 4c87483a48..6978a21e2a 100644 --- a/Sources/NIOPosix/ControlMessage.swift +++ b/Sources/NIOPosix/ControlMessage.swift @@ -26,30 +26,46 @@ import CNIOWindows struct UnsafeControlMessageStorage: Collection { let bytesPerMessage: Int var buffer: UnsafeMutableRawBufferPointer + private let deallocateBuffer: Bool /// Initialise which includes allocating memory /// parameter: /// - bytesPerMessage: How many bytes have been allocated for each supported message. /// - buffer: The memory allocated to use for control messages. - private init(bytesPerMessage: Int, buffer: UnsafeMutableRawBufferPointer) { + /// - deallocateBuffer: buffer owning indicator + private init(bytesPerMessage: Int, buffer: UnsafeMutableRawBufferPointer, deallocateBuffer: Bool) { self.bytesPerMessage = bytesPerMessage self.buffer = buffer + self.deallocateBuffer = deallocateBuffer } + // Guess that 4 Int32 payload messages is enough for anyone. + static var bytesPerMessage: Int { NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride) * 4 } + /// Allocate new memory - Caller must call `deallocate` when no longer required. /// parameter: /// - msghdrCount: How many `msghdr` structures will be fed from this buffer - we assume 4 Int32 cmsgs for each. static func allocate(msghdrCount: Int) -> UnsafeControlMessageStorage { - // Guess that 4 Int32 payload messages is enough for anyone. - let bytesPerMessage = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride) * 4 + let bytesPerMessage = Self.bytesPerMessage let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: bytesPerMessage * msghdrCount, - alignment: MemoryLayout.alignment) - return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer) + alignment: MemoryLayout.alignment) + return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer, deallocateBuffer: true) + } + + /// Create an instance not owning the buffer + /// parameter: + /// - bytesPerMessage: How many bytes have been allocated for each supported message. + /// - buffer: The memory allocated to use for control messages. + static func makeNotOwning(bytesPerMessage: Int, buffer: UnsafeMutableRawBufferPointer) -> UnsafeControlMessageStorage { + precondition(buffer.count >= bytesPerMessage) + return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer, deallocateBuffer: false) } mutating func deallocate() { - self.buffer.deallocate() - self.buffer = UnsafeMutableRawBufferPointer(start: UnsafeMutableRawPointer(bitPattern: 0x7eadbeef), count: 0) + if self.deallocateBuffer { + self.buffer.deallocate() + self.buffer = UnsafeMutableRawBufferPointer(start: UnsafeMutableRawPointer(bitPattern: 0x7eadbeef), count: 0) + } } /// Get the part of the buffer for use with a message. @@ -65,7 +81,6 @@ struct UnsafeControlMessageStorage: Collection { func index(after: Int) -> Int { return after + 1 } - } /// Representation of a `cmsghdr` and associated data. diff --git a/Sources/NIOPosix/PendingDatagramWritesManager.swift b/Sources/NIOPosix/PendingDatagramWritesManager.swift index 315ec03327..e2bfea7816 100644 --- a/Sources/NIOPosix/PendingDatagramWritesManager.swift +++ b/Sources/NIOPosix/PendingDatagramWritesManager.swift @@ -383,7 +383,6 @@ final class PendingDatagramWritesManager: PendingWritesManager { private let bufferPool: Pool private let msgBufferPool: Pool - private let controlMessageStorage: UnsafeControlMessageStorage private var state = PendingDatagramWritesState() @@ -400,13 +399,10 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// /// - parameters: /// - bufferPool: a pool of buffers to be used for IOVector and storage references - /// - msgs: A pre-allocated array of `MMsgHdr` elements - /// - addresses: A pre-allocated array of `sockaddr_storage` elements - /// - controlMessageStorage: Pre-allocated memory for storing cmsghdr data during a vector write operation. - init(bufferPool: Pool, msgBufferPool: Pool, controlMessageStorage: UnsafeControlMessageStorage) { + /// - msgBufferPool: a pool of buffers to be usded for `MMsgHdr`, `sockaddr_storage` and cmsghdr elements + init(bufferPool: Pool, msgBufferPool: Pool) { self.bufferPool = bufferPool self.msgBufferPool = msgBufferPool - self.controlMessageStorage = controlMessageStorage } /// Mark the flush checkpoint. @@ -610,12 +606,12 @@ final class PendingDatagramWritesManager: PendingWritesManager { let msgBuffer = self.msgBufferPool.get() defer { self.msgBufferPool.put(msgBuffer) } - return try msgBuffer.withUnsafePointers { msgs, addresses in + return try msgBuffer.withUnsafePointers { msgs, addresses, controlMessageStorage in return self.didWrite(try doPendingDatagramWriteVectorOperation(pending: self.state, bufferPool: self.bufferPool, msgs: msgs, addresses: addresses, - controlMessageStorage: self.controlMessageStorage, + controlMessageStorage: controlMessageStorage, { try vectorWriteOperation($0) }), messages: msgs) } diff --git a/Sources/NIOPosix/Pool.swift b/Sources/NIOPosix/Pool.swift index 56999c7824..6d6d10b22e 100644 --- a/Sources/NIOPosix/Pool.swift +++ b/Sources/NIOPosix/Pool.swift @@ -212,6 +212,7 @@ struct PooledMsgBuffer: PoolElement { let count: Int let spaceForMsgHdrs: Int let spaceForAddresses: Int + let spaceForControlData: Int init(count: Int) { var spaceForMsgHdrs = MemoryLayout.stride * count @@ -220,13 +221,17 @@ struct PooledMsgBuffer: PoolElement { var spaceForAddress = MemoryLayout.stride * count spaceForAddress.roundUpToAlignment(for: MemorySentinel.self) + var spaceForControlData = (UnsafeControlMessageStorage.bytesPerMessage * count) + spaceForControlData.roundUpToAlignment(for: cmsghdr.self) + self.count = count self.spaceForMsgHdrs = spaceForMsgHdrs self.spaceForAddresses = spaceForAddress + self.spaceForControlData = spaceForControlData } var totalByteCount: Int { - self.spaceForMsgHdrs + self.spaceForAddresses + MemoryLayout.size + self.spaceForMsgHdrs + self.spaceForAddresses + self.spaceForControlData + MemoryLayout.size } var msgHdrsOffset: Int { @@ -237,8 +242,12 @@ struct PooledMsgBuffer: PoolElement { self.spaceForMsgHdrs } + var controlDataOffset: Int { + self.spaceForMsgHdrs + self.spaceForAddresses + } + var memorySentinelOffset: Int { - return self.spaceForMsgHdrs + self.spaceForAddresses + return self.spaceForMsgHdrs + self.spaceForAddresses + self.spaceForControlData } } @@ -254,6 +263,7 @@ struct PooledMsgBuffer: PoolElement { storage.withUnsafeMutablePointers { headPointer, tailPointer in UnsafeRawPointer(tailPointer + headPointer.pointee.msgHdrsOffset).bindMemory(to: MMsgHdr.self, capacity: count) UnsafeRawPointer(tailPointer + headPointer.pointee.addressesOffset).bindMemory(to: sockaddr_storage.self, capacity: count) + // space for control message data not needed to be bound UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory(to: MemorySentinel.self, capacity: 1) } @@ -261,11 +271,12 @@ struct PooledMsgBuffer: PoolElement { } func withUnsafeMutableTypedPointers( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeMutablePointer) throws -> ReturnType + _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeControlMessageStorage, UnsafeMutablePointer) throws -> ReturnType ) rethrows -> ReturnType { return try self.withUnsafeMutablePointers { headPointer, tailPointer in let msgHdrsPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.msgHdrsOffset).assumingMemoryBound(to: MMsgHdr.self) let addressesPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.addressesOffset).assumingMemoryBound(to: sockaddr_storage.self) + let controlDataPointer = UnsafeMutableRawBufferPointer(start: tailPointer + headPointer.pointee.controlDataOffset, count: headPointer.pointee.spaceForControlData) let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).assumingMemoryBound(to: MemorySentinel.self) let msgHdrsBufferPointer = UnsafeMutableBufferPointer( @@ -274,13 +285,16 @@ struct PooledMsgBuffer: PoolElement { let addressesBufferPointer = UnsafeMutableBufferPointer( start: addressesPointer, count: headPointer.pointee.count ) - return try body(msgHdrsBufferPointer, addressesBufferPointer, sentinelPointer) + let controlMessageStorage = UnsafeControlMessageStorage.makeNotOwning( + bytesPerMessage: UnsafeControlMessageStorage.bytesPerMessage, + buffer: controlDataPointer) + return try body(msgHdrsBufferPointer, addressesBufferPointer, controlMessageStorage, sentinelPointer) } } } private func validateSentinel() { - self.storage.withUnsafeMutableTypedPointers { _, _, sentinelPointer in + self.storage.withUnsafeMutableTypedPointers { _, _, _, sentinelPointer in precondition(sentinelPointer.pointee == Self.sentinelValue, "Detected memory handling error!") } } @@ -289,7 +303,7 @@ struct PooledMsgBuffer: PoolElement { init() { self.storage = .create(count: Socket.writevLimitIOVectors) - self.storage.withUnsafeMutableTypedPointers { _, _, sentinelPointer in + self.storage.withUnsafeMutableTypedPointers { _, _, _, sentinelPointer in sentinelPointer.pointee = Self.sentinelValue } } @@ -299,22 +313,22 @@ struct PooledMsgBuffer: PoolElement { } func withUnsafePointers( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer) throws -> ReturnValue + _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeControlMessageStorage) throws -> ReturnValue ) rethrows -> ReturnValue { defer { self.validateSentinel() } - return try self.storage.withUnsafeMutableTypedPointers { msgs, addresses, _ in - return try body(msgs, addresses) + return try self.storage.withUnsafeMutableTypedPointers { msgs, addresses, controlMessageStorage, _ in + return try body(msgs, addresses, controlMessageStorage) } } func withUnsafePointersWithStorageManagement( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, Unmanaged) throws -> ReturnValue + _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeControlMessageStorage, Unmanaged) throws -> ReturnValue ) rethrows -> ReturnValue { let storageRef: Unmanaged = Unmanaged.passUnretained(self.storage) - return try self.storage.withUnsafeMutableTypedPointers { msgs, addresses, _ in - try body(msgs, addresses, storageRef) + return try self.storage.withUnsafeMutableTypedPointers { msgs, addresses, controlMessageStorage, _ in + try body(msgs, addresses, controlMessageStorage, storageRef) } } } diff --git a/Sources/NIOPosix/SelectableEventLoop.swift b/Sources/NIOPosix/SelectableEventLoop.swift index cab345cbfe..083babd2d0 100644 --- a/Sources/NIOPosix/SelectableEventLoop.swift +++ b/Sources/NIOPosix/SelectableEventLoop.swift @@ -105,9 +105,6 @@ internal final class SelectableEventLoop: EventLoop { let bufferPool: Pool let msgBufferPool: Pool - // Used for UDP control messages. - private(set) var controlMessageStorage: UnsafeControlMessageStorage - // The `_parentGroup` will always be set unless this is a thread takeover or we shut down. @usableFromInline internal var _parentGroup: Optional @@ -185,7 +182,6 @@ Further information: self.thread = thread self.bufferPool = Pool(maxSize: 16) self.msgBufferPool = Pool(maxSize: 16) - self.controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: Socket.writevLimitIOVectors) // We will process 4096 tasks per while loop. self.tasksCopy.reserveCapacity(4096) self.canBeShutdownIndividually = canBeShutdownIndividually @@ -202,7 +198,6 @@ Further information: "illegal internal state on deinit: \(self.internalState)") assert(self.externalState == .resourcesReclaimed, "illegal external state on shutdown: \(self.externalState)") - self.controlMessageStorage.deallocate() } /// Is this `SelectableEventLoop` still open (ie. not shutting down or shut down) diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 6a0f3201c5..e0a1b8bac0 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -424,8 +424,7 @@ final class DatagramChannel: BaseSocketChannel { } self.pendingWrites = PendingDatagramWritesManager(bufferPool: eventLoop.bufferPool, - msgBufferPool: eventLoop.msgBufferPool, - controlMessageStorage: eventLoop.controlMessageStorage) + msgBufferPool: eventLoop.msgBufferPool) try super.init( socket: socket, @@ -440,8 +439,7 @@ final class DatagramChannel: BaseSocketChannel { self.vectorReadManager = nil try socket.setNonBlocking() self.pendingWrites = PendingDatagramWritesManager(bufferPool: eventLoop.bufferPool, - msgBufferPool: eventLoop.msgBufferPool, - controlMessageStorage: eventLoop.controlMessageStorage) + msgBufferPool: eventLoop.msgBufferPool) try super.init( socket: socket, parent: parent, @@ -607,24 +605,22 @@ final class DatagramChannel: BaseSocketChannel { override func readFromSocket() throws -> ReadResult { if self.vectorReadManager != nil { return try self.vectorReadFromSocket() + } else if self.reportExplicitCongestionNotifications || self.receivePacketInfo { + let pooledMsgBuffer = self.selectableEventLoop.msgBufferPool.get() + defer { self.selectableEventLoop.msgBufferPool.put(pooledMsgBuffer) } + return try pooledMsgBuffer.withUnsafePointers { _, _, controlMessageStorage in + return try self.singleReadFromSocket(controlBytesBuffer: controlMessageStorage[0]) + } } else { - return try self.singleReadFromSocket() + return try self.singleReadFromSocket(controlBytesBuffer: UnsafeMutableRawBufferPointer(start: nil, count: 0)) } } - private func singleReadFromSocket() throws -> ReadResult { + private func singleReadFromSocket(controlBytesBuffer: UnsafeMutableRawBufferPointer) throws -> ReadResult { var rawAddress = sockaddr_storage() var rawAddressLength = socklen_t(MemoryLayout.size) var readResult = ReadResult.none - // These control bytes must not escape the current call stack - let controlBytesBuffer: UnsafeMutableRawBufferPointer - if self.reportExplicitCongestionNotifications || self.receivePacketInfo { - controlBytesBuffer = self.selectableEventLoop.controlMessageStorage[0] - } else { - controlBytesBuffer = UnsafeMutableRawBufferPointer(start: nil, count: 0) - } - for _ in 1...self.maxMessagesPerRead { guard self.isOpen else { throw ChannelError.eof @@ -804,16 +800,17 @@ final class DatagramChannel: BaseSocketChannel { override func writeToSocket() throws -> OverallWriteResult { let result = try self.pendingWrites.triggerAppropriateWriteOperations( scalarWriteOperation: { (ptr, destinationPtr, destinationSize, metadata) in - // normal write - // Control bytes must not escape current stack. - var controlBytes = UnsafeOutboundControlBytes( - controlBytes: self.selectableEventLoop.controlMessageStorage[0]) - controlBytes.appendExplicitCongestionState(metadata: metadata, - protocolFamily: self.localAddress?.protocol) - return try self.socket.sendmsg(pointer: ptr, - destinationPtr: destinationPtr, - destinationSize: destinationSize, - controlBytes: controlBytes.validControlBytes) + let msgBuffer = self.selectableEventLoop.msgBufferPool.get() + defer { self.selectableEventLoop.msgBufferPool.put(msgBuffer) } + return try msgBuffer.withUnsafePointers { _, _, controlMessageStorage in + var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[0]) + controlBytes.appendExplicitCongestionState(metadata: metadata, + protocolFamily: self.localAddress?.protocol) + return try self.socket.sendmsg(pointer: ptr, + destinationPtr: destinationPtr, + destinationSize: destinationSize, + controlBytes: controlBytes.validControlBytes) + } }, vectorWriteOperation: { msgs in return try self.socket.sendmmsg(msgs: msgs) @@ -822,7 +819,6 @@ final class DatagramChannel: BaseSocketChannel { return result } - // MARK: Datagram Channel overrides not required by BaseSocketChannel override func bind0(to address: SocketAddress, promise: EventLoopPromise?) { diff --git a/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift b/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift index 6f3cc4580a..323143b9b0 100644 --- a/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift +++ b/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift @@ -49,11 +49,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { private func withPendingDatagramWritesManager(_ body: (PendingDatagramWritesManager) throws -> Void) rethrows { let bufferPool = Pool(maxSize: 16) let msgBufferPool = Pool(maxSize: 16) - var controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: Socket.writevLimitIOVectors) - defer { - controlMessageStorage.deallocate() - } - let pwm = NIOPosix.PendingDatagramWritesManager(bufferPool: bufferPool, msgBufferPool: msgBufferPool, controlMessageStorage: controlMessageStorage) + let pwm = NIOPosix.PendingDatagramWritesManager(bufferPool: bufferPool, msgBufferPool: msgBufferPool) XCTAssertTrue(pwm.isEmpty) XCTAssertTrue(pwm.isOpen)