Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions app/src/main/java/to/bitkit/data/TrezorStore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.dataStore
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import to.bitkit.data.serializers.TrezorDataSerializer
import to.bitkit.di.IoDispatcher
import to.bitkit.repositories.KnownDevice
import javax.inject.Inject
import javax.inject.Singleton
Expand All @@ -20,20 +23,24 @@ private val Context.trezorDataStore: DataStore<TrezorData> by dataStore(
@Singleton
class TrezorStore @Inject constructor(
@ApplicationContext private val context: Context,
@IoDispatcher private val ioDispatcher: CoroutineDispatcher,
) {
private val store = context.trezorDataStore

val data: Flow<TrezorData> = store.data

suspend fun loadKnownDevices(): List<KnownDevice> =
suspend fun loadKnownDevices(): List<KnownDevice> = withContext(ioDispatcher) {
store.data.first().knownDevices
}

suspend fun saveKnownDevices(devices: List<KnownDevice>) {
suspend fun saveKnownDevices(devices: List<KnownDevice>) = withContext(ioDispatcher) {
store.updateData { it.copy(knownDevices = devices) }
Unit
}

suspend fun reset() {
suspend fun reset() = withContext(ioDispatcher) {
store.updateData { TrezorData() }
Unit
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import java.io.InputStream
import java.io.OutputStream

object TrezorDataSerializer : Serializer<TrezorData> {
private const val TAG = "TrezorDataSerializer"

override val defaultValue: TrezorData = TrezorData()

override suspend fun readFrom(input: InputStream): TrezorData {
return try {
json.decodeFromString(input.readBytes().decodeToString())
} catch (e: SerializationException) {
Logger.error("Failed to deserialize: $e")
Logger.error("Deserialize Trezor data failed", e, context = TAG)
defaultValue
}
}
Expand Down
112 changes: 73 additions & 39 deletions app/src/main/java/to/bitkit/repositories/TrezorRepo.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package to.bitkit.repositories

import android.content.Context
import androidx.compose.runtime.Immutable
import androidx.compose.runtime.Stable
import com.synonym.bitkitcore.AccountInfoResult
import com.synonym.bitkitcore.AccountType
Expand All @@ -21,6 +22,9 @@ import com.synonym.bitkitcore.TrezorSignedTx
import com.synonym.bitkitcore.TrezorTransportType
import com.synonym.bitkitcore.WalletParams
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.collections.immutable.ImmutableList
import kotlinx.collections.immutable.persistentListOf
import kotlinx.collections.immutable.toImmutableList
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.MutableStateFlow
Expand All @@ -29,6 +33,7 @@ import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.withContext
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import to.bitkit.data.TrezorStore
import to.bitkit.di.IoDispatcher
Expand Down Expand Up @@ -88,7 +93,7 @@ class TrezorRepo @Inject constructor(
Logger.debug("Initializing Trezor with credential path: '$credentialPath'", context = TAG)
trezorService.initialize(credentialPath)
val known = loadKnownDevices()
_state.update { it.copy(isInitialized = true, knownDevices = known, error = null) }
_state.update { it.copy(isInitialized = true, knownDevices = known.toImmutableList(), error = null) }
}.onFailure { e ->
Logger.error("Trezor init failed", e, context = TAG)
_state.update { it.copy(error = e.message) }
Expand All @@ -101,7 +106,7 @@ class TrezorRepo @Inject constructor(
val devices = trezorService.scan()
val knownIds = _state.value.knownDevices.map { it.id }.toSet()
val nearby = devices.filter { it.id !in knownIds }
_state.update { it.copy(isScanning = false, nearbyDevices = nearby) }
_state.update { it.copy(isScanning = false, nearbyDevices = nearby.toImmutableList()) }
devices
}.onFailure { e ->
Logger.error("Trezor scan failed", e, context = TAG)
Expand All @@ -114,7 +119,7 @@ class TrezorRepo @Inject constructor(
val devices = trezorService.listDevices()
val knownIds = _state.value.knownDevices.map { it.id }.toSet()
val nearby = devices.filter { it.id !in knownIds }
_state.update { it.copy(nearbyDevices = nearby) }
_state.update { it.copy(nearbyDevices = nearby.toImmutableList()) }
devices
}.onFailure { e ->
Logger.error("Trezor listDevices failed", e, context = TAG)
Expand All @@ -132,10 +137,7 @@ class TrezorRepo @Inject constructor(
?: _state.value.knownDevices.find { it.id == deviceId }?.let { known ->
TrezorDeviceInfo(
id = known.id,
transportType = when (known.transportType) {
"bluetooth" -> TrezorTransportType.BLUETOOTH
else -> TrezorTransportType.USB
},
transportType = known.transportType.toCoreTransportType(),
name = known.name,
path = known.path,
label = known.label,
Expand All @@ -149,9 +151,8 @@ class TrezorRepo @Inject constructor(
_state.update {
it.copy(
isConnecting = false,
connectedDevice = features,
connectedDeviceId = deviceId,
nearbyDevices = it.nearbyDevices.filter { d -> d.id != deviceId },
connected = ConnectedTrezorDevice(id = deviceId, features = features),
nearbyDevices = it.nearbyDevices.filter { d -> d.id != deviceId }.toImmutableList(),
)
}
features
Expand Down Expand Up @@ -316,12 +317,12 @@ class TrezorRepo @Inject constructor(
}

suspend fun disconnect(): Result<Unit> = withContext(ioDispatcher) {
runCatching {
TrezorDebugLog.log("DISCONNECT", "disconnect() called, connectedDeviceId=${_state.value.connectedDeviceId}")
runCatching { trezorService.disconnect() }
_state.update {
it.copy(connectedDevice = null, connectedDeviceId = null, lastAddress = null, lastPublicKey = null)
}
TrezorDebugLog.log("DISCONNECT", "disconnect() called, connectedDeviceId=${_state.value.connectedDeviceId}")
val result = runCatching { trezorService.disconnect() }
_state.update {
it.copy(connected = null, lastAddress = null, lastPublicKey = null)
}
result.onSuccess {
TrezorDebugLog.log("DISCONNECT", "disconnect() complete (credentials NOT cleared)")
}.onFailure { e ->
TrezorDebugLog.log("DISCONNECT", "FAILED: ${e.message}")
Expand Down Expand Up @@ -386,7 +387,7 @@ class TrezorRepo @Inject constructor(
initialize(walletIndex).getOrThrow()
}
if (trezorService.isConnected()) {
_state.value.connectedDevice ?: error("Connected but no features")
_state.value.connectedDevice ?: throw AppError("Connected but no features")
} else {
val scannedDevices = scan().getOrThrow()
val knownIds = knownDevices.map { it.id }.toSet()
Expand All @@ -396,7 +397,7 @@ class TrezorRepo @Inject constructor(
val idMatch = knownDevices.firstNotNullOfOrNull { known ->
scannedDevices.find { it.id == known.id }
}
val match = idMatch ?: usbDevice ?: error("No known device found nearby")
val match = idMatch ?: usbDevice ?: throw AppError("No known device found nearby")
connect(match.id).getOrThrow()
}
}.onSuccess {
Expand Down Expand Up @@ -436,15 +437,15 @@ class TrezorRepo @Inject constructor(
TrezorDebugLog.log("RECONNECT", "Preferring USB over BLE")
usbDevice
} else {
exactMatch ?: error("Device not found nearby — is it powered on?")
exactMatch ?: throw AppError("Device not found nearby — is it powered on?")
}
TrezorDebugLog.log("RECONNECT", "Found matching device: id=${device.id}, name=${device.name}")
TrezorDebugLog.log("RECONNECT", "Calling connectWithThpRetry...")
val features = connectWithThpRetry(device.id)
TrezorDebugLog.log("RECONNECT", "Connected! label=${features.label}, model=${features.model}")
addOrUpdateKnownDevice(device, features)
_state.update {
it.copy(isConnecting = false, connectedDevice = features, connectedDeviceId = device.id)
it.copy(isConnecting = false, connected = ConnectedTrezorDevice(id = device.id, features = features))
}
TrezorDebugLog.log("RECONNECT", "=== connectKnownDevice SUCCESS ===")
features
Expand All @@ -458,16 +459,21 @@ class TrezorRepo @Inject constructor(
suspend fun forgetDevice(deviceId: String): Result<Unit> = withContext(ioDispatcher) {
runCatching {
TrezorDebugLog.log("FORGET", "forgetDevice called for: $deviceId")
if (_state.value.connectedDeviceId == deviceId) {
runCatching { trezorService.disconnect() }
_state.update { it.copy(connectedDevice = null, connectedDeviceId = null) }
val disconnectResult = if (_state.value.connectedDeviceId == deviceId) {
runCatching { trezorService.disconnect() }.also {
_state.update { it.copy(connected = null) }
}
} else {
Result.success(Unit)
}
TrezorDebugLog.log("FORGET", "Clearing credentials...")
trezorTransport.clearDeviceCredential(deviceId)
runCatching { trezorService.clearCredentials(deviceId) }
val clearCredentialsResult = runCatching { trezorService.clearCredentials(deviceId) }
val updated = _state.value.knownDevices.filter { it.id != deviceId }
saveKnownDevices(updated)
_state.update { it.copy(knownDevices = updated) }
_state.update { it.copy(knownDevices = updated.toImmutableList()) }
disconnectResult.getOrThrow()
clearCredentialsResult.getOrThrow()
TrezorDebugLog.log("FORGET", "Device forgotten successfully")
Logger.info("Forgot device: '$deviceId'", context = TAG)
}.onFailure { e ->
Expand All @@ -488,7 +494,7 @@ class TrezorRepo @Inject constructor(
if (knownDevice?.id == currentId || path.contains(currentId)) {
Logger.warn("External disconnect detected for '$currentId'", context = TAG)
_state.update {
it.copy(connectedDevice = null, connectedDeviceId = null, error = "Device disconnected")
it.copy(connected = null, error = "Device disconnected")
}
}
}.launchIn(scope)
Expand All @@ -500,17 +506,14 @@ class TrezorRepo @Inject constructor(
id = deviceInfo.id,
name = deviceInfo.name,
path = deviceInfo.path,
transportType = when (deviceInfo.transportType) {
TrezorTransportType.BLUETOOTH -> "bluetooth"
TrezorTransportType.USB -> "usb"
},
transportType = deviceInfo.transportType.toKnownTransportType(),
label = features.label ?: deviceInfo.label,
model = features.model ?: deviceInfo.model,
lastConnectedAt = System.currentTimeMillis(),
)
val updated = existing.filter { it.id != known.id } + known
saveKnownDevices(updated)
_state.update { it.copy(knownDevices = updated) }
_state.update { it.copy(knownDevices = updated.toImmutableList()) }
}

private suspend fun loadKnownDevices(): List<KnownDevice> = runCatching {
Expand All @@ -531,15 +534,15 @@ class TrezorRepo @Inject constructor(
if (trezorService.isConnected()) return
val deviceId = _state.value.connectedDeviceId
?: _state.value.knownDevices.firstOrNull()?.id
?: error("No device to reconnect")
?: throw AppError("No device to reconnect")
if (!_state.value.isInitialized) {
initialize().getOrThrow()
}
val devices = trezorService.scan()
val device = devices.find { it.id == deviceId }
?: error("Device not found during reconnect")
?: throw AppError("Device not found during reconnect")
val features = connectWithThpRetry(device.id)
_state.update { it.copy(connectedDevice = features, connectedDeviceId = deviceId) }
_state.update { it.copy(connected = ConnectedTrezorDevice(id = deviceId, features = features)) }
}

suspend fun clearCredentials(deviceId: String): Result<Unit> = withContext(ioDispatcher) {
Expand Down Expand Up @@ -598,22 +601,53 @@ data class TrezorState(
val isScanning: Boolean = false,
val isConnecting: Boolean = false,
val isAutoReconnecting: Boolean = false,
val knownDevices: List<KnownDevice> = emptyList(),
val nearbyDevices: List<TrezorDeviceInfo> = emptyList(),
val connectedDevice: TrezorFeatures? = null,
val connectedDeviceId: String? = null,
val knownDevices: ImmutableList<KnownDevice> = persistentListOf(),
val nearbyDevices: ImmutableList<TrezorDeviceInfo> = persistentListOf(),
val connected: ConnectedTrezorDevice? = null,
val lastAddress: TrezorAddressResponse? = null,
val lastPublicKey: TrezorPublicKeyResponse? = null,
val error: String? = null,
) {
val connectedDevice: TrezorFeatures?
get() = connected?.features

val connectedDeviceId: String?
get() = connected?.id
}

@Stable
data class ConnectedTrezorDevice(
val id: String,
val features: TrezorFeatures,
)

@Serializable
@Immutable
data class KnownDevice(
val id: String,
val name: String?,
val path: String,
val transportType: String,
val transportType: KnownDeviceTransportType,
val label: String?,
val model: String?,
val lastConnectedAt: Long,
)

@Serializable
enum class KnownDeviceTransportType {
@SerialName("bluetooth")
BLUETOOTH,

@SerialName("usb")
USB,
}

private fun TrezorTransportType.toKnownTransportType(): KnownDeviceTransportType = when (this) {
TrezorTransportType.BLUETOOTH -> KnownDeviceTransportType.BLUETOOTH
TrezorTransportType.USB -> KnownDeviceTransportType.USB
}

private fun KnownDeviceTransportType.toCoreTransportType(): TrezorTransportType = when (this) {
KnownDeviceTransportType.BLUETOOTH -> TrezorTransportType.BLUETOOTH
KnownDeviceTransportType.USB -> TrezorTransportType.USB
}
15 changes: 11 additions & 4 deletions app/src/main/java/to/bitkit/services/TrezorDebugLog.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package to.bitkit.services

import kotlinx.collections.immutable.ImmutableList
import kotlinx.collections.immutable.persistentListOf
import kotlinx.collections.immutable.toImmutableList
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
Expand All @@ -10,8 +13,8 @@ import java.util.Locale

object TrezorDebugLog {
private const val MAX_LINES = 300
private val _lines = MutableStateFlow<List<String>>(emptyList())
val lines: StateFlow<List<String>> = _lines.asStateFlow()
private val _lines = MutableStateFlow<ImmutableList<String>>(persistentListOf())
val lines: StateFlow<ImmutableList<String>> = _lines.asStateFlow()

private val fmt = SimpleDateFormat("HH:mm:ss.SSS", Locale.US)

Expand All @@ -20,11 +23,15 @@ object TrezorDebugLog {
val line = "$ts [$tag] $msg"
_lines.update { current ->
val updated = current + line
if (updated.size > MAX_LINES) updated.takeLast(MAX_LINES) else updated
if (updated.size > MAX_LINES) {
updated.takeLast(MAX_LINES).toImmutableList()
} else {
updated.toImmutableList()
}
}
}

fun clear() {
_lines.update { emptyList() }
_lines.update { persistentListOf() }
}
}
Loading
Loading