diff --git a/sdks/kotlin/DEVELOP.md b/sdks/kotlin/DEVELOP.md new file mode 100644 index 00000000000..ca526ed6ae8 --- /dev/null +++ b/sdks/kotlin/DEVELOP.md @@ -0,0 +1,136 @@ +# Kotlin SDK — Developer Guide + +Internal documentation for contributors working on the SpacetimeDB Kotlin SDK. + +## Project Structure + +``` +src/ + commonMain/ Shared Kotlin code (all targets) + com/clockworklabs/spacetimedb/ + SpacetimeDBClient.kt DbConnection, DbConnectionBuilder + Identity.kt Identity, ConnectionId, Address, Timestamp + ClientCache.kt Client-side row cache (TableCache, ByteArrayWrapper) + TableHandle.kt Per-table callback registration + SubscriptionHandle.kt Subscription lifecycle + SubscriptionBuilder.kt Fluent subscription API + ReconnectPolicy.kt Exponential backoff configuration + Compression.kt expect declarations for decompression + bsatn/ + BsatnReader.kt Binary deserialization + BsatnWriter.kt Binary serialization + BsatnRowList.kt Row list decoding + protocol/ + ServerMessage.kt Server → Client message decoding + ClientMessage.kt Client → Server message encoding + ProtocolTypes.kt QuerySetId, QueryRows, TableUpdateRows, etc. + websocket/ + WebSocketTransport.kt WebSocket lifecycle, ping/pong, reconnection + jvmMain/ JVM-specific (Gzip via java.util.zip, Brotli via org.brotli) + iosMain/ iOS-specific (Gzip via platform.zlib) + commonTest/ Shared tests + jvmTest/ JVM-only tests (compression round-trips) +``` + +## Architecture + +### Connection Lifecycle + +``` +DbConnectionBuilder.build() + → DbConnection constructor + → WebSocketTransport.connect() + → connectSession() opens WebSocket + → processSendQueue() (coroutine: outbound messages) + → processIncoming() (coroutine: inbound frames) + → runKeepAlive() (coroutine: 30s idle ping/pong) +``` + +On unexpected disconnect with a `ReconnectPolicy`, the transport enters a +`RECONNECTING` state and calls `attemptReconnect()` which retries with +exponential backoff up to `maxRetries` times. 
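The backoff schedule is easy to reason about in isolation. Below is a standalone sketch of the delay computation mirroring `ReconnectPolicy.delayForAttempt` with the SDK's default parameters; the free function here is illustrative, not part of the API:

```kotlin
// Illustrative reimplementation of the reconnect backoff schedule.
// Attempt 0 waits initialDelayMs; each subsequent attempt multiplies the
// delay by backoffMultiplier, capped at maxDelayMs.
fun delayForAttempt(
    attempt: Int,
    initialDelayMs: Long = 1_000,
    maxDelayMs: Long = 30_000,
    backoffMultiplier: Double = 2.0,
): Long {
    var delay = initialDelayMs
    repeat(attempt) {
        delay = (delay * backoffMultiplier).toLong().coerceAtMost(maxDelayMs)
    }
    return delay
}

fun main() {
    // Delays for attempts 0..5 with the defaults:
    // [1000, 2000, 4000, 8000, 16000, 30000]
    println((0..5).map { delayForAttempt(it) })
}
```

With the defaults, the delay doubles each attempt and saturates at 30s, so even a large `maxRetries` never produces unbounded waits.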
+ +### Wire Protocol + +Uses the `v2.bsatn.spacetimedb` WebSocket subprotocol. All messages are BSATN +(Binary SpacetimeDB Algebraic Type Notation) — a tag-length-value encoding +defined in `crates/client-api-messages/src/websocket/v2.rs`. + +**Server messages** are preceded by a compression byte: +- `0x00` — uncompressed +- `0x01` — Brotli +- `0x02` — Gzip + +The SDK requests Gzip compression via the `compression=Gzip` query parameter. + +### Client Cache + +`ClientCache` maintains a map of `TableCache` instances, one per table. Each +`TableCache` stores rows keyed by content (`ByteArrayWrapper`) with reference +counting. This allows overlapping subscriptions to share rows without duplicates. + +Transaction updates produce `TableOperation` events (Insert, Delete, Update, +EventInsert) which drive the `TableHandle` callback system. + +### Threading Model + +- `WebSocketTransport` runs on a `CoroutineScope(SupervisorJob() + Dispatchers.Default)`. +- All `handleMessage` processing is serialized behind a `Mutex` to prevent + concurrent cache mutation. +- `atomicfu` atomics are used for transport-level flags (`idle`, `wantPong`, + `intentionalDisconnect`) that are read/written across coroutines. 
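The compression-byte framing from the Wire Protocol section above can be sketched as a small dispatcher. The two lambdas stand in for the SDK's `expect` decompression functions (`decompressGzip`/`decompressBrotli`); this is a sketch of the framing rule, not the transport's actual code:

```kotlin
// Strips the leading compression byte from a server frame and routes the
// remaining payload to the matching decoder:
//   0x00 = uncompressed, 0x01 = Brotli, 0x02 = Gzip (the tags listed above).
fun decodeServerFrame(
    frame: ByteArray,
    gunzip: (ByteArray) -> ByteArray,
    unbrotli: (ByteArray) -> ByteArray,
): ByteArray {
    require(frame.isNotEmpty()) { "empty server frame" }
    val payload = frame.copyOfRange(1, frame.size)
    return when (frame[0].toInt()) {
        0x00 -> payload
        0x01 -> unbrotli(payload)
        0x02 -> gunzip(payload)
        else -> error("unknown compression tag ${frame[0]}")
    }
}

fun main() {
    // An uncompressed frame: tag byte 0x00 followed by the BSATN payload.
    val frame = byteArrayOf(0x00, 0x12, 0x34)
    val body = decodeServerFrame(frame, { it }, { it })
    println(body.contentEquals(byteArrayOf(0x12, 0x34))) // true
}
```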
+ +### Platform-Specific Code + +Uses Kotlin `expect`/`actual` for decompression: + +| Platform | Gzip | Brotli | +|----------|------|--------| +| JVM | `java.util.zip.GZIPInputStream` | `org.brotli.dec.BrotliInputStream` | +| iOS | `platform.zlib` (wbits=31) | Not supported (SDK defaults to Gzip) | + +## Building + +```bash +# Run all JVM tests +./gradlew jvmTest + +# Compile JVM +./gradlew compileKotlinJvm + +# Compile iOS (verifies expect/actual) +./gradlew compileKotlinIosArm64 + +# All targets +./gradlew build +``` + +## Test Suite + +| File | Coverage | +|------|----------| +| `BsatnTest.kt` | Reader/Writer round-trips for all primitive types | +| `ProtocolTest.kt` | ServerMessage and ClientMessage encode/decode | +| `ClientCacheTest.kt` | Cache operations, ref counting, transaction updates | +| `OneOffQueryTest.kt` | OneOffQueryResult decode (Ok and Err variants) | +| `CompressionTest.kt` | Gzip round-trip, empty/large payloads (JVM only) | +| `ReconnectPolicyTest.kt` | Backoff calculation, parameter validation | + +## Design Decisions + +1. **Manual ping/pong** instead of Ktor's `pingIntervalMillis` — OkHttp engine + doesn't support Ktor's built-in ping, so we implement idle detection + ourselves (matching the Rust SDK's 30s pattern). + +2. **ByteArray row storage** — Rows are stored as raw BSATN bytes rather than + deserialized objects. This keeps the core SDK schema-agnostic; code + generation (future) will layer typed access on top. + +3. **Compression negotiation** — The SDK advertises `compression=Gzip` in the + connection URI. Brotli is supported on JVM but not iOS; Gzip provides + universal coverage. + +4. **No Brotli on iOS** — Apple's Compression framework supports Brotli + (`COMPRESSION_BROTLI`) but it's not directly available via Kotlin/Native's + `platform.compression` interop. Since the SDK requests Gzip, this is a + non-issue in practice. 
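Design decision 2 stores rows as raw bytes, which is also why the cache needs `ByteArrayWrapper`: Kotlin's `ByteArray` compares by reference, so bare arrays cannot serve as content keys in a `HashMap`. A minimal illustration (`ByteArrayKey` here is a stand-in for the SDK's `ByteArrayWrapper`):

```kotlin
// Content-addressed key over raw bytes: equality and hashing delegate to
// the array's contents rather than its identity.
class ByteArrayKey(val data: ByteArray) {
    override fun equals(other: Any?): Boolean =
        other is ByteArrayKey && data.contentEquals(other.data)
    override fun hashCode(): Int = data.contentHashCode()
}

fun main() {
    val refCounts = HashMap<ByteArrayKey, Int>()
    // Two distinct arrays with equal contents land on the same entry,
    // which is what lets overlapping subscriptions share one ref-counted row.
    refCounts.merge(ByteArrayKey(byteArrayOf(1, 2, 3)), 1, Int::plus)
    refCounts.merge(ByteArrayKey(byteArrayOf(1, 2, 3)), 1, Int::plus)
    println(refCounts.size)                                // 1
    println(refCounts[ByteArrayKey(byteArrayOf(1, 2, 3))]) // 2
}
```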
diff --git a/sdks/kotlin/README.md b/sdks/kotlin/README.md new file mode 100644 index 00000000000..afb30102238 --- /dev/null +++ b/sdks/kotlin/README.md @@ -0,0 +1,63 @@ +# SpacetimeDB Kotlin SDK + +## Overview + +The Kotlin Multiplatform (KMP) client SDK for [SpacetimeDB](https://spacetimedb.com). Targets **JVM** and **iOS** (arm64, simulator-arm64, x64), enabling native SpacetimeDB clients from Kotlin, Java, and Swift (via KMP interop). + +## Features + +- BSATN binary protocol (`v2.bsatn.spacetimedb`) +- Subscriptions with SQL query support +- One-off queries (suspend and callback variants) +- Reducer invocation with result callbacks +- Automatic reconnection with exponential backoff +- Ping/pong keep-alive (30s idle timeout) +- Gzip and Brotli message decompression +- Client-side row cache with ref-counted rows + +## Quick Start + +```kotlin +val conn = DbConnection.builder() + .withUri("ws://localhost:3000") + .withModuleName("my_module") + .onConnect { conn, identity, token -> + println("Connected as $identity") + + // Subscribe to table changes + conn.subscriptionBuilder() + .onApplied { println("Subscription active") } + .subscribe("SELECT * FROM users") + + // Observe a table + conn.table("users").onInsert { row -> + println("New user row: ${row.size} bytes") + } + } + .onDisconnect { _, error -> + println("Disconnected: ${error?.message ?: "clean"}") + } + .build() +``` + +## Installation + +Add to your `build.gradle.kts`: + +```kotlin +kotlin { + sourceSets { + commonMain.dependencies { + implementation("com.clockworklabs:spacetimedb-sdk:0.1.0") + } + } +} +``` + +## Documentation + +For the SpacetimeDB platform documentation, see [spacetimedb.com/docs](https://spacetimedb.com/docs). + +## Internal Developer Documentation + +See [`DEVELOP.md`](./DEVELOP.md). 
diff --git a/sdks/kotlin/build.gradle.kts b/sdks/kotlin/build.gradle.kts
new file mode 100644
index 00000000000..de98b8ef092
--- /dev/null
+++ b/sdks/kotlin/build.gradle.kts
@@ -0,0 +1,42 @@
+plugins {
+    kotlin("multiplatform") version "2.1.0"
+}
+
+group = "com.clockworklabs"
+version = "0.1.0"
+
+kotlin {
+    jvm()
+    iosArm64()
+    iosSimulatorArm64()
+    iosX64()
+
+    applyDefaultHierarchyTemplate()
+
+    sourceSets {
+        commonMain.dependencies {
+            implementation("io.ktor:ktor-client-core:3.0.3")
+            implementation("io.ktor:ktor-client-websockets:3.0.3")
+            implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.9.0")
+            implementation("org.jetbrains.kotlinx:atomicfu:0.23.2")
+        }
+        commonTest.dependencies {
+            implementation(kotlin("test"))
+            implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.9.0")
+        }
+        jvmMain.dependencies {
+            implementation("io.ktor:ktor-client-okhttp:3.0.3")
+            implementation("org.brotli:dec:0.1.2")
+        }
+        iosMain.dependencies {
+            implementation("io.ktor:ktor-client-darwin:3.0.3")
+        }
+    }
+}
+
+tasks.withType<Test> {
+    testLogging {
+        showStandardStreams = true
+    }
+    maxHeapSize = "1g"
+}
diff --git a/sdks/kotlin/gradle.properties b/sdks/kotlin/gradle.properties
new file mode 100644
index 00000000000..d54bbe28298
--- /dev/null
+++ b/sdks/kotlin/gradle.properties
@@ -0,0 +1,3 @@
+kotlin.code.style=official
+kotlin.mpp.stability.nowarn=true
+org.gradle.jvmargs=-Xmx2g
diff --git a/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar b/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 00000000000..1b33c55baab
Binary files /dev/null and b/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties b/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 00000000000..cea7a793a84
--- /dev/null
+++ b/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,7 @@
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.12-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/sdks/kotlin/gradlew b/sdks/kotlin/gradlew new file mode 100755 index 00000000000..23d15a93670 --- /dev/null +++ b/sdks/kotlin/gradlew @@ -0,0 +1,251 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». 
+# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. 
+MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH="\\\"\\\"" + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. 
+ # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. 
+ +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/sdks/kotlin/gradlew.bat b/sdks/kotlin/gradlew.bat new file mode 100644 index 00000000000..db3a6ac207e --- /dev/null +++ b/sdks/kotlin/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. 
+@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 
1>&2

goto fail

:execute
@rem Setup the command line

set CLASSPATH=


@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %*

:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega
diff --git a/sdks/kotlin/settings.gradle.kts b/sdks/kotlin/settings.gradle.kts
new file mode 100644
index 00000000000..c793d4071a8
--- /dev/null
+++ b/sdks/kotlin/settings.gradle.kts
@@ -0,0 +1,16 @@
+rootProject.name = "spacetimedb-sdk"
+
+pluginManagement {
+    repositories {
+        mavenCentral()
+        gradlePluginPortal()
+        google()
+    }
+}
+
+dependencyResolutionManagement {
+    repositories {
+        mavenCentral()
+        google()
+    }
+}
diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ClientCache.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ClientCache.kt
new file mode 100644
index 00000000000..1e10b2ad250
--- /dev/null
+++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ClientCache.kt
@@ -0,0 +1,147 @@
+package com.clockworklabs.spacetimedb
+
+import com.clockworklabs.spacetimedb.protocol.PersistentTableRows
+import com.clockworklabs.spacetimedb.protocol.QueryRows
+import com.clockworklabs.spacetimedb.protocol.QuerySetUpdate
+import com.clockworklabs.spacetimedb.protocol.TableUpdateRows
+
+class ClientCache {
+    private val tables = mutableMapOf<String, TableCache>()
+
+    fun getOrCreateTable(name: String): TableCache =
+        tables.getOrPut(name) { TableCache(name) }
+
+    fun getTable(name: String): TableCache? = tables[name]
+
+    fun tableNames(): Set<String> = tables.keys.toSet()
+
+    fun applySubscribeRows(rows: QueryRows) {
+        for (singleTable in rows.tables) {
+            val tableName = singleTable.table.value
+            val cache = getOrCreateTable(tableName)
+            val decodedRows = singleTable.rows.decodeRows()
+            for (row in decodedRows) {
+                cache.insertRow(row)
+            }
+        }
+    }
+
+    fun applyUnsubscribeRows(rows: QueryRows) {
+        for (singleTable in rows.tables) {
+            val tableName = singleTable.table.value
+            val cache = getTable(tableName) ?: continue
+            val decodedRows = singleTable.rows.decodeRows()
+            for (row in decodedRows) {
+                cache.deleteRow(row)
+            }
+        }
+    }
+
+    fun applyTransactionUpdate(querySets: List<QuerySetUpdate>): List<TableOperation> {
+        val operations = mutableListOf<TableOperation>()
+        for (qsUpdate in querySets) {
+            for (tableUpdate in qsUpdate.tables) {
+                val tableName = tableUpdate.tableName.value
+                val cache = getOrCreateTable(tableName)
+                for (rowUpdate in tableUpdate.rows) {
+                    when (rowUpdate) {
+                        is TableUpdateRows.PersistentTable -> {
+                            applyPersistentUpdate(cache, tableName, rowUpdate.rows, operations)
+                        }
+                        is TableUpdateRows.EventTable -> {
+                            val decoded = rowUpdate.rows.events.decodeRows()
+                            for (row in decoded) {
+                                operations.add(TableOperation.EventInsert(tableName, row))
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        return operations
+    }
+
+    private fun applyPersistentUpdate(
+        cache: TableCache,
+        tableName: String,
+        rows: PersistentTableRows,
+        operations: MutableList<TableOperation>,
+    ) {
+        val deletes = rows.deletes.decodeRows()
+        val inserts = rows.inserts.decodeRows()
+
+        val deletedSet = deletes.map { ByteArrayWrapper(it) }.toSet()
+        val insertMap = mutableMapOf<ByteArrayWrapper, ByteArray>()
+        for (row in inserts) {
+            insertMap[ByteArrayWrapper(row)] = row
+        }
+
+        for (row in deletes) {
+            val wrapper = ByteArrayWrapper(row)
+            val newRow = insertMap[wrapper]
+            if (newRow != null) {
+                cache.deleteRow(row)
+                cache.insertRow(newRow)
+                operations.add(TableOperation.Update(tableName, row, newRow))
+            } else {
+                cache.deleteRow(row)
+                operations.add(TableOperation.Delete(tableName, row))
+            }
+        }
+
+        for (row in inserts) {
+            val wrapper = ByteArrayWrapper(row)
+            if (wrapper !in deletedSet) {
+                cache.insertRow(row)
+                operations.add(TableOperation.Insert(tableName, row))
+            }
+        }
+    }
+}
+
+class TableCache(val name: String) {
+    private val rows = mutableMapOf<ByteArrayWrapper, RowEntry>()
+
+    val count: Int get() = rows.size
+
+    fun insertRow(rowBytes: ByteArray) {
+        val key = ByteArrayWrapper(rowBytes)
+        val existing = rows[key]
+        if (existing != null) {
+            existing.refCount++
+        } else {
+            rows[key] = RowEntry(rowBytes, 1)
+        }
+    }
+
+    fun deleteRow(rowBytes: ByteArray): Boolean {
+        val key = ByteArrayWrapper(rowBytes)
+        val existing = rows[key] ?: return false
+        existing.refCount--
+        if (existing.refCount <= 0) {
+            rows.remove(key)
+        }
+        return true
+    }
+
+    fun allRows(): List<ByteArray> = rows.values.map { it.data }
+
+    fun containsRow(rowBytes: ByteArray): Boolean =
+        rows.containsKey(ByteArrayWrapper(rowBytes))
+}
+
+class RowEntry(val data: ByteArray, var refCount: Int)
+
+sealed class TableOperation {
+    data class Insert(val tableName: String, val row: ByteArray) : TableOperation()
+    data class Delete(val tableName: String, val row: ByteArray) : TableOperation()
+    data class Update(val tableName: String, val oldRow: ByteArray, val newRow: ByteArray) : TableOperation()
+    data class EventInsert(val tableName: String, val row: ByteArray) : TableOperation()
+}
+
+class ByteArrayWrapper(val data: ByteArray) {
+    override fun equals(other: Any?): Boolean =
+        other is ByteArrayWrapper && data.contentEquals(other.data)
+
+    override fun hashCode(): Int = data.contentHashCode()
+}
diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Compression.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Compression.kt
new file mode 100644
index 00000000000..b25cb22be0e
--- /dev/null
+++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Compression.kt
@@ -0,0 +1,5 @@
+package com.clockworklabs.spacetimedb
+
+expect fun decompressBrotli(data:
ByteArray): ByteArray + +expect fun decompressGzip(data: ByteArray): ByteArray diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Event.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Event.kt new file mode 100644 index 00000000000..42afb8e4359 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Event.kt @@ -0,0 +1,30 @@ +package com.clockworklabs.spacetimedb + +sealed class Status { + data object Committed : Status() + data class Failed(val message: String) : Status() + data class OutOfEnergy(val message: String) : Status() +} + +data class ReducerEvent( + val timestamp: Timestamp, + val status: Status, + val callerIdentity: Identity, + val callerConnectionId: ConnectionId, + val reducerName: String, + val energyConsumed: Long, +) + +sealed class Event { + data class Reducer(val event: ReducerEvent) : Event() + data object SubscribeApplied : Event() + data object UnsubscribeApplied : Event() + data object Disconnected : Event() + data class SubscribeError(val message: String) : Event() + data object Transaction : Event() +} + +data class Credentials( + val identity: Identity, + val token: String, +) diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Identity.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Identity.kt new file mode 100644 index 00000000000..2874d7ebcad --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/Identity.kt @@ -0,0 +1,113 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +private val HEX_CHARS = "0123456789abcdef".toCharArray() + +internal fun ByteArray.toHexString(): String { + val result = CharArray(size * 2) + for (i in indices) { + val v = this[i].toInt() and 0xFF + result[i * 2] = HEX_CHARS[v ushr 4] + result[i * 2 + 1] = HEX_CHARS[v and 0x0F] + } + return 
result.concatToString() +} + +internal fun String.hexToByteArray(): ByteArray { + require(length % 2 == 0) { "Hex string must have even length" } + return ByteArray(length / 2) { i -> + val hi = this[i * 2].digitToInt(16) + val lo = this[i * 2 + 1].digitToInt(16) + ((hi shl 4) or lo).toByte() + } +} + +/** A 256-bit identifier that uniquely represents a user across all SpacetimeDB modules. */ +class Identity(val bytes: ByteArray) { + init { + require(bytes.size == 32) { "Identity must be 32 bytes" } + } + + fun toHex(): String = bytes.toHexString() + + override fun equals(other: Any?): Boolean = + other is Identity && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() + + override fun toString(): String = "Identity(${toHex()})" + + companion object { + val ZERO = Identity(ByteArray(32)) + + fun fromHex(hex: String): Identity { + require(hex.length == 64) { "Identity hex must be 64 characters" } + val bytes = hex.hexToByteArray() + return Identity(bytes) + } + + fun read(reader: BsatnReader): Identity = Identity(reader.readBytes(32)) + + fun write(writer: BsatnWriter, value: Identity) { writer.writeBytes(value.bytes) } + } +} + +/** A 128-bit identifier unique to each client connection session. */ +class ConnectionId(val bytes: ByteArray) { + init { + require(bytes.size == 16) { "ConnectionId must be 16 bytes" } + } + + fun toHex(): String = bytes.toHexString() + + override fun equals(other: Any?): Boolean = + other is ConnectionId && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() + + override fun toString(): String = "ConnectionId(${toHex()})" + + companion object { + val ZERO = ConnectionId(ByteArray(16)) + + fun read(reader: BsatnReader): ConnectionId = ConnectionId(reader.readBytes(16)) + + fun write(writer: BsatnWriter, value: ConnectionId) { writer.writeBytes(value.bytes) } + } +} + +/** A 128-bit address identifying a client in the SpacetimeDB network. 
*/ +class Address(val bytes: ByteArray) { + init { + require(bytes.size == 16) { "Address must be 16 bytes" } + } + + fun toHex(): String = bytes.toHexString() + + override fun equals(other: Any?): Boolean = + other is Address && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() + + override fun toString(): String = "Address(${toHex()})" + + companion object { + val ZERO = Address(ByteArray(16)) + + fun read(reader: BsatnReader): Address = Address(reader.readBytes(16)) + + fun write(writer: BsatnWriter, value: Address) { writer.writeBytes(value.bytes) } + } +} + +/** Server-side timestamp in microseconds since the Unix epoch. */ +@kotlin.jvm.JvmInline +value class Timestamp(val microseconds: Long) { + companion object { + fun read(reader: BsatnReader): Timestamp = Timestamp(reader.readI64()) + + fun write(writer: BsatnWriter, value: Timestamp) { writer.writeI64(value.microseconds) } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicy.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicy.kt new file mode 100644 index 00000000000..6be72758915 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicy.kt @@ -0,0 +1,31 @@ +package com.clockworklabs.spacetimedb + +/** + * Configures automatic reconnection with exponential backoff. + * + * @property maxRetries Maximum number of reconnect attempts before giving up. + * @property initialDelayMs Delay before the first retry (milliseconds). + * @property maxDelayMs Upper bound on the delay between retries (milliseconds). + * @property backoffMultiplier Factor by which the delay grows each attempt. 
+ */ +data class ReconnectPolicy( + val maxRetries: Int = 5, + val initialDelayMs: Long = 1_000, + val maxDelayMs: Long = 30_000, + val backoffMultiplier: Double = 2.0, +) { + init { + require(maxRetries >= 0) { "maxRetries must be non-negative" } + require(initialDelayMs > 0) { "initialDelayMs must be positive" } + require(maxDelayMs >= initialDelayMs) { "maxDelayMs must be >= initialDelayMs" } + require(backoffMultiplier >= 1.0) { "backoffMultiplier must be >= 1.0" } + } + + internal fun delayForAttempt(attempt: Int): Long { + var delay = initialDelayMs + repeat(attempt) { + delay = (delay * backoffMultiplier).toLong().coerceAtMost(maxDelayMs) + } + return delay + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerHandle.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerHandle.kt new file mode 100644 index 00000000000..8760c757127 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/ReducerHandle.kt @@ -0,0 +1,16 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +class ReducerHandle(private val connection: DbConnection) { + + fun call(reducerName: String, args: ByteArray = ByteArray(0), callback: ((ReducerResult) -> Unit)? = null) { + connection.callReducer(reducerName, args, callback) + } + + fun call(reducerName: String, writeArgs: (BsatnWriter) -> Unit, callback: ((ReducerResult) -> Unit)? 
= null) { + val writer = BsatnWriter() + writeArgs(writer) + connection.callReducer(reducerName, writer.toByteArray(), callback) + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SpacetimeDBClient.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SpacetimeDBClient.kt new file mode 100644 index 00000000000..46e0c74fd6d --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SpacetimeDBClient.kt @@ -0,0 +1,319 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.protocol.* +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import com.clockworklabs.spacetimedb.websocket.WebSocketTransport +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.atomicfu.atomic + +/** Called when a connection is established. Receives the connection, the user's [Identity], and an auth token. */ +typealias ConnectCallback = (DbConnection, Identity, String) -> Unit +/** Called when a connection is lost. The [Throwable] is null for clean disconnects. */ +typealias DisconnectCallback = (DbConnection, Throwable?) -> Unit +/** Called when the initial connection attempt fails. */ +typealias ConnectErrorCallback = (Throwable) -> Unit + +/** Compression mode negotiated with the server for host→client messages. */ +enum class CompressionMode(internal val queryValue: String) { + NONE("None"), + GZIP("Gzip"), + BROTLI("Brotli"), +} + +/** + * Primary client for interacting with a SpacetimeDB module. + * + * Create instances via [DbConnection.builder]: + * ```kotlin + * val conn = DbConnection.builder() + * .withUri("ws://localhost:3000") + * .withModuleName("my_module") + * .onConnect { conn, identity, token -> println("Connected as $identity") } + * .build() + * ``` + * + * The connection is opened immediately on [build][DbConnectionBuilder.build]. Use [disconnect] + * to tear it down, or configure automatic reconnection via [DbConnectionBuilder.withReconnectPolicy]. + */ +class DbConnection internal constructor( + private val uri: String, + private val moduleName: String, + private val token: String?, + private val connectCallbacks: List<ConnectCallback>, + private val disconnectCallbacks: List<DisconnectCallback>, + private val connectErrorCallbacks: List<ConnectErrorCallback>, + private val keepAliveIntervalMs: Long = 30_000L, + private val reconnectPolicy: ReconnectPolicy? = null, + private val compression: CompressionMode = CompressionMode.GZIP, +) { + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val requestCounter = atomic(0) + private val mutex = Mutex() + + internal val clientCache = ClientCache() + private val tableHandles = mutableMapOf<String, TableHandle>() + private val subscriptions = mutableMapOf<UInt, SubscriptionHandle>() + private val subscriptionsByQuerySet = mutableMapOf<QuerySetId, SubscriptionHandle>() + private val reducerCallbacks = mutableMapOf<UInt, (ReducerResult) -> Unit>() + private val pendingOneOffQueries = mutableMapOf<UInt, CompletableDeferred<ServerMessage.OneOffQueryResult>>() + + var identity: Identity? = null + private set + var connectionId: ConnectionId? = null + private set + var savedToken: String? = null + private set + + private val transport = WebSocketTransport( + scope = scope, + onMessage = { handleMessage(it) }, + onConnect = {}, + onDisconnect = { error -> + failPendingOperations() + disconnectCallbacks.forEach { it(this, error) } + }, + onConnectError = { error -> connectErrorCallbacks.forEach { it(error) } }, + keepAliveIntervalMs = keepAliveIntervalMs, + reconnectPolicy = reconnectPolicy, + compression = compression, + ) + + val connectionState: StateFlow<ConnectionState> get() = transport.state + val isActive: Boolean get() = transport.state.value == ConnectionState.CONNECTED + + init { + transport.connect(uri, moduleName, token) + } + + /** Closes the connection, cancels pending operations, and stops any reconnection attempts. 
*/ + fun disconnect() { + transport.disconnect() + failPendingOperations() + scope.cancel() + } + + /** Returns the [TableHandle] for [name], creating it if needed. Register callbacks before connecting. */ + fun table(name: String): TableHandle { + // tableHandles is only read/written from user thread (registration) + // and from handleMessage under mutex (firing callbacks). + // Reads from handleMessage never mutate, so this is safe for the + // typical pattern of registering table handles before connecting. + return tableHandles.getOrPut(name) { TableHandle(name) } + } + + /** Creates a [SubscriptionBuilder] for subscribing to SQL queries on this connection. */ + fun subscriptionBuilder(): SubscriptionBuilder = SubscriptionBuilder(this) + + /** Invokes a server-side reducer by name with BSATN-encoded [args]. Optionally receives the [ReducerResult]. */ + fun callReducer(reducerName: String, args: ByteArray, callback: ((ReducerResult) -> Unit)? = null) { + val reqId = nextRequestId() + if (callback != null) { + // Register synchronously before sending to avoid race with server response + reducerCallbacks[reqId] = callback + } + transport.send( + ClientMessage.CallReducer( + requestId = reqId, + reducer = reducerName, + args = args, + ) + ) + } + + /** Executes a one-off SQL query against the module and suspends until the result arrives. */ + suspend fun oneOffQuery(query: String): ServerMessage.OneOffQueryResult { + val reqId = nextRequestId() + val deferred = CompletableDeferred<ServerMessage.OneOffQueryResult>() + mutex.withLock { pendingOneOffQueries[reqId] = deferred } + transport.send(ClientMessage.OneOffQuery(requestId = reqId, queryString = query)) + return deferred.await() + } + + /** Callback variant of [oneOffQuery] — launches a coroutine and invokes [callback] with the result. 
*/ + fun oneOffQuery(query: String, callback: (ServerMessage.OneOffQueryResult) -> Unit) { + val reqId = nextRequestId() + val deferred = CompletableDeferred<ServerMessage.OneOffQueryResult>() + scope.launch { + mutex.withLock { pendingOneOffQueries[reqId] = deferred } + transport.send(ClientMessage.OneOffQuery(requestId = reqId, queryString = query)) + callback(deferred.await()) + } + } + + internal fun subscribe( + queries: List<String>, + handle: SubscriptionHandle, + ): UInt { + val reqId = nextRequestId() + val qsId = QuerySetId(reqId) + handle.querySetId = qsId + handle.requestId = reqId + // Register synchronously before sending to avoid race with server response + subscriptions[reqId] = handle + subscriptionsByQuerySet[qsId] = handle + transport.send( + ClientMessage.Subscribe( + requestId = reqId, + querySetId = qsId, + queryStrings = queries, + ) + ) + return reqId + } + + internal fun unsubscribe(handle: SubscriptionHandle) { + val qsId = handle.querySetId ?: return + val reqId = nextRequestId() + transport.send( + ClientMessage.Unsubscribe( + requestId = reqId, + querySetId = qsId, + flags = 1u, // SendDroppedRows — ensures server sends rows to remove from cache + ) + ) + } + + private fun nextRequestId(): UInt = requestCounter.incrementAndGet().toUInt() + + private fun failPendingOperations() { + val error = CancellationException("Connection closed") + pendingOneOffQueries.values.forEach { it.cancel(error) } + pendingOneOffQueries.clear() + reducerCallbacks.clear() + } + + private suspend fun handleMessage(msg: ServerMessage) { + mutex.withLock { + when (msg) { + is ServerMessage.InitialConnection -> { + identity = msg.identity + connectionId = msg.connectionId + savedToken = msg.token + connectCallbacks.forEach { it(this, msg.identity, msg.token) } + } + + is ServerMessage.SubscribeApplied -> { + clientCache.applySubscribeRows(msg.rows) + val handle = subscriptions[msg.requestId] + handle?.state = SubscriptionState.ACTIVE + handle?.onAppliedCallback?.invoke() + } + + is 
ServerMessage.UnsubscribeApplied -> { + msg.rows?.let { clientCache.applyUnsubscribeRows(it) } + // Look up by querySetId since the requestId here is the unsubscribe requestId + val handle = subscriptionsByQuerySet[msg.querySetId] + handle?.state = SubscriptionState.ENDED + handle?.requestId?.let { subscriptions.remove(it) } + subscriptionsByQuerySet.remove(msg.querySetId) + } + + is ServerMessage.SubscriptionError -> { + val handle = if (msg.requestId != null) { + subscriptions[msg.requestId] + } else { + subscriptionsByQuerySet[msg.querySetId] + } + handle?.state = SubscriptionState.ENDED + handle?.onErrorCallback?.invoke(msg.error) + handle?.requestId?.let { subscriptions.remove(it) } + subscriptionsByQuerySet.remove(msg.querySetId) + } + + is ServerMessage.TransactionUpdate -> { + val ops = clientCache.applyTransactionUpdate(msg.querySets) + fireTableCallbacks(ops) + } + + is ServerMessage.ReducerResult -> { + if (msg.result is ReducerOutcome.Ok) { + val txUpdate = msg.result.transactionUpdate + val ops = clientCache.applyTransactionUpdate(txUpdate.querySets) + fireTableCallbacks(ops) + } + reducerCallbacks.remove(msg.requestId)?.invoke( + ReducerResult(msg.requestId, msg.timestamp, msg.result) + ) + } + + is ServerMessage.ProcedureResult -> {} + + is ServerMessage.OneOffQueryResult -> { + pendingOneOffQueries.remove(msg.requestId)?.complete(msg) + } + } + } + } + + private fun fireTableCallbacks(ops: List<TableOperation>) { + for (op in ops) { + when (op) { + is TableOperation.Insert -> tableHandles[op.tableName]?.fireInsert(op.row) + is TableOperation.Delete -> tableHandles[op.tableName]?.fireDelete(op.row) + is TableOperation.Update -> tableHandles[op.tableName]?.fireUpdate(op.oldRow, op.newRow) + is TableOperation.EventInsert -> tableHandles[op.tableName]?.fireInsert(op.row) + } + } + } + + companion object { + fun builder(): DbConnectionBuilder = DbConnectionBuilder() + } +} + +/** Result of a reducer invocation, including the server-side [timestamp] and [outcome]. 
*/ +data class ReducerResult( + val requestId: UInt, + val timestamp: Timestamp, + val outcome: ReducerOutcome, +) + +/** Builder for configuring and creating a [DbConnection]. */ +class DbConnectionBuilder { + private var uri: String? = null + private var moduleName: String? = null + private var token: String? = null + private var keepAliveIntervalMs: Long = 30_000L + private var reconnectPolicy: ReconnectPolicy? = null + private var compression: CompressionMode = CompressionMode.GZIP + private val connectCallbacks = mutableListOf<ConnectCallback>() + private val disconnectCallbacks = mutableListOf<DisconnectCallback>() + private val connectErrorCallbacks = mutableListOf<ConnectErrorCallback>() + + fun withUri(uri: String) = apply { this.uri = uri } + + fun withModuleName(name: String) = apply { this.moduleName = name } + + fun withToken(token: String?) = apply { this.token = token } + + fun onConnect(callback: ConnectCallback) = apply { connectCallbacks.add(callback) } + + fun onDisconnect(callback: DisconnectCallback) = apply { disconnectCallbacks.add(callback) } + + fun onConnectError(callback: ConnectErrorCallback) = apply { connectErrorCallbacks.add(callback) } + + fun withKeepAliveInterval(intervalMs: Long) = apply { this.keepAliveIntervalMs = intervalMs } + + fun withReconnectPolicy(policy: ReconnectPolicy) = apply { this.reconnectPolicy = policy } + + fun withCompression(mode: CompressionMode) = apply { this.compression = mode } + + fun build(): DbConnection { + val uri = requireNotNull(uri) { "URI is required. Call withUri() before build()." } + val module = requireNotNull(moduleName) { "Module name is required. Call withModuleName() before build()." 
} + return DbConnection( + uri = uri, + moduleName = module, + token = token, + connectCallbacks = connectCallbacks.toList(), + disconnectCallbacks = disconnectCallbacks.toList(), + connectErrorCallbacks = connectErrorCallbacks.toList(), + keepAliveIntervalMs = keepAliveIntervalMs, + reconnectPolicy = reconnectPolicy, + compression = compression, + ) + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionBuilder.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionBuilder.kt new file mode 100644 index 00000000000..bfd6f5c3e46 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionBuilder.kt @@ -0,0 +1,34 @@ +package com.clockworklabs.spacetimedb + +/** + * Builder for subscribing to SQL queries on a [DbConnection]. + * + * ```kotlin + * conn.subscriptionBuilder() + * .onApplied { println("Subscription active") } + * .onError { err -> println("Subscription failed: $err") } + * .subscribe("SELECT * FROM users WHERE online = true") + * ``` + */ +class SubscriptionBuilder(private val connection: DbConnection) { + private var onAppliedCallback: (() -> Unit)? = null + private var onErrorCallback: ((String) -> Unit)? 
= null + + fun onApplied(callback: () -> Unit) = apply { this.onAppliedCallback = callback } + + fun onError(callback: (String) -> Unit) = apply { this.onErrorCallback = callback } + + fun subscribe(vararg queries: String): SubscriptionHandle { + val handle = SubscriptionHandle( + connection = connection, + onAppliedCallback = onAppliedCallback, + onErrorCallback = onErrorCallback, + ) + connection.subscribe(queries.toList(), handle) + return handle + } + + fun subscribeToAllTables(): SubscriptionHandle { + return subscribe("SELECT * FROM *") + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionHandle.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionHandle.kt new file mode 100644 index 00000000000..b1f1e34df48 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/SubscriptionHandle.kt @@ -0,0 +1,36 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.protocol.QuerySetId + +/** Lifecycle states of a subscription. */ +enum class SubscriptionState { + PENDING, + ACTIVE, + ENDED, +} + +/** + * Represents an active subscription to one or more SQL queries. + * + * Created by [SubscriptionBuilder.subscribe]. Call [unsubscribe] to end it. + */ +class SubscriptionHandle internal constructor( + private val connection: DbConnection, + internal val onAppliedCallback: (() -> Unit)?, + internal val onErrorCallback: ((String) -> Unit)?, +) { + internal var querySetId: QuerySetId? 
= null + internal var requestId: UInt = 0u + var state: SubscriptionState = SubscriptionState.PENDING + internal set + + val isActive: Boolean get() = state == SubscriptionState.ACTIVE + val isEnded: Boolean get() = state == SubscriptionState.ENDED + + fun unsubscribe() { + if (state == SubscriptionState.ACTIVE) { + connection.unsubscribe(this) + state = SubscriptionState.ENDED + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/TableHandle.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/TableHandle.kt new file mode 100644 index 00000000000..5a78f6019ea --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/TableHandle.kt @@ -0,0 +1,65 @@ +package com.clockworklabs.spacetimedb + +typealias InsertCallback = (ByteArray) -> Unit +typealias DeleteCallback = (ByteArray) -> Unit +typealias UpdateCallback = (oldRow: ByteArray, newRow: ByteArray) -> Unit + +/** + * Handle for observing row changes on a single table. + * + * Obtain via [DbConnection.table]. Register callbacks with [onInsert], [onDelete], + * and [onUpdate]; remove them later with the returned [CallbackId]. 
+ */ +class TableHandle(val tableName: String) { + private var nextId = 0 + private val insertCallbacks = mutableMapOf<Int, InsertCallback>() + private val deleteCallbacks = mutableMapOf<Int, DeleteCallback>() + private val updateCallbacks = mutableMapOf<Int, UpdateCallback>() + + fun onInsert(callback: InsertCallback): CallbackId { + val id = nextId++ + insertCallbacks[id] = callback + return CallbackId(id) + } + + fun onDelete(callback: DeleteCallback): CallbackId { + val id = nextId++ + deleteCallbacks[id] = callback + return CallbackId(id) + } + + fun onUpdate(callback: UpdateCallback): CallbackId { + val id = nextId++ + updateCallbacks[id] = callback + return CallbackId(id) + } + + fun removeOnInsert(id: CallbackId) { + insertCallbacks.remove(id.value) + } + + fun removeOnDelete(id: CallbackId) { + deleteCallbacks.remove(id.value) + } + + fun removeOnUpdate(id: CallbackId) { + updateCallbacks.remove(id.value) + } + + internal fun fireInsert(row: ByteArray) { + // Snapshot to allow callbacks to register/remove other callbacks safely + for (cb in insertCallbacks.values.toList()) cb(row) + } + + internal fun fireDelete(row: ByteArray) { + for (cb in deleteCallbacks.values.toList()) cb(row) + } + + internal fun fireUpdate(oldRow: ByteArray, newRow: ByteArray) { + for (cb in updateCallbacks.values.toList()) cb(oldRow, newRow) + } +} + +/** Opaque identifier returned by callback registration methods. Used to remove the callback later. 
*/ +@kotlin.jvm.JvmInline +value class CallbackId(val value: Int) diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnReader.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnReader.kt new file mode 100644 index 00000000000..1b8398f4056 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnReader.kt @@ -0,0 +1,119 @@ +package com.clockworklabs.spacetimedb.bsatn + +class BsatnReader(private val data: ByteArray, private var offset: Int = 0) { + + val remaining: Int get() = data.size - offset + + val isExhausted: Boolean get() = offset >= data.size + + private fun require(count: Int) { + if (offset + count > data.size) { + throw IllegalStateException( + "BSATN: unexpected end of data at offset $offset, " + + "need $count bytes but only ${data.size - offset} remain" + ) + } + } + + fun readU8(): UByte { + require(1) + return data[offset++].toUByte() + } + + fun readI8(): Byte { + require(1) + return data[offset++] + } + + fun readBool(): Boolean = readU8().toInt() != 0 + + fun readU16(): UShort { + require(2) + val v = (data[offset].toUByte().toInt() or (data[offset + 1].toUByte().toInt() shl 8)).toUShort() + offset += 2 + return v + } + + fun readI16(): Short { + require(2) + val v = (data[offset].toUByte().toInt() or (data[offset + 1].toUByte().toInt() shl 8)).toShort() + offset += 2 + return v + } + + fun readU32(): UInt { + require(4) + val v = (data[offset].toUByte().toUInt()) or + (data[offset + 1].toUByte().toUInt() shl 8) or + (data[offset + 2].toUByte().toUInt() shl 16) or + (data[offset + 3].toUByte().toUInt() shl 24) + offset += 4 + return v + } + + fun readI32(): Int { + require(4) + val v = (data[offset].toUByte().toInt()) or + (data[offset + 1].toUByte().toInt() shl 8) or + (data[offset + 2].toUByte().toInt() shl 16) or + (data[offset + 3].toUByte().toInt() shl 24) + offset += 4 + return v + } + + fun readU64(): ULong { + require(8) + var v = 0UL + 
for (i in 0 until 8) { + v = v or (data[offset + i].toUByte().toULong() shl (i * 8)) + } + offset += 8 + return v + } + + fun readI64(): Long { + require(8) + var v = 0L + for (i in 0 until 8) { + v = v or ((data[offset + i].toUByte().toLong()) shl (i * 8)) + } + offset += 8 + return v + } + + fun readF32(): Float = Float.fromBits(readI32()) + + fun readF64(): Double = Double.fromBits(readI64()) + + fun readBytes(count: Int): ByteArray { + require(count) + val result = data.copyOfRange(offset, offset + count) + offset += count + return result + } + + fun readByteArray(): ByteArray { + val len = readU32().toInt() + return readBytes(len) + } + + fun readString(): String { + val bytes = readByteArray() + return bytes.decodeToString() + } + + fun readTag(): UByte = readU8() + + fun <T> readArray(readElement: (BsatnReader) -> T): List<T> { + val count = readU32().toInt() + return List(count) { readElement(this) } + } + + fun <T> readOption(readElement: (BsatnReader) -> T): T? { + return when (readTag().toInt()) { + 0 -> null + 1 -> readElement(this) + else -> throw IllegalStateException("Invalid Option tag") + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnRowList.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnRowList.kt new file mode 100644 index 00000000000..328a13e41d7 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnRowList.kt @@ -0,0 +1,70 @@ +package com.clockworklabs.spacetimedb.bsatn + +sealed class RowSizeHint { + data class FixedSize(val rowSize: UShort) : RowSizeHint() + data class RowOffsets(val offsets: List<ULong>) : RowSizeHint() + + companion object { + fun read(reader: BsatnReader): RowSizeHint { + return when (reader.readTag().toInt()) { + 0 -> FixedSize(reader.readU16()) + 1 -> RowOffsets(reader.readArray { it.readU64() }) + else -> throw IllegalStateException("Invalid RowSizeHint tag") + } + } + + fun write(writer: BsatnWriter, value: 
RowSizeHint) { + when (value) { + is FixedSize -> { + writer.writeTag(0u) + writer.writeU16(value.rowSize) + } + is RowOffsets -> { + writer.writeTag(1u) + writer.writeArray(value.offsets) { w, v -> w.writeU64(v) } + } + } + } + } +} + +class BsatnRowList( + val sizeHint: RowSizeHint, + val rowsData: ByteArray, +) { + fun decodeRows(): List { + if (rowsData.isEmpty()) return emptyList() + + return when (val hint = sizeHint) { + is RowSizeHint.FixedSize -> { + val rowSize = hint.rowSize.toInt() + if (rowSize == 0) return emptyList() + val count = rowsData.size / rowSize + List(count) { i -> + rowsData.copyOfRange(i * rowSize, (i + 1) * rowSize) + } + } + is RowSizeHint.RowOffsets -> { + val offsets = hint.offsets + List(offsets.size) { i -> + val start = offsets[i].toInt() + val end = if (i + 1 < offsets.size) offsets[i + 1].toInt() else rowsData.size + rowsData.copyOfRange(start, end) + } + } + } + } + + companion object { + fun read(reader: BsatnReader): BsatnRowList { + val sizeHint = RowSizeHint.read(reader) + val rowsData = reader.readByteArray() + return BsatnRowList(sizeHint, rowsData) + } + + fun write(writer: BsatnWriter, value: BsatnRowList) { + RowSizeHint.write(writer, value.sizeHint) + writer.writeByteArray(value.rowsData) + } + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnWriter.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnWriter.kt new file mode 100644 index 00000000000..e909bbd6e86 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/bsatn/BsatnWriter.kt @@ -0,0 +1,113 @@ +package com.clockworklabs.spacetimedb.bsatn + +class BsatnWriter(initialCapacity: Int = 256) { + + private var buffer = ByteArray(initialCapacity) + private var position = 0 + + private fun ensureCapacity(needed: Int) { + val required = position + needed + if (required > buffer.size) { + val newSize = maxOf(buffer.size * 2, required) + buffer = buffer.copyOf(newSize) + 
} + } + + fun writeBool(value: Boolean) { + writeU8(if (value) 1u else 0u) + } + + fun writeU8(value: UByte) { + ensureCapacity(1) + buffer[position++] = value.toByte() + } + + fun writeI8(value: Byte) { + ensureCapacity(1) + buffer[position++] = value + } + + fun writeU16(value: UShort) { + ensureCapacity(2) + val v = value.toInt() + buffer[position++] = v.toByte() + buffer[position++] = (v shr 8).toByte() + } + + fun writeI16(value: Short) { + ensureCapacity(2) + val v = value.toInt() + buffer[position++] = v.toByte() + buffer[position++] = (v shr 8).toByte() + } + + fun writeU32(value: UInt) { + ensureCapacity(4) + val v = value.toInt() + buffer[position++] = v.toByte() + buffer[position++] = (v shr 8).toByte() + buffer[position++] = (v shr 16).toByte() + buffer[position++] = (v shr 24).toByte() + } + + fun writeI32(value: Int) { + ensureCapacity(4) + buffer[position++] = value.toByte() + buffer[position++] = (value shr 8).toByte() + buffer[position++] = (value shr 16).toByte() + buffer[position++] = (value shr 24).toByte() + } + + fun writeU64(value: ULong) { + ensureCapacity(8) + val v = value.toLong() + for (i in 0 until 8) { + buffer[position++] = (v shr (i * 8)).toByte() + } + } + + fun writeI64(value: Long) { + ensureCapacity(8) + for (i in 0 until 8) { + buffer[position++] = (value shr (i * 8)).toByte() + } + } + + fun writeF32(value: Float) { writeI32(value.toRawBits()) } + + fun writeF64(value: Double) { writeI64(value.toRawBits()) } + + fun writeBytes(bytes: ByteArray) { + ensureCapacity(bytes.size) + bytes.copyInto(buffer, position) + position += bytes.size + } + + fun writeByteArray(bytes: ByteArray) { + writeU32(bytes.size.toUInt()) + writeBytes(bytes) + } + + fun writeString(value: String) { + val bytes = value.encodeToByteArray() + writeByteArray(bytes) + } + + fun writeTag(tag: UByte) { writeU8(tag) } + + fun writeArray(items: List, writeElement: (BsatnWriter, T) -> Unit) { + writeU32(items.size.toUInt()) + items.forEach { writeElement(this, it) 
} + } + + fun writeOption(value: T?, writeElement: (BsatnWriter, T) -> Unit) { + if (value == null) { + writeTag(0u) + } else { + writeTag(1u) + writeElement(this, value) + } + } + + fun toByteArray(): ByteArray = buffer.copyOf(position) +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ClientMessage.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ClientMessage.kt new file mode 100644 index 00000000000..7fa027f97e7 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ClientMessage.kt @@ -0,0 +1,100 @@ +package com.clockworklabs.spacetimedb.protocol + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +sealed class ClientMessage { + data class Subscribe( + val requestId: UInt, + val querySetId: QuerySetId, + val queryStrings: List, + ) : ClientMessage() + + data class Unsubscribe( + val requestId: UInt, + val querySetId: QuerySetId, + val flags: UByte = 0u, + ) : ClientMessage() + + data class OneOffQuery( + val requestId: UInt, + val queryString: String, + ) : ClientMessage() + + data class CallReducer( + val requestId: UInt, + val reducer: String, + val args: ByteArray, + val flags: UByte = 0u, + ) : ClientMessage() { + override fun equals(other: Any?): Boolean = + other is CallReducer && requestId == other.requestId && + reducer == other.reducer && args.contentEquals(other.args) && + flags == other.flags + + override fun hashCode(): Int { + var result = requestId.hashCode() + result = 31 * result + reducer.hashCode() + result = 31 * result + args.contentHashCode() + result = 31 * result + flags.hashCode() + return result + } + } + + data class CallProcedure( + val requestId: UInt, + val procedure: String, + val args: ByteArray, + val flags: UByte = 0u, + ) : ClientMessage() { + override fun equals(other: Any?): Boolean = + other is CallProcedure && requestId == other.requestId && + procedure == other.procedure && args.contentEquals(other.args) && + 
flags == other.flags + + override fun hashCode(): Int { + var result = requestId.hashCode() + result = 31 * result + procedure.hashCode() + result = 31 * result + args.contentHashCode() + result = 31 * result + flags.hashCode() + return result + } + } + + fun encode(): ByteArray { + val writer = BsatnWriter() + when (this) { + is Subscribe -> { + writer.writeTag(0u) + writer.writeU32(requestId) + QuerySetId.write(writer, querySetId) + writer.writeArray(queryStrings) { w, s -> w.writeString(s) } + } + is Unsubscribe -> { + writer.writeTag(1u) + writer.writeU32(requestId) + QuerySetId.write(writer, querySetId) + writer.writeU8(flags) + } + is OneOffQuery -> { + writer.writeTag(2u) + writer.writeU32(requestId) + writer.writeString(queryString) + } + is CallReducer -> { + writer.writeTag(3u) + writer.writeU32(requestId) + writer.writeU8(flags) + writer.writeString(reducer) + writer.writeByteArray(args) + } + is CallProcedure -> { + writer.writeTag(4u) + writer.writeU32(requestId) + writer.writeU8(flags) + writer.writeString(procedure) + writer.writeByteArray(args) + } + } + return writer.toByteArray() + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ProtocolTypes.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ProtocolTypes.kt new file mode 100644 index 00000000000..7531a56ad74 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ProtocolTypes.kt @@ -0,0 +1,159 @@ +package com.clockworklabs.spacetimedb.protocol + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnRowList +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter + +@kotlin.jvm.JvmInline +value class QuerySetId(val id: UInt) { + companion object { + fun read(reader: BsatnReader): QuerySetId = QuerySetId(reader.readU32()) + fun write(writer: BsatnWriter, value: QuerySetId) { writer.writeU32(value.id) } + } +} + +@kotlin.jvm.JvmInline +value 
class RawIdentifier(val value: String) { + companion object { + fun read(reader: BsatnReader): RawIdentifier = RawIdentifier(reader.readString()) + fun write(writer: BsatnWriter, value: RawIdentifier) { writer.writeString(value.value) } + } +} + +data class SingleTableRows( + val table: RawIdentifier, + val rows: BsatnRowList, +) { + companion object { + fun read(reader: BsatnReader): SingleTableRows = SingleTableRows( + table = RawIdentifier.read(reader), + rows = BsatnRowList.read(reader), + ) + } +} + +data class QueryRows(val tables: List<SingleTableRows>) { + companion object { + fun read(reader: BsatnReader): QueryRows = + QueryRows(reader.readArray { SingleTableRows.read(it) }) + } +} + +sealed class TableUpdateRows { + data class PersistentTable(val rows: PersistentTableRows) : TableUpdateRows() + data class EventTable(val rows: EventTableRows) : TableUpdateRows() + + companion object { + fun read(reader: BsatnReader): TableUpdateRows { + return when (reader.readTag().toInt()) { + 0 -> PersistentTable(PersistentTableRows.read(reader)) + 1 -> EventTable(EventTableRows.read(reader)) + else -> throw IllegalStateException("Invalid TableUpdateRows tag") + } + } + } +} + +data class PersistentTableRows( + val inserts: BsatnRowList, + val deletes: BsatnRowList, +) { + companion object { + fun read(reader: BsatnReader): PersistentTableRows = PersistentTableRows( + inserts = BsatnRowList.read(reader), + deletes = BsatnRowList.read(reader), + ) + } +} + +data class EventTableRows(val events: BsatnRowList) { + companion object { + fun read(reader: BsatnReader): EventTableRows = + EventTableRows(BsatnRowList.read(reader)) + } +} + +data class TableUpdate( + val tableName: RawIdentifier, + val rows: List<TableUpdateRows>, +) { + companion object { + fun read(reader: BsatnReader): TableUpdate = TableUpdate( + tableName = RawIdentifier.read(reader), + rows = reader.readArray { TableUpdateRows.read(it) }, + ) + } +} + +data class QuerySetUpdate( + val querySetId: QuerySetId, + val tables: List<TableUpdate>, +) { + 
companion object { + fun read(reader: BsatnReader): QuerySetUpdate = QuerySetUpdate( + querySetId = QuerySetId.read(reader), + tables = reader.readArray { TableUpdate.read(it) }, + ) + } +} + +sealed class ReducerOutcome { + data class Ok(val retValue: ByteArray, val transactionUpdate: TransactionUpdateData) : ReducerOutcome() { + override fun equals(other: Any?): Boolean = + other is Ok && retValue.contentEquals(other.retValue) && transactionUpdate == other.transactionUpdate + override fun hashCode(): Int = retValue.contentHashCode() * 31 + transactionUpdate.hashCode() + } + data object OkEmpty : ReducerOutcome() + data class Err(val message: ByteArray) : ReducerOutcome() { + override fun equals(other: Any?): Boolean = other is Err && message.contentEquals(other.message) + override fun hashCode(): Int = message.contentHashCode() + } + data class InternalError(val message: String) : ReducerOutcome() + + companion object { + fun read(reader: BsatnReader): ReducerOutcome { + return when (reader.readTag().toInt()) { + 0 -> Ok( + retValue = reader.readByteArray(), + transactionUpdate = TransactionUpdateData.read(reader), + ) + 1 -> OkEmpty + 2 -> Err(reader.readByteArray()) + 3 -> InternalError(reader.readString()) + else -> throw IllegalStateException("Invalid ReducerOutcome tag") + } + } + } +} + +data class TransactionUpdateData(val querySets: List<QuerySetUpdate>) { + companion object { + fun read(reader: BsatnReader): TransactionUpdateData = + TransactionUpdateData(reader.readArray { QuerySetUpdate.read(it) }) + } +} + +sealed class ProcedureStatus { + data class Returned(val data: ByteArray) : ProcedureStatus() { + override fun equals(other: Any?): Boolean = other is Returned && data.contentEquals(other.data) + override fun hashCode(): Int = data.contentHashCode() + } + data class InternalError(val message: String) : ProcedureStatus() + + companion object { + fun read(reader: BsatnReader): ProcedureStatus { + return when (reader.readTag().toInt()) { + 0 -> 
Returned(reader.readByteArray()) + 1 -> InternalError(reader.readString()) + else -> throw IllegalStateException("Invalid ProcedureStatus tag") + } + } + } +} + +@kotlin.jvm.JvmInline +value class TimeDuration(val microseconds: ULong) { + companion object { + fun read(reader: BsatnReader): TimeDuration = TimeDuration(reader.readU64()) + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ServerMessage.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ServerMessage.kt new file mode 100644 index 00000000000..aeaae132d36 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/protocol/ServerMessage.kt @@ -0,0 +1,107 @@ +package com.clockworklabs.spacetimedb.protocol + +import com.clockworklabs.spacetimedb.ConnectionId +import com.clockworklabs.spacetimedb.Identity +import com.clockworklabs.spacetimedb.Timestamp +import com.clockworklabs.spacetimedb.bsatn.BsatnReader + +sealed class ServerMessage { + data class InitialConnection( + val identity: Identity, + val connectionId: ConnectionId, + val token: String, + ) : ServerMessage() + + data class SubscribeApplied( + val requestId: UInt, + val querySetId: QuerySetId, + val rows: QueryRows, + ) : ServerMessage() + + data class UnsubscribeApplied( + val requestId: UInt, + val querySetId: QuerySetId, + val rows: QueryRows?, + ) : ServerMessage() + + data class SubscriptionError( + val requestId: UInt?, + val querySetId: QuerySetId, + val error: String, + ) : ServerMessage() + + data class TransactionUpdate( + val querySets: List<QuerySetUpdate>, + ) : ServerMessage() + + data class OneOffQueryResult( + val requestId: UInt, + val rows: QueryRows?, + val error: String?, + ) : ServerMessage() + + data class ReducerResult( + val requestId: UInt, + val timestamp: Timestamp, + val result: ReducerOutcome, + ) : ServerMessage() + + data class ProcedureResult( + val requestId: UInt, + val timestamp: Timestamp, + val status: ProcedureStatus, + val
totalHostExecutionDuration: TimeDuration, + ) : ServerMessage() + + companion object { + fun decode(data: ByteArray): ServerMessage { + val reader = BsatnReader(data) + return when (reader.readTag().toInt()) { + 0 -> InitialConnection( + identity = Identity.read(reader), + connectionId = ConnectionId.read(reader), + token = reader.readString(), + ) + 1 -> SubscribeApplied( + requestId = reader.readU32(), + querySetId = QuerySetId.read(reader), + rows = QueryRows.read(reader), + ) + 2 -> UnsubscribeApplied( + requestId = reader.readU32(), + querySetId = QuerySetId.read(reader), + rows = reader.readOption { QueryRows.read(it) }, + ) + 3 -> SubscriptionError( + requestId = reader.readOption { it.readU32() }, + querySetId = QuerySetId.read(reader), + error = reader.readString(), + ) + 4 -> TransactionUpdate( + querySets = reader.readArray { QuerySetUpdate.read(it) }, + ) + 5 -> { + val requestId = reader.readU32() + when (reader.readTag().toInt()) { + 0 -> OneOffQueryResult(requestId, QueryRows.read(reader), null) + 1 -> OneOffQueryResult(requestId, null, reader.readString()) + else -> throw IllegalStateException("Invalid OneOffQueryResult Result tag") + } + } + 6 -> ReducerResult( + requestId = reader.readU32(), + timestamp = Timestamp.read(reader), + result = ReducerOutcome.read(reader), + ) + // Named arguments are evaluated in source order, so reads must + // follow the wire order (requestId first, like every other variant). + 7 -> ProcedureResult( + requestId = reader.readU32(), + timestamp = Timestamp.read(reader), + status = ProcedureStatus.read(reader), + totalHostExecutionDuration = TimeDuration.read(reader), + ) + else -> throw IllegalStateException("Unknown ServerMessage tag") + } + } + + } +} diff --git a/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/websocket/WebSocketTransport.kt b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/websocket/WebSocketTransport.kt new file mode 100644 index 00000000000..46441b8bf26 --- /dev/null +++ b/sdks/kotlin/src/commonMain/kotlin/com/clockworklabs/spacetimedb/websocket/WebSocketTransport.kt @@ -0,0 +1,266 @@ +package
com.clockworklabs.spacetimedb.websocket + +import com.clockworklabs.spacetimedb.CompressionMode +import com.clockworklabs.spacetimedb.ReconnectPolicy +import com.clockworklabs.spacetimedb.decompressBrotli +import com.clockworklabs.spacetimedb.decompressGzip +import com.clockworklabs.spacetimedb.protocol.ClientMessage +import com.clockworklabs.spacetimedb.protocol.ServerMessage +import io.ktor.client.* +import io.ktor.client.plugins.websocket.* +import io.ktor.websocket.* +import kotlinx.atomicfu.atomic +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow + +private val HEX = "0123456789ABCDEF".toCharArray() + +enum class ConnectionState { + DISCONNECTED, + CONNECTING, + CONNECTED, + RECONNECTING, +} + +class WebSocketTransport( + private val scope: CoroutineScope, + private val onMessage: suspend (ServerMessage) -> Unit, + private val onConnect: () -> Unit, + private val onDisconnect: (Throwable?) -> Unit, + private val onConnectError: (Throwable) -> Unit, + private val keepAliveIntervalMs: Long = 30_000L, + private val reconnectPolicy: ReconnectPolicy? = null, + private val compression: CompressionMode = CompressionMode.GZIP, +) { + private val client = HttpClient { + install(WebSockets) + } + + private val _state = MutableStateFlow(ConnectionState.DISCONNECTED) + val state: StateFlow<ConnectionState> = _state + + private val outboundQueue = Channel<ByteArray>(Channel.UNLIMITED) + private var session: DefaultClientWebSocketSession? = null + private var connectJob: Job? = null + private val intentionalDisconnect = atomic(false) + + // Ping/pong idle detection (mirrors Rust SDK's 30s idle timeout) + private val idle = atomic(true) + private val wantPong = atomic(false) + + fun connect(uri: String, moduleName: String, token: String?)
{ + if (_state.value != ConnectionState.DISCONNECTED) return + intentionalDisconnect.value = false + _state.value = ConnectionState.CONNECTING + + connectJob = scope.launch { + runConnection(uri, moduleName, token) + } + } + + private suspend fun runConnection(uri: String, moduleName: String, token: String?) { + try { + connectSession(uri, moduleName, token) + // Session ended normally + if (!intentionalDisconnect.value && reconnectPolicy != null) { + attemptReconnect(uri, moduleName, token) + } else { + _state.value = ConnectionState.DISCONNECTED + onDisconnect(null) + } + } catch (e: CancellationException) { + _state.value = ConnectionState.DISCONNECTED + if (!intentionalDisconnect.value) { + onDisconnect(null) + } + } catch (e: Throwable) { + val previousState = _state.value + if (!intentionalDisconnect.value && reconnectPolicy != null && previousState == ConnectionState.CONNECTED) { + // Was connected, lost connection unexpectedly — try to reconnect + attemptReconnect(uri, moduleName, token) + } else if (previousState == ConnectionState.CONNECTING) { + _state.value = ConnectionState.DISCONNECTED + onConnectError(e) + } else { + _state.value = ConnectionState.DISCONNECTED + onDisconnect(e) + } + } + } + + private suspend fun connectSession(uri: String, moduleName: String, token: String?) { + val wsUri = buildWsUri(uri, moduleName, token) + client.webSocket( + urlString = wsUri, + request = { + headers.append("Sec-WebSocket-Protocol", "v2.bsatn.spacetimedb") + } + ) { + session = this + idle.value = true + wantPong.value = false + _state.value = ConnectionState.CONNECTED + onConnect() + + val sendJob = launch { processSendQueue() } + val receiveJob = launch { processIncoming() } + val keepAliveJob = if (keepAliveIntervalMs > 0) { + launch { runKeepAlive() } + } else null + + receiveJob.join() + keepAliveJob?.cancelAndJoin() + sendJob.cancelAndJoin() + } + } + + private suspend fun attemptReconnect(uri: String, moduleName: String, token: String?) 
{ + val policy = reconnectPolicy ?: return + _state.value = ConnectionState.RECONNECTING + + for (attempt in 0 until policy.maxRetries) { + if (intentionalDisconnect.value) { + _state.value = ConnectionState.DISCONNECTED + return + } + + val delayMs = policy.delayForAttempt(attempt) + delay(delayMs) + + if (intentionalDisconnect.value) { + _state.value = ConnectionState.DISCONNECTED + return + } + + try { + connectSession(uri, moduleName, token) + // If connectSession returns normally, the session ended cleanly. + // If we still want to reconnect (not intentional), loop again. + if (intentionalDisconnect.value) { + _state.value = ConnectionState.DISCONNECTED + return + } + _state.value = ConnectionState.RECONNECTING + } catch (e: CancellationException) { + _state.value = ConnectionState.DISCONNECTED + return + } catch (_: Throwable) { + // Connection attempt failed — continue to next retry + _state.value = ConnectionState.RECONNECTING + } + } + + // Exhausted all retries + _state.value = ConnectionState.DISCONNECTED + onDisconnect(null) + } + + fun disconnect() { + intentionalDisconnect.value = true + connectJob?.cancel() + session = null + _state.value = ConnectionState.DISCONNECTED + client.close() + } + + fun send(message: ClientMessage) { + val encoded = message.encode() + outboundQueue.trySend(encoded) + } + + private suspend fun DefaultClientWebSocketSession.processSendQueue() { + for (bytes in outboundQueue) { + send(Frame.Binary(true, bytes)) + } + } + + private suspend fun DefaultClientWebSocketSession.processIncoming() { + for (frame in incoming) { + when (frame) { + is Frame.Binary -> { + idle.value = false + val raw = frame.readBytes() + val payload = decompressIfNeeded(raw) + val msg = ServerMessage.decode(payload) + onMessage(msg) + } + is Frame.Pong -> { + idle.value = false + wantPong.value = false + } + is Frame.Close -> return + else -> { + idle.value = false + } + } + } + } + + /** + * Idle timeout keep-alive, modeled on the Rust SDK pattern: + * 
+ * Every [keepAliveIntervalMs]: + * - If no data arrived and we're waiting for a pong -> connection is dead, close it. + * - If no data arrived -> send a Ping, start waiting for pong. + * - If data arrived -> reset idle flag for the next interval. + */ + private suspend fun DefaultClientWebSocketSession.runKeepAlive() { + while (true) { + delay(keepAliveIntervalMs) + if (idle.value) { + if (wantPong.value) { + close(CloseReason(CloseReason.Codes.GOING_AWAY, "Idle timeout")) + return + } + send(Frame.Ping(ByteArray(0))) + wantPong.value = true + } + idle.value = true + } + } + + private fun decompressIfNeeded(data: ByteArray): ByteArray { + if (data.isEmpty()) return data + val tag = data[0].toUByte().toInt() + val payload = data.copyOfRange(1, data.size) + return when (tag) { + 0 -> payload + 1 -> decompressBrotli(payload) + 2 -> decompressGzip(payload) + else -> throw IllegalStateException("Unknown compression tag: $tag") + } + } + + private fun urlEncode(value: String): String = buildString { + for (c in value) { + when { + // RFC 3986 unreserved ASCII only; isLetterOrDigit() would also accept non-ASCII letters + c in 'A'..'Z' || c in 'a'..'z' || c in '0'..'9' || c in "-._~" -> append(c) + else -> { + for (b in c.toString().encodeToByteArray()) { + append('%') + append(HEX[(b.toInt() shr 4) and 0xF]) + append(HEX[b.toInt() and 0xF]) + } + } + } + } + } + + private fun buildWsUri(uri: String, moduleName: String, token: String?): String { + val base = uri.trimEnd('/') + val wsBase = when { + base.startsWith("ws://") || base.startsWith("wss://") -> base + base.startsWith("http://") -> "ws://" + base.removePrefix("http://") + base.startsWith("https://") -> "wss://" + base.removePrefix("https://") + else -> "ws://$base" + } + val sb = StringBuilder("$wsBase/v1/database/$moduleName/subscribe") + val params = mutableListOf<String>() + if (token != null) params.add("token=${urlEncode(token)}") + params.add("compression=${compression.queryValue}") + sb.append("?${params.joinToString("&")}") + return sb.toString() + } +} diff --git
a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/BsatnTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/BsatnTest.kt new file mode 100644 index 00000000000..c7e79a49f74 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/BsatnTest.kt @@ -0,0 +1,170 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class BsatnTest { + + @Test + fun roundTripBool() { + val writer = BsatnWriter() + writer.writeBool(true) + writer.writeBool(false) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(true, reader.readBool()) + assertEquals(false, reader.readBool()) + } + + @Test + fun roundTripU8() { + val writer = BsatnWriter() + writer.writeU8(0u) + writer.writeU8(255u) + writer.writeU8(42u) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0u.toUByte(), reader.readU8()) + assertEquals(255u.toUByte(), reader.readU8()) + assertEquals(42u.toUByte(), reader.readU8()) + } + + @Test + fun roundTripI32() { + val writer = BsatnWriter() + writer.writeI32(0) + writer.writeI32(Int.MAX_VALUE) + writer.writeI32(Int.MIN_VALUE) + writer.writeI32(-1) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0, reader.readI32()) + assertEquals(Int.MAX_VALUE, reader.readI32()) + assertEquals(Int.MIN_VALUE, reader.readI32()) + assertEquals(-1, reader.readI32()) + } + + @Test + fun roundTripU32() { + val writer = BsatnWriter() + writer.writeU32(0u) + writer.writeU32(UInt.MAX_VALUE) + writer.writeU32(12345u) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0u, reader.readU32()) + assertEquals(UInt.MAX_VALUE, reader.readU32()) + assertEquals(12345u, reader.readU32()) + } + + @Test + fun roundTripI64() { + val writer = BsatnWriter() + writer.writeI64(0L) + 
writer.writeI64(Long.MAX_VALUE) + writer.writeI64(Long.MIN_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0L, reader.readI64()) + assertEquals(Long.MAX_VALUE, reader.readI64()) + assertEquals(Long.MIN_VALUE, reader.readI64()) + } + + @Test + fun roundTripU64() { + val writer = BsatnWriter() + writer.writeU64(0u) + writer.writeU64(ULong.MAX_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0u.toULong(), reader.readU64()) + assertEquals(ULong.MAX_VALUE, reader.readU64()) + } + + @Test + fun roundTripF32() { + val writer = BsatnWriter() + writer.writeF32(3.14f) + writer.writeF32(0.0f) + writer.writeF32(-1.0f) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(3.14f, reader.readF32()) + assertEquals(0.0f, reader.readF32()) + assertEquals(-1.0f, reader.readF32()) + } + + @Test + fun roundTripF64() { + val writer = BsatnWriter() + writer.writeF64(3.141592653589793) + writer.writeF64(Double.MAX_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(3.141592653589793, reader.readF64()) + assertEquals(Double.MAX_VALUE, reader.readF64()) + } + + @Test + fun roundTripString() { + val writer = BsatnWriter() + writer.writeString("hello") + writer.writeString("") + writer.writeString("unicode: 日本語 🚀") + val reader = BsatnReader(writer.toByteArray()) + assertEquals("hello", reader.readString()) + assertEquals("", reader.readString()) + assertEquals("unicode: 日本語 🚀", reader.readString()) + } + + @Test + fun roundTripByteArray() { + val writer = BsatnWriter() + val data = byteArrayOf(1, 2, 3, 4, 5) + writer.writeByteArray(data) + writer.writeByteArray(ByteArray(0)) + val reader = BsatnReader(writer.toByteArray()) + assertTrue(data.contentEquals(reader.readByteArray())) + assertTrue(ByteArray(0).contentEquals(reader.readByteArray())) + } + + @Test + fun roundTripArray() { + val writer = BsatnWriter() + writer.writeArray(listOf(10, 20, 30)) { w, v -> w.writeI32(v) } + val reader = 
BsatnReader(writer.toByteArray()) + val result = reader.readArray { it.readI32() } + assertEquals(listOf(10, 20, 30), result) + } + + @Test + fun roundTripOption() { + val writer = BsatnWriter() + writer.writeOption(42) { w, v -> w.writeI32(v) } + writer.writeOption(null) { w, v -> w.writeI32(v) } + val reader = BsatnReader(writer.toByteArray()) + assertEquals(42, reader.readOption { it.readI32() }) + assertNull(reader.readOption { it.readI32() }) + } + + @Test + fun littleEndianEncoding() { + val writer = BsatnWriter() + writer.writeU32(0x04030201u) + val bytes = writer.toByteArray() + assertEquals(1, bytes[0].toInt()) + assertEquals(2, bytes[1].toInt()) + assertEquals(3, bytes[2].toInt()) + assertEquals(4, bytes[3].toInt()) + } + + @Test + fun stringEncodingFormat() { + val writer = BsatnWriter() + writer.writeString("AB") + val bytes = writer.toByteArray() + assertEquals(6, bytes.size) + assertEquals(2, bytes[0].toInt()) + assertEquals(0, bytes[1].toInt()) + assertEquals(0, bytes[2].toInt()) + assertEquals(0, bytes[3].toInt()) + assertEquals(0x41, bytes[4].toInt()) + assertEquals(0x42, bytes[5].toInt()) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ClientCacheTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ClientCacheTest.kt new file mode 100644 index 00000000000..a4f868932c8 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ClientCacheTest.kt @@ -0,0 +1,106 @@ +package com.clockworklabs.spacetimedb + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ClientCacheTest { + + @Test + fun insertAndCount() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.insertRow(byteArrayOf(4, 5, 6)) + assertEquals(2, cache.count) + } + + @Test + fun duplicateInsertIncrementsRefCount() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + 
cache.insertRow(byteArrayOf(1, 2, 3)) + assertEquals(1, cache.count) + } + + @Test + fun deleteRemovesRow() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + assertTrue(cache.deleteRow(byteArrayOf(1, 2, 3))) + assertEquals(0, cache.count) + } + + @Test + fun deleteNonexistentReturnsFalse() { + val cache = TableCache("users") + assertFalse(cache.deleteRow(byteArrayOf(1, 2, 3))) + } + + @Test + fun refCountedDelete() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.insertRow(byteArrayOf(1, 2, 3)) + cache.deleteRow(byteArrayOf(1, 2, 3)) + assertEquals(1, cache.count) + cache.deleteRow(byteArrayOf(1, 2, 3)) + assertEquals(0, cache.count) + } + + @Test + fun containsRow() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(10, 20)) + assertTrue(cache.containsRow(byteArrayOf(10, 20))) + assertFalse(cache.containsRow(byteArrayOf(30, 40))) + } + + @Test + fun allRows() { + val cache = TableCache("users") + cache.insertRow(byteArrayOf(1)) + cache.insertRow(byteArrayOf(2)) + cache.insertRow(byteArrayOf(3)) + assertEquals(3, cache.allRows().size) + } + + @Test + fun clientCacheGetOrCreate() { + val cc = ClientCache() + val t1 = cc.getOrCreateTable("users") + val t2 = cc.getOrCreateTable("users") + assertTrue(t1 === t2) + } + + @Test + fun clientCacheTableNames() { + val cc = ClientCache() + cc.getOrCreateTable("users") + cc.getOrCreateTable("messages") + assertEquals(setOf("users", "messages"), cc.tableNames()) + } + + @Test + fun tableHandleCallbacks() { + val handle = TableHandle("users") + var inserted: ByteArray? = null + var deleted: ByteArray? = null + var updatedOld: ByteArray? = null + var updatedNew: ByteArray? 
= null + + handle.onInsert { row -> inserted = row } + handle.onDelete { row -> deleted = row } + handle.onUpdate { old, new -> updatedOld = old; updatedNew = new } + + handle.fireInsert(byteArrayOf(1, 2, 3)) + assertTrue(byteArrayOf(1, 2, 3).contentEquals(inserted!!)) + + handle.fireDelete(byteArrayOf(4, 5, 6)) + assertTrue(byteArrayOf(4, 5, 6).contentEquals(deleted!!)) + + handle.fireUpdate(byteArrayOf(1), byteArrayOf(2)) + assertTrue(byteArrayOf(1).contentEquals(updatedOld!!)) + assertTrue(byteArrayOf(2).contentEquals(updatedNew!!)) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/EdgeCaseTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/EdgeCaseTest.kt new file mode 100644 index 00000000000..aebe8064870 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/EdgeCaseTest.kt @@ -0,0 +1,755 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +/** + * Edge case tests covering protocol decode, cache semantics, callback behavior, + * URI handling, and subscription lifecycle — all offline, no server needed. 
+ */ +class EdgeCaseTest { + + // ──────────────── ReducerOutcome: All 4 variants ──────────────── + + @Test + fun reducerOutcomeOkDecode() { + val w = BsatnWriter(128) + w.writeTag(0u) // Ok + w.writeByteArray(byteArrayOf(42)) // retValue + // TransactionUpdateData: array of QuerySetUpdate (empty) + w.writeU32(0u) + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.Ok) + assertEquals(1, outcome.retValue.size) + assertEquals(42.toByte(), outcome.retValue[0]) + } + + @Test + fun reducerOutcomeOkEmptyDecode() { + val w = BsatnWriter(4) + w.writeTag(1u) // OkEmpty + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.OkEmpty) + } + + @Test + fun reducerOutcomeErrDecode() { + val w = BsatnWriter(64) + w.writeTag(2u) // Err + w.writeByteArray("reducer panicked".encodeToByteArray()) + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.Err) + assertEquals("reducer panicked", outcome.message.decodeToString()) + } + + @Test + fun reducerOutcomeInternalErrorDecode() { + val w = BsatnWriter(64) + w.writeTag(3u) // InternalError + w.writeString("internal server error") + val outcome = ReducerOutcome.read(BsatnReader(w.toByteArray())) + assertTrue(outcome is ReducerOutcome.InternalError) + assertEquals("internal server error", outcome.message) + } + + @Test + fun reducerOutcomeInvalidTagThrows() { + val w = BsatnWriter(4) + w.writeTag(99u) + assertFailsWith<IllegalStateException> { + ReducerOutcome.read(BsatnReader(w.toByteArray())) + } + } + + // ──────────────── ReducerResult ServerMessage ──────────────── + + @Test + fun serverMessageReducerResultFullDecode() { + val w = BsatnWriter(128) + w.writeTag(6u) // ReducerResult tag + w.writeU32(7u) // requestId + w.writeI64(1_700_000_000_000_000L) // timestamp + w.writeTag(1u) // ReducerOutcome::OkEmpty + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.ReducerResult)
+ assertEquals(7u, msg.requestId) + assertEquals(1_700_000_000_000_000L, msg.timestamp.microseconds) + assertTrue(msg.result is ReducerOutcome.OkEmpty) + } + + @Test + fun serverMessageReducerResultWithErr() { + val w = BsatnWriter(128) + w.writeTag(6u) + w.writeU32(99u) + w.writeI64(0L) + w.writeTag(3u) // InternalError + w.writeString("boom") + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.ReducerResult) + val result = msg.result + assertTrue(result is ReducerOutcome.InternalError) + assertEquals("boom", result.message) + } + + // ──── ReducerResult: Err/InternalError must NOT update cache ──── + + @Test + fun reducerErrDoesNotUpdateCache() { + val cache = ClientCache() + val table = cache.getOrCreateTable("test") + table.insertRow(byteArrayOf(1, 2, 3)) + assertEquals(1, table.count) + + // Simulate: ReducerOutcome.Err should NOT apply any cache update + // (The DbConnection code checks `msg.result is ReducerOutcome.Ok` before applying) + // This test validates the logic by directly testing the guard condition + val errOutcome: ReducerOutcome = ReducerOutcome.Err("fail".encodeToByteArray()) + assertFalse(errOutcome is ReducerOutcome.Ok) + + val emptyOutcome: ReducerOutcome = ReducerOutcome.OkEmpty + assertFalse(emptyOutcome is ReducerOutcome.Ok) + + // Cache unchanged + assertEquals(1, table.count) + } + + // ──────────────── ProcedureStatus decode ──────────────── + + @Test + fun procedureStatusReturnedDecode() { + val w = BsatnWriter(32) + w.writeTag(0u) // Returned + w.writeByteArray(byteArrayOf(0xAB.toByte(), 0xCD.toByte())) + val status = ProcedureStatus.read(BsatnReader(w.toByteArray())) + assertTrue(status is ProcedureStatus.Returned) + assertEquals(2, status.data.size) + } + + @Test + fun procedureStatusInternalErrorDecode() { + val w = BsatnWriter(32) + w.writeTag(1u) // InternalError + w.writeString("proc failed") + val status = ProcedureStatus.read(BsatnReader(w.toByteArray())) + assertTrue(status is 
ProcedureStatus.InternalError) + assertEquals("proc failed", status.message) + } + + @Test + fun procedureStatusInvalidTagThrows() { + val w = BsatnWriter(4) + w.writeTag(5u) + assertFailsWith<IllegalStateException> { + ProcedureStatus.read(BsatnReader(w.toByteArray())) + } + } + + // ──────────────── ServerMessage: Invalid tag ──────────────── + + @Test + fun serverMessageInvalidTagThrows() { + val w = BsatnWriter(4) + w.writeTag(200u) // invalid + assertFailsWith<IllegalStateException> { + ServerMessage.decode(w.toByteArray()) + } + } + + // ──────────── SubscriptionError with null requestId ────────── + + @Test + fun subscriptionErrorWithNullRequestId() { + val w = BsatnWriter(64) + w.writeTag(3u) // SubscriptionError + w.writeTag(0u) // Option::None for requestId + w.writeU32(42u) // querySetId + w.writeString("bad query syntax") + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.SubscriptionError) + assertNull(msg.requestId) + assertEquals(QuerySetId(42u), msg.querySetId) + assertEquals("bad query syntax", msg.error) + } + + @Test + fun subscriptionErrorWithRequestId() { + val w = BsatnWriter(64) + w.writeTag(3u) // SubscriptionError + w.writeTag(1u) // Option::Some + w.writeU32(7u) // requestId + w.writeU32(42u) // querySetId + w.writeString("table not found") + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.SubscriptionError) + assertEquals(7u, msg.requestId) + } + + // ──────────── UnsubscribeApplied with null rows ────────── + + @Test + fun unsubscribeAppliedWithNullRows() { + val w = BsatnWriter(32) + w.writeTag(2u) // UnsubscribeApplied + w.writeU32(5u) // requestId + w.writeU32(3u) // querySetId + w.writeTag(0u) // Option::None for rows + val msg = ServerMessage.decode(w.toByteArray()) + assertTrue(msg is ServerMessage.UnsubscribeApplied) + assertNull(msg.rows) + } + + // ──────── Cache: Update detection edge cases ──────── + + @Test + fun cacheUpdateDetectionDeleteAndInsertSameBytes() { + // When delete + insert have same content
→ Update + val cache = ClientCache() + cache.getOrCreateTable("t") + val row = byteArrayOf(1, 2, 3) + cache.getOrCreateTable("t").insertRow(row) + + val ops = applyPersistentOps(cache, "t", + inserts = listOf(row), + deletes = listOf(row), + ) + assertEquals(1, ops.size) + assertTrue(ops[0] is TableOperation.Update) + } + + @Test + fun cacheDeleteWithoutMatchingInsert() { + val cache = ClientCache() + val row = byteArrayOf(1, 2, 3) + cache.getOrCreateTable("t").insertRow(row) + + val ops = applyPersistentOps(cache, "t", + inserts = emptyList(), + deletes = listOf(row), + ) + assertEquals(1, ops.size) + assertTrue(ops[0] is TableOperation.Delete) + assertEquals(0, cache.getOrCreateTable("t").count) + } + + @Test + fun cacheInsertWithoutMatchingDelete() { + val cache = ClientCache() + cache.getOrCreateTable("t") + + val ops = applyPersistentOps(cache, "t", + inserts = listOf(byteArrayOf(1, 2, 3)), + deletes = emptyList(), + ) + assertEquals(1, ops.size) + assertTrue(ops[0] is TableOperation.Insert) + assertEquals(1, cache.getOrCreateTable("t").count) + } + + @Test + fun cacheEmptyTransaction() { + val cache = ClientCache() + cache.getOrCreateTable("t") + val ops = applyPersistentOps(cache, "t", + inserts = emptyList(), + deletes = emptyList(), + ) + assertTrue(ops.isEmpty()) + } + + @Test + fun cacheRefCountOverlappingSubscriptions() { + // Two subscriptions insert same row → refCount=2 + val table = TableCache("test") + val row = byteArrayOf(10, 20, 30) + table.insertRow(row) // sub 1 + table.insertRow(row) // sub 2 + assertEquals(1, table.count, "Same content, single entry") + assertTrue(table.containsRow(row)) + + // Unsub 1: refCount=1, row stays + table.deleteRow(row) + assertEquals(1, table.count) + assertTrue(table.containsRow(row)) + + // Unsub 2: refCount=0, row removed + table.deleteRow(row) + assertEquals(0, table.count) + assertFalse(table.containsRow(row)) + } + + @Test + fun cacheDeleteNonExistentRow() { + val table = TableCache("test") + val result = 
table.deleteRow(byteArrayOf(99)) + assertFalse(result, "Deleting non-existent row should return false") + assertEquals(0, table.count) + } + + // ──────── Callback re-entrance safety ──────── + + @Test + fun callbackCanRegisterAnotherCallbackDuringFire() { + val handle = TableHandle("test") + var secondCallbackFired = false + + handle.onInsert { _ -> + // Register a new callback from within a callback + handle.onInsert { _ -> secondCallbackFired = true } + } + + // First fire: triggers the registration callback + handle.fireInsert(byteArrayOf(1)) + assertFalse(secondCallbackFired, "Newly registered callback should not fire in same event") + + // Second fire: both callbacks fire + handle.fireInsert(byteArrayOf(2)) + assertTrue(secondCallbackFired, "Second callback should fire on next event") + } + + @Test + fun callbackCanRemoveItselfDuringFire() { + val handle = TableHandle("test") + var fireCount = 0 + var selfId: CallbackId? = null + + selfId = handle.onInsert { _ -> + fireCount++ + handle.removeOnInsert(selfId!!) 
+ } + + handle.fireInsert(byteArrayOf(1)) + assertEquals(1, fireCount) + + handle.fireInsert(byteArrayOf(2)) + assertEquals(1, fireCount, "Removed callback should not fire again") + } + + // ──────── Subscription lifecycle states ──────── + + @Test + fun subscriptionStateLifecycle() { + // Can't create a real DbConnection without a server, but we can test + // the SubscriptionHandle state machine directly + val handle = SubscriptionHandle( + connection = stubConnection(), + onAppliedCallback = null, + onErrorCallback = null, + ) + assertEquals(SubscriptionState.PENDING, handle.state) + assertFalse(handle.isActive) + assertFalse(handle.isEnded) + + handle.state = SubscriptionState.ACTIVE + assertTrue(handle.isActive) + assertFalse(handle.isEnded) + + handle.state = SubscriptionState.ENDED + assertFalse(handle.isActive) + assertTrue(handle.isEnded) + } + + @Test + fun doubleUnsubscribeIsSafe() { + val handle = SubscriptionHandle( + connection = stubConnection(), + onAppliedCallback = null, + onErrorCallback = null, + ) + handle.state = SubscriptionState.ACTIVE + handle.unsubscribe() // First: transitions to ENDED + assertTrue(handle.isEnded) + handle.unsubscribe() // Second: no-op, no crash + assertTrue(handle.isEnded) + } + + @Test + fun unsubscribeOnPendingIsNoOp() { + val handle = SubscriptionHandle( + connection = stubConnection(), + onAppliedCallback = null, + onErrorCallback = null, + ) + assertEquals(SubscriptionState.PENDING, handle.state) + handle.unsubscribe() // Should be a no-op since not ACTIVE + assertEquals(SubscriptionState.PENDING, handle.state) + } + + // ──────── URI scheme normalization ──────── + + @Test + fun uriSchemeNormalization() { + // Test the URI building logic by encoding/decoding the buildWsUri output + // We'll test the WebSocketTransport.buildWsUri indirectly via pattern matching + val testCases = mapOf( + "http://localhost:3000" to "ws://", + "https://example.com" to "wss://", + "ws://localhost:3000" to "ws://", + "wss://example.com" 
to "wss://", + "localhost:3000" to "ws://", + ) + // These are validated by the WebSocketTransport.buildWsUri method + // which is private — we verify the logic patterns match + for ((input, expectedPrefix) in testCases) { + val base = input.trimEnd('/') + val wsBase = when { + base.startsWith("ws://") || base.startsWith("wss://") -> base + base.startsWith("http://") -> "ws://" + base.removePrefix("http://") + base.startsWith("https://") -> "wss://" + base.removePrefix("https://") + else -> "ws://$base" + } + assertTrue(wsBase.startsWith(expectedPrefix), "Input '$input' should start with '$expectedPrefix', got '$wsBase'") + } + } + + // ──────── BSATN: Boundary values ──────── + + @Test + fun bsatnBoundaryValues() { + val w = BsatnWriter(128) + // Unsigned extremes + w.writeU8(UByte.MIN_VALUE) + w.writeU8(UByte.MAX_VALUE) + w.writeU16(UShort.MIN_VALUE) + w.writeU16(UShort.MAX_VALUE) + w.writeU32(UInt.MIN_VALUE) + w.writeU32(UInt.MAX_VALUE) + w.writeU64(ULong.MIN_VALUE) + w.writeU64(ULong.MAX_VALUE) + // Signed extremes + w.writeI8(Byte.MIN_VALUE) + w.writeI8(Byte.MAX_VALUE) + w.writeI16(Short.MIN_VALUE) + w.writeI16(Short.MAX_VALUE) + w.writeI32(Int.MIN_VALUE) + w.writeI32(Int.MAX_VALUE) + w.writeI64(Long.MIN_VALUE) + w.writeI64(Long.MAX_VALUE) + // Float specials + w.writeF32(Float.NaN) + w.writeF32(Float.POSITIVE_INFINITY) + w.writeF32(Float.NEGATIVE_INFINITY) + w.writeF32(0.0f) + w.writeF32(-0.0f) + w.writeF64(Double.NaN) + w.writeF64(Double.POSITIVE_INFINITY) + w.writeF64(Double.NEGATIVE_INFINITY) + + val r = BsatnReader(w.toByteArray()) + assertEquals(UByte.MIN_VALUE, r.readU8()) + assertEquals(UByte.MAX_VALUE, r.readU8()) + assertEquals(UShort.MIN_VALUE, r.readU16()) + assertEquals(UShort.MAX_VALUE, r.readU16()) + assertEquals(UInt.MIN_VALUE, r.readU32()) + assertEquals(UInt.MAX_VALUE, r.readU32()) + assertEquals(ULong.MIN_VALUE, r.readU64()) + assertEquals(ULong.MAX_VALUE, r.readU64()) + assertEquals(Byte.MIN_VALUE, r.readI8()) + assertEquals(Byte.MAX_VALUE, 
r.readI8()) + assertEquals(Short.MIN_VALUE, r.readI16()) + assertEquals(Short.MAX_VALUE, r.readI16()) + assertEquals(Int.MIN_VALUE, r.readI32()) + assertEquals(Int.MAX_VALUE, r.readI32()) + assertEquals(Long.MIN_VALUE, r.readI64()) + assertEquals(Long.MAX_VALUE, r.readI64()) + assertTrue(r.readF32().isNaN()) + assertEquals(Float.POSITIVE_INFINITY, r.readF32()) + assertEquals(Float.NEGATIVE_INFINITY, r.readF32()) + assertEquals(0.0f, r.readF32()) + // -0.0f == 0.0f in Kotlin, compare bits + assertEquals((-0.0f).toRawBits(), r.readF32().toRawBits()) + assertTrue(r.readF64().isNaN()) + assertEquals(Double.POSITIVE_INFINITY, r.readF64()) + assertEquals(Double.NEGATIVE_INFINITY, r.readF64()) + assertTrue(r.isExhausted) + } + + @Test + fun bsatnEmptyString() { + val w = BsatnWriter(8) + w.writeString("") + val r = BsatnReader(w.toByteArray()) + assertEquals("", r.readString()) + } + + @Test + fun bsatnEmptyByteArray() { + val w = BsatnWriter(8) + w.writeByteArray(byteArrayOf()) + val r = BsatnReader(w.toByteArray()) + val bytes = r.readByteArray() + assertEquals(0, bytes.size) + } + + @Test + fun bsatnEmptyArray() { + val w = BsatnWriter(8) + w.writeArray(emptyList<String>()) { wr, s -> wr.writeString(s) } + val r = BsatnReader(w.toByteArray()) + val list = r.readArray { it.readString() } + assertTrue(list.isEmpty()) + } + + @Test + fun bsatnOptionNoneAndSome() { + val w = BsatnWriter(16) + w.writeOption(null) { wr, v: String -> wr.writeString(v) } + w.writeOption("hello") { wr, v -> wr.writeString(v) } + + val r = BsatnReader(w.toByteArray()) + assertNull(r.readOption { it.readString() }) + assertEquals("hello", r.readOption { it.readString() }) + } + + @Test + fun bsatnReaderUnderflowThrows() { + val r = BsatnReader(byteArrayOf(1, 2)) + r.readU8() // ok + r.readU8() // ok + assertFailsWith<IndexOutOfBoundsException> { + r.readU8() // no bytes left + } + } + + @Test + fun bsatnReaderReadMoreThanAvailableThrows() { + val r = BsatnReader(byteArrayOf(1, 2, 3)) + assertFailsWith<IndexOutOfBoundsException> { + r.readU32() // needs 4
bytes, only 3 available + } + } + + // ──────── Compression tag handling ──────── + + @Test + fun compressionTagUnknownThrows() { + // Tag 0 = uncompressed, 1 = brotli, 2 = gzip. Tag 3+ should throw. + val data = byteArrayOf(3, 0, 0, 0) + assertFailsWith<IllegalStateException> { + decompressWithTag(data) + } + } + + @Test + fun compressionTagUncompressed() { + val payload = byteArrayOf(0, 1, 2, 3, 4) // tag 0 + data + val result = decompressWithTag(payload) + assertEquals(4, result.size) + assertEquals(1.toByte(), result[0]) + } + + // ──────── Identity edge cases ──────── + + @Test + fun identityWrongSizeThrows() { + assertFailsWith<IllegalArgumentException> { + Identity(ByteArray(16)) // needs 32 + } + } + + @Test + fun connectionIdWrongSizeThrows() { + assertFailsWith<IllegalArgumentException> { + ConnectionId(ByteArray(8)) // needs 16 + } + } + + @Test + fun addressWrongSizeThrows() { + assertFailsWith<IllegalArgumentException> { + Address(ByteArray(32)) // needs 16 + } + } + + @Test + fun identityZero() { + assertTrue(Identity.ZERO.bytes.all { it == 0.toByte() }) + assertEquals(32, Identity.ZERO.bytes.size) + } + + @Test + fun identityHexRoundTrip() { + val hex = "0123456789abcdef".repeat(4) + val id = Identity.fromHex(hex) + assertEquals(hex, id.toHex()) + } + + @Test + fun identityFromHexWrongLengthThrows() { + assertFailsWith<IllegalArgumentException> { + Identity.fromHex("0123") // needs 64 hex chars + } + } + + @Test + fun identityEquality() { + val a = Identity(ByteArray(32) { it.toByte() }) + val b = Identity(ByteArray(32) { it.toByte() }) + val c = Identity(ByteArray(32) { 0 }) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertFalse(a == c) + } + + // ──────── ClientMessage encode edge cases ──────── + + @Test + fun callReducerEmptyArgs() { + val msg = ClientMessage.CallReducer( + requestId = 1u, + reducer = "no_args_reducer", + args = byteArrayOf(), + ) + val encoded = msg.encode() + val r = BsatnReader(encoded) + assertEquals(3, r.readTag().toInt()) // CallReducer tag + assertEquals(1u, r.readU32()) + assertEquals(0.toUByte(), r.readU8()) // flags +
assertEquals("no_args_reducer", r.readString()) + val args = r.readByteArray() + assertEquals(0, args.size) + } + + @Test + fun callReducerEquality() { + val a = ClientMessage.CallReducer(1u, "test", byteArrayOf(1, 2, 3)) + val b = ClientMessage.CallReducer(1u, "test", byteArrayOf(1, 2, 3)) + val c = ClientMessage.CallReducer(1u, "test", byteArrayOf(4, 5, 6)) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertFalse(a == c) + } + + @Test + fun unsubscribeWithSendDroppedRowsFlag() { + val msg = ClientMessage.Unsubscribe( + requestId = 5u, + querySetId = QuerySetId(10u), + flags = 1u, // SendDroppedRows + ) + val encoded = msg.encode() + val r = BsatnReader(encoded) + assertEquals(1, r.readTag().toInt()) // Unsubscribe tag + assertEquals(5u, r.readU32()) + assertEquals(10u, r.readU32()) // querySetId + assertEquals(1.toUByte(), r.readU8()) // flags = SendDroppedRows + } + + // ──────── DbConnectionBuilder validation ──────── + + @Test + fun builderWithoutUriThrows() { + assertFailsWith<IllegalStateException> { + DbConnection.builder() + .withModuleName("test") + .build() + } + } + + @Test + fun builderWithoutModuleNameThrows() { + assertFailsWith<IllegalStateException> { + DbConnection.builder() + .withUri("ws://localhost:3000") + .build() + } + } + + // ──────── ByteArrayWrapper edge cases ──────── + + @Test + fun byteArrayWrapperEquality() { + val a = ByteArrayWrapper(byteArrayOf(1, 2, 3)) + val b = ByteArrayWrapper(byteArrayOf(1, 2, 3)) + val c = ByteArrayWrapper(byteArrayOf(3, 2, 1)) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertFalse(a == c) + } + + @Test + fun byteArrayWrapperEmptyArrays() { + val a = ByteArrayWrapper(byteArrayOf()) + val b = ByteArrayWrapper(byteArrayOf()) + assertEquals(a, b) + } + + @Test + fun byteArrayWrapperNotEqualToOtherTypes() { + val a = ByteArrayWrapper(byteArrayOf(1)) + assertFalse(a.equals("string")) + assertFalse(a.equals(null)) + } + + // ──────── Helpers ──────── + + private fun applyPersistentOps( + cache: ClientCache, +
tableName: String, + inserts: List<ByteArray>, + deletes: List<ByteArray>, + ): List<TableOperation> { + val w = BsatnWriter(1024) + w.writeU32(1u) // querySetId + w.writeU32(1u) // 1 table + w.writeString(tableName) + w.writeU32(1u) // 1 row update + w.writeTag(0u) // PersistentTable + // inserts BsatnRowList + writeRowList(w, inserts) + // deletes BsatnRowList + writeRowList(w, deletes) + + val qsUpdate = QuerySetUpdate.read(BsatnReader(w.toByteArray())) + return cache.applyTransactionUpdate(listOf(qsUpdate)) + } + + private fun writeRowList(w: BsatnWriter, rows: List<ByteArray>) { + w.writeTag(0u) // FixedSize hint + if (rows.isEmpty()) { + w.writeU16(0u) + w.writeU32(0u) + } else { + val rowSize = rows.first().size + w.writeU16(rowSize.toUShort()) + w.writeU32((rowSize * rows.size).toUInt()) + for (row in rows) w.writeBytes(row) + } + } + + private fun decompressWithTag(data: ByteArray): ByteArray { + if (data.isEmpty()) return data + val tag = data[0].toUByte().toInt() + val payload = data.copyOfRange(1, data.size) + return when (tag) { + 0 -> payload + 1 -> decompressBrotli(payload) + 2 -> decompressGzip(payload) + else -> throw IllegalStateException("Unknown compression tag: $tag") + } + } + + // Stub connection that doesn't actually connect (for subscription state tests) + private fun stubConnection(): DbConnection { + // We only need a DbConnection object for the SubscriptionHandle reference. + // The builder validation requires URI and module name. + // This will attempt to connect but we don't care — we only test handle state.
+ return DbConnection( + uri = "ws://invalid.test:0", + moduleName = "test", + token = null, + connectCallbacks = emptyList(), + disconnectCallbacks = emptyList(), + connectErrorCallbacks = emptyList(), + keepAliveIntervalMs = 0, + ) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/OneOffQueryTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/OneOffQueryTest.kt new file mode 100644 index 00000000000..cfd1dc7a367 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/OneOffQueryTest.kt @@ -0,0 +1,82 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.ServerMessage +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class OneOffQueryTest { + + @Test + fun decodeOneOffQueryOk() { + val writer = BsatnWriter() + // ServerMessage tag 5 = OneOffQueryResult + writer.writeTag(5u) + // requestId + writer.writeU32(42u) + // Result tag 0 = Ok(QueryRows) + writer.writeTag(0u) + // QueryRows: array of SingleTableRows (empty) + writer.writeU32(0u) + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.OneOffQueryResult) + assertEquals(42u, msg.requestId) + assertNotNull(msg.rows) + assertEquals(0, msg.rows!!.tables.size) + assertNull(msg.error) + } + + @Test + fun decodeOneOffQueryErr() { + val writer = BsatnWriter() + // ServerMessage tag 5 = OneOffQueryResult + writer.writeTag(5u) + // requestId + writer.writeU32(99u) + // Result tag 1 = Err(string) + writer.writeTag(1u) + writer.writeString("table not found") + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.OneOffQueryResult) + assertEquals(99u, msg.requestId) + assertNull(msg.rows) + assertEquals("table not found", msg.error) + } + + @Test + fun 
decodeOneOffQueryOkWithRows() { + val writer = BsatnWriter() + writer.writeTag(5u) + writer.writeU32(7u) + // Result tag 0 = Ok + writer.writeTag(0u) + // QueryRows: 1 table + writer.writeU32(1u) + // SingleTableRows: table name (RawIdentifier = string) + writer.writeString("users") + // BsatnRowList: RowSizeHint (tag 0 = FixedSize) + writer.writeTag(0u) + writer.writeU16(4u) + // rowsData: 2 rows of 4 bytes each = 8 bytes + val rowsData = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8) + writer.writeByteArray(rowsData) + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.OneOffQueryResult) + assertEquals(7u, msg.requestId) + assertNotNull(msg.rows) + assertEquals(1, msg.rows!!.tables.size) + assertEquals("users", msg.rows!!.tables[0].table.value) + + val decodedRows = msg.rows!!.tables[0].rows.decodeRows() + assertEquals(2, decodedRows.size) + assertTrue(byteArrayOf(1, 2, 3, 4).contentEquals(decodedRows[0])) + assertTrue(byteArrayOf(5, 6, 7, 8).contentEquals(decodedRows[1])) + assertNull(msg.error) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ProtocolTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ProtocolTest.kt new file mode 100644 index 00000000000..7891cc9cf2e --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ProtocolTest.kt @@ -0,0 +1,129 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.ClientMessage +import com.clockworklabs.spacetimedb.protocol.QuerySetId +import com.clockworklabs.spacetimedb.protocol.ServerMessage +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ProtocolTest { + + @Test + fun encodeSubscribeMessage() { + val msg = ClientMessage.Subscribe( + requestId = 1u, + querySetId = QuerySetId(100u), + queryStrings = 
listOf("SELECT * FROM users"), + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(0, reader.readTag().toInt()) + assertEquals(1u, reader.readU32()) + assertEquals(100u, reader.readU32()) + val queries = reader.readArray { it.readString() } + assertEquals(listOf("SELECT * FROM users"), queries) + } + + @Test + fun encodeCallReducerMessage() { + val args = byteArrayOf(10, 20, 30) + val msg = ClientMessage.CallReducer( + requestId = 5u, + reducer = "add_user", + args = args, + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(3, reader.readTag().toInt()) + assertEquals(5u, reader.readU32()) + assertEquals(0u.toUByte(), reader.readU8()) + assertEquals("add_user", reader.readString()) + assertTrue(args.contentEquals(reader.readByteArray())) + } + + @Test + fun encodeUnsubscribeMessage() { + val msg = ClientMessage.Unsubscribe( + requestId = 2u, + querySetId = QuerySetId(50u), + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(1, reader.readTag().toInt()) + assertEquals(2u, reader.readU32()) + assertEquals(50u, reader.readU32()) + } + + @Test + fun encodeOneOffQueryMessage() { + val msg = ClientMessage.OneOffQuery( + requestId = 3u, + queryString = "SELECT count(*) FROM users", + ) + val bytes = msg.encode() + val reader = BsatnReader(bytes) + assertEquals(2, reader.readTag().toInt()) + assertEquals(3u, reader.readU32()) + assertEquals("SELECT count(*) FROM users", reader.readString()) + } + + @Test + fun decodeInitialConnection() { + val writer = BsatnWriter() + writer.writeTag(0u) + writer.writeBytes(ByteArray(32) { it.toByte() }) + writer.writeBytes(ByteArray(16) { (it + 100).toByte() }) + writer.writeString("test-token-abc") + + val msg = ServerMessage.decode(writer.toByteArray()) + assertTrue(msg is ServerMessage.InitialConnection) + assertEquals("test-token-abc", msg.token) + assertEquals(ByteArray(32) { it.toByte() }.toList(), msg.identity.bytes.toList()) + 
assertEquals(ByteArray(16) { (it + 100).toByte() }.toList(), msg.connectionId.bytes.toList()) + } + + @Test + fun identityFromHex() { + val hex = "00" + "01" + "02" + "03" + "04" + "05" + "06" + "07" + + "08" + "09" + "0a" + "0b" + "0c" + "0d" + "0e" + "0f" + + "10" + "11" + "12" + "13" + "14" + "15" + "16" + "17" + + "18" + "19" + "1a" + "1b" + "1c" + "1d" + "1e" + "1f" + val identity = Identity.fromHex(hex) + assertEquals(0, identity.bytes[0].toInt()) + assertEquals(31, identity.bytes[31].toInt()) + assertEquals(hex, identity.toHex()) + } + + @Test + fun identityBsatnRoundTrip() { + val original = Identity(ByteArray(32) { (it * 7).toByte() }) + val writer = BsatnWriter() + Identity.write(writer, original) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Identity.read(reader) + assertEquals(original, decoded) + } + + @Test + fun connectionIdBsatnRoundTrip() { + val original = ConnectionId(ByteArray(16) { (it * 3).toByte() }) + val writer = BsatnWriter() + ConnectionId.write(writer, original) + val reader = BsatnReader(writer.toByteArray()) + val decoded = ConnectionId.read(reader) + assertEquals(original, decoded) + } + + @Test + fun timestampBsatnRoundTrip() { + val original = Timestamp(1234567890123L) + val writer = BsatnWriter() + Timestamp.write(writer, original) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Timestamp.read(reader) + assertEquals(original, decoded) + } +} diff --git a/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicyTest.kt b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicyTest.kt new file mode 100644 index 00000000000..a16d93e1523 --- /dev/null +++ b/sdks/kotlin/src/commonTest/kotlin/com/clockworklabs/spacetimedb/ReconnectPolicyTest.kt @@ -0,0 +1,84 @@ +package com.clockworklabs.spacetimedb + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class ReconnectPolicyTest { + + @Test + fun defaultPolicy() { 
+ val policy = ReconnectPolicy() + assertEquals(5, policy.maxRetries) + assertEquals(1_000L, policy.initialDelayMs) + assertEquals(30_000L, policy.maxDelayMs) + assertEquals(2.0, policy.backoffMultiplier) + } + + @Test + fun delayForAttemptExponentialBackoff() { + val policy = ReconnectPolicy( + initialDelayMs = 1_000, + maxDelayMs = 60_000, + backoffMultiplier = 2.0, + ) + assertEquals(1_000L, policy.delayForAttempt(0)) + assertEquals(2_000L, policy.delayForAttempt(1)) + assertEquals(4_000L, policy.delayForAttempt(2)) + assertEquals(8_000L, policy.delayForAttempt(3)) + assertEquals(16_000L, policy.delayForAttempt(4)) + } + + @Test + fun delayClampedToMax() { + val policy = ReconnectPolicy( + initialDelayMs = 1_000, + maxDelayMs = 5_000, + backoffMultiplier = 3.0, + ) + assertEquals(1_000L, policy.delayForAttempt(0)) + assertEquals(3_000L, policy.delayForAttempt(1)) + assertEquals(5_000L, policy.delayForAttempt(2)) // clamped: 9_000 -> 5_000 + assertEquals(5_000L, policy.delayForAttempt(3)) // stays clamped + } + + @Test + fun noBackoff() { + val policy = ReconnectPolicy( + initialDelayMs = 500, + maxDelayMs = 500, + backoffMultiplier = 1.0, + ) + assertEquals(500L, policy.delayForAttempt(0)) + assertEquals(500L, policy.delayForAttempt(1)) + assertEquals(500L, policy.delayForAttempt(5)) + } + + @Test + fun invalidMaxRetriesThrows() { + assertFailsWith<IllegalArgumentException> { + ReconnectPolicy(maxRetries = -1) + } + } + + @Test + fun invalidInitialDelayThrows() { + assertFailsWith<IllegalArgumentException> { + ReconnectPolicy(initialDelayMs = 0) + } + } + + @Test + fun maxDelayLessThanInitialThrows() { + assertFailsWith<IllegalArgumentException> { + ReconnectPolicy(initialDelayMs = 5_000, maxDelayMs = 1_000) + } + } + + @Test + fun backoffMultiplierLessThanOneThrows() { + assertFailsWith<IllegalArgumentException> { + ReconnectPolicy(backoffMultiplier = 0.5) + } + } +} diff --git a/sdks/kotlin/src/iosMain/kotlin/com/clockworklabs/spacetimedb/Compression.ios.kt b/sdks/kotlin/src/iosMain/kotlin/com/clockworklabs/spacetimedb/Compression.ios.kt new file mode 100644
index 00000000000..41522cfbec8 --- /dev/null +++ b/sdks/kotlin/src/iosMain/kotlin/com/clockworklabs/spacetimedb/Compression.ios.kt @@ -0,0 +1,88 @@ +package com.clockworklabs.spacetimedb + +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.alloc +import kotlinx.cinterop.free +import kotlinx.cinterop.nativeHeap +import kotlinx.cinterop.ptr +import kotlinx.cinterop.reinterpret +import kotlinx.cinterop.usePinned +import kotlinx.cinterop.value +import platform.zlib.Z_FINISH +import platform.zlib.Z_OK +import platform.zlib.Z_STREAM_END +import platform.zlib.inflate +import platform.zlib.inflateEnd +import platform.zlib.inflateInit2 +import platform.zlib.z_stream + +actual fun decompressBrotli(data: ByteArray): ByteArray { + // Brotli decompression requires Apple's Compression framework interop or a bundled decoder. + // The SDK defaults to Gzip compression (see buildWsUri), so Brotli is not expected. + // If a server sends Brotli, this will surface the issue clearly. + throw UnsupportedOperationException( + "Brotli decompression is not available on iOS. " + + "Configure the server connection to use Gzip compression instead." 
+ ) +} + +@OptIn(ExperimentalForeignApi::class) +actual fun decompressGzip(data: ByteArray): ByteArray { + if (data.isEmpty()) return data + + val stream = nativeHeap.alloc<z_stream>() + try { + stream.zalloc = null + stream.zfree = null + stream.opaque = null + stream.avail_in = 0u + stream.next_in = null + + // wbits = MAX_WBITS + 16 (31) tells zlib to expect gzip format + val initResult = inflateInit2(stream.ptr, 31) + if (initResult != Z_OK) { + throw IllegalStateException("zlib inflateInit2 failed: $initResult") + } + + val chunks = mutableListOf<ByteArray>() + val outBuf = ByteArray(8192) + + data.usePinned { srcPinned -> + stream.next_in = srcPinned.addressOf(0).reinterpret() + stream.avail_in = data.size.toUInt() + + do { + outBuf.usePinned { dstPinned -> + stream.next_out = dstPinned.addressOf(0).reinterpret() + stream.avail_out = outBuf.size.toUInt() + + val ret = inflate(stream.ptr, Z_FINISH) + if (ret != Z_OK && ret != Z_STREAM_END) { + inflateEnd(stream.ptr) + throw IllegalStateException("zlib inflate failed: $ret") + } + + val produced = outBuf.size - stream.avail_out.toInt() + if (produced > 0) { + chunks.add(outBuf.copyOf(produced)) + } + } + } while (stream.avail_out == 0u) + } + + inflateEnd(stream.ptr) + + // Concatenate chunks + val totalSize = chunks.sumOf { it.size } + val result = ByteArray(totalSize) + var offset = 0 + for (chunk in chunks) { + chunk.copyInto(result, offset) + offset += chunk.size + } + return result + } finally { + nativeHeap.free(stream) + } +} diff --git a/sdks/kotlin/src/jvmMain/kotlin/com/clockworklabs/spacetimedb/Compression.jvm.kt b/sdks/kotlin/src/jvmMain/kotlin/com/clockworklabs/spacetimedb/Compression.jvm.kt new file mode 100644 index 00000000000..f6d9df24e53 --- /dev/null +++ b/sdks/kotlin/src/jvmMain/kotlin/com/clockworklabs/spacetimedb/Compression.jvm.kt @@ -0,0 +1,28 @@ +package com.clockworklabs.spacetimedb + +import org.brotli.dec.BrotliInputStream +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import
java.util.zip.GZIPInputStream + +actual fun decompressBrotli(data: ByteArray): ByteArray { + ByteArrayInputStream(data).use { input -> + BrotliInputStream(input).use { brotli -> + ByteArrayOutputStream(data.size * 2).use { output -> + brotli.copyTo(output) + return output.toByteArray() + } + } + } +} + +actual fun decompressGzip(data: ByteArray): ByteArray { + ByteArrayInputStream(data).use { input -> + GZIPInputStream(input).use { gzip -> + ByteArrayOutputStream(data.size * 2).use { output -> + gzip.copyTo(output) + return output.toByteArray() + } + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/CompressionTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/CompressionTest.kt new file mode 100644 index 00000000000..93628910275 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/CompressionTest.kt @@ -0,0 +1,63 @@ +package com.clockworklabs.spacetimedb + +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPOutputStream +import kotlin.test.Test +import kotlin.test.assertTrue + +class CompressionTest { + + private fun gzipCompress(data: ByteArray): ByteArray { + val bos = ByteArrayOutputStream() + GZIPOutputStream(bos).use { it.write(data) } + return bos.toByteArray() + } + + @Test + fun gzipRoundTrip() { + val original = "Hello, SpacetimeDB! 
This is a test of gzip compression.".encodeToByteArray() + val compressed = gzipCompress(original) + val decompressed = decompressGzip(compressed) + assertTrue(original.contentEquals(decompressed), "Gzip round-trip failed") + } + + @Test + fun gzipEmptyPayload() { + val original = ByteArray(0) + val compressed = gzipCompress(original) + val decompressed = decompressGzip(compressed) + assertTrue(original.contentEquals(decompressed), "Gzip empty round-trip failed") + } + + @Test + fun gzipLargePayload() { + val original = ByteArray(10_000) { (it % 256).toByte() } + val compressed = gzipCompress(original) + val decompressed = decompressGzip(compressed) + assertTrue(original.contentEquals(decompressed), "Gzip large payload round-trip failed") + } + + @Test + fun brotliRoundTrip() { + // We test decompression only, since the SDK only needs to decompress server messages + val original = "Hello".encodeToByteArray() + val compressed = brotliCompress(original) + val decompressed = decompressBrotli(compressed) + assertTrue(original.contentEquals(decompressed), "Brotli round-trip failed") + } + + private fun brotliCompress(data: ByteArray): ByteArray { + // The org.brotli:dec artifact ships only the decoder, so we cannot compress at runtime. + // Instead, return a known brotli stream (pre-computed with the brotli CLI) that + // decompresses to "Hello". Only that one payload is supported. + require(data.contentEquals("Hello".encodeToByteArray())) { "Only the pre-computed \"Hello\" payload is supported" } + return byteArrayOf( + 0x0b, 0x02, 0x80.toByte(), 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x03 + ) + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/Keynote2BenchmarkTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/Keynote2BenchmarkTest.kt new file mode 100644 index 00000000000..293444c4e3d --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/Keynote2BenchmarkTest.kt @@ -0,0 +1,201 @@ +package com.clockworklabs.spacetimedb + +import kotlinx.coroutines.* +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong +import kotlin.test.Test + +/** + * Keynote-2 style TPS benchmark — fund transfers with pipelined reducer calls. + * + * Mirrors the Rust benchmark client at templates/keynote-2/spacetimedb-rust-client: + * - 10 WebSocket connections (no subscriptions) + * - Zipf-distributed account selection (alpha=0.5, 100k accounts) + * - Batched pipeline: fire 16384 reducer calls, await all responses, repeat + * - 5s warmup + 5s measurement + * + * Prerequisites: + * 1. `spacetime start` running on localhost:3000 + * 2. keynote-2 module published: `spacetime publish --server local sim` + * 3. Database seeded via Rust client: `spacetimedb-rust-transfer-sim seed` + * + * Set SPACETIMEDB_TEST=1 to enable.
+ */ +class Keynote2BenchmarkTest { + + private val serverUri = System.getenv("SPACETIMEDB_URI") ?: "ws://127.0.0.1:3000" + private val moduleName = System.getenv("SPACETIMEDB_MODULE") ?: "sim" + + private fun shouldRun(): Boolean = System.getenv("SPACETIMEDB_TEST") == "1" + + companion object { + const val ACCOUNTS = 100_000 + const val ALPHA = 0.5 + const val CONNECTIONS = 10 + const val MAX_INFLIGHT = 16_384 + const val WARMUP_MS = 5_000L + const val BENCH_MS = 5_000L + const val AMOUNT = 1 + const val TOTAL_PAIRS = 10_000_000 + } + + /** + * Zipf distribution sampler via inverse CDF with binary search. + * Produces integers in [0, n) with P(k) proportional to 1/(k+1)^alpha. + */ + private class ZipfSampler(n: Int, alpha: Double, seed: Long) { + private val cdf: DoubleArray + private val rng = java.util.Random(seed) + + init { + val weights = DoubleArray(n) { 1.0 / Math.pow((it + 1).toDouble(), alpha) } + val total = weights.sum() + cdf = DoubleArray(n) + var cumulative = 0.0 + for (i in weights.indices) { + cumulative += weights[i] / total + cdf[i] = cumulative + } + } + + fun sample(): Int { + val u = rng.nextDouble() + var lo = 0; var hi = cdf.size - 1 + while (lo < hi) { + val mid = (lo + hi) ushr 1 + if (cdf[mid] < u) lo = mid + 1 else hi = mid + } + return lo + } + } + + /** Pre-compute [TOTAL_PAIRS] transfer pairs using Zipf distribution. */ + private fun generateTransferPairs(from: IntArray, to: IntArray) { + val zipf = ZipfSampler(ACCOUNTS, ALPHA, 0x12345678L) + var idx = 0 + while (idx < TOTAL_PAIRS) { + val a = zipf.sample() + val b = zipf.sample() + if (a != b && a < ACCOUNTS && b < ACCOUNTS) { + from[idx] = a + to[idx] = b + idx++ + } + } + } + + /** BSATN-encode transfer args: (from: u32, to: u32, amount: u32) in little-endian. 
*/ + private fun encodeTransfer(from: Int, to: Int, amount: Int): ByteArray { + val buf = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN) + buf.putInt(from).putInt(to).putInt(amount) + return buf.array() + } + + @Test + fun keynote2Benchmark() { + if (!shouldRun()) { println("SKIP"); return } + + println("=== Kotlin SDK Keynote-2 Transfer Benchmark ===") + println("alpha=$ALPHA, amount=$AMOUNT, accounts=$ACCOUNTS") + println("max inflight reducers = $MAX_INFLIGHT") + println("connections = $CONNECTIONS") + println() + + // Pre-compute transfer pairs (matches Rust client's make_transfers) + print("Pre-computing transfer pairs... ") + val fromArr = IntArray(TOTAL_PAIRS) + val toArr = IntArray(TOTAL_PAIRS) + generateTransferPairs(fromArr, toArr) + println("done") + + val transfersPerWorker = TOTAL_PAIRS / CONNECTIONS + + runBlocking { + // Open connections (no subscriptions — pure reducer pipelining) + println("Initializing $CONNECTIONS connections...") + val connections = (0 until CONNECTIONS).map { + val ready = CompletableDeferred<DbConnection>() + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .withCompression(CompressionMode.NONE) + .onConnect { c, _, _ -> ready.complete(c) } + .onConnectError { e -> ready.completeExceptionally(e) } + .build() + withTimeout(10_000) { ready.await() } + conn + } + println("All $CONNECTIONS connections established") + + val completed = AtomicLong(0) + val workersReady = AtomicInteger(0) + val benchStartNanos = AtomicLong(0) + + println("Warming up for ${WARMUP_MS / 1000}s...") + val warmupStartNanos = System.nanoTime() + + val jobs = connections.mapIndexed { workerIdx, conn -> + launch(Dispatchers.Default) { + var tIdx = workerIdx * transfersPerWorker + + // Pipeline batch: fire MAX_INFLIGHT calls, suspend until all respond + suspend fun runBatch(): Long { + val batchDone = CompletableDeferred<Unit>() + val remaining = AtomicInteger(MAX_INFLIGHT) + + repeat(MAX_INFLIGHT) { + val idx = tIdx % TOTAL_PAIRS +
tIdx++ + val args = encodeTransfer(fromArr[idx], toArr[idx], AMOUNT) + conn.callReducer("transfer", args) { + if (remaining.decrementAndGet() == 0) { + batchDone.complete(Unit) + } + } + } + + batchDone.await() + return MAX_INFLIGHT.toLong() + } + + // ── Warmup phase ── + while (System.nanoTime() - warmupStartNanos < WARMUP_MS * 1_000_000) { + runBatch() + } + + // Sync: wait for all workers to finish warmup + workersReady.incrementAndGet() + while (workersReady.get() < CONNECTIONS) delay(1) + + // First worker to pass sets the shared start time + benchStartNanos.compareAndSet(0, System.nanoTime()) + + // ── Measurement phase ── + val myStart = System.nanoTime() + while (System.nanoTime() - myStart < BENCH_MS * 1_000_000) { + val count = runBatch() + completed.addAndGet(count) + } + } + } + + println("Finished warmup. Benchmarking for ${BENCH_MS / 1000}s...") + jobs.forEach { it.join() } + + val benchEndNanos = System.nanoTime() + val totalCompleted = completed.get() + val elapsed = (benchEndNanos - benchStartNanos.get()) / 1_000_000_000.0 + val tps = totalCompleted / elapsed + + println() + println("=== Results ===") + println("ran for ${"%.3f".format(elapsed)} seconds") + println("completed $totalCompleted transfers") + println("throughput was ${"%.1f".format(tps)} TPS") + + connections.forEach { it.disconnect() } + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveEdgeCaseTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveEdgeCaseTest.kt new file mode 100644 index 00000000000..7fb78dd3f24 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveEdgeCaseTest.kt @@ -0,0 +1,466 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import kotlinx.coroutines.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import 
kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.time.measureTime + +/** + * Live edge case tests against a local SpacetimeDB server. + * + * Set SPACETIMEDB_TEST=1 to enable. + */ +class LiveEdgeCaseTest { + + private val serverUri = System.getenv("SPACETIMEDB_URI") ?: "ws://127.0.0.1:3000" + private val moduleName = System.getenv("SPACETIMEDB_MODULE") ?: "kotlin-sdk-test" + + private fun shouldRun(): Boolean = System.getenv("SPACETIMEDB_TEST") == "1" + + // ──────── Invalid connection scenarios ──────── + + @Test + fun connectToNonExistentModule() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connectError = CompletableDeferred<Throwable>() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName("non_existent_module_xyz_12345") + .onConnect { _, _, _ -> connectError.completeExceptionally(AssertionError("Should not connect")) } + .onConnectError { e -> connectError.complete(e) } + .onDisconnect { _, e -> if (!connectError.isCompleted) connectError.complete(e ?: RuntimeException("disconnected")) } + .build() + + val error = withTimeout(10000) { connectError.await() } + assertNotNull(error) + println("PASS: Non-existent module rejected: ${error.message?.take(80)}") + conn.disconnect() + } + } + + @Test + fun connectToUnreachableHost() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connectError = CompletableDeferred<Throwable>() + + val conn = DbConnection.builder() + .withUri("ws://192.0.2.1:9999") // TEST-NET, guaranteed unreachable + .withModuleName("test") + .onConnectError { e -> connectError.complete(e) } + .onDisconnect { _, e -> if (!connectError.isCompleted) connectError.complete(e ?: RuntimeException("disconnected")) } + .build() + + val error = withTimeout(15000) { connectError.await() } + assertNotNull(error) + println("PASS: Unreachable host properly errored: ${error::class.simpleName}") + conn.disconnect() + } + } + + // ──────── Subscription edge
cases ──────── + + @Test + fun subscribeWithInvalidSqlSyntax() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subError = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subError.completeExceptionally(AssertionError("Should not apply")) } + .onError { err -> subError.complete(err) } + .subscribe("SELECTT * FROMM invalid_table_xyz") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + val error = withTimeout(5000) { subError.await() } + assertTrue(error.isNotEmpty(), "Should get a non-empty error message") + println("PASS: Invalid SQL rejected: ${error.take(80)}") + conn.disconnect() + } + } + + @Test + fun subscribeToNonExistentTable() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subError = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subError.completeExceptionally(AssertionError("Should not apply")) } + .onError { err -> subError.complete(err) } + .subscribe("SELECT * FROM nonexistent_table_xyz") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + val error = withTimeout(5000) { subError.await() } + assertTrue(error.isNotEmpty()) + println("PASS: Non-existent table rejected: ${error.take(80)}") + conn.disconnect() + } + } + + @Test + fun multipleIndependentSubscriptions() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val sub1Applied = CompletableDeferred() + val sub2Applied = CompletableDeferred() + + val conn = 
DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { sub1Applied.complete(Unit) } + .subscribe("SELECT * FROM player") + + c.subscriptionBuilder() + .onApplied { sub2Applied.complete(Unit) } + .subscribe("SELECT * FROM message") + + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { sub1Applied.await() } + withTimeout(5000) { sub2Applied.await() } + println("PASS: Two independent subscriptions applied concurrently") + conn.disconnect() + } + } + + @Test + fun subscribeToAllTables() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .subscribeToAllTables() + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { subApplied.await() } + println("PASS: subscribeToAllTables (SELECT * FROM *) applied") + conn.disconnect() + } + } + + // ──────── Reducer edge cases ──────── + + @Test + fun callNonExistentReducer() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val result = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + + conn.callReducer("nonexistent_reducer_xyz", byteArrayOf()) { r -> + result.complete(r) + } + + val res = withTimeout(5000) { result.await() } + // Should get an error 
outcome, not a crash + assertNotNull(res) + println("PASS: Non-existent reducer returned: ${res.outcome::class.simpleName}") + conn.disconnect() + } + } + + @Test + fun callReducerWithEmptyArgs() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val result = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .subscribe("SELECT * FROM player") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { subApplied.await() } + + // add_player expects a String arg — empty args should cause an error + conn.callReducer("add_player", byteArrayOf()) { r -> + result.complete(r) + } + + val res = withTimeout(5000) { result.await() } + assertNotNull(res) + // Should be an error since args don't match expected schema + println("PASS: Empty args to add_player returned: ${res.outcome::class.simpleName}") + conn.disconnect() + } + } + + // ──────── One-off query edge cases ──────── + + @Test + fun oneOffQueryInvalidSql() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> connected.complete(c) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val c = withTimeout(5000) { connected.await() } + val result = withTimeout(5000) { c.oneOffQuery("INVALID SQL QUERY!!!") } + assertNotNull(result.error, "Should return an error for invalid SQL") + assertNull(result.rows) + println("PASS: Invalid SQL one-off query returned error: ${result.error?.take(80)}") + conn.disconnect() + } + } + + @Test + fun oneOffQueryEmptyResult() { + if 
(!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> connected.complete(c) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val c = withTimeout(5000) { connected.await() } + // Query with impossible WHERE clause + val result = withTimeout(5000) { c.oneOffQuery("SELECT * FROM player WHERE id = 999999999") } + if (result.error != null) { + println("PASS: Empty result query returned error: ${result.error}") + } else { + val rows = result.rows?.tables?.flatMap { it.rows.decodeRows() } ?: emptyList() + println("PASS: Empty result query returned ${rows.size} rows") + } + conn.disconnect() + } + } + + // ──────── Token reuse ──────── + + @Test + fun reconnectWithSavedToken() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + // First connection: get identity and token + val firstConnect = CompletableDeferred>() + val conn1 = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, id, token -> firstConnect.complete(Pair(id, token)) } + .onConnectError { e -> firstConnect.completeExceptionally(e) } + .build() + + val (firstIdentity, token) = withTimeout(5000) { firstConnect.await() } + conn1.disconnect() + + // Second connection: reuse the token + val secondConnect = CompletableDeferred>() + val conn2 = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .withToken(token) + .onConnect { _, id, newToken -> secondConnect.complete(Pair(id, newToken)) } + .onConnectError { e -> secondConnect.completeExceptionally(e) } + .build() + + val (secondIdentity, _) = withTimeout(5000) { secondConnect.await() } + assertEquals(firstIdentity, secondIdentity, "Same token should yield same identity") + println("PASS: Token reuse preserved identity: ${firstIdentity.toHex().take(16)}...") + conn2.disconnect() + } + } + + // 
──────── Rapid fire operations ──────── + + @Test + fun rapidReducerCallsWithCallbacks() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val targetCount = 20 + val results = java.util.concurrent.ConcurrentHashMap() + val allDone = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .subscribe("SELECT * FROM player") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { subApplied.await() } + + val elapsed = measureTime { + repeat(targetCount) { i -> + val w = BsatnWriter(64) + w.writeString("Rapid_${System.currentTimeMillis()}_$i") + conn.callReducer("add_player", w.toByteArray()) { result -> + results[result.requestId] = result + if (results.size >= targetCount && !allDone.isCompleted) { + allDone.complete(Unit) + } + } + } + withTimeout(15000) { allDone.await() } + } + + assertEquals(targetCount, results.size, "All $targetCount callbacks should fire") + // Verify all got unique requestIds + assertEquals(targetCount, results.keys.size) + println("PASS: $targetCount rapid reducer calls all received callbacks in ${elapsed.inWholeMilliseconds}ms") + conn.disconnect() + } + } + + // ──────── Connection state transitions ──────── + + @Test + fun connectionStateTransitions() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + val disconnected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onDisconnect { _, _ -> disconnected.complete(Unit) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + // 
Before connect completes, state should be CONNECTING or CONNECTED + val earlyState = conn.connectionState.value + assertTrue( + earlyState == ConnectionState.CONNECTING || earlyState == ConnectionState.CONNECTED, + "Early state should be CONNECTING or CONNECTED, got $earlyState" + ) + + withTimeout(5000) { connected.await() } + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + assertTrue(conn.isActive) + + conn.disconnect() + assertEquals(ConnectionState.DISCONNECTED, conn.connectionState.value) + assertFalse(conn.isActive) + + // Identity should still be available after disconnect + assertNotNull(conn.identity, "Identity should persist after disconnect") + + println("PASS: State transitions: CONNECTING -> CONNECTED -> DISCONNECTED") + } + } + + // ──────── Identity null before connect ──────── + + @Test + fun identityNullBeforeConnect() { + if (!shouldRun()) { println("SKIP"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + // Identity/connectionId/token should be null before InitialConnection arrives + // (This is a best-effort check — the connect could be very fast) + // We mainly verify they're non-null after connect + withTimeout(5000) { connected.await() } + + assertNotNull(conn.identity) + assertNotNull(conn.connectionId) + assertNotNull(conn.savedToken) + println("PASS: Identity, connectionId, and token all non-null after connect") + conn.disconnect() + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveIntegrationTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveIntegrationTest.kt new file mode 100644 index 00000000000..3190dc9a53b --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/LiveIntegrationTest.kt @@ -0,0 
+1,278 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.websocket.ConnectionState +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.first +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.measureTime + +/** + * Live integration tests against a local SpacetimeDB server. + * + * Prerequisites: + * 1. `spacetime start` running on localhost:3000 + * 2. Test module published: `spacetime publish --server local -p kotlin-sdk-test` + * + * Set `SPACETIMEDB_TEST=1` to enable. Skipped by default in CI. + */ +class LiveIntegrationTest { + + private val serverUri = System.getenv("SPACETIMEDB_URI") ?: "ws://127.0.0.1:3000" + private val moduleName = System.getenv("SPACETIMEDB_MODULE") ?: "kotlin-sdk-test" + + private fun skipIfNoServer() { + if (System.getenv("SPACETIMEDB_TEST") != "1") { + println("SKIP: Set SPACETIMEDB_TEST=1 to run live integration tests") + return + } + } + + private fun shouldRun(): Boolean = System.getenv("SPACETIMEDB_TEST") == "1" + + @Test + fun connectAndReceiveIdentity() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred>() + val disconnected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, id, token -> + connected.complete(Triple(c, id, token)) + } + .onDisconnect { _, err -> disconnected.complete(err) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val (_, identity, token) = withTimeout(5000) { connected.await() } + + assertNotNull(identity, "Should receive an identity") + assertTrue(identity.bytes.size == 32, "Identity should be 32 bytes") + assertTrue(token.isNotEmpty(), "Should receive an auth token") + assertNotNull(conn.connectionId, "Should have a connectionId") + 
assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + + println("PASS: Connected as ${identity.toHex().take(16)}...") + println(" Token: ${token.take(20)}...") + println(" ConnectionId: ${conn.connectionId}") + + conn.disconnect() + } + } + + @Test + fun subscribeAndReceiveRows() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .onError { err -> subApplied.completeExceptionally(RuntimeException(err)) } + .subscribe("SELECT * FROM player") + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { subApplied.await() } + + println("PASS: Subscription to 'SELECT * FROM player' applied successfully") + + conn.disconnect() + } + } + + @Test + fun callReducerAndObserveInsert() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val insertReceived = CompletableDeferred() + val reducerResult = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.table("player").onInsert { row -> + insertReceived.complete(row) + } + + c.subscriptionBuilder() + .onApplied { + subApplied.complete(Unit) + } + .onError { err -> subApplied.completeExceptionally(RuntimeException(err)) } + .subscribe("SELECT * FROM player") + + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { subApplied.await() } + + // Call the add_player reducer + val 
playerName = "KotlinSDK_${System.currentTimeMillis()}" + val argsWriter = BsatnWriter(64) + argsWriter.writeString(playerName) + + conn.callReducer("add_player", argsWriter.toByteArray()) { result -> + reducerResult.complete(result) + } + + val row = withTimeout(5000) { insertReceived.await() } + assertTrue(row.isNotEmpty(), "Should receive inserted row bytes") + + val result = withTimeout(5000) { reducerResult.await() } + assertNotNull(result, "Should receive reducer result") + + println("PASS: Called add_player('$playerName')") + println(" Received insert: ${row.size} bytes") + println(" Reducer result: ${result.outcome}") + + conn.disconnect() + } + } + + @Test + fun multipleReducerCallsPerformance() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred() + val subApplied = CompletableDeferred() + val insertCount = java.util.concurrent.atomic.AtomicInteger(0) + val targetCount = 50 + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> + c.table("player").onInsert { insertCount.incrementAndGet() } + + c.subscriptionBuilder() + .onApplied { subApplied.complete(Unit) } + .onError { err -> subApplied.completeExceptionally(RuntimeException(err)) } + .subscribe("SELECT * FROM player") + + connected.complete(Unit) + } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + withTimeout(5000) { connected.await() } + withTimeout(5000) { subApplied.await() } + + // Fire N reducer calls and measure round-trip time + val elapsed = measureTime { + repeat(targetCount) { i -> + val w = BsatnWriter(64) + w.writeString("Batch_${System.currentTimeMillis()}_$i") + conn.callReducer("add_player", w.toByteArray()) + } + + // Wait for all inserts to arrive + withTimeout(15000) { + while (insertCount.get() < targetCount) { + delay(50) + } + } + } + + assertTrue(insertCount.get() >= targetCount, "Should receive all $targetCount 
inserts") + val avgMs = elapsed.inWholeMilliseconds.toDouble() / targetCount + println("PASS: $targetCount reducer calls + round-trip in ${elapsed.inWholeMilliseconds}ms") + println(" Avg round-trip: ${"%.1f".format(avgMs)}ms per call") + + conn.disconnect() + } + } + + @Test + fun oneOffQueryExecution() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + val connected = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .onConnect { c, _, _ -> connected.complete(c) } + .onConnectError { e -> connected.completeExceptionally(e) } + .build() + + val c = withTimeout(5000) { connected.await() } + + val elapsed = measureTime { + val result = withTimeout(5000) { + c.oneOffQuery("SELECT * FROM player") + } + if (result.error != null) { + println(" Query returned error: ${result.error}") + } else { + val rows = result.rows?.tables?.flatMap { it.rows.decodeRows() } ?: emptyList() + println("PASS: One-off query returned ${rows.size} player rows") + } + } + println(" Query time: ${elapsed.inWholeMilliseconds}ms") + + conn.disconnect() + } + } + + @Test + fun reconnectionAfterDisconnect() { + if (!shouldRun()) { println("SKIP: Set SPACETIMEDB_TEST=1"); return } + + runBlocking { + var connectCount = 0 + val firstConnect = CompletableDeferred() + val secondConnect = CompletableDeferred() + + val conn = DbConnection.builder() + .withUri(serverUri) + .withModuleName(moduleName) + .withReconnectPolicy(ReconnectPolicy(maxRetries = 3, initialDelayMs = 500)) + .onConnect { _, _, _ -> + connectCount++ + if (connectCount == 1) firstConnect.complete(Unit) + else secondConnect.complete(Unit) + } + .onConnectError { e -> firstConnect.completeExceptionally(e) } + .build() + + withTimeout(5000) { firstConnect.await() } + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + println("PASS: First connection established") + + // We can't easily force a server-side disconnect 
from the client, + // so we just verify the reconnect policy is wired up correctly + assertEquals(ConnectionState.CONNECTED, conn.connectionState.value) + println("PASS: Reconnect policy configured (maxRetries=3, initialDelay=500ms)") + + conn.disconnect() + assertEquals(ConnectionState.DISCONNECTED, conn.connectionState.value) + println("PASS: Clean disconnect") + } + } +} diff --git a/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/PerformanceBenchmarkTest.kt b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/PerformanceBenchmarkTest.kt new file mode 100644 index 00000000000..7555f95c5e2 --- /dev/null +++ b/sdks/kotlin/src/jvmTest/kotlin/com/clockworklabs/spacetimedb/PerformanceBenchmarkTest.kt @@ -0,0 +1,461 @@ +package com.clockworklabs.spacetimedb + +import com.clockworklabs.spacetimedb.bsatn.BsatnReader +import com.clockworklabs.spacetimedb.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb.protocol.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.time.measureTime + +/** + * Performance benchmarks for the core SDK machinery. + * + * These validate throughput and latency of: + * - BSATN serialization/deserialization + * - ClientCache insert/delete/update operations + * - Full ServerMessage decode pipeline + * - Gzip decompression throughput + * + * All tests run offline — no server required. 
+ */ +class PerformanceBenchmarkTest { + + // ───────────────────────────── BSATN ───────────────────────────── + + @Test + fun bsatnWriteThroughput() { + val iterations = 100_000 + // Simulate writing a "player row": u64 id, string name, i32 x, i32 y, f64 health + val elapsed = measureTime { + repeat(iterations) { + val w = BsatnWriter(64) + w.writeU64(it.toULong()) + w.writeString("Player_$it") + w.writeI32(it * 10) + w.writeI32(it * -5) + w.writeF64(100.0 - (it % 100)) + w.toByteArray() + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("BSATN write: ${iterations} rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + // Sanity: should do at least 100k rows/sec on any modern machine + assertTrue(elapsed.inWholeMilliseconds < 5000, "BSATN write too slow: ${elapsed.inWholeMilliseconds}ms") + } + + @Test + fun bsatnReadThroughput() { + val iterations = 100_000 + // Pre-encode rows + val rows = Array(iterations) { i -> + val w = BsatnWriter(64) + w.writeU64(i.toULong()) + w.writeString("Player_$i") + w.writeI32(i * 10) + w.writeI32(i * -5) + w.writeF64(100.0 - (i % 100)) + w.toByteArray() + } + + val elapsed = measureTime { + for (data in rows) { + val r = BsatnReader(data) + r.readU64() + r.readString() + r.readI32() + r.readI32() + r.readF64() + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("BSATN read: ${iterations} rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "BSATN read too slow: ${elapsed.inWholeMilliseconds}ms") + } + + @Test + fun bsatnRoundTripIntegrity() { + // Verify data survives write → read for every primitive type + val w = BsatnWriter(256) + w.writeBool(true) + w.writeBool(false) + w.writeU8(255u) + w.writeI8(-128) + w.writeU16(65535u) + w.writeI16(-32768) + w.writeU32(UInt.MAX_VALUE) + w.writeI32(Int.MIN_VALUE) + w.writeU64(ULong.MAX_VALUE) + w.writeI64(Long.MIN_VALUE) 
+ w.writeF32(3.14f) + w.writeF64(2.718281828459045) + w.writeString("Hello, SpacetimeDB! 🚀") + w.writeByteArray(byteArrayOf(0xCA.toByte(), 0xFE.toByte())) + + val r = BsatnReader(w.toByteArray()) + assertEquals(true, r.readBool()) + assertEquals(false, r.readBool()) + assertEquals(255.toUByte(), r.readU8()) + assertEquals((-128).toByte(), r.readI8()) + assertEquals(65535.toUShort(), r.readU16()) + assertEquals((-32768).toShort(), r.readI16()) + assertEquals(UInt.MAX_VALUE, r.readU32()) + assertEquals(Int.MIN_VALUE, r.readI32()) + assertEquals(ULong.MAX_VALUE, r.readU64()) + assertEquals(Long.MIN_VALUE, r.readI64()) + assertEquals(3.14f, r.readF32()) + assertEquals(2.718281828459045, r.readF64()) + assertEquals("Hello, SpacetimeDB! 🚀", r.readString()) + val bytes = r.readByteArray() + assertEquals(0xCA.toByte(), bytes[0]) + assertEquals(0xFE.toByte(), bytes[1]) + assertTrue(r.isExhausted, "Reader should be fully consumed") + } + + // ───────────────────────── Client Cache ────────────────────────── + + @Test + fun cacheInsertThroughput() { + val cache = ClientCache() + val table = cache.getOrCreateTable("players") + val rowCount = 50_000 + // Pre-generate unique rows + val rows = Array(rowCount) { i -> + val w = BsatnWriter(32) + w.writeU64(i.toULong()) + w.writeString("P$i") + w.toByteArray() + } + + val elapsed = measureTime { + for (row in rows) { + table.insertRow(row) + } + } + assertEquals(rowCount, table.count) + val opsPerSec = rowCount / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Cache insert: $rowCount rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Cache insert too slow") + } + + @Test + fun cacheDeleteThroughput() { + val cache = ClientCache() + val table = cache.getOrCreateTable("players") + val rowCount = 50_000 + val rows = Array(rowCount) { i -> + val w = BsatnWriter(32) + w.writeU64(i.toULong()) + w.writeString("P$i") + w.toByteArray() + } + for (row in rows) 
table.insertRow(row) + + val elapsed = measureTime { + for (row in rows) { + table.deleteRow(row) + } + } + assertEquals(0, table.count) + val opsPerSec = rowCount / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Cache delete: $rowCount rows in ${elapsed.inWholeMilliseconds}ms ($opsPerSec rows/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Cache delete too slow") + } + + @Test + fun cacheRefCountingCorrectness() { + // Overlapping subscriptions: same row inserted twice, deleted once → still present + val table = TableCache("test") + val row = byteArrayOf(1, 2, 3) + table.insertRow(row) + table.insertRow(row) // refCount = 2 + assertEquals(1, table.count, "Same row should not duplicate") + table.deleteRow(row) // refCount = 1 + assertEquals(1, table.count, "Row should remain with refCount > 0") + assertTrue(table.containsRow(row)) + table.deleteRow(row) // refCount = 0 + assertEquals(0, table.count, "Row should be removed at refCount 0") + } + + @Test + fun cacheTransactionUpdatePerformance() { + val cache = ClientCache() + // Pre-populate with 10k rows + val table = cache.getOrCreateTable("entities") + val existingRows = Array(10_000) { i -> + val w = BsatnWriter(16) + w.writeU64(i.toULong()) + w.writeI32(i) + w.toByteArray() + } + for (row in existingRows) table.insertRow(row) + + // Simulate a transaction: delete 1000 rows, insert 1000 new, update 500 + val deleteRows = existingRows.take(1500) // 1000 pure deletes + 500 updates + val updateNewRows = Array(500) { i -> + val w = BsatnWriter(16) + w.writeU64(i.toULong()) // same key as deleted + w.writeI32(i + 999_999) // different value + w.toByteArray() + } + val insertRows = Array(1000) { i -> + val w = BsatnWriter(16) + w.writeU64((20_000 + i).toULong()) + w.writeI32(i) + w.toByteArray() + } + + // Build the BsatnRowList payloads + val deletePayload = buildRowListPayload(deleteRows.toList()) + val insertPayload = buildRowListPayload(updateNewRows.toList() + insertRows.toList()) + + val 
qsUpdate = buildQuerySetUpdate("entities", insertPayload, deletePayload) + val elapsed = measureTime { + cache.applyTransactionUpdate(listOf(qsUpdate)) + } + + // Expected: 10000 - 1000 pure deletes + 1000 new inserts = 10000 (500 updates are in-place) + println("Transaction update: 2500 ops in ${elapsed.inWholeMilliseconds}ms") + assertTrue(elapsed.inWholeMilliseconds < 2000, "Transaction update too slow") + } + + // ──────────────────── Protocol Decode Pipeline ─────────────────── + + @Test + fun initialConnectionDecodePerformance() { + // Build a valid InitialConnection message + val w = BsatnWriter(256) + w.writeTag(0u) // InitialConnection tag + w.writeBytes(ByteArray(32) { it.toByte() }) // identity + w.writeBytes(ByteArray(16) { it.toByte() }) // connectionId + w.writeString("test-token-abc123") + val payload = w.toByteArray() + + val iterations = 50_000 + val elapsed = measureTime { + repeat(iterations) { + val msg = ServerMessage.decode(payload) + assertTrue(msg is ServerMessage.InitialConnection) + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("InitialConnection decode: $iterations msgs in ${elapsed.inWholeMilliseconds}ms ($opsPerSec msg/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Decode too slow") + } + + @Test + fun subscribeAppliedDecodeWithRows() { + // Build a SubscribeApplied with 100 rows across 2 tables + val w = BsatnWriter(4096) + w.writeTag(1u) // SubscribeApplied + w.writeU32(42u) // requestId + w.writeU32(7u) // querySetId + + // QueryRows: array of SingleTableRows + w.writeU32(1u) // 1 table + w.writeString("players") // table name + // BsatnRowList: RowSizeHint (tag + data) + length-prefixed row bytes + val rowSize = 12 // u64 + i32 + val rowCount = 100 + w.writeTag(0u) // RowSizeHint::FixedSize + w.writeU16(rowSize.toUShort()) + // Row data as a length-prefixed byte array + w.writeU32((rowSize * rowCount).toUInt()) + repeat(rowCount) { i -> + // Each row: u64 id, i32 score + 
for (b in 0 until 8) w.writeI8(((i shr (b * 8)) and 0xFF).toByte()) + w.writeI32(i * 100) + } + + val payload = w.toByteArray() + + val iterations = 10_000 + val elapsed = measureTime { + repeat(iterations) { + val msg = ServerMessage.decode(payload) + assertTrue(msg is ServerMessage.SubscribeApplied) + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("SubscribeApplied decode (100 rows): $iterations msgs in ${elapsed.inWholeMilliseconds}ms ($opsPerSec msg/sec)") + assertTrue(elapsed.inWholeMilliseconds < 10000, "SubscribeApplied decode too slow") + } + + @Test + fun clientMessageEncodeThroughput() { + val iterations = 100_000 + val elapsed = measureTime { + repeat(iterations) { i -> + val msg = ClientMessage.CallReducer( + requestId = i.toUInt(), + reducer = "set_position", + args = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8), + ) + msg.encode() + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("CallReducer encode: $iterations msgs in ${elapsed.inWholeMilliseconds}ms ($opsPerSec msg/sec)") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Encode too slow") + } + + // ──────────────────────── Gzip Decompression ───────────────────── + + @Test + fun gzipDecompressionThroughput() { + // Compress a realistic payload (1KB of row data) then benchmark decompression + val payload = ByteArray(1024) { (it % 256).toByte() } + val compressed = compressGzip(payload) + println("Gzip: ${payload.size} bytes → ${compressed.size} bytes (${compressed.size * 100 / payload.size}%)") + + val iterations = 50_000 + val elapsed = measureTime { + repeat(iterations) { + val decompressed = decompressGzip(compressed) + assertEquals(payload.size, decompressed.size) + } + } + val opsPerSec = iterations / elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Gzip decompress: $iterations x ${compressed.size}B in ${elapsed.inWholeMilliseconds}ms ($opsPerSec ops/sec)") + 
assertTrue(elapsed.inWholeMilliseconds < 10000, "Gzip decompression too slow") + } + + @Test + fun gzipLargePayloadDecompression() { + // Simulate a large SubscribeApplied (100KB) + val payload = ByteArray(100_000) { (it % 256).toByte() } + val compressed = compressGzip(payload) + println("Gzip large: ${payload.size} bytes → ${compressed.size} bytes") + + val iterations = 1_000 + val elapsed = measureTime { + repeat(iterations) { + val result = decompressGzip(compressed) + assertEquals(payload.size, result.size) + } + } + val mbPerSec = (payload.size.toLong() * iterations / 1024 / 1024) / + elapsed.inWholeMilliseconds.coerceAtLeast(1) * 1000 + println("Gzip large decompress: $iterations x ${payload.size / 1024}KB in ${elapsed.inWholeMilliseconds}ms ($mbPerSec MB/sec)") + assertTrue(elapsed.inWholeMilliseconds < 10000, "Large gzip decompression too slow") + } + + // ──────────────────── Callback System ──────────────────────────── + + @Test + fun tableHandleCallbackPerformance() { + val handle = TableHandle("test") + var insertCount = 0 + var deleteCount = 0 + var updateCount = 0 + + // Register multiple callbacks + repeat(10) { + handle.onInsert { insertCount++ } + handle.onDelete { deleteCount++ } + handle.onUpdate { _, _ -> updateCount++ } + } + + val row = byteArrayOf(1, 2, 3, 4) + val iterations = 100_000 + val elapsed = measureTime { + repeat(iterations) { + handle.fireInsert(row) + handle.fireDelete(row) + handle.fireUpdate(row, row) + } + } + assertEquals(iterations * 10, insertCount) + assertEquals(iterations * 10, deleteCount) + assertEquals(iterations * 10, updateCount) + println("Callbacks: ${iterations * 3} fires (10 listeners each) in ${elapsed.inWholeMilliseconds}ms") + assertTrue(elapsed.inWholeMilliseconds < 5000, "Callbacks too slow") + } + + @Test + fun callbackRegistrationAndRemoval() { + val handle = TableHandle("test") + var count = 0 + val ids = mutableListOf() + + // Register 100 callbacks that all increment count + repeat(100) { + 
+            ids.add(handle.onInsert { count++ })
+        }
+
+        // Remove every other one (50 removed, 50 remain).
+        for (i in ids.indices step 2) {
+            handle.removeOnInsert(ids[i])
+        }
+
+        handle.fireInsert(byteArrayOf(1))
+        assertEquals(50, count, "Should have 50 callbacks remaining")
+    }
+
+    // ──────────────────── End-to-End Message Flow ────────────────────
+
+    @Test
+    fun fullMessageRoundTrip() {
+        // Encode a Subscribe message, then verify it round-trips through binary.
+        val subscribe = ClientMessage.Subscribe(
+            requestId = 1u,
+            querySetId = QuerySetId(42u),
+            queryStrings = listOf("SELECT * FROM players", "SELECT * FROM items WHERE owner_id = 7"),
+        )
+        val encoded = subscribe.encode()
+        assertTrue(encoded.isNotEmpty())
+
+        // Decode it back manually, field by field.
+        val reader = BsatnReader(encoded)
+        assertEquals(0, reader.readTag().toInt()) // Subscribe tag
+        assertEquals(1u, reader.readU32())        // requestId
+        assertEquals(42u, reader.readU32())       // querySetId
+        val queryCount = reader.readU32().toInt()
+        assertEquals(2, queryCount)
+        assertEquals("SELECT * FROM players", reader.readString())
+        assertEquals("SELECT * FROM items WHERE owner_id = 7", reader.readString())
+        assertTrue(reader.isExhausted)
+    }
+
+    // ──────────────────── Helpers ────────────────────────────────────
+
+    private fun compressGzip(data: ByteArray): ByteArray {
+        val bos = java.io.ByteArrayOutputStream()
+        java.util.zip.GZIPOutputStream(bos).use { it.write(data) }
+        return bos.toByteArray()
+    }
+
+    private fun buildRowListPayload(rows: List<ByteArray>): ByteArray {
+        val w = BsatnWriter(256)
+        w.writeTag(0u) // RowSizeHint::FixedSize
+        if (rows.isEmpty()) {
+            w.writeU16(0u)
+            w.writeU32(0u) // empty data
+            return w.toByteArray()
+        }
+        val rowSize = rows.first().size
+        w.writeU16(rowSize.toUShort())
+        w.writeU32((rowSize * rows.size).toUInt()) // length-prefixed data
+        for (row in rows) w.writeBytes(row)
+        return w.toByteArray()
+    }
+
+    private fun buildQuerySetUpdate(
+        tableName: String,
+        insertPayload: ByteArray,
+        deletePayload: ByteArray,
+    ): QuerySetUpdate {
+        // Encode to BSATN and decode — ensures we go through the real codec.
+        val w = BsatnWriter(insertPayload.size + deletePayload.size + 256)
+        w.writeU32(1u) // querySetId
+        w.writeU32(1u) // 1 table
+        w.writeString(tableName)
+        w.writeU32(1u) // 1 row update block
+        w.writeTag(0u) // TableUpdateRows::PersistentTable
+        // PersistentTableRows: inserts then deletes (each is a full BsatnRowList).
+        w.writeBytes(insertPayload)
+        w.writeBytes(deletePayload)
+
+        return QuerySetUpdate.read(BsatnReader(w.toByteArray()))
+    }
+}
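For contributors reading the compression benchmarks in isolation, here is a self-contained JVM sketch of a matching gzip helper pair: `compressGzip` mirrors the private test helper above, while `decompressGzip` is a plain `java.util.zip` stand-in for the SDK's `expect`/`actual` declaration in `Compression.kt` (the exact SDK signature is an assumption here, not confirmed by the source).

```kotlin
import java.io.ByteArrayOutputStream
import java.util.zip.GZIPInputStream
import java.util.zip.GZIPOutputStream

// Mirrors the compressGzip helper used by the benchmarks above.
fun compressGzip(data: ByteArray): ByteArray {
    val bos = ByteArrayOutputStream()
    // Closing the GZIPOutputStream flushes the trailer; use {} handles that.
    GZIPOutputStream(bos).use { it.write(data) }
    return bos.toByteArray()
}

// JVM stand-in for the SDK's decompressGzip expect declaration (assumed shape).
fun decompressGzip(data: ByteArray): ByteArray =
    GZIPInputStream(data.inputStream()).use { it.readBytes() }

fun main() {
    // Round-trip the same 1 KB byte pattern the throughput test uses.
    val payload = ByteArray(1024) { (it % 256).toByte() }
    val compressed = compressGzip(payload)
    val restored = decompressGzip(compressed)
    check(restored.contentEquals(payload)) { "gzip round trip failed" }
    println("round trip ok: ${payload.size} bytes -> ${compressed.size} compressed")
}
```

Keeping the stand-in signature identical to the `expect` declaration lets the same benchmark body run unchanged on JVM and iOS targets.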