Skip to content
23 changes: 18 additions & 5 deletions Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ final class HTTP2Connection {
let multiplexer: HTTP2StreamMultiplexer
let logger: Logger

/// A method with access to the stream channel that is called when creating the stream.
let streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?

/// the connection pool that created the connection
let delegate: HTTP2ConnectionDelegate

Expand Down Expand Up @@ -95,7 +98,8 @@ final class HTTP2Connection {
decompression: HTTPClient.Decompression,
maximumConnectionUses: Int?,
delegate: HTTP2ConnectionDelegate,
logger: Logger
logger: Logger,
streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)? = nil
) {
self.channel = channel
self.id = connectionID
Expand All @@ -114,6 +118,7 @@ final class HTTP2Connection {
)
self.delegate = delegate
self.state = .initialized
self.streamChannelDebugInitializer = streamChannelDebugInitializer
}

deinit {
Expand All @@ -128,15 +133,17 @@ final class HTTP2Connection {
delegate: HTTP2ConnectionDelegate,
decompression: HTTPClient.Decompression,
maximumConnectionUses: Int?,
logger: Logger
logger: Logger,
streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)? = nil
) -> EventLoopFuture<(HTTP2Connection, Int)> {
let connection = HTTP2Connection(
channel: channel,
connectionID: connectionID,
decompression: decompression,
maximumConnectionUses: maximumConnectionUses,
delegate: delegate,
logger: logger
logger: logger,
streamChannelDebugInitializer: streamChannelDebugInitializer
)
return connection._start0().map { maxStreams in (connection, maxStreams) }
}
Expand Down Expand Up @@ -259,8 +266,14 @@ final class HTTP2Connection {
self.openStreams.remove(box)
}

channel.write(request, promise: nil)
return channel.eventLoop.makeSucceededVoidFuture()
if let streamChannelDebugInitializer = self.streamChannelDebugInitializer {
return streamChannelDebugInitializer(channel).map { _ in
channel.write(request, promise: nil)
}
} else {
channel.write(request, promise: nil)
return channel.eventLoop.makeSucceededVoidFuture()
}
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,19 @@ extension HTTPConnectionPool.ConnectionFactory {
decompression: self.clientConfiguration.decompression,
logger: logger
)
requester.http1ConnectionCreated(connection)

if let connectionDebugInitializer = self.clientConfiguration.http1_1ConnectionDebugInitializer {
connectionDebugInitializer(channel).whenComplete { debugInitializerResult in
switch debugInitializerResult {
case .success:
requester.http1ConnectionCreated(connection)
case .failure(let error):
requester.failedToCreateHTTPConnection(connectionID, error: error)
}
}
} else {
requester.http1ConnectionCreated(connection)
}
} catch {
requester.failedToCreateHTTPConnection(connectionID, error: error)
}
Expand All @@ -96,11 +108,34 @@ extension HTTPConnectionPool.ConnectionFactory {
delegate: http2ConnectionDelegate,
decompression: self.clientConfiguration.decompression,
maximumConnectionUses: self.clientConfiguration.maximumUsesPerConnection,
logger: logger
logger: logger,
streamChannelDebugInitializer:
self.clientConfiguration.http2StreamChannelDebugInitializer
).whenComplete { result in
switch result {
case .success((let connection, let maximumStreams)):
requester.http2ConnectionCreated(connection, maximumStreams: maximumStreams)
if let connectionDebugInitializer = self.clientConfiguration.http2ConnectionDebugInitializer {
connectionDebugInitializer(channel).whenComplete {
debugInitializerResult in
switch debugInitializerResult {
case .success:
requester.http2ConnectionCreated(
connection,
maximumStreams: maximumStreams
)
case .failure(let error):
requester.failedToCreateHTTPConnection(
connectionID,
error: error
)
}
}
} else {
requester.http2ConnectionCreated(
connection,
maximumStreams: maximumStreams
)
}
case .failure(let error):
requester.failedToCreateHTTPConnection(connectionID, error: error)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ final class HTTPConnectionPool:
connection.executeRequest(request.req)

case .executeRequests(let requests, let connection):
for request in requests { connection.executeRequest(request.req) }
for request in requests {
connection.executeRequest(request.req)
}

case .failRequest(let request, let error):
request.req.fail(error)
Expand Down
35 changes: 35 additions & 0 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,15 @@ public class HTTPClient {
/// By default, don't use it
public var enableMultipath: Bool

/// A method with access to the HTTP/1 connection channel that is called when creating the connection.
public var http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?

/// A method with access to the HTTP/2 connection channel that is called when creating the connection.
public var http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?

/// A method with access to the HTTP/2 stream channel that is called when creating the stream.
public var http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?

public init(
tlsConfiguration: TLSConfiguration? = nil,
redirectConfiguration: RedirectConfiguration? = nil,
Expand Down Expand Up @@ -949,6 +958,32 @@ public class HTTPClient {
decompression: decompression
)
}

public init(
tlsConfiguration: TLSConfiguration? = nil,
redirectConfiguration: RedirectConfiguration? = nil,
timeout: Timeout = Timeout(),
connectionPool: ConnectionPool = ConnectionPool(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled,
http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)? = nil,
http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)? = nil,
http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)? = nil
) {
self.init(
tlsConfiguration: tlsConfiguration,
redirectConfiguration: redirectConfiguration,
timeout: timeout,
connectionPool: connectionPool,
proxy: proxy,
ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown,
decompression: decompression
)
self.http1_1ConnectionDebugInitializer = http1_1ConnectionDebugInitializer
self.http2ConnectionDebugInitializer = http2ConnectionDebugInitializer
self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer
}
}

/// Specifies how `EventLoopGroup` will be created and establishes lifecycle ownership.
Expand Down
170 changes: 170 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4436,4 +4436,174 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass {
request.setBasicAuth(username: "foo", password: "bar")
XCTAssertEqual(request.headers.first(name: "Authorization"), "Basic Zm9vOmJhcg==")
}

func runBaseTestForHTTP1ConnectionDebugInitializer(ssl: Bool) {
let connectionDebugInitializerUtil = CountingDebugInitializerUtil()

// Initializing even with just `http1_1ConnectionDebugInitializer` (rather than manually
// modifying `config`) to ensure that the matching `init` actually wires up this argument
// with the respective property. This is necessary as these parameters are defaulted and can
// be easy to miss.
var config = HTTPClient.Configuration(
http1_1ConnectionDebugInitializer: { channel in
connectionDebugInitializerUtil.initialize(channel: channel)
}
)
config.httpVersion = .http1Only

if ssl {
config.tlsConfiguration = .clientDefault
config.tlsConfiguration?.certificateVerification = .none
}

let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100)
var configWithHigherTimeout = config
configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout)

let clientWithHigherTimeout = HTTPClient(
eventLoopGroupProvider: .singleton,
configuration: configWithHigherTimeout,
backgroundActivityLogger: Logger(
label: "HTTPClient",
factory: StreamLogHandler.standardOutput(label:)
)
)
defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) }

let bin = HTTPBin(.http1_1(ssl: ssl, compress: false))
defer { XCTAssertNoThrow(try bin.shutdown()) }

let scheme = ssl ? "https" : "http"

for _ in 0..<3 {
XCTAssertNoThrow(
try clientWithHigherTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait()
)
}

// Even though multiple requests were made, the connection debug initializer must be called
// only once.
XCTAssertEqual(connectionDebugInitializerUtil.executionCount, 1)

let lowerConnectTimeout = CountingDebugInitializerUtil.duration - .milliseconds(100)
var configWithLowerTimeout = config
configWithLowerTimeout.timeout = .init(connect: lowerConnectTimeout)

let clientWithLowerTimeout = HTTPClient(
eventLoopGroupProvider: .singleton,
configuration: configWithLowerTimeout,
backgroundActivityLogger: Logger(
label: "HTTPClient",
factory: StreamLogHandler.standardOutput(label:)
)
)
defer { XCTAssertNoThrow(try clientWithLowerTimeout.syncShutdown()) }

XCTAssertThrowsError(
try clientWithLowerTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait()
) {
XCTAssertEqual($0 as? HTTPClientError, .connectTimeout)
}
}

func testHTTP1PlainTextConnectionDebugInitializer() {
runBaseTestForHTTP1ConnectionDebugInitializer(ssl: false)
}

func testHTTP1EncryptedConnectionDebugInitializer() {
runBaseTestForHTTP1ConnectionDebugInitializer(ssl: true)
}

func testHTTP2ConnectionAndStreamChannelDebugInitializers() {
let connectionDebugInitializerUtil = CountingDebugInitializerUtil()
let streamChannelDebugInitializerUtil = CountingDebugInitializerUtil()

// Initializing even with just `http2ConnectionDebugInitializer` and
// `http2StreamChannelDebugInitializer` (rather than manually modifying `config`) to ensure
// that the matching `init` actually wires up these arguments with the respective
// properties. This is necessary as these parameters are defaulted and can be easy to miss.
var config = HTTPClient.Configuration(
http2ConnectionDebugInitializer: { channel in
connectionDebugInitializerUtil.initialize(channel: channel)
},
http2StreamChannelDebugInitializer: { channel in
streamChannelDebugInitializerUtil.initialize(channel: channel)
}
)
config.tlsConfiguration = .clientDefault
config.tlsConfiguration?.certificateVerification = .none
config.httpVersion = .automatic

let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100)
var configWithHigherTimeout = config
configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout)

let clientWithHigherTimeout = HTTPClient(
eventLoopGroupProvider: .singleton,
configuration: configWithHigherTimeout,
backgroundActivityLogger: Logger(
label: "HTTPClient",
factory: StreamLogHandler.standardOutput(label:)
)
)
defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) }

let bin = HTTPBin(.http2(compress: false))
defer { XCTAssertNoThrow(try bin.shutdown()) }

let numberOfRequests = 3

for _ in 0..<numberOfRequests {
XCTAssertNoThrow(
try clientWithHigherTimeout.get(url: "https://localhost:\(bin.port)/get").wait()
)
}

// Even though multiple requests were made, the connection debug initializer must be called
// only once.
XCTAssertEqual(connectionDebugInitializerUtil.executionCount, 1)

// The stream channel debug initializer must be called only as much as the number of
// requests made.
XCTAssertEqual(streamChannelDebugInitializerUtil.executionCount, numberOfRequests)

let lowerConnectTimeout = CountingDebugInitializerUtil.duration - .milliseconds(100)
var configWithLowerTimeout = config
configWithLowerTimeout.timeout = .init(connect: lowerConnectTimeout)

let clientWithLowerTimeout = HTTPClient(
eventLoopGroupProvider: .singleton,
configuration: configWithLowerTimeout,
backgroundActivityLogger: Logger(
label: "HTTPClient",
factory: StreamLogHandler.standardOutput(label:)
)
)
defer { XCTAssertNoThrow(try clientWithLowerTimeout.syncShutdown()) }

XCTAssertThrowsError(
try clientWithLowerTimeout.get(url: "https://localhost:\(bin.port)/get").wait()
) {
XCTAssertEqual($0 as? HTTPClientError, .connectTimeout)
}
}
}

final class CountingDebugInitializerUtil: Sendable {
private let _executionCount = NIOLockedValueBox<Int>(0)
var executionCount: Int { self._executionCount.withLockedValue { $0 } }

/// The minimum time to spend running the debug initializer.
static let duration: TimeAmount = .milliseconds(300)

/// The actual debug initializer.
func initialize(channel: Channel) -> EventLoopFuture<Void> {
self._executionCount.withLockedValue { $0 += 1 }

let someScheduledTask = channel.eventLoop.scheduleTask(in: Self.duration) {
channel.eventLoop.makeSucceededVoidFuture()
}

return someScheduledTask.futureResult.flatMap { $0 }
}
}
Loading