diff --git a/Package.swift b/Package.swift index 330890a89..197abece0 100644 --- a/Package.swift +++ b/Package.swift @@ -44,6 +44,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-distributed-tracing.git", from: "1.3.0"), + .package(url: "https://github.com/apple/swift-crypto.git", from: "4.2.0"), .package(url: "https://github.com/apple/swift-configuration.git", from: "1.0.0"), ], targets: [ @@ -70,6 +71,7 @@ let package = Package( .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Algorithms", package: "swift-algorithms"), + .product(name: "Crypto", package: "swift-crypto"), .product(name: "Configuration", package: "swift-configuration"), // Observability support .product(name: "Logging", package: "swift-log"), diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift index 03cd8e464..7c731ba8f 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift @@ -42,6 +42,7 @@ extension HTTPClientRequest { var head: HTTPRequestHead var body: Body? var tlsConfiguration: TLSConfiguration? + var tlsPinning: SPKIPinningConfiguration? } } @@ -82,7 +83,8 @@ extension HTTPClientRequest.Prepared { headers: headers ), body: request.body.map { .init($0) }, - tlsConfiguration: request.tlsConfiguration + tlsConfiguration: request.tlsConfiguration, + tlsPinning: request.tlsPinning ) } } diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift index dca7de0ef..125fe69ee 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift @@ -53,6 +53,9 @@ public struct HTTPClientRequest: Sendable { /// Request-specific TLS configuration, defaults to no request-specific TLS configuration. public var tlsConfiguration: TLSConfiguration? + /// Optional SPKI pinning configuration for TLS certificate validation. + public var tlsPinning: SPKIPinningConfiguration? + public init(url: String) { self.url = url self.method = .GET diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index 30c7c877f..f6eb23e4d 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -173,6 +173,7 @@ final class Transaction: extension Transaction: HTTPSchedulableRequest { var poolKey: ConnectionPool.Key { self.request.poolKey } var tlsConfiguration: TLSConfiguration? { self.request.tlsConfiguration } + var tlsPinning: SPKIPinningConfiguration? { self.request.tlsPinning } var requiredEventLoop: EventLoop? { nil } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SPKIPinningHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SPKIPinningHandler.swift new file mode 100644 index 000000000..5e32dd1c1 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SPKIPinningHandler.swift @@ -0,0 +1,352 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import NIOCore +import NIOTLS +import NIOSSL +import Logging +import Crypto +import Algorithms + +/// SPKI hash for certificate pinning validation. +/// +/// Validates server identity using the DER-encoded public key structure (RFC 5280, Section 4.1). +/// Survives legitimate certificate rotations and prevents algorithm downgrade attacks. +/// +/// - SeeAlso: https://datatracker.ietf.org/doc/html/rfc5280#section-4.1 +/// - SeeAlso: https://owasp.org/www-project-mobile-security-testing-guide/latest/0x05g-Testing-Network-Communication.html +public struct SPKIHash: Sendable, Hashable { + + /// Raw hash digest bytes of the SPKI structure. + public let bytes: Data + + fileprivate let algorithmID: ObjectIdentifier + private let algorithm: @Sendable (Data) -> any Sequence + + // MARK: - Initialization + + /// Creates an SPKI hash using a custom hash algorithm and base64-encoded string. + /// + /// - Parameters: + /// - algorithm: Hash algorithm used to generate the digest. + /// - base64: Base64-encoded hash digest. + /// + /// - Throws: `HTTPClientError.invalidDigestLength` if length doesn't match algorithm. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public init(algorithm: Algorithm.Type, base64: String) throws { + guard let data = Data(base64Encoded: base64) else { + throw HTTPClientError.invalidDigestLength + } + try self.init(algorithm: algorithm, bytes: data) + } + + /// Creates an SPKI hash from raw digest bytes using a specified hash algorithm. + /// + /// - Parameters: + /// - algorithm: Hash algorithm that generated the digest bytes. + /// - bytes: Raw digest bytes. + /// + /// - Throws: `HTTPClientError.invalidDigestLength` if byte count doesn't match algorithm. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public init(algorithm: Algorithm.Type, bytes: Data) throws { + guard bytes.count == Algorithm.Digest.byteCount else { + throw HTTPClientError.invalidDigestLength + } + self.bytes = bytes + self.algorithm = Algorithm.hash(data:) + self.algorithmID = .init(algorithm) + } + + // MARK: - Equality and Hashing + + public static func ==(lhs: Self, rhs: Self) -> Bool { + lhs.bytes == rhs.bytes && lhs.algorithmID == rhs.algorithmID + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(bytes) + hasher.combine(algorithmID) + } + + fileprivate func hash(_ spkiData: Data) -> Data { + Data(algorithm(spkiData)) + } +} + +/// Constant-time comparison to prevent timing attacks during SPKI pin validation. +/// +/// Timing attacks exploit micro-variations in execution time to infer secret values. +/// This function eliminates such leaks by: +/// 1. Iterating a fixed number of times (determined by hash algorithm) +/// 2. Processing all candidates even after a match is found +/// 3. Performing all secret-dependent operations before any conditional branches +/// +/// SECURITY INVARIANT: All candidates share identical length (enforced by +/// SPKIPinningConfiguration grouping and SPKIHash validation). Length mismatches +/// are rejected early using public knowledge (algorithm-determined digest size), +/// which cannot leak secret information. +internal func constantTimeAnyMatch(_ target: Data, _ candidates: [SPKIHash]) -> Bool { + guard !candidates.isEmpty else { return false } + + let expectedLength = candidates[0].bytes.count + guard target.count == expectedLength else { return false } + + var anyMatch: UInt8 = 0 + for candidate in candidates { + precondition( + candidate.bytes.count == expectedLength, + "Algorithm grouping invariant violated: candidates must share identical length" + ) + + var diff: UInt8 = 0 + for i in 0 ..< expectedLength { + diff |= target[i] ^ candidate.bytes[i] + } + anyMatch |= (diff == 0) ? 1 : 0 + } + return anyMatch != 0 +} + +/// Configuration for SPKI pinning validation. +/// +/// - Warning: Always deploy multiple pins to enable safe certificate rotation. +/// Single-pin configurations in `.strict` mode risk catastrophic lockout. +public struct SPKIPinningConfiguration: Sendable, Hashable { + + /// SPKI hashes of trusted certificates. + public let pins: [SPKIHash] + + /// Policy for handling pin validation failures. + public let policy: SPKIPinningPolicy + + private let pinsByAlgorithm: [ObjectIdentifier: [SPKIHash]] + + /// Creates an SPKI pinning configuration. + /// + /// - Parameters: + /// - pins: Hashes of trusted certificates. For production safety, include + /// pins for both current and upcoming certificates to enable rotation. + /// - policy: Validation failure policy (`.strict` for production, `.audit` for debugging). + public init( + pins: [SPKIHash], + policy: SPKIPinningPolicy = .strict + ) { + self.pins = pins + self.pinsByAlgorithm = Dictionary(grouping: Set(pins), by: \.algorithmID) + self.policy = policy + } + + internal func contains(spkiBytes: [UInt8]) -> Bool { + let spkiData = Data(spkiBytes) + + var anyMatch: UInt8 = 0 + for hashes in pinsByAlgorithm.values { + guard let first = hashes.first else { continue } + let computedHash = first.hash(spkiData) + let isMatch = constantTimeAnyMatch(computedHash, hashes) + anyMatch |= isMatch ? 1 : 0 + } + return anyMatch != 0 + } +} + +/// Security policy for SPKI pin validation failures. +/// +/// Determines the client's response when a server's certificate fails SPKI pin validation: +/// - `.audit`: Permit the connection for observability (staging/debugging only) +/// - `.strict`: Terminate the connection immediately (production environments) +/// +/// - Warning: Never use `.audit` in production — it effectively disables pinning security +/// guarantees while maintaining audit visibility. +public struct SPKIPinningPolicy: Sendable, Hashable { + + private enum RawValue: Sendable, Hashable { + case audit + case strict + } + + /// Permit connections with untrusted certificates for observability only. + /// + /// Use exclusively for debugging, testing, or migration scenarios. Never use in production. + public static let audit = SPKIPinningPolicy(rawValue: .audit) + + /// Reject connections with untrusted certificates. + /// + /// Use in production environments where security is paramount. + public static let strict = SPKIPinningPolicy(rawValue: .strict) + + public var description: String { + switch self.rawValue { + case .audit: return "audit" + case .strict: return "strict" + } + } + + private let rawValue: RawValue + + private init(rawValue: RawValue) { + self.rawValue = rawValue + } + + public static func ==(lhs: Self, rhs: Self) -> Bool { + lhs.rawValue == rhs.rawValue + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(rawValue) + } +} + +/// ChannelHandler that validates server certificates using SPKI pinning. +/// +/// Performs constant-time comparison of the server's public key hash against trusted pins. +/// Rejects connections in `.strict` mode on mismatch; permits in `.audit` mode for observability. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +final class SPKIPinningHandler: ChannelInboundHandler, RemovableChannelHandler { + + typealias InboundIn = NIOAny + + private let tlsPinning: SPKIPinningConfiguration + private let logger: Logger + + init( + tlsPinning: SPKIPinningConfiguration, + logger: Logger + ) { + self.tlsPinning = tlsPinning + self.logger = logger + + if tlsPinning.pins.count < 2 && tlsPinning.policy == .strict { + logger.warning( + "SPKIPinningHandler deployed with < 2 pins in strict mode — catastrophic lockout risk on certificate rotation!", + metadata: [ + "current_pin_count": .stringConvertible(tlsPinning.pins.count), + "recommendation": .string("Deploy multiple pins to enable safe certificate rotation") + ] + ) + } + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + guard case .handshakeCompleted = (event as? TLSUserEvent) else { + context.fireUserInboundEventTriggered(event) + return + } + + // ⚠️ Security-critical: handshake propagation is delayed until validation completes + let result = self.validatePinning(for: Result { + try context.pipeline.syncOperations.handler(type: NIOSSLHandler.self).peerCertificate + }) + + switch result { + case .accepted: + context.fireUserInboundEventTriggered(event) + + case .auditWarning(let error): + logger.warning( + "SPKI pinning failed — connection allowed for audit purposes", + metadata: [ + "error": .string(String(describing: error)), + "policy": .string(tlsPinning.policy.description) + ] + ) + context.fireUserInboundEventTriggered(event) + + case .rejected(let error): + let metadata: Logger.Metadata = [ + "policy": .string(tlsPinning.policy.description), + "expected_pins": .string( + tlsPinning.pins.map { $0.bytes.base64EncodedString() }.joined(separator: ", ") + ) + ] + + logger.error("SPKI pinning failed — connection blocked", metadata: metadata) + + let error = HTTPClientError.invalidCertificatePinning(String(describing: error)) + context.fireErrorCaught(error) + context.close(promise: nil) + } + } + + func validatePinning(for peerCertificate: Result) -> PinningValidationResult { + switch peerCertificate { + case .success(let peerCertificate): + guard let leaf = peerCertificate else { + let error = SPKIPinningHandlerError.emptyCertificateChain + return tlsPinning.policy == .audit + ? .auditWarning(error) + : .rejected(error) + } + + let spkiBytes: [UInt8] + do { + let publicKey = try leaf.extractPublicKey() + spkiBytes = try publicKey.toSPKIBytes() + } catch { + let error = SPKIPinningHandlerError.extractionFailed(String(describing: error)) + return tlsPinning.policy == .audit + ? .auditWarning(error) + : .rejected(error) + } + + if tlsPinning.contains(spkiBytes: spkiBytes) { + return .accepted + } + + let error = SPKIPinningHandlerError.pinMismatch + return tlsPinning.policy == .audit + ? .auditWarning(error) + : .rejected(error) + + case .failure(let error): + let handlerError = SPKIPinningHandlerError.handlerNotFound(String(describing: error)) + return tlsPinning.policy == .audit + ? .auditWarning(handlerError) + : .rejected(handlerError) + } + } +} + +/// Result of SPKI pinning validation — decoupled from pipeline side effects. +enum PinningValidationResult { + /// Pin matched or audit mode allowed mismatch — propagate handshake event. + case accepted + + /// Pin mismatch in strict mode or critical error — close connection. + case rejected(Error) + + /// Pin mismatch in audit mode — propagate with warning (still accepted). + case auditWarning(Error) +} + +enum SPKIPinningHandlerError: Error, CustomStringConvertible { + case emptyCertificateChain + case pinMismatch + case extractionFailed(String) + case handlerNotFound(String) + + var description: String { + switch self { + case .emptyCertificateChain: + return "Empty certificate chain" + case .pinMismatch: + return "SPKI pin mismatch" + case .extractionFailed(let error): + return "SPKI extraction failed: \(error)" + case .handlerNotFound(let error): + return "SSL handler not found: \(error)" + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 3dc47c5ae..4d17a01bc 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -30,12 +30,14 @@ extension HTTPConnectionPool { struct ConnectionFactory { let key: ConnectionPool.Key let clientConfiguration: HTTPClient.Configuration + let tlsPinning: SPKIPinningConfiguration? let tlsConfiguration: TLSConfiguration let sslContextCache: SSLContextCache init( key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, + tlsPinning: SPKIPinningConfiguration?, clientConfiguration: HTTPClient.Configuration, sslContextCache: SSLContextCache ) { @@ -44,6 +46,7 @@ extension HTTPConnectionPool { self.sslContextCache = sslContextCache self.tlsConfiguration = tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() + self.tlsPinning = tlsPinning ?? clientConfiguration.tlsPinning } } } @@ -393,6 +396,19 @@ extension HTTPConnectionPool.ConnectionFactory { serverHostname: sslServerHostname ) try channel.pipeline.syncOperations.addHandler(sslHandler) + + if let tlsPinning { + if #available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) { + let pinningHandler = SPKIPinningHandler( + tlsPinning: tlsPinning, + logger: logger + ) + try channel.pipeline.syncOperations.addHandler(pinningHandler) + } else { + fatalError("SPKI pinning requires minimum OS version 10.15/13.0. Cannot proceed with pinning disabled.") + } + } + let tlsEventHandler = TLSEventsHandler(deadline: deadline) try channel.pipeline.syncOperations.addHandler(tlsEventHandler) @@ -596,7 +612,21 @@ extension HTTPConnectionPool.ConnectionFactory { let tlsEventHandler = TLSEventsHandler(deadline: deadline) try sync.addHandler(sslHandler) + + if let tlsPinning { + if #available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) { + let pinningHandler = SPKIPinningHandler( + tlsPinning: tlsPinning, + logger: logger + ) + try sync.addHandler(pinningHandler) + } else { + fatalError("SPKI pinning requires minimum OS version 10.15/13.0. Cannot proceed with pinning disabled.") + } + } + try sync.addHandler(tlsEventHandler) + return channel.eventLoop.makeSucceededVoidFuture() } catch { return channel.eventLoop.makeFailedFuture(error) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift index 4c313e92b..d2321221a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift @@ -64,6 +64,7 @@ extension HTTPConnectionPool { eventLoopGroup: self.eventLoopGroup, sslContextCache: self.sslContextCache, tlsConfiguration: request.tlsConfiguration, + tlsPinning: request.tlsPinning, clientConfiguration: self.configuration, key: poolKey, delegate: self, diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index 676df915a..edb9f3247 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -51,6 +51,7 @@ final class HTTPConnectionPool: eventLoopGroup: EventLoopGroup, sslContextCache: SSLContextCache, tlsConfiguration: TLSConfiguration?, + tlsPinning: SPKIPinningConfiguration?, clientConfiguration: HTTPClient.Configuration, key: ConnectionPool.Key, delegate: HTTPConnectionPoolDelegate, @@ -61,6 +62,7 @@ final class HTTPConnectionPool: self.connectionFactory = ConnectionFactory( key: key, tlsConfiguration: tlsConfiguration, + tlsPinning: tlsPinning, clientConfiguration: clientConfiguration, sslContextCache: sslContextCache ) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index 32308a6be..d08674ee0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -151,6 +151,9 @@ protocol HTTPSchedulableRequest: HTTPExecutableRequest { /// If you want to override the default `TLSConfiguration` ensure that this property is non nil var tlsConfiguration: TLSConfiguration? { get } + /// Optional SPKI pinning configuration for TLS certificate validation. + var tlsPinning: SPKIPinningConfiguration? { get } + /// The task's logger var logger: Logger { get } diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 1aa37fb7e..1e9910f6d 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -828,6 +828,9 @@ public final class HTTPClient: Sendable { /// TLS configuration, defaults to `TLSConfiguration.makeClientConfiguration()`. public var tlsConfiguration: Optional + /// Optional SPKI pinning configuration for TLS certificate validation. + public var tlsPinning: SPKIPinningConfiguration? + /// Sometimes it can be useful to connect to one host e.g. `x.example.com` but /// request and validate the certificate chain as if we would connect to `y.example.com`. /// ``dnsOverride`` allows to do just that by mapping host names which we will request and validate the certificate chain, to a different @@ -917,6 +920,7 @@ public final class HTTPClient: Sendable { decompression: Decompression = .disabled ) { self.tlsConfiguration = tlsConfiguration + self.tlsPinning = nil self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration() self.timeout = timeout self.connectionPool = connectionPool @@ -1063,6 +1067,35 @@ public final class HTTPClient: Sendable { self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer self.tracing = tracing } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + tlsPinning: SPKIPinningConfiguration?, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: ConnectionPool = ConnectionPool(), + proxy: Proxy? = nil, + decompression: Decompression = .disabled, + http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + tracing: TracingConfiguration = .init() + ) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: connectionPool, + proxy: proxy, + ignoreUncleanSSLShutdown: false, + decompression: decompression, + http1_1ConnectionDebugInitializer: http1_1ConnectionDebugInitializer, + http2ConnectionDebugInitializer: http2ConnectionDebugInitializer, + http2StreamChannelDebugInitializer: http2StreamChannelDebugInitializer, + tracing: tracing + ) + self.tlsPinning = tlsPinning + } } public struct TracingConfiguration: Sendable { @@ -1443,6 +1476,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case socksHandshakeTimeout case httpProxyHandshakeTimeout case tlsHandshakeTimeout + case invalidDigestLength + case invalidCertificatePinning(String) case serverOfferedUnsupportedApplicationProtocol(String) case requestStreamCancelled case getConnectionFromPoolTimeout @@ -1526,6 +1561,10 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { return "HTTP proxy handshake timeout" case .tlsHandshakeTimeout: return "TLS handshake timeout" + case .invalidDigestLength: + return "Invalid digest length" + case .invalidCertificatePinning: + return "Invalid certificate pinning" case .serverOfferedUnsupportedApplicationProtocol: return "Server offered unsupported application protocol" case .requestStreamCancelled: @@ -1615,6 +1654,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let httpProxyHandshakeTimeout = HTTPClientError(code: .httpProxyHandshakeTimeout) /// The tls handshake timed out. public static let tlsHandshakeTimeout = HTTPClientError(code: .tlsHandshakeTimeout) + /// The hash digest length is invalid. + public static let invalidDigestLength = HTTPClientError(code: .invalidDigestLength) + /// The server's certificate did not match any pinned SPKI hash. + public static func invalidCertificatePinning(_ reason: String) -> HTTPClientError { + HTTPClientError(code: .invalidCertificatePinning(reason)) + } /// The remote server only offered an unsupported application protocol public static func serverOfferedUnsupportedApplicationProtocol(_ proto: String) -> HTTPClientError { HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 9c7becb0b..15cedf5b1 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -232,6 +232,9 @@ extension HTTPClient { /// Request-specific TLS configuration, defaults to no request-specific TLS configuration. public var tlsConfiguration: TLSConfiguration? + /// Optional SPKI pinning configuration for TLS certificate validation. + public var tlsPinning: SPKIPinningConfiguration? + /// Parsed, validated and deconstructed URL. let deconstructedURL: DeconstructedURL @@ -253,7 +256,14 @@ extension HTTPClient { headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil ) throws { - try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) + try self.init( + url: url, + method: method, + headers: headers, + body: body, + tlsConfiguration: nil, + tlsPinning: nil + ) } /// Create HTTP request. @@ -263,7 +273,8 @@ extension HTTPClient { /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. - /// - tlsConfiguration: Request TLS configuration + /// - tlsConfiguration: Request TLS configuration. + /// - tlsPinning: SPKI pinning configuration to validate server certificates. /// - throws: /// - `invalidURL` if URL cannot be parsed. /// - `emptyScheme` if URL does not contain HTTP scheme. @@ -274,13 +285,21 @@ extension HTTPClient { method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, - tlsConfiguration: TLSConfiguration? + tlsConfiguration: TLSConfiguration?, + tlsPinning: SPKIPinningConfiguration? ) throws { guard let url = URL(string: url) else { throw HTTPClientError.invalidURL } - try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: tlsConfiguration) + try self.init( + url: url, + method: method, + headers: headers, + body: body, + tlsConfiguration: tlsConfiguration, + tlsPinning: tlsPinning + ) } /// Create an HTTP `Request`. @@ -297,7 +316,14 @@ extension HTTPClient { /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { - try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) + try self.init( + url: url, + method: method, + headers: headers, + body: body, + tlsConfiguration: nil, + tlsPinning: nil + ) } /// Create an HTTP `Request`. @@ -307,7 +333,8 @@ extension HTTPClient { /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. - /// - tlsConfiguration: Request TLS configuration + /// - tlsConfiguration: Request TLS configuration. + /// - tlsPinning: SPKI pinning configuration to validate server certificates. /// - throws: /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. @@ -318,7 +345,8 @@ extension HTTPClient { method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, - tlsConfiguration: TLSConfiguration? + tlsConfiguration: TLSConfiguration?, + tlsPinning: SPKIPinningConfiguration? ) throws { self.deconstructedURL = try DeconstructedURL(url: url) @@ -327,6 +355,7 @@ extension HTTPClient { self.headers = headers self.body = body self.tlsConfiguration = tlsConfiguration + self.tlsPinning = tlsPinning } /// Remote host, resolved from `URL`. diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 67385e3f1..b0374e435 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -86,6 +86,7 @@ final class RequestBag: Sendabl let eventLoopPreference: HTTPClient.EventLoopPreference let tlsConfiguration: TLSConfiguration? + let tlsPinning: SPKIPinningConfiguration? init( request: HTTPClient.Request, @@ -116,6 +117,7 @@ final class RequestBag: Sendabl self.delegate = delegate self.tlsConfiguration = request.tlsConfiguration + self.tlsPinning = request.tlsPinning self.task.taskDelegate = self self.task.futureResult.whenComplete { _ in diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift index 43430b85c..13df58194 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift @@ -19,6 +19,7 @@ import NIOHTTP1 import NIOPosix import NIOSSL import XCTest +import Crypto @testable import AsyncHTTPClient @@ -1020,6 +1021,192 @@ final class AsyncAwaitEndToEndTests: XCTestCase { } } + func testSPKIPinning_ValidPin_AllowsConnection() { + XCTAsyncTest { + let certificate = TestTLS.certificate + let privateKey = TestTLS.privateKey + + let tlsConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(certificate)], + privateKey: .privateKey(privateKey) + ) + + let bin = HTTPBin(.http2(tlsConfiguration: tlsConfig)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let publicKey = try certificate.extractPublicKey() + let spkiBytes = try publicKey.toSPKIBytes() + let spkiHash = SHA256.hash(data: Data(spkiBytes)) + let pinBase64 = Data(spkiHash).base64EncodedString() + + var config = HTTPClient.Configuration().enableFastFailureModeForTesting() + config.tlsConfiguration = TLSConfiguration.makeClientConfiguration() + config.tlsConfiguration?.trustRoots = .certificates([certificate]) + config.tlsConfiguration?.certificateVerification = .noHostnameVerification + config.httpVersion = .automatic + + config.tlsPinning = SPKIPinningConfiguration( + pins: [try SPKIHash(algorithm: SHA256.self, base64: pinBase64)], + policy: .strict + ) + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(MultiThreadedEventLoopGroup.singleton), + configuration: config + ) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + + guard let response = await XCTAssertNoThrowWithResult( + try await localClient.execute(request, deadline: .now() + .seconds(10)) + ) else { return } + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + } + } + + func testSPKIPinning_InvalidPin_RejectsConnection() { + XCTAsyncTest { + let certificate = TestTLS.certificate + let privateKey = TestTLS.privateKey + + let tlsConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(certificate)], + privateKey: .privateKey(privateKey) + ) + + let bin = HTTPBin(.http2(tlsConfiguration: tlsConfig)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let spkiHash = SHA256.hash(data: Data(UUID().uuidString.utf8)) + let pinBase64 = Data(spkiHash).base64EncodedString() + + var config = HTTPClient.Configuration().enableFastFailureModeForTesting() + config.tlsConfiguration = TLSConfiguration.makeClientConfiguration() + config.tlsConfiguration?.trustRoots = .certificates([certificate]) + config.tlsConfiguration?.certificateVerification = .noHostnameVerification + config.httpVersion = .automatic + + config.tlsPinning = SPKIPinningConfiguration( + pins: [try SPKIHash(algorithm: SHA256.self, base64: pinBase64)], + policy: .strict + ) + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(MultiThreadedEventLoopGroup.singleton), + configuration: config + ) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + + await XCTAssertThrowsError( + try await localClient.execute(request, deadline: .now() + .seconds(10)) + ) { error in + guard let httpClientError = error as? HTTPClientError else { + return XCTFail("Expecting HTTPClientError, received: \(type(of: error))") + } + XCTAssertTrue( + httpClientError.description.contains("pinning") || + httpClientError.description.contains("SPKI"), + "Unexpected error: \(httpClientError.description)" + ) + } + } + } + + func testSPKIPinning_ValidPin_AuditMode_AllowsConnection() { + XCTAsyncTest { + let certificate = TestTLS.certificate + let privateKey = TestTLS.privateKey + + let tlsConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(certificate)], + privateKey: .privateKey(privateKey) + ) + + let bin = HTTPBin(.http2(tlsConfiguration: tlsConfig)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let publicKey = try certificate.extractPublicKey() + let spkiBytes = try publicKey.toSPKIBytes() + let spkiHash = SHA256.hash(data: Data(spkiBytes)) + let pinBase64 = Data(spkiHash).base64EncodedString() + + var config = HTTPClient.Configuration().enableFastFailureModeForTesting() + config.tlsConfiguration = TLSConfiguration.makeClientConfiguration() + config.tlsConfiguration?.trustRoots = .certificates([certificate]) + config.tlsConfiguration?.certificateVerification = .noHostnameVerification + config.httpVersion = .automatic + + config.tlsPinning = SPKIPinningConfiguration( + pins: [try SPKIHash(algorithm: SHA256.self, base64: pinBase64)], + policy: .audit // <-- AUDIT MODE + ) + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(MultiThreadedEventLoopGroup.singleton), + configuration: config + ) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + + guard let response = await XCTAssertNoThrowWithResult( + try await localClient.execute(request, deadline: .now() + .seconds(10)) + ) else { return } + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + } + } + + func testSPKIPinning_InvalidPin_AuditMode_AllowsConnection() { + XCTAsyncTest { + let certificate = TestTLS.certificate + let privateKey = TestTLS.privateKey + + let tlsConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(certificate)], + privateKey: .privateKey(privateKey) + ) + + let bin = HTTPBin(.http2(tlsConfiguration: tlsConfig)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let spkiHash = SHA256.hash(data: Data(UUID().uuidString.utf8)) + let pinBase64 = Data(spkiHash).base64EncodedString() + + var config = HTTPClient.Configuration().enableFastFailureModeForTesting() + config.tlsConfiguration = TLSConfiguration.makeClientConfiguration() + config.tlsConfiguration?.trustRoots = .certificates([certificate]) + config.tlsConfiguration?.certificateVerification = .noHostnameVerification + config.httpVersion = .automatic + + config.tlsPinning = SPKIPinningConfiguration( + pins: [try SPKIHash(algorithm: SHA256.self, base64: pinBase64)], + policy: .audit // <-- AUDIT MODE + ) + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(MultiThreadedEventLoopGroup.singleton), + configuration: config + ) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + + guard let response = await XCTAssertNoThrowWithResult( + try await localClient.execute(request, deadline: .now() + .seconds(10)) + ) else { return } + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + } + } + // MARK: - POST to GET conversion on redirects @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift index 14a4d5630..555d2447d 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift @@ -417,6 +417,7 @@ final class TestConnectionCreator { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: tlsConfiguration, + tlsPinning: nil, clientConfiguration: config, sslContextCache: .init() ) @@ -460,6 +461,7 @@ final class TestConnectionCreator { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: tlsConfiguration, + tlsPinning: nil, clientConfiguration: config, sslContextCache: .init() ) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 959e0f939..e1532c1f6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -4044,7 +4044,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let firstRequest = try HTTPClient.Request( url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, - tlsConfiguration: tlsConfig + tlsConfiguration: tlsConfig, + tlsPinning: nil ) let firstResponse = try localClient.execute(request: firstRequest).wait() guard let firstBody = firstResponse.body else { @@ -4056,7 +4057,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let secondRequest = try HTTPClient.Request( url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, - tlsConfiguration: tlsConfig + tlsConfiguration: tlsConfig, + tlsPinning: nil ) let secondResponse = try localClient.execute(request: secondRequest).wait() guard let secondBody = secondResponse.body else { @@ -4065,14 +4067,15 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } let secondConnectionNumber = try decoder.decode(RequestInfo.self, from: secondBody).connectionNumber - // Uses a differrent TLS config. + // Uses a different TLS config. var tlsConfig2 = TLSConfiguration.makeClientConfiguration() tlsConfig2.certificateVerification = .none tlsConfig2.maximumTLSVersion = .tlsv1 let thirdRequest = try HTTPClient.Request( url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, - tlsConfiguration: tlsConfig2 + tlsConfiguration: tlsConfig2, + tlsPinning: nil ) let thirdResponse = try localClient.execute(request: thirdRequest).wait() guard let thirdBody = thirdResponse.body else { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift index 37ff3a1ef..4afea1878 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift @@ -44,6 +44,7 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, + tlsPinning: nil, clientConfiguration: .init(proxy: .socksServer(host: "127.0.0.1", port: server!.localAddress!.port!)), sslContextCache: .init() ) @@ -86,6 +87,7 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, + tlsPinning: nil, clientConfiguration: .init(proxy: .socksServer(host: "127.0.0.1", port: server!.localAddress!.port!)) .enableFastFailureModeForTesting(), sslContextCache: .init() @@ -126,6 +128,7 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, + tlsPinning: nil, clientConfiguration: .init(proxy: .server(host: "127.0.0.1", port: server!.localAddress!.port!)) .enableFastFailureModeForTesting(), sslContextCache: .init() @@ -168,6 +171,7 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, + tlsPinning: nil, clientConfiguration: .init(tlsConfiguration: tlsConfig) .enableFastFailureModeForTesting(), sslContextCache: .init() diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift index 6fbdda385..11f41b106 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift @@ -92,6 +92,7 @@ final private class MockScheduledRequest: HTTPSchedulableRequest { var poolKey: ConnectionPool.Key { preconditionFailure("Unimplemented") } var tlsConfiguration: TLSConfiguration? { nil } + var tlsPinning: SPKIPinningConfiguration? { nil } var logger: Logger { preconditionFailure("Unimplemented") } var connectionDeadline: NIODeadline { preconditionFailure("Unimplemented") } var preferredEventLoop: EventLoop { preconditionFailure("Unimplemented") } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift index a40703456..4ba67ff35 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift @@ -35,6 +35,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init(), key: .init(request), delegate: poolDelegate, @@ -89,6 +90,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init(), key: .init(request), delegate: poolDelegate, @@ -157,6 +159,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: configuration, key: .init(request), delegate: poolDelegate, @@ -214,6 +217,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init(connectionPool: .init(idleTimeout: .milliseconds(500))), key: .init(request), delegate: poolDelegate, @@ -274,6 +278,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init( proxy: .init(host: "localhost", port: httpBin.port, type: .http(.basic(credentials: "invalid"))) ), @@ -328,6 +333,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init( proxy: .init(host: "localhost", port: httpBin.port, type: .http(.basic(credentials: "invalid"))) ), @@ -382,6 +388,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init( proxy: .init(host: "localhost", port: httpBin.port, type: .http(.basic(credentials: "invalid"))) ), @@ -438,6 +445,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: .none, + tlsPinning: .none, clientConfiguration: .init(), key: .init(request), delegate: poolDelegate, @@ -494,6 +502,7 @@ class HTTPConnectionPoolTests: XCTestCase { eventLoopGroup: eventLoopGroup, sslContextCache: .init(), tlsConfiguration: nil, + tlsPinning: nil, clientConfiguration: .init(), key: .init(request), delegate: poolDelegate, diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift index bd8f0736a..494a46c1c 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift @@ -715,6 +715,8 @@ final class MockHTTPScheduableRequest: HTTPSchedulableRequest { var tlsConfiguration: TLSConfiguration? { nil } + var tlsPinning: SPKIPinningConfiguration? { nil } + func requestWasQueued(_: HTTPRequestScheduler) { preconditionFailure("Unimplemented") } diff --git a/Tests/AsyncHTTPClientTests/SPKIHashTests.swift b/Tests/AsyncHTTPClientTests/SPKIHashTests.swift new file mode 100644 index 000000000..68ffa4b1f --- /dev/null +++ b/Tests/AsyncHTTPClientTests/SPKIHashTests.swift @@ -0,0 +1,118 @@ +import XCTest +@testable import AsyncHTTPClient +import Crypto + +final class SPKIHashTests: XCTestCase { + + // MARK: - Initialization (custom algorithm + base64) + + func testInitWithSHA384AndValidBase64() throws { + let base64 = Data(repeating: 0, count: 48).base64EncodedString() + let hash = try SPKIHash(algorithm: SHA384.self, base64: base64) + XCTAssertEqual(hash.bytes.count, 48) + XCTAssertEqual(hash.bytes, Data(repeating: 0, count: 48)) + } + + func testInitWithSHA384AndWrongLengthThrows() throws { + let base64 = Data(repeating: 0, count: 32).base64EncodedString() + XCTAssertThrowsError(try SPKIHash(algorithm: SHA384.self, base64: base64)) { error in + XCTAssertEqual(error as? HTTPClientError, .invalidDigestLength) + } + } + + // MARK: - Initialization (raw bytes) + + func testInitWithSHA256AndValidBytes() throws { + let bytes = Data(repeating: 0, count: 32) + let hash = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + XCTAssertEqual(hash.bytes, bytes) + } + + func testInitWithSHA512AndValidBytes() throws { + let bytes = Data(repeating: 0, count: 64) + let hash = try SPKIHash(algorithm: SHA512.self, bytes: bytes) + XCTAssertEqual(hash.bytes, bytes) + } + + func testInitWithWrongByteCountThrows() throws { + let bytes = Data(repeating: 0, count: 31) + XCTAssertThrowsError(try SPKIHash(algorithm: SHA256.self, bytes: bytes)) { error in + XCTAssertEqual(error as? HTTPClientError, .invalidDigestLength) + } + } + + // MARK: - Initialization (convenience bytes initializer) + + func testConvenienceBytesInitializerUsesSHA256() throws { + let bytes = Data([UInt8](0..<32)) + let hash1 = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + let hash2 = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + XCTAssertEqual(hash1, hash2) + } + + // MARK: - Equality + + func testEqualityWithSameBytesAndAlgorithm() throws { + let bytes = Data(repeating: 0, count: 32) + let hash1 = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + let hash2 = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + XCTAssertEqual(hash1, hash2) + } + + func testInequalityWithSameBytesDifferentAlgorithm() throws { + let hash1 = try SPKIHash(algorithm: SHA256.self, bytes: Data(repeating: 0, count: 32)) + let hash2 = try SPKIHash(algorithm: SHA384.self, bytes: Data(repeating: 0, count: 48)) + XCTAssertNotEqual(hash1, hash2) + } + + func testInequalityWithDifferentBytesSameAlgorithm() throws { + let hash1 = try SPKIHash(algorithm: SHA256.self, bytes: Data(repeating: 0, count: 32)) + let hash2 = try SPKIHash(algorithm: SHA256.self, bytes: Data(repeating: 1, count: 32)) + XCTAssertNotEqual(hash1, hash2) + } + + // MARK: - Hashable + + func testHashableWithEqualValues() throws { + let bytes = Data(repeating: 0, count: 32) + let hash1 = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + let hash2 = try SPKIHash(algorithm: SHA256.self, bytes: bytes) + + var set = Set() + set.insert(hash1) + set.insert(hash2) + XCTAssertEqual(set.count, 1) + } + + func testHashableWithDifferentAlgorithms() throws { + let hash1 = try SPKIHash(algorithm: SHA256.self, bytes: Data(repeating: 0, count: 32)) + let hash2 = try SPKIHash(algorithm: SHA384.self, bytes: Data(repeating: 0, count: 48)) + + var set = Set() + set.insert(hash1) + set.insert(hash2) + XCTAssertEqual(set.count, 2) + } + + // MARK: - Real-world test vectors + + func testSHA256EmptyInputHash() throws { + let expectedBase64 = "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=" + let hash = try SPKIHash(algorithm: SHA256.self, base64: expectedBase64) + + let expectedBytes = Data([ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55 + ]) + XCTAssertEqual(hash.bytes, expectedBytes) + } + + func testSHA384EmptyInputHash() throws { + let emptyHash = Data(SHA384.hash(data: Data())) + let base64 = emptyHash.base64EncodedString() + let hash = try SPKIHash(algorithm: SHA384.self, base64: base64) + XCTAssertEqual(hash.bytes, emptyHash) + } +} diff --git a/Tests/AsyncHTTPClientTests/SPKIPinningTests.swift b/Tests/AsyncHTTPClientTests/SPKIPinningTests.swift new file mode 100644 index 000000000..a62fe4ac1 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/SPKIPinningTests.swift @@ -0,0 +1,278 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Crypto +import XCTest +import NIOSSL +import NIOTLS +import Logging +import NIOCore +import NIOEmbedded + +@testable import AsyncHTTPClient + +class SPKIPinningTests: XCTestCase { + + // MARK: - SPKIPinningConfiguration.contains(spkiBytes:) + + func testContains_WithMatchingPin_ReturnsTrue() throws { + let (certificate, spkiHash) = try Self.testCertificateAndSPKIHash() + let pin = try SPKIHash(algorithm: SHA256.self, bytes: Data(spkiHash)) + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .strict + ) + + let publicKey = try certificate.extractPublicKey() + let spkiBytes = try publicKey.toSPKIBytes() + + XCTAssertTrue(config.contains(spkiBytes: spkiBytes)) + } + + func testContains_WithMismatchedPin_ReturnsFalse() throws { + let (certificate, _) = try Self.testCertificateAndSPKIHash() + let mismatchedPin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [mismatchedPin], + policy: .strict + ) + + let publicKey = try certificate.extractPublicKey() + let spkiBytes = try publicKey.toSPKIBytes() + + XCTAssertFalse(config.contains(spkiBytes: spkiBytes)) + } + + func testContains_WithEmptyInput_ReturnsFalse() throws { + let pin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .strict + ) + + XCTAssertFalse(config.contains(spkiBytes: [])) + } + + // MARK: - SPKIPinningHandler.validatePinning(for:) + + func testValidatePinning_WithValidPin_InStrictMode_ReturnsAccepted() throws { + let (certificate, spkiHash) = try Self.testCertificateAndSPKIHash() + let pin = try SPKIHash(algorithm: SHA256.self, bytes: Data(spkiHash)) + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .strict + ) + let handler = try makeHandler(config: config) + + let result = handler.validatePinning(for: .success(certificate)) + + if case .accepted = result { + return + } + + XCTFail("Expected validation to succeed") + } + + func testValidatePinning_WithValidPin_InAuditMode_ReturnsAccepted() throws { + let (certificate, spkiHash) = try Self.testCertificateAndSPKIHash() + let pin = try SPKIHash(algorithm: SHA256.self, bytes: Data(spkiHash)) + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .audit + ) + let handler = try makeHandler(config: config) + + let result = handler.validatePinning(for: .success(certificate)) + + if case .accepted = result { + return + } + + XCTFail("Expected validation to succeed, got \(result)") + } + + func testValidatePinning_WithMismatchedPin_InStrictMode_ReturnsRejected() throws { + let (certificate, _) = try Self.testCertificateAndSPKIHash() + let mismatchedPin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [mismatchedPin], + policy: .strict + ) + let handler = try makeHandler(config: config) + + let result = handler.validatePinning(for: .success(certificate)) + + guard case .rejected(let error) = result else { + XCTFail("Expected .rejected, got \(result)") + return + } + + if case .pinMismatch = error as? SPKIPinningHandlerError { + return + } + + XCTFail("Expected .pinMismatch, got \(error)") + } + + func testValidatePinning_WithMismatchedPin_InAuditMode_ReturnsAuditWarning() throws { + let (certificate, _) = try Self.testCertificateAndSPKIHash() + let mismatchedPin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [mismatchedPin], + policy: .audit + ) + let handler = try makeHandler(config: config) + + let result = handler.validatePinning(for: .success(certificate)) + + guard case .auditWarning(let error) = result else { + XCTFail("Expected .auditWarning, got \(result)") + return + } + + if case .pinMismatch = error as? SPKIPinningHandlerError { + return + } + + XCTFail("Expected .pinMismatch, got \(error)") + } + + func testValidatePinning_WithNilCertificate_InStrictMode_ReturnsRejected() throws { + let pin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .strict + ) + let handler = try makeHandler(config: config) + + let result = handler.validatePinning(for: .success(nil)) + + guard case .rejected(let error) = result else { + XCTFail("Expected .rejected, got \(result)") + return + } + + if case .emptyCertificateChain = error as? SPKIPinningHandlerError { + return + } + + XCTFail("Expected .emptyCertificateChain, got \(error)") + } + + func testValidatePinning_WithNilCertificate_InAuditMode_ReturnsAuditWarning() throws { + let pin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .audit + ) + let handler = try makeHandler(config: config) + + let result = handler.validatePinning(for: .success(nil)) + + guard case .auditWarning(let error) = result else { + XCTFail("Expected .auditWarning, got \(result)") + return + } + + if case .emptyCertificateChain = error as? SPKIPinningHandlerError { + return + } + + XCTFail("Expected .emptyCertificateChain, got \(error)") + } + + func testValidatePinning_WithExtractionFailure_InStrictMode_ReturnsRejected() throws { + let pin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .strict + ) + let handler = try makeHandler(config: config) + let extractionError = NSError(domain: "TestError", code: 1, userInfo: nil) + + let result = handler.validatePinning(for: .failure(extractionError)) + + guard case .rejected(let error) = result else { + XCTFail("Expected .rejected, got \(result)") + return + } + XCTAssertTrue((error as? SPKIPinningHandlerError)?.description.contains("SSL handler not found:") == true) + } + + func testValidatePinning_WithExtractionFailure_InAuditMode_ReturnsAuditWarning() throws { + let pin = try SPKIHash(algorithm: SHA256.self, base64: "9uO07DlRgCzpXEaC2+ZiqB0VFcjdn43d6h+U2lUHORo=") + let config = SPKIPinningConfiguration( + pins: [pin], + policy: .audit + ) + let handler = try makeHandler(config: config) + let extractionError = NSError(domain: "TestError", code: 1, userInfo: nil) + + let result = handler.validatePinning(for: .failure(extractionError)) + + guard case .auditWarning(let error) = result else { + XCTFail("Expected .auditWarning, got \(result)") + return + } + XCTAssertTrue((error as? SPKIPinningHandlerError)?.description.contains("SSL handler not found:") == true) + } + + // MARK: - SPKIPinningHandler.userInboundEventTriggered(...) + + func testUserInboundEventTriggered_IgnoresNonHandshakeEvents() throws { + let config = SPKIPinningConfiguration( + pins: [], + policy: .strict + ) + let handler = try makeHandler(config: config) + let event = TLSUserEvent.shutdownCompleted + + let embedded = EmbeddedChannel(handlers: [handler]) + embedded.pipeline.fireUserInboundEventTriggered(event) + try embedded.throwIfErrorCaught() + } + + func testUserInboundEventTriggered_OnHandshakeInitiatesValidation() throws { + let config = SPKIPinningConfiguration( + pins: [], + policy: .strict + ) + let handler = try makeHandler(config: config) + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: nil) + + let embedded = EmbeddedChannel(handlers: [handler]) + embedded.pipeline.fireUserInboundEventTriggered(event) + + XCTAssertThrowsError(try embedded.throwIfErrorCaught()) { + if let error = $0 as? HTTPClientError { + XCTAssertTrue(error.description.contains("SSL handler not found:")) + } + } + } + + // MARK: - Helpers + + private func makeHandler(config: SPKIPinningConfiguration) throws -> SPKIPinningHandler { + let logger = Logger(label: "test", factory: SwiftLogNoOpLogHandler.init) + return SPKIPinningHandler(tlsPinning: config, logger: logger) + } + + private static func testCertificateAndSPKIHash() throws -> (NIOSSLCertificate, SHA256Digest) { + let certificate = TestTLS.certificate + let publicKey = try certificate.extractPublicKey() + let spkiBytes = try publicKey.toSPKIBytes() + let spkiHash = SHA256.hash(data: Data(spkiBytes)) + return (certificate, spkiHash) + } +}