Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Package.resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ let package = Package(
dependencies: [
.package(url: "https://github.com/apple/swift-log", from: "1.6.0"),
.package(url: "https://github.com/vapor/postgres-nio", from: "1.27.0"),
.package(url: "https://github.com/feather-framework/feather-database", exact: "1.0.0-beta.1"),
.package(url: "https://github.com/feather-framework/feather-database", exact: "1.0.0-beta.2"),
// [docc-plugin-placeholder]
],
targets: [
.target(
Expand Down
16 changes: 12 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

Postgres driver implementation for the abstract [Feather Database](https://github.com/feather-framework/feather-database) Swift API package.

![Release: 1.0.0-beta.1](https://img.shields.io/badge/Release-1%2E0%2E0--beta%2E1-F05138)
[
![Release: 1.0.0-beta.2](https://img.shields.io/badge/Release-1%2E0%2E0--beta%2E2-F05138)
](
https://github.com/feather-framework/feather-postgres-database/releases/tag/1.0.0-beta.2
)

## Features

Expand Down Expand Up @@ -33,7 +37,7 @@ Postgres driver implementation for the abstract [Feather Database](https://githu
Add the dependency to your `Package.swift`:

```swift
.package(url: "https://github.com/feather-framework/feather-postgres-database", exact: "1.0.0-beta.1"),
.package(url: "https://github.com/feather-framework/feather-postgres-database", exact: "1.0.0-beta.2"),
```

Then add `FeatherPostgresDatabase` to your target dependencies:
Expand All @@ -45,7 +49,11 @@ Then add `FeatherPostgresDatabase` to your target dependencies:

## Usage

![DocC API documentation](https://img.shields.io/badge/DocC-API_documentation-F05138)
[
![DocC API documentation](https://img.shields.io/badge/DocC-API_documentation-F05138)
](
https://feather-framework.github.io/feather-postgres-database/documentation/featherpostgresdatabase/
)

API documentation is available at the following link.

Expand Down Expand Up @@ -127,7 +135,7 @@ The following database driver implementations are available for use:
- Build: `swift build`
- Test:
- local: `swift test`
- using Docker: `swift docker-test`
- using Docker: `make docker-test`
- Format: `make format`
- Check: `make check`

Expand Down
15 changes: 6 additions & 9 deletions Sources/FeatherPostgresDatabase/PostgresDatabaseClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ public struct PostgresDatabaseClient: DatabaseClient {
/// - Throws: A `DatabaseError` if connection handling fails.
/// - Returns: The query result produced by the closure.
@discardableResult
public func connection(
public func connection<T>(
isolation: isolated (any Actor)? = #isolation,
_ closure: (PostgresConnection) async throws ->
sending PostgresQueryResult,
) async throws(DatabaseError) -> sending PostgresQueryResult {
_ closure: (PostgresConnection) async throws -> sending T,
) async throws(DatabaseError) -> sending T {
do {
return try await client.withConnection(closure)
}
Expand All @@ -72,12 +71,10 @@ public struct PostgresDatabaseClient: DatabaseClient {
/// - Throws: A `DatabaseError` if the transaction fails.
/// - Returns: The query result produced by the closure.
@discardableResult
public func transaction(
public func transaction<T>(
isolation: isolated (any Actor)? = #isolation,
_ closure: (
(PostgresConnection) async throws -> sending PostgresQueryResult
),
) async throws(DatabaseError) -> sending PostgresQueryResult {
_ closure: ((PostgresConnection) async throws -> sending T),
) async throws(DatabaseError) -> sending T {
do {
return try await client.withTransaction(
logger: logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,150 @@ struct FeatherPostgresDatabaseTestSuite {
}
}

@Test
func concurrentTransactionUpdates() async throws {
try await runUsingTestDatabaseClient { database in
let suffix = randomTableSuffix()
let table = "sessions_\(suffix)"
let sessionID = "session_\(suffix)"

enum TestError: Error {
case missingRow
}

try await database.execute(
query: #"""
DROP TABLE IF EXISTS "\#(unescaped: table)" CASCADE;
"""#
)
try await database.execute(
query: #"""
CREATE TABLE "\#(unescaped: table)" (
"id" TEXT NOT NULL PRIMARY KEY,
"access_token" TEXT NOT NULL,
"access_expires_at" TIMESTAMPTZ NOT NULL,
"refresh_token" TEXT NOT NULL,
"refresh_count" INTEGER NOT NULL DEFAULT 0
);
"""#
)

// set an expired token
try await database.execute(
query: #"""
INSERT INTO "\#(unescaped: table)"
("id", "access_token", "access_expires_at", "refresh_token", "refresh_count")
VALUES
(
\#(sessionID),
'stale',
NOW() - INTERVAL '5 minutes',
'refresh',
0
);
"""#
)

func getValidAccessToken(sessionID: String) async throws -> String {
try await database.transaction { connection in
let result = try await connection.execute(
query: #"""
SELECT
"access_token",
"refresh_count",
"access_expires_at" > NOW() + INTERVAL '60 seconds' AS "is_valid"
FROM "\#(unescaped: table)"
WHERE "id" = \#(sessionID)
FOR UPDATE;
"""#
)
let rows = try await result.collect()

guard let row = rows.first else {
throw TestError.missingRow
}

let isValid = try row.decode(
column: "is_valid",
as: Bool.self
)
if isValid {
// token was valid, must be called X times
return try row.decode(
column: "access_token",
as: String.self
)
}

// refresh, this branch can only be called 1 time
let refreshCount = try row.decode(
column: "refresh_count",
as: Int.self
)
let newRefreshCount = refreshCount + 1
let newToken = "token_\(newRefreshCount)"

try await Task.sleep(for: .milliseconds(40))

_ = try await connection.execute(
query: #"""
UPDATE "\#(unescaped: table)"
SET
"access_token" = \#(newToken),
"access_expires_at" = NOW() + INTERVAL '10 minutes',
"refresh_count" = \#(newRefreshCount)
WHERE "id" = \#(sessionID);
"""#
)

return newToken
}
}

let workerCount = 80
var tokens: [String] = []
try await withThrowingTaskGroup(of: String.self) { group in
for _ in 0..<workerCount {
group.addTask {
try await getValidAccessToken(sessionID: sessionID)
}
}
for try await token in group {
tokens.append(token)
}
}

#expect(Set(tokens).count == 1)

let result =
try await database.execute(
query: #"""
SELECT
"access_token",
"refresh_count",
"access_expires_at" > NOW() AS "is_valid"
FROM "\#(unescaped: table)"
WHERE "id" = \#(sessionID);
"""#
)
.collect()

#expect(result.count == 1)
#expect(
try result[0].decode(column: "refresh_count", as: Int.self)
== 1
)
#expect(
try result[0].decode(column: "access_token", as: String.self)
== "token_1"
)
#expect(
try result[0].decode(column: "is_valid", as: Bool.self)
== true
)
}
}

@Test
func doubleRoundTrip() async throws {
try await runUsingTestDatabaseClient { database in
Expand Down