diff --git a/app/src/main/java/to/bitkit/data/TrezorStore.kt b/app/src/main/java/to/bitkit/data/TrezorStore.kt index 12aceb195..2757cd1de 100644 --- a/app/src/main/java/to/bitkit/data/TrezorStore.kt +++ b/app/src/main/java/to/bitkit/data/TrezorStore.kt @@ -4,10 +4,13 @@ import android.content.Context import androidx.datastore.core.DataStore import androidx.datastore.dataStore import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.first +import kotlinx.coroutines.withContext import kotlinx.serialization.Serializable import to.bitkit.data.serializers.TrezorDataSerializer +import to.bitkit.di.IoDispatcher import to.bitkit.repositories.KnownDevice import javax.inject.Inject import javax.inject.Singleton @@ -20,20 +23,24 @@ private val Context.trezorDataStore: DataStore by dataStore( @Singleton class TrezorStore @Inject constructor( @ApplicationContext private val context: Context, + @IoDispatcher private val ioDispatcher: CoroutineDispatcher, ) { private val store = context.trezorDataStore val data: Flow = store.data - suspend fun loadKnownDevices(): List = + suspend fun loadKnownDevices(): List = withContext(ioDispatcher) { store.data.first().knownDevices + } - suspend fun saveKnownDevices(devices: List) { + suspend fun saveKnownDevices(devices: List) = withContext(ioDispatcher) { store.updateData { it.copy(knownDevices = devices) } + Unit } - suspend fun reset() { + suspend fun reset() = withContext(ioDispatcher) { store.updateData { TrezorData() } + Unit } } diff --git a/app/src/main/java/to/bitkit/data/serializers/TrezorDataSerializer.kt b/app/src/main/java/to/bitkit/data/serializers/TrezorDataSerializer.kt index b9556998b..7a93a59f6 100644 --- a/app/src/main/java/to/bitkit/data/serializers/TrezorDataSerializer.kt +++ b/app/src/main/java/to/bitkit/data/serializers/TrezorDataSerializer.kt @@ -9,13 +9,15 @@ import java.io.InputStream import java.io.OutputStream object TrezorDataSerializer : Serializer { + private const val TAG = "TrezorDataSerializer" + override val defaultValue: TrezorData = TrezorData() override suspend fun readFrom(input: InputStream): TrezorData { return try { json.decodeFromString(input.readBytes().decodeToString()) } catch (e: SerializationException) { - Logger.error("Failed to deserialize: $e") + Logger.error("Deserialize Trezor data failed", e, context = TAG) defaultValue } } diff --git a/app/src/main/java/to/bitkit/repositories/TrezorRepo.kt b/app/src/main/java/to/bitkit/repositories/TrezorRepo.kt index 8dd2f5d0b..df68f6ebe 100644 --- a/app/src/main/java/to/bitkit/repositories/TrezorRepo.kt +++ b/app/src/main/java/to/bitkit/repositories/TrezorRepo.kt @@ -1,6 +1,7 @@ package to.bitkit.repositories import android.content.Context +import androidx.compose.runtime.Immutable import androidx.compose.runtime.Stable import com.synonym.bitkitcore.AccountInfoResult import com.synonym.bitkitcore.AccountType @@ -21,6 +22,9 @@ import com.synonym.bitkitcore.TrezorSignedTx import com.synonym.bitkitcore.TrezorTransportType import com.synonym.bitkitcore.WalletParams import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.collections.immutable.ImmutableList +import kotlinx.collections.immutable.persistentListOf +import kotlinx.collections.immutable.toImmutableList import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.MutableStateFlow @@ -29,6 +33,7 @@ import kotlinx.coroutines.flow.launchIn import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.update import kotlinx.coroutines.withContext +import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import to.bitkit.data.TrezorStore import to.bitkit.di.IoDispatcher @@ -88,7 +93,7 @@ class TrezorRepo @Inject constructor( Logger.debug("Initializing Trezor with credential path: '$credentialPath'", context = TAG) trezorService.initialize(credentialPath) val known = loadKnownDevices() - _state.update { it.copy(isInitialized = true, knownDevices = known, error = null) } + _state.update { it.copy(isInitialized = true, knownDevices = known.toImmutableList(), error = null) } }.onFailure { e -> Logger.error("Trezor init failed", e, context = TAG) _state.update { it.copy(error = e.message) } @@ -101,7 +106,7 @@ class TrezorRepo @Inject constructor( val devices = trezorService.scan() val knownIds = _state.value.knownDevices.map { it.id }.toSet() val nearby = devices.filter { it.id !in knownIds } - _state.update { it.copy(isScanning = false, nearbyDevices = nearby) } + _state.update { it.copy(isScanning = false, nearbyDevices = nearby.toImmutableList()) } devices }.onFailure { e -> Logger.error("Trezor scan failed", e, context = TAG) @@ -114,7 +119,7 @@ class TrezorRepo @Inject constructor( val devices = trezorService.listDevices() val knownIds = _state.value.knownDevices.map { it.id }.toSet() val nearby = devices.filter { it.id !in knownIds } - _state.update { it.copy(nearbyDevices = nearby) } + _state.update { it.copy(nearbyDevices = nearby.toImmutableList()) } devices }.onFailure { e -> Logger.error("Trezor listDevices failed", e, context = TAG) @@ -132,10 +137,7 @@ class TrezorRepo @Inject constructor( ?: _state.value.knownDevices.find { it.id == deviceId }?.let { known -> TrezorDeviceInfo( id = known.id, - transportType = when (known.transportType) { - "bluetooth" -> TrezorTransportType.BLUETOOTH - else -> TrezorTransportType.USB - }, + transportType = known.transportType.toCoreTransportType(), name = known.name, path = known.path, label = known.label, @@ -149,9 +151,8 @@ class TrezorRepo @Inject constructor( _state.update { it.copy( isConnecting = false, - connectedDevice = features, - connectedDeviceId = deviceId, - nearbyDevices = it.nearbyDevices.filter { d -> d.id != deviceId }, + connected = ConnectedTrezorDevice(id = deviceId, features = features), + nearbyDevices = it.nearbyDevices.filter { d -> d.id != deviceId }.toImmutableList(), ) } features @@ -316,12 +317,12 @@ class TrezorRepo @Inject constructor( } suspend fun disconnect(): Result = withContext(ioDispatcher) { - runCatching { - TrezorDebugLog.log("DISCONNECT", "disconnect() called, connectedDeviceId=${_state.value.connectedDeviceId}") - runCatching { trezorService.disconnect() } - _state.update { - it.copy(connectedDevice = null, connectedDeviceId = null, lastAddress = null, lastPublicKey = null) - } + TrezorDebugLog.log("DISCONNECT", "disconnect() called, connectedDeviceId=${_state.value.connectedDeviceId}") + val result = runCatching { trezorService.disconnect() } + _state.update { + it.copy(connected = null, lastAddress = null, lastPublicKey = null) + } + result.onSuccess { TrezorDebugLog.log("DISCONNECT", "disconnect() complete (credentials NOT cleared)") }.onFailure { e -> TrezorDebugLog.log("DISCONNECT", "FAILED: ${e.message}") @@ -386,7 +387,7 @@ class TrezorRepo @Inject constructor( initialize(walletIndex).getOrThrow() } if (trezorService.isConnected()) { - _state.value.connectedDevice ?: error("Connected but no features") + _state.value.connectedDevice ?: throw AppError("Connected but no features") } else { val scannedDevices = scan().getOrThrow() val knownIds = knownDevices.map { it.id }.toSet() @@ -396,7 +397,7 @@ class TrezorRepo @Inject constructor( val idMatch = knownDevices.firstNotNullOfOrNull { known -> scannedDevices.find { it.id == known.id } } - val match = idMatch ?: usbDevice ?: error("No known device found nearby") + val match = idMatch ?: usbDevice ?: throw AppError("No known device found nearby") connect(match.id).getOrThrow() } }.onSuccess { @@ -436,7 +437,7 @@ class TrezorRepo @Inject constructor( TrezorDebugLog.log("RECONNECT", "Preferring USB over BLE") usbDevice } else { - exactMatch ?: error("Device not found nearby — is it powered on?") + exactMatch ?: throw AppError("Device not found nearby — is it powered on?") } TrezorDebugLog.log("RECONNECT", "Found matching device: id=${device.id}, name=${device.name}") TrezorDebugLog.log("RECONNECT", "Calling connectWithThpRetry...") @@ -444,7 +445,7 @@ class TrezorRepo @Inject constructor( TrezorDebugLog.log("RECONNECT", "Connected! label=${features.label}, model=${features.model}") addOrUpdateKnownDevice(device, features) _state.update { - it.copy(isConnecting = false, connectedDevice = features, connectedDeviceId = device.id) + it.copy(isConnecting = false, connected = ConnectedTrezorDevice(id = device.id, features = features)) } TrezorDebugLog.log("RECONNECT", "=== connectKnownDevice SUCCESS ===") features @@ -458,16 +459,21 @@ class TrezorRepo @Inject constructor( suspend fun forgetDevice(deviceId: String): Result = withContext(ioDispatcher) { runCatching { TrezorDebugLog.log("FORGET", "forgetDevice called for: $deviceId") - if (_state.value.connectedDeviceId == deviceId) { - runCatching { trezorService.disconnect() } - _state.update { it.copy(connectedDevice = null, connectedDeviceId = null) } + val disconnectResult = if (_state.value.connectedDeviceId == deviceId) { + runCatching { trezorService.disconnect() }.also { + _state.update { it.copy(connected = null) } + } + } else { + Result.success(Unit) } TrezorDebugLog.log("FORGET", "Clearing credentials...") trezorTransport.clearDeviceCredential(deviceId) - runCatching { trezorService.clearCredentials(deviceId) } + val clearCredentialsResult = runCatching { trezorService.clearCredentials(deviceId) } val updated = _state.value.knownDevices.filter { it.id != deviceId } saveKnownDevices(updated) - _state.update { it.copy(knownDevices = updated) } + _state.update { it.copy(knownDevices = updated.toImmutableList()) } + disconnectResult.getOrThrow() + clearCredentialsResult.getOrThrow() TrezorDebugLog.log("FORGET", "Device forgotten successfully") Logger.info("Forgot device: '$deviceId'", context = TAG) }.onFailure { e -> @@ -488,7 +494,7 @@ class TrezorRepo @Inject constructor( if (knownDevice?.id == currentId || path.contains(currentId)) { Logger.warn("External disconnect detected for '$currentId'", context = TAG) _state.update { - it.copy(connectedDevice = null, connectedDeviceId = null, error = "Device disconnected") + it.copy(connected = null, error = "Device disconnected") } } }.launchIn(scope) @@ -500,17 +506,14 @@ class TrezorRepo @Inject constructor( id = deviceInfo.id, name = deviceInfo.name, path = deviceInfo.path, - transportType = when (deviceInfo.transportType) { - TrezorTransportType.BLUETOOTH -> "bluetooth" - TrezorTransportType.USB -> "usb" - }, + transportType = deviceInfo.transportType.toKnownTransportType(), label = features.label ?: deviceInfo.label, model = features.model ?: deviceInfo.model, lastConnectedAt = System.currentTimeMillis(), ) val updated = existing.filter { it.id != known.id } + known saveKnownDevices(updated) - _state.update { it.copy(knownDevices = updated) } + _state.update { it.copy(knownDevices = updated.toImmutableList()) } } private suspend fun loadKnownDevices(): List = runCatching { @@ -531,15 +534,15 @@ class TrezorRepo @Inject constructor( if (trezorService.isConnected()) return val deviceId = _state.value.connectedDeviceId ?: _state.value.knownDevices.firstOrNull()?.id - ?: error("No device to reconnect") + ?: throw AppError("No device to reconnect") if (!_state.value.isInitialized) { initialize().getOrThrow() } val devices = trezorService.scan() val device = devices.find { it.id == deviceId } - ?: error("Device not found during reconnect") + ?: throw AppError("Device not found during reconnect") val features = connectWithThpRetry(device.id) - _state.update { it.copy(connectedDevice = features, connectedDeviceId = deviceId) } + _state.update { it.copy(connected = ConnectedTrezorDevice(id = deviceId, features = features)) } } suspend fun clearCredentials(deviceId: String): Result = withContext(ioDispatcher) { @@ -598,22 +601,53 @@ data class TrezorState( val isScanning: Boolean = false, val isConnecting: Boolean = false, val isAutoReconnecting: Boolean = false, - val knownDevices: List = emptyList(), - val nearbyDevices: List = emptyList(), - val connectedDevice: TrezorFeatures? = null, - val connectedDeviceId: String? = null, + val knownDevices: ImmutableList = persistentListOf(), + val nearbyDevices: ImmutableList = persistentListOf(), + val connected: ConnectedTrezorDevice? = null, val lastAddress: TrezorAddressResponse? = null, val lastPublicKey: TrezorPublicKeyResponse? = null, val error: String? = null, +) { + val connectedDevice: TrezorFeatures? + get() = connected?.features + + val connectedDeviceId: String? + get() = connected?.id +} + +@Stable +data class ConnectedTrezorDevice( + val id: String, + val features: TrezorFeatures, ) @Serializable +@Immutable data class KnownDevice( val id: String, val name: String?, val path: String, - val transportType: String, + val transportType: KnownDeviceTransportType, val label: String?, val model: String?, val lastConnectedAt: Long, ) + +@Serializable +enum class KnownDeviceTransportType { + @SerialName("bluetooth") + BLUETOOTH, + + @SerialName("usb") + USB, +} + +private fun TrezorTransportType.toKnownTransportType(): KnownDeviceTransportType = when (this) { + TrezorTransportType.BLUETOOTH -> KnownDeviceTransportType.BLUETOOTH + TrezorTransportType.USB -> KnownDeviceTransportType.USB +} + +private fun KnownDeviceTransportType.toCoreTransportType(): TrezorTransportType = when (this) { + KnownDeviceTransportType.BLUETOOTH -> TrezorTransportType.BLUETOOTH + KnownDeviceTransportType.USB -> TrezorTransportType.USB +} diff --git a/app/src/main/java/to/bitkit/services/TrezorDebugLog.kt b/app/src/main/java/to/bitkit/services/TrezorDebugLog.kt index 8ba7c0bf7..48ffb1115 100644 --- a/app/src/main/java/to/bitkit/services/TrezorDebugLog.kt +++ b/app/src/main/java/to/bitkit/services/TrezorDebugLog.kt @@ -1,5 +1,8 @@ package to.bitkit.services +import kotlinx.collections.immutable.ImmutableList +import kotlinx.collections.immutable.persistentListOf +import kotlinx.collections.immutable.toImmutableList import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow @@ -10,8 +13,8 @@ import java.util.Locale object TrezorDebugLog { private const val MAX_LINES = 300 - private val _lines = MutableStateFlow>(emptyList()) - val lines: StateFlow> = _lines.asStateFlow() + private val _lines = MutableStateFlow>(persistentListOf()) + val lines: StateFlow> = _lines.asStateFlow() private val fmt = SimpleDateFormat("HH:mm:ss.SSS", Locale.US) @@ -20,11 +23,15 @@ object TrezorDebugLog { val line = "$ts [$tag] $msg" _lines.update { current -> val updated = current + line - if (updated.size > MAX_LINES) updated.takeLast(MAX_LINES) else updated + if (updated.size > MAX_LINES) { + updated.takeLast(MAX_LINES).toImmutableList() + } else { + updated.toImmutableList() + } } } fun clear() { - _lines.update { emptyList() } + _lines.update { persistentListOf() } } } diff --git a/app/src/main/java/to/bitkit/services/TrezorTransport.kt b/app/src/main/java/to/bitkit/services/TrezorTransport.kt index 13b946c6d..826e5d0e9 100644 --- a/app/src/main/java/to/bitkit/services/TrezorTransport.kt +++ b/app/src/main/java/to/bitkit/services/TrezorTransport.kt @@ -158,10 +158,8 @@ class TrezorTransport @Inject constructor( bluetoothManager.adapter } - // USB connections private val usbConnections = ConcurrentHashMap() - // BLE connections private val bleConnections = ConcurrentHashMap() private val discoveredBleDevices = ConcurrentHashMap() @@ -184,12 +182,9 @@ class TrezorTransport @Inject constructor( @Volatile var writeStatus: Int = BluetoothGatt.GATT_SUCCESS, ) - // ==================== TrezorTransportCallback Implementation ==================== - override fun enumerateDevices(): List { val devices = mutableListOf() - // Enumerate USB devices runCatching { usbManager.deviceList.values .filter { isTrezorDevice(it) } @@ -209,7 +204,6 @@ class TrezorTransport @Inject constructor( Logger.error("USB enumerate failed", it, context = TAG) } - // Enumerate Bluetooth devices runCatching { enumerateBleDevices() }.onSuccess { @@ -272,58 +266,65 @@ class TrezorTransport @Inject constructor( messageType: UShort, data: ByteArray, ): TrezorCallMessageResult? { - // For BLE/THP devices, the Rust side now handles THP protocol directly. - // This callback returns null to let Rust use its built-in THP implementation. Logger.debug( - "callMessage called for '$path', type='$messageType' - returning null (Rust handles THP)", + "Delegating callMessage for '$path', type='$messageType' to core THP handling", context = TAG, ) return null } override fun getPairingCode(): String { - // This is called by Rust during BLE THP pairing when the device - // displays a 6-digit code that must be entered. - // - // We use a blocking approach with a latch. The UI observes needsPairingCode - // and shows a dialog. When the user enters the code, submitPairingCode() - // is called which releases the latch. TrezorDebugLog.log("PAIR", ">>> PAIRING CODE REQUESTED - Device requires re-pairing! <<<") - Logger.info(">>> PAIRING CODE REQUESTED <<<", context = TAG) - Logger.info("Look at your Trezor screen for a 6-digit code", context = TAG) + Logger.info("Requested pairing code from user", context = TAG) + Logger.info("Asked user to read the 6-digit code from Trezor screen", context = TAG) val latch = CountDownLatch(1) synchronized(pairingCodeLock) { - submittedPairingCode = "" + pairingCodeResult = null pairingCodeRequest = PairingCodeRequest(isRequested = true, latch = latch) _needsPairingCode.update { true } } - try { - // Wait for user to enter the code (with timeout) + val result = try { val received = latch.await(PAIRING_CODE_TIMEOUT_MS, TimeUnit.MILLISECONDS) - - if (!received) { - Logger.warn("Pairing code entry timed out", context = TAG) - _needsPairingCode.update { false } - return "" + synchronized(pairingCodeLock) { + val result = if (received) { + pairingCodeResult ?: PairingCodeResult.Cancelled + } else { + PairingCodeResult.TimedOut + } + clearPairingCodeRequest() + result } - - val code = submittedPairingCode - Logger.info("Pairing code received (len='${code.length}')", context = TAG) - return code } catch (e: InterruptedException) { - Logger.error("Pairing code wait interrupted", e, context = TAG) - _needsPairingCode.update { false } - return "" + Thread.currentThread().interrupt() + synchronized(pairingCodeLock) { + clearPairingCodeRequest() + } + PairingCodeResult.Interrupted(e) + } + + return when (result) { + is PairingCodeResult.Submitted -> { + Logger.info("Received pairing code (len='${result.code.length}')", context = TAG) + result.code + } + PairingCodeResult.Cancelled -> { + Logger.info("Cancelled pairing code entry", context = TAG) + "" + } + PairingCodeResult.TimedOut -> { + Logger.warn("Timed out waiting for pairing code entry", context = TAG) + "" + } + is PairingCodeResult.Interrupted -> { + Logger.error("Interrupted pairing code wait", result.error, context = TAG) + "" + } } } - /** - * Pairing code request state for UI observation. - * When getPairingCode() is called by Rust, we set this to true and wait. - */ data class PairingCodeRequest( val isRequested: Boolean = false, val latch: CountDownLatch? = null, @@ -333,10 +334,26 @@ class TrezorTransport @Inject constructor( private var pairingCodeRequest: PairingCodeRequest = PairingCodeRequest() @Volatile - private var submittedPairingCode: String = "" + private var pairingCodeResult: PairingCodeResult? = null private val pairingCodeLock = Object() + private sealed interface PairingCodeResult { + data class Submitted(val code: String) : PairingCodeResult + + data object Cancelled : PairingCodeResult + + data object TimedOut : PairingCodeResult + + data class Interrupted(val error: InterruptedException) : PairingCodeResult + } + + private fun clearPairingCodeRequest() { + pairingCodeRequest = PairingCodeRequest() + pairingCodeResult = null + _needsPairingCode.update { false } + } + /** * Flow to observe when a pairing code is needed. * UI should show a dialog when this is true. @@ -350,18 +367,20 @@ class TrezorTransport @Inject constructor( */ fun submitPairingCode(code: String) { synchronized(pairingCodeLock) { - Logger.info("Pairing code submitted (len='${code.length}')", context = TAG) - submittedPairingCode = code + Logger.info("Submitted pairing code (len='${code.length}')", context = TAG) + pairingCodeResult = PairingCodeResult.Submitted(code) _needsPairingCode.update { false } pairingCodeRequest.latch?.countDown() } } - /** - * Cancel pairing code entry (submit empty string). - */ fun cancelPairingCode() { - submitPairingCode("") + synchronized(pairingCodeLock) { + Logger.info("Cancelled pairing code entry", context = TAG) + pairingCodeResult = PairingCodeResult.Cancelled + _needsPairingCode.update { false } + pairingCodeRequest.latch?.countDown() + } } @Suppress("TooGenericExceptionCaught") @@ -375,8 +394,12 @@ class TrezorTransport @Inject constructor( if (credentialJson.isEmpty()) { val existed = file.exists() - file.delete() + val deleted = !existed || file.delete() TrezorDebugLog.log("SAVE", "CLEARED credential (file existed=$existed)") + if (!deleted) { + Logger.warn("Clear THP credential file failed for '${file.absolutePath}'", context = TAG) + return false + } Logger.info( "Cleared THP credential for device: '$deviceId' (path='${file.absolutePath}')", context = TAG, @@ -386,7 +409,6 @@ class TrezorTransport @Inject constructor( file.writeText(credentialJson) - // Immediately verify the file was written val verifyExists = file.exists() val verifySize = if (verifyExists) file.length() else 0 TrezorDebugLog.log( @@ -395,6 +417,7 @@ class TrezorTransport @Inject constructor( ) if (!verifyExists || verifySize == 0L) { TrezorDebugLog.log("SAVE", "WARNING: File verification FAILED after write!") + return false } Logger.info( @@ -423,7 +446,6 @@ class TrezorTransport @Inject constructor( TrezorDebugLog.log("LOAD", "loadThpCredential for: $deviceId") TrezorDebugLog.log("LOAD", "File: ${file.absolutePath}, exists=$exists, size=$size") - // List all files in credential directory for debugging val allFiles = credentialDir.listFiles()?.map { "${it.name} (${it.length()}b)" } ?: emptyList() TrezorDebugLog.log("LOAD", "All credential files: $allFiles") @@ -470,8 +492,6 @@ class TrezorTransport @Inject constructor( return File(credentialDir, "$sanitizedId.json") } - // ==================== USB Methods ==================== - /** * Request USB permission for a device and block until the user responds. * Returns true if permission was granted, false otherwise. @@ -511,7 +531,6 @@ class TrezorTransport @Inject constructor( Logger.info("Requesting USB permission for '${device.deviceName}'", context = TAG) usbManager.requestPermission(device, permissionIntent) - // Block until user responds (up to 60 seconds) val responded = latch.await(USB_PERMISSION_TIMEOUT_MS, TimeUnit.MILLISECONDS) if (!responded) { Logger.warn("USB permission request timed out", context = TAG) @@ -551,7 +570,6 @@ class TrezorTransport @Inject constructor( @Suppress("TooGenericExceptionCaught", "ReturnCount") private fun openUsbDevice(path: String): TrezorTransportWriteResult { return try { - // Close existing connection if any closeUsbDevice(path) val device = usbManager.deviceList[path] @@ -700,8 +718,6 @@ class TrezorTransport @Inject constructor( } } - // ==================== Bluetooth Methods ==================== - @SuppressLint("MissingPermission") private fun enumerateBleDevices(): List { if (bluetoothAdapter?.isEnabled != true) { @@ -711,7 +727,6 @@ class TrezorTransport @Inject constructor( val scanner = bluetoothAdapter?.bluetoothLeScanner ?: return emptyList() - // Start fresh scan discoveredBleDevices.clear() val scanFilter = ScanFilter.Builder() @@ -725,7 +740,6 @@ class TrezorTransport @Inject constructor( scanner.startScan(listOf(scanFilter), scanSettings, bleScanCallback) Logger.debug("BLE scan started", context = TAG) - // Wait for scan results Thread.sleep(SCAN_DURATION_MS) scanner.stopScan(bleScanCallback) @@ -802,10 +816,8 @@ class TrezorTransport @Inject constructor( val device = discoveredBleDevices[address] ?: return TrezorTransportWriteResult(success = false, error = "Device not found: $path") - // Close existing connection closeBleDevice(path) - // Check if device needs bonding val bondError = waitForBonding(device, address) if (bondError != null) return bondError @@ -832,10 +844,8 @@ class TrezorTransport @Inject constructor( return TrezorTransportWriteResult(success = false, error = "Failed to connect") } - // Request high-priority BLE connection for faster, more reliable handshake gatt.requestConnectionPriority(BluetoothGatt.CONNECTION_PRIORITY_HIGH) - // Drain any stale notifications from a previous connection attempt val staleCount = updatedConnection.readQueue.size if (staleCount > 0) { updatedConnection.readQueue.clear() @@ -856,13 +866,14 @@ class TrezorTransport @Inject constructor( ?: return TrezorTransportWriteResult(success = true, error = "") userInitiatedCloseSet.add(path) - try { + return try { val disconnectLatch = CountDownLatch(1) bleConnections[path] = connection.copy(disconnectLatch = disconnectLatch) connection.gatt.disconnect() val disconnected = disconnectLatch.await(DISCONNECT_TIMEOUT_MS, TimeUnit.MILLISECONDS) + val timeoutError = if (disconnected) null else "BLE disconnect timed out; forced close" if (!disconnected) { Logger.warn("BLE disconnect timeout, forcing close: '$path'", context = TAG) } @@ -870,14 +881,14 @@ class TrezorTransport @Inject constructor( bleConnections.remove(path) connection.gatt.close() Thread.sleep(100) + Logger.info("BLE device closed: '$path'", context = TAG) + TrezorTransportWriteResult(success = timeoutError == null, error = timeoutError.orEmpty()) } catch (e: Exception) { Logger.error("BLE close failed", e, context = TAG) + TrezorTransportWriteResult(success = false, error = e.message ?: "BLE close failed") } finally { userInitiatedCloseSet.remove(path) } - - Logger.info("BLE device closed: '$path'", context = TAG) - return TrezorTransportWriteResult(success = true, error = "") } @Suppress("TooGenericExceptionCaught") @@ -927,7 +938,6 @@ class TrezorTransport @Inject constructor( } return try { - // Retry logic for transient GATT busy states var lastError = "Write initiation failed" for (attempt in 1..BLE_WRITE_RETRY_COUNT) { val writeLatch = CountDownLatch(1) @@ -940,7 +950,6 @@ class TrezorTransport @Inject constructor( val success = connection.gatt.writeCharacteristic(writeChar) if (!success) { - // Get more diagnostic info val connState = connection.isConnected val charPropsHex = Integer.toHexString(writeChar.properties) Logger.warn( @@ -982,7 +991,6 @@ class TrezorTransport @Inject constructor( return TrezorTransportWriteResult(success = false, error = lastError) } - // Success! Logger.debug("BLE wrote '${data.size}' bytes to '$path' (attempt '$attempt')", context = TAG) // Small delay between writes to avoid overwhelming the GATT @@ -1004,6 +1012,17 @@ class TrezorTransport @Inject constructor( val path = "ble:${gatt.device.address}" val connection = bleConnections[path] + if (status != BluetoothGatt.GATT_SUCCESS) { + Logger.warn("BLE connection state changed with status '$status' for '$path'", context = TAG) + connection?.isConnected = false + connection?.connectionLatch?.countDown() + connection?.disconnectLatch?.countDown() + if (!userInitiatedCloseSet.remove(path)) { + _externalDisconnect.tryEmit(path) + } + return + } + when (newState) { BluetoothProfile.STATE_CONNECTED -> { Logger.debug("BLE connected, requesting MTU: '$path'", context = TAG) @@ -1068,7 +1087,6 @@ class TrezorTransport @Inject constructor( gatt.setCharacteristicNotification(notifyChar, true) - // Also subscribe to PUSH characteristic val pushChar = service.getCharacteristic(PUSH_CHAR_UUID) if (pushChar != null) { gatt.setCharacteristicNotification(pushChar, true) @@ -1198,8 +1216,6 @@ class TrezorTransport @Inject constructor( return result } - // ==================== Utility Methods ==================== - private fun isBleDevice(path: String): Boolean = path.startsWith("ble:") private fun isTrezorDevice(device: UsbDevice): Boolean { diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/AddressSection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/AddressSection.kt index 788aa2619..6f4e83c2d 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/AddressSection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/AddressSection.kt @@ -150,7 +150,7 @@ private fun PreviewAddressSectionLoading() { AppThemeSurface { AddressSection( trezorState = TrezorPreviewData.connectedState, - uiState = TrezorUiState(isGettingAddress = true), + uiState = TrezorUiState(network = TrezorNetworkState(isGettingAddress = true)), onGetAddress = {}, onIncrementIndex = {}, ) diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/BalanceLookupSection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/BalanceLookupSection.kt index 20ce7df30..46a7f3b73 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/BalanceLookupSection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/BalanceLookupSection.kt @@ -330,7 +330,12 @@ private fun PreviewBalanceLookupWithAddressInfo() { private fun PreviewBalanceLookupLoading() { AppThemeSurface { BalanceLookupSection( - uiState = TrezorUiState(lookupInput = "xpub6C...", isLookingUp = true), + uiState = TrezorUiState( + lookup = TrezorLookupState( + input = "xpub6C...", + isLookingUp = true, + ), + ), isDeviceConnected = false, onInputChange = {}, onLookup = {}, diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/DeviceListSection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/DeviceListSection.kt index 198cc3797..3eb472eb3 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/DeviceListSection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/DeviceListSection.kt @@ -22,6 +22,7 @@ import com.synonym.bitkitcore.TrezorDeviceInfo import com.synonym.bitkitcore.TrezorTransportType import to.bitkit.R import to.bitkit.repositories.KnownDevice +import to.bitkit.repositories.KnownDeviceTransportType import to.bitkit.ui.components.Caption import to.bitkit.ui.components.CaptionB import to.bitkit.ui.components.HorizontalSpacer @@ -96,10 +97,9 @@ internal fun KnownDeviceCard( ) { Icon( painter = painterResource( - if (device.transportType == "bluetooth") { - R.drawable.ic_broadcast - } else { - R.drawable.ic_git_branch + when (device.transportType) { + KnownDeviceTransportType.BLUETOOTH -> R.drawable.ic_broadcast + KnownDeviceTransportType.USB -> R.drawable.ic_git_branch } ), contentDescription = null, @@ -119,7 +119,10 @@ internal fun KnownDeviceCard( verticalAlignment = Alignment.CenterVertically, ) { Caption( - text = if (device.transportType == "bluetooth") "Bluetooth" else "USB", + text = when (device.transportType) { + KnownDeviceTransportType.BLUETOOTH -> "Bluetooth" + KnownDeviceTransportType.USB -> "USB" + }, color = Colors.White50, ) Caption( diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/PublicKeySection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/PublicKeySection.kt index 6d87f5b4b..86480a55c 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/PublicKeySection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/PublicKeySection.kt @@ -176,7 +176,7 @@ private fun PreviewPublicKeySectionLoading() { AppThemeSurface { PublicKeySection( trezorState = TrezorPreviewData.connectedState, - uiState = TrezorUiState(isGettingPublicKey = true), + uiState = TrezorUiState(network = TrezorNetworkState(isGettingPublicKey = true)), onGetPublicKey = {}, ) } diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/SendTransactionSection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/SendTransactionSection.kt index ed6b76f1a..ed88355dc 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/SendTransactionSection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/SendTransactionSection.kt @@ -68,8 +68,8 @@ internal fun SendTransactionSection( ) VerticalSpacer(8.dp) - when (uiState.sendStep) { - SendStep.FORM -> ComposeForm( + when (val step = uiState.sendStep) { + SendStep.Form -> ComposeForm( uiState = uiState, onAddressChange = onAddressChange, onAmountChange = onAmountChange, @@ -78,24 +78,20 @@ internal fun SendTransactionSection( onCoinSelectionChange = onCoinSelectionChange, onCompose = onCompose, ) - SendStep.REVIEW -> uiState.composeResult?.let { result -> - ReviewSection( - result = result, - isDeviceConnected = isDeviceConnected, - isSigning = uiState.isSigning, - onSign = onSign, - onBack = onBack, - ) - } - SendStep.SIGNED -> uiState.signedTxResult?.let { signedTx -> - SignedResultSection( - signedTx = signedTx, - isBroadcasting = uiState.isBroadcasting, - broadcastTxid = uiState.broadcastTxid, - onBroadcast = onBroadcast, - onReset = onReset, - ) - } + is SendStep.Review -> ReviewSection( + result = step.composeResult, + isDeviceConnected = isDeviceConnected, + isSigning = uiState.isSigning, + onSign = onSign, + onBack = onBack, + ) + is SendStep.Signed -> SignedResultSection( + signedTx = step.signedTx, + isBroadcasting = uiState.isBroadcasting, + broadcastTxid = step.broadcastTxid, + onBroadcast = onBroadcast, + onReset = onReset, + ) } } } @@ -405,9 +401,11 @@ private fun PreviewSendFormFilled() { AppThemeSurface { SendTransactionSection( uiState = TrezorUiState( - sendAddress = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh", - sendAmountSats = "45000", - sendFeeRate = "5", + send = TrezorSendState( + address = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh", + amountSats = "45000", + feeRate = "5", + ), ), isDeviceConnected = true, onAddressChange = {}, diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/SignMessageSection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/SignMessageSection.kt index 22846e953..14ce5ebfa 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/SignMessageSection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/SignMessageSection.kt @@ -158,7 +158,7 @@ private fun PreviewSignMessageSectionWithSignature() { private fun PreviewSignMessageSectionSigning() { AppThemeSurface { SignMessageSection( - uiState = TrezorUiState(isSigningMessage = true), + uiState = TrezorUiState(message = TrezorMessageState(isSigningMessage = true)), onMessageChange = {}, onSignMessage = {}, onVerifyMessage = {}, diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/TransactionHistorySection.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/TransactionHistorySection.kt index 5da323a3d..70a8f8dab 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/TransactionHistorySection.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/TransactionHistorySection.kt @@ -175,7 +175,12 @@ private fun PreviewTransactionHistoryEmpty() { private fun PreviewTransactionHistoryLoading() { AppThemeSurface { TransactionHistorySection( - uiState = TrezorUiState(txHistoryInput = "vpub5Y...", isLoadingTxHistory = true), + uiState = TrezorUiState( + txHistory = TrezorTxHistoryState( + input = "vpub5Y...", + isLoading = true, + ), + ), onInputChange = {}, onLookup = {}, ) diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorPreviewData.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorPreviewData.kt index 164a121b1..23b2d0fdc 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorPreviewData.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorPreviewData.kt @@ -17,7 +17,9 @@ import com.synonym.bitkitcore.TrezorSignedTx import com.synonym.bitkitcore.TrezorTransportType import com.synonym.bitkitcore.TxDirection import com.synonym.bitkitcore.WalletBalance +import to.bitkit.repositories.ConnectedTrezorDevice import to.bitkit.repositories.KnownDevice +import to.bitkit.repositories.KnownDeviceTransportType import to.bitkit.repositories.TrezorState import com.synonym.bitkitcore.Network as BitkitCoreNetwork @@ -55,7 +57,7 @@ internal object TrezorPreviewData { id = "usb-1", name = "Trezor Safe 5", path = "/dev/usb/001", - transportType = "usb", + transportType = KnownDeviceTransportType.USB, label = "My Savings", model = "Safe 5", lastConnectedAt = 1_700_000_000_000L, @@ -65,7 +67,7 @@ internal object TrezorPreviewData { id = "ble-1", name = "Trezor Safe 7", path = "AA:BB:CC:DD:EE:FF", - transportType = "bluetooth", + transportType = KnownDeviceTransportType.BLUETOOTH, label = "Daily Wallet", model = "Safe 7", lastConnectedAt = 1_700_000_000_000L, @@ -182,49 +184,59 @@ internal object TrezorPreviewData { val connectedState = TrezorState( isInitialized = true, - connectedDevice = sampleFeatures, - connectedDeviceId = "trezor-abc123", + connected = ConnectedTrezorDevice( + id = "trezor-abc123", + features = sampleFeatures, + ), ) val connectedStateWithResults = TrezorState( isInitialized = true, - connectedDevice = sampleFeatures, - connectedDeviceId = "trezor-abc123", + connected = ConnectedTrezorDevice( + id = "trezor-abc123", + features = sampleFeatures, + ), lastAddress = sampleAddressResponse, lastPublicKey = samplePublicKeyResponse, ) val uiStateWithSignature = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - lastSignature = "H3bK9x...signature...base64==", - lastSigningAddress = SAMPLE_ADDRESS, + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + message = TrezorMessageState( + lastSignature = "H3bK9x...signature...base64==", + lastSigningAddress = SAMPLE_ADDRESS, + ), ) val uiStateWithAccountInfo = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - lookupInput = SAMPLE_XPUB, - accountInfoResult = sampleAccountInfoResult, + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + lookup = TrezorLookupState( + input = SAMPLE_XPUB, + accountInfoResult = sampleAccountInfoResult, + ), ) val uiStateWithAddressInfo = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - lookupInput = SAMPLE_ADDRESS, - addressInfoResult = sampleAddressInfoResult, + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + lookup = TrezorLookupState( + input = SAMPLE_ADDRESS, + addressInfoResult = sampleAddressInfoResult, + ), ) val uiStateReview = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - sendStep = SendStep.REVIEW, - sendAddress = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh", - sendAmountSats = "45000", - sendFeeRate = "5", - composeResult = sampleComposeResult, + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + send = TrezorSendState( + address = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh", + amountSats = "45000", + feeRate = "5", + step = SendStep.Review(sampleComposeResult), + ), ) val uiStateSigned = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - sendStep = SendStep.SIGNED, - signedTxResult = sampleSignedTx, + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + send = TrezorSendState(step = SendStep.Signed(sampleSignedTx)), ) val sampleWalletBalance = WalletBalance( @@ -284,15 +296,20 @@ internal object TrezorPreviewData { ) val uiStateWithTxHistory = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - txHistoryInput = SAMPLE_XPUB, - txHistoryResult = sampleTransactionHistoryResult, + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + txHistory = TrezorTxHistoryState( + input = SAMPLE_XPUB, + result = sampleTransactionHistoryResult, + ), ) val uiStateBroadcast = TrezorUiState( - selectedNetwork = BitkitCoreNetwork.REGTEST, - sendStep = SendStep.SIGNED, - signedTxResult = sampleSignedTx, - broadcastTxid = "c4d5e6f7a8b9c4d5e6f7a8b9c4d5e6f7a8b9c4d5e6f7a8b9c4d5e6f7a8b9c4d5", + network = TrezorNetworkState(selectedNetwork = BitkitCoreNetwork.REGTEST), + send = TrezorSendState( + step = SendStep.Signed( + signedTx = sampleSignedTx, + broadcastTxid = "c4d5e6f7a8b9c4d5e6f7a8b9c4d5e6f7a8b9c4d5e6f7a8b9c4d5e6f7a8b9c4d5", + ), + ), ) } diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorScreen.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorScreen.kt index 4786f9c86..87ec05f21 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorScreen.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorScreen.kt @@ -45,7 +45,9 @@ import androidx.navigation.NavController import com.google.accompanist.permissions.ExperimentalPermissionsApi import com.google.accompanist.permissions.rememberMultiplePermissionsState import com.synonym.bitkitcore.CoinSelection +import kotlinx.collections.immutable.toImmutableList import to.bitkit.R +import to.bitkit.repositories.ConnectedTrezorDevice import to.bitkit.repositories.KnownDevice import to.bitkit.repositories.TrezorState import to.bitkit.services.TrezorDebugLog @@ -657,9 +659,12 @@ private fun PreviewWithDevices() { Content( trezorState = TrezorState( isInitialized = true, - knownDevices = listOf(TrezorPreviewData.sampleKnownDevice), - nearbyDevices = listOf(TrezorPreviewData.sampleNearbyDevice), - connectedDeviceId = TrezorPreviewData.sampleKnownDevice.id, + knownDevices = listOf(TrezorPreviewData.sampleKnownDevice).toImmutableList(), + nearbyDevices = listOf(TrezorPreviewData.sampleNearbyDevice).toImmutableList(), + connected = ConnectedTrezorDevice( + id = TrezorPreviewData.sampleKnownDevice.id, + features = TrezorPreviewData.sampleFeatures, + ), ), uiState = TrezorUiState(), ) diff --git a/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorViewModel.kt b/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorViewModel.kt index 94cc72b16..a51ed468d 100644 --- a/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorViewModel.kt +++ b/app/src/main/java/to/bitkit/ui/screens/trezor/TrezorViewModel.kt @@ -1,5 +1,6 @@ package to.bitkit.ui.screens.trezor +import androidx.compose.runtime.Immutable import androidx.compose.runtime.Stable import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope @@ -64,6 +65,7 @@ class TrezorViewModel @Inject constructor( val label = it.label ?: it.model ?: "Trezor" ToastEventBus.send(type = Toast.ToastType.INFO, title = "Reconnected to $label") } + .onFailure { ToastEventBus.send(it) } } } @@ -126,7 +128,7 @@ class TrezorViewModel @Inject constructor( fun getAddress(showOnTrezor: Boolean = false) { viewModelScope.launch(bgDispatcher) { - _uiState.update { it.copy(isGettingAddress = true) } + _uiState.update { it.copy(network = it.network.copy(isGettingAddress = true)) } val state = _uiState.value trezorRepo.getAddress( path = state.derivationPath, @@ -135,11 +137,11 @@ class TrezorViewModel @Inject constructor( coin = state.selectedNetwork.toTrezorCoinType(), ) .onSuccess { - _uiState.update { it.copy(isGettingAddress = false) } + _uiState.update { it.copy(network = it.network.copy(isGettingAddress = false)) } ToastEventBus.send(type = Toast.ToastType.INFO, title = "Address generated") } .onFailure { - _uiState.update { it.copy(isGettingAddress = false) } + _uiState.update { it.copy(network = it.network.copy(isGettingAddress = false)) } ToastEventBus.send(it) } } @@ -147,36 +149,36 @@ class TrezorViewModel @Inject constructor( fun getPublicKey(showOnTrezor: Boolean = false) { viewModelScope.launch(bgDispatcher) { - _uiState.update { it.copy(isGettingPublicKey = true) } + _uiState.update { it.copy(network = it.network.copy(isGettingPublicKey = true)) } val state = _uiState.value - val accountPath = state.derivationPath.split("/").take(4).joinToString("/") trezorRepo.getPublicKey( - path = accountPath, + path = accountPath(state.derivationPath), showOnTrezor = showOnTrezor, coin = state.selectedNetwork.toTrezorCoinType(), ) .onSuccess { - _uiState.update { it.copy(isGettingPublicKey = false) } + _uiState.update { it.copy(network = it.network.copy(isGettingPublicKey = false)) } ToastEventBus.send(type = Toast.ToastType.INFO, title = "Public key retrieved") } .onFailure { - _uiState.update { it.copy(isGettingPublicKey = false) } + _uiState.update { it.copy(network = it.network.copy(isGettingPublicKey = false)) } ToastEventBus.send(it) } } } fun setDerivationPath(path: String) { - _uiState.update { it.copy(derivationPath = path) } + _uiState.update { it.copy(network = it.network.copy(derivationPath = path)) } } fun setSelectedNetwork(network: BitkitCoreNetwork) { - val coinType = if (network == BitkitCoreNetwork.BITCOIN) "0" else "1" _uiState.update { it.copy( - selectedNetwork = network, - addressIndex = 0, - derivationPath = "m/84'/$coinType'/0'/0/0", + network = it.network.copy( + selectedNetwork = network, + addressIndex = 0, + derivationPath = derivationPath(network = network, index = 0), + ) ) } } @@ -184,10 +186,11 @@ class TrezorViewModel @Inject constructor( fun incrementAddressIndex() { _uiState.update { state -> val newIndex = state.addressIndex + 1 - val coinType = if (state.selectedNetwork == BitkitCoreNetwork.BITCOIN) "0" else "1" state.copy( - addressIndex = newIndex, - derivationPath = "m/84'/$coinType'/0'/0/$newIndex", + network = state.network.copy( + addressIndex = newIndex, + derivationPath = derivationPath(network = state.selectedNetwork, index = newIndex), + ) ) } } @@ -203,11 +206,11 @@ class TrezorViewModel @Inject constructor( } fun setMessageToSign(message: String) { - _uiState.update { it.copy(messageToSign = message) } + _uiState.update { it.copy(message = it.message.copy(messageToSign = message)) } } fun setLookupInput(input: String) { - _uiState.update { it.copy(lookupInput = input) } + _uiState.update { it.copy(lookup = it.lookup.copy(input = input)) } } fun lookupBalanceInfo() { @@ -219,21 +222,12 @@ class TrezorViewModel @Inject constructor( } _uiState.update { it.copy( - isLookingUp = true, - accountInfoResult = null, - addressInfoResult = null, - sendAddress = "", - sendAmountSats = "", - sendFeeRate = "2", - isSendMax = false, - isComposing = false, - isSigning = false, - composeResult = null, - signedTxResult = null, - sendStep = SendStep.FORM, - coinSelection = CoinSelection.BRANCH_AND_BOUND, - isBroadcasting = false, - broadcastTxid = null, + lookup = it.lookup.copy( + isLookingUp = true, + accountInfoResult = null, + addressInfoResult = null, + ), + send = TrezorSendState(), ) } @@ -241,21 +235,35 @@ class TrezorViewModel @Inject constructor( if (isExtendedKey(input)) { trezorRepo.getAccountInfo(extendedKey = input, network = network) .onSuccess { result -> - _uiState.update { it.copy(isLookingUp = false, accountInfoResult = result) } + _uiState.update { + it.copy( + lookup = it.lookup.copy( + isLookingUp = false, + accountInfoResult = result, + ) + ) + } ToastEventBus.send(type = Toast.ToastType.INFO, title = "Account info retrieved") } .onFailure { - _uiState.update { it.copy(isLookingUp = false) } + _uiState.update { it.copy(lookup = it.lookup.copy(isLookingUp = false)) } ToastEventBus.send(it) } } else { trezorRepo.getAddressInfo(address = input, network = network) .onSuccess { result -> - _uiState.update { it.copy(isLookingUp = false, addressInfoResult = result) } + _uiState.update { + it.copy( + lookup = it.lookup.copy( + isLookingUp = false, + addressInfoResult = result, + ) + ) + } ToastEventBus.send(type = Toast.ToastType.INFO, title = "Address info retrieved") } .onFailure { - _uiState.update { it.copy(isLookingUp = false) } + _uiState.update { it.copy(lookup = it.lookup.copy(isLookingUp = false)) } ToastEventBus.send(it) } } @@ -275,7 +283,7 @@ class TrezorViewModel @Inject constructor( return@launch } - _uiState.update { it.copy(isSigningMessage = true) } + _uiState.update { it.copy(message = it.message.copy(isSigningMessage = true)) } val state = _uiState.value trezorRepo.signMessage( path = state.derivationPath, @@ -285,15 +293,17 @@ class TrezorViewModel @Inject constructor( .onSuccess { response -> _uiState.update { it.copy( - lastSignature = response.signature, - lastSigningAddress = response.address, - isSigningMessage = false + message = it.message.copy( + lastSignature = response.signature, + lastSigningAddress = response.address, + isSigningMessage = false, + ) ) } ToastEventBus.send(type = Toast.ToastType.INFO, title = "Message signed!") } .onFailure { e -> - _uiState.update { it.copy(isSigningMessage = false) } + _uiState.update { it.copy(message = it.message.copy(isSigningMessage = false)) } ToastEventBus.send(e) } } @@ -310,7 +320,7 @@ class TrezorViewModel @Inject constructor( return@launch } - _uiState.update { it.copy(isVerifyingMessage = true) } + _uiState.update { it.copy(message = it.message.copy(isVerifyingMessage = true)) } trezorRepo.verifyMessage( address = address, signature = signature, @@ -318,82 +328,79 @@ class TrezorViewModel @Inject constructor( coin = _uiState.value.selectedNetwork.toTrezorCoinType(), ) .onSuccess { isValid -> - _uiState.update { it.copy(isVerifyingMessage = false) } + _uiState.update { it.copy(message = it.message.copy(isVerifyingMessage = false)) } val msg = if (isValid) "Signature is valid!" else "Signature is invalid" val type = if (isValid) Toast.ToastType.SUCCESS else Toast.ToastType.ERROR ToastEventBus.send(type = type, title = msg) } .onFailure { - _uiState.update { it.copy(isVerifyingMessage = false) } + _uiState.update { it.copy(message = it.message.copy(isVerifyingMessage = false)) } ToastEventBus.send(it) } } } fun setSendAddress(address: String) { - _uiState.update { it.copy(sendAddress = address) } + _uiState.update { it.copy(send = it.send.copy(address = address)) } } fun setSendAmount(amount: String) { - _uiState.update { it.copy(sendAmountSats = amount) } + _uiState.update { it.copy(send = it.send.copy(amountSats = amount)) } } fun setSendFeeRate(feeRate: String) { - _uiState.update { it.copy(sendFeeRate = feeRate) } + _uiState.update { it.copy(send = it.send.copy(feeRate = feeRate)) } } fun toggleSendMax() { - _uiState.update { it.copy(isSendMax = !it.isSendMax) } + _uiState.update { it.copy(send = it.send.copy(isMax = !it.isSendMax)) } } fun setCoinSelection(selection: CoinSelection) { - _uiState.update { it.copy(coinSelection = selection) } + _uiState.update { it.copy(send = it.send.copy(coinSelection = selection)) } } fun broadcastSignedTx() { viewModelScope.launch(bgDispatcher) { val state = _uiState.value - val rawTx = state.signedTxResult?.serializedTx ?: return@launch - _uiState.update { it.copy(isBroadcasting = true) } + val signedStep = state.sendStep as? SendStep.Signed ?: return@launch + val rawTx = signedStep.signedTx.serializedTx + _uiState.update { it.copy(send = it.send.copy(isBroadcasting = true)) } trezorRepo.broadcastRawTx(serializedTx = rawTx, network = state.selectedNetwork) .onSuccess { txid -> TrezorDebugLog.log("BROADCAST", "SUCCESS txid=$txid") - _uiState.update { it.copy(isBroadcasting = false, broadcastTxid = txid) } + _uiState.update { + if (it.send.step != signedStep) return@update it + + it.copy( + send = it.send.copy( + isBroadcasting = false, + step = signedStep.copy(broadcastTxid = txid), + ) + ) + } ToastEventBus.send(type = Toast.ToastType.SUCCESS, title = "Transaction broadcast") } .onFailure { TrezorDebugLog.log("BROADCAST", "FAILED: ${it.message}") - _uiState.update { it.copy(isBroadcasting = false) } + _uiState.update { it.copy(send = it.send.copy(isBroadcasting = false)) } ToastEventBus.send(it) } } } fun resetSendFlow() { - _uiState.update { - it.copy( - sendAddress = "", - sendAmountSats = "", - sendFeeRate = "2", - isSendMax = false, - isComposing = false, - isSigning = false, - composeResult = null, - signedTxResult = null, - sendStep = SendStep.FORM, - coinSelection = CoinSelection.BRANCH_AND_BOUND, - isBroadcasting = false, - broadcastTxid = null, - ) - } + _uiState.update { it.copy(send = TrezorSendState()) } } fun backToComposeForm() { _uiState.update { it.copy( - sendStep = SendStep.FORM, - composeResult = null, - signedTxResult = null, + send = it.send.copy( + step = SendStep.Form, + isSigning = false, + isBroadcasting = false, + ) ) } } @@ -404,9 +411,8 @@ class TrezorViewModel @Inject constructor( val accountInfo = state.accountInfoResult ?: return@launch if (!validateComposeInputs(state)) return@launch - _uiState.update { it.copy(isComposing = true) } - val feeRate = state.sendFeeRate.toFloatOrNull() ?: return@launch + _uiState.update { it.copy(send = it.send.copy(isComposing = true)) } TrezorDebugLog.log("COMPOSE", "=== composeTx START ===") TrezorDebugLog.log("COMPOSE", "address=${state.sendAddress}") TrezorDebugLog.log("COMPOSE", "amount=${state.sendAmountSats}, sendMax=${state.isSendMax}") @@ -419,7 +425,7 @@ class TrezorViewModel @Inject constructor( } else { val amountSats = state.sendAmountSats.toULongOrNull() if (amountSats == null) { - _uiState.update { it.copy(isComposing = false) } + _uiState.update { it.copy(send = it.send.copy(isComposing = false)) } ToastEventBus.send(type = Toast.ToastType.ERROR, title = "Enter a valid amount") return@launch } @@ -437,7 +443,7 @@ class TrezorViewModel @Inject constructor( .onSuccess { handleComposeResults(it) } .onFailure { TrezorDebugLog.log("COMPOSE", "FAILED: ${it.message}") - _uiState.update { it.copy(isComposing = false) } + _uiState.update { it.copy(send = it.send.copy(isComposing = false)) } ToastEventBus.send(it) } } @@ -480,16 +486,21 @@ class TrezorViewModel @Inject constructor( if (successResult != null) { TrezorDebugLog.log("COMPOSE", "=== composeTx SUCCESS ===") _uiState.update { - it.copy(isComposing = false, composeResult = successResult, sendStep = SendStep.REVIEW) + it.copy( + send = it.send.copy( + isComposing = false, + step = SendStep.Review(successResult), + ) + ) } ToastEventBus.send(type = Toast.ToastType.INFO, title = "Transaction composed") } else if (errorResult != null) { TrezorDebugLog.log("COMPOSE", "=== composeTx FAILED (compose error) ===") - _uiState.update { it.copy(isComposing = false) } + _uiState.update { it.copy(send = it.send.copy(isComposing = false)) } ToastEventBus.send(type = Toast.ToastType.ERROR, title = errorResult.error) } else { TrezorDebugLog.log("COMPOSE", "=== composeTx FAILED (no valid result) ===") - _uiState.update { it.copy(isComposing = false) } + _uiState.update { it.copy(send = it.send.copy(isComposing = false)) } ToastEventBus.send(type = Toast.ToastType.ERROR, title = "No valid composition returned") } } @@ -497,13 +508,13 @@ class TrezorViewModel @Inject constructor( fun signComposedTx() { viewModelScope.launch(bgDispatcher) { val state = _uiState.value - val result = state.composeResult ?: return@launch + val result = (state.sendStep as? SendStep.Review)?.composeResult ?: return@launch TrezorDebugLog.log("SIGN", "=== signComposedTx START ===") TrezorDebugLog.log("SIGN", "network=${state.selectedNetwork}") TrezorDebugLog.log("SIGN", "psbt length=${result.psbt.length}") - _uiState.update { it.copy(isSigning = true) } + _uiState.update { it.copy(send = it.send.copy(isSigning = true)) } trezorRepo.signTxFromPsbt( psbtBase64 = result.psbt, @@ -517,7 +528,12 @@ class TrezorViewModel @Inject constructor( "txid=${signedTx.txid}, rawTxLen=${signedTx.serializedTx.length}" ) _uiState.update { - it.copy(isSigning = false, signedTxResult = signedTx, sendStep = SendStep.SIGNED) + it.copy( + send = it.send.copy( + isSigning = false, + step = SendStep.Signed(signedTx = signedTx), + ) + ) } ToastEventBus.send( type = Toast.ToastType.SUCCESS, @@ -526,14 +542,14 @@ class TrezorViewModel @Inject constructor( } .onFailure { TrezorDebugLog.log("SIGN", "signTxFromPsbt FAILED: ${it.message}") - _uiState.update { s -> s.copy(isSigning = false) } + _uiState.update { it.copy(send = it.send.copy(isSigning = false)) } ToastEventBus.send(it) } } } fun setTxHistoryInput(input: String) { - _uiState.update { it.copy(txHistoryInput = input) } + _uiState.update { it.copy(txHistory = it.txHistory.copy(input = input)) } } fun lookupTransactionHistory() { @@ -543,19 +559,23 @@ class TrezorViewModel @Inject constructor( ToastEventBus.send(type = Toast.ToastType.ERROR, title = "Enter an xpub") return@launch } - _uiState.update { it.copy(isLoadingTxHistory = true, txHistoryResult = null) } + _uiState.update { + it.copy(txHistory = it.txHistory.copy(isLoading = true, result = null)) + } val network = _uiState.value.selectedNetwork trezorRepo.getTransactionHistory(extendedKey = input, network = network) .onSuccess { result -> - _uiState.update { it.copy(isLoadingTxHistory = false, txHistoryResult = result) } + _uiState.update { + it.copy(txHistory = it.txHistory.copy(isLoading = false, result = result)) + } ToastEventBus.send( type = Toast.ToastType.INFO, title = "Found ${result.txCount} transaction${if (result.txCount != 1u) "s" else ""}" ) } .onFailure { - _uiState.update { it.copy(isLoadingTxHistory = false) } + _uiState.update { it.copy(txHistory = it.txHistory.copy(isLoading = false)) } ToastEventBus.send(it) } } @@ -595,36 +615,168 @@ class TrezorViewModel @Inject constructor( @Stable data class TrezorUiState( + val network: TrezorNetworkState = TrezorNetworkState(), + val message: TrezorMessageState = TrezorMessageState(), + val lookup: TrezorLookupState = TrezorLookupState(), + val send: TrezorSendState = TrezorSendState(), + val txHistory: TrezorTxHistoryState = TrezorTxHistoryState(), +) { + val selectedNetwork: BitkitCoreNetwork + get() = network.selectedNetwork + + val addressIndex: Int + get() = network.addressIndex + + val derivationPath: String + get() = network.derivationPath + + val messageToSign: String + get() = message.messageToSign + + val lastSignature: String? + get() = message.lastSignature + + val lastSigningAddress: String? + get() = message.lastSigningAddress + + val isSigningMessage: Boolean + get() = message.isSigningMessage + + val isGettingAddress: Boolean + get() = network.isGettingAddress + + val isGettingPublicKey: Boolean + get() = network.isGettingPublicKey + + val isVerifyingMessage: Boolean + get() = message.isVerifyingMessage + + val lookupInput: String + get() = lookup.input + + val isLookingUp: Boolean + get() = lookup.isLookingUp + + val accountInfoResult: AccountInfoResult? + get() = lookup.accountInfoResult + + val addressInfoResult: SingleAddressInfoResult? + get() = lookup.addressInfoResult + + val sendAddress: String + get() = send.address + + val sendAmountSats: String + get() = send.amountSats + + val sendFeeRate: String + get() = send.feeRate + + val isSendMax: Boolean + get() = send.isMax + + val isComposing: Boolean + get() = send.isComposing + + val isSigning: Boolean + get() = send.isSigning + + val sendStep: SendStep + get() = send.step + + val composeResult: ComposeResult.Success? + get() = (send.step as? SendStep.Review)?.composeResult + + val signedTxResult: TrezorSignedTx? + get() = (send.step as? SendStep.Signed)?.signedTx + + val coinSelection: CoinSelection + get() = send.coinSelection + + val isBroadcasting: Boolean + get() = send.isBroadcasting + + val broadcastTxid: String? + get() = (send.step as? SendStep.Signed)?.broadcastTxid + + val txHistoryInput: String + get() = txHistory.input + + val isLoadingTxHistory: Boolean + get() = txHistory.isLoading + + val txHistoryResult: TransactionHistoryResult? + get() = txHistory.result +} + +@Stable +data class TrezorNetworkState( val selectedNetwork: BitkitCoreNetwork = Env.network.toCoreNetwork(), val addressIndex: Int = 0, - val derivationPath: String = - "m/84'/${if (Env.network.toCoreNetwork() == BitkitCoreNetwork.BITCOIN) "0" else "1"}'/0'/0/0", + val derivationPath: String = derivationPath( + network = selectedNetwork, + index = addressIndex, + ), + val isGettingAddress: Boolean = false, + val isGettingPublicKey: Boolean = false, +) + +@Immutable +data class TrezorMessageState( val messageToSign: String = "Hello, Trezor!", val lastSignature: String? = null, val lastSigningAddress: String? = null, val isSigningMessage: Boolean = false, - val isGettingAddress: Boolean = false, - val isGettingPublicKey: Boolean = false, val isVerifyingMessage: Boolean = false, - val lookupInput: String = "", +) + +@Stable +data class TrezorLookupState( + val input: String = "", val isLookingUp: Boolean = false, val accountInfoResult: AccountInfoResult? = null, val addressInfoResult: SingleAddressInfoResult? = null, - val sendAddress: String = "", - val sendAmountSats: String = "", - val sendFeeRate: String = "2", - val isSendMax: Boolean = false, +) + +@Stable +data class TrezorSendState( + val address: String = "", + val amountSats: String = "", + val feeRate: String = "2", + val isMax: Boolean = false, val isComposing: Boolean = false, val isSigning: Boolean = false, - val composeResult: ComposeResult.Success? = null, - val signedTxResult: TrezorSignedTx? = null, - val sendStep: SendStep = SendStep.FORM, + val step: SendStep = SendStep.Form, val coinSelection: CoinSelection = CoinSelection.BRANCH_AND_BOUND, val isBroadcasting: Boolean = false, - val broadcastTxid: String? = null, - val txHistoryInput: String = "", - val isLoadingTxHistory: Boolean = false, - val txHistoryResult: TransactionHistoryResult? = null, ) -enum class SendStep { FORM, REVIEW, SIGNED } +@Stable +data class TrezorTxHistoryState( + val input: String = "", + val isLoading: Boolean = false, + val result: TransactionHistoryResult? = null, +) + +sealed interface SendStep { + data object Form : SendStep + + data class Review(val composeResult: ComposeResult.Success) : SendStep + + data class Signed( + val signedTx: TrezorSignedTx, + val broadcastTxid: String? = null, + ) : SendStep +} + +private fun derivationPath(network: BitkitCoreNetwork, index: Int): String { + return "m/84'/${coinTypeFor(network)}'/0'/0/$index" +} + +private fun accountPath(derivationPath: String): String { + return derivationPath.split("/").take(4).joinToString("/") +} + +private fun coinTypeFor(network: BitkitCoreNetwork): String { + return if (network == BitkitCoreNetwork.BITCOIN) "0" else "1" +} diff --git a/app/src/test/java/to/bitkit/repositories/TrezorRepoTest.kt b/app/src/test/java/to/bitkit/repositories/TrezorRepoTest.kt index d7271223e..746bd797a 100644 --- a/app/src/test/java/to/bitkit/repositories/TrezorRepoTest.kt +++ b/app/src/test/java/to/bitkit/repositories/TrezorRepoTest.kt @@ -16,7 +16,10 @@ import org.junit.Test import org.junit.rules.TemporaryFolder import org.mockito.kotlin.any import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.argumentCaptor import org.mockito.kotlin.mock +import org.mockito.kotlin.times +import org.mockito.kotlin.verify import org.mockito.kotlin.whenever import to.bitkit.data.TrezorStore import to.bitkit.env.Env @@ -102,6 +105,22 @@ class TrezorRepoTest : BaseUnitTest() { on { this.model }.thenReturn(model) } + private fun mockKnownDevice( + id: String = DEVICE_ID, + name: String? = DEVICE_NAME, + path: String = DEVICE_PATH, + label: String? = DEVICE_LABEL, + model: String? = DEVICE_MODEL, + ) = KnownDevice( + id = id, + name = name, + path = path, + transportType = KnownDeviceTransportType.USB, + label = label, + model = model, + lastConnectedAt = 123L, + ) + // region initialize @Test @@ -145,6 +164,23 @@ class TrezorRepoTest : BaseUnitTest() { assertFalse(sut.state.value.isScanning) } + @Test + fun `scan should exclude known devices from nearbyDevices state`() = test { + val knownDevice = mockKnownDevice() + val known = mockDeviceInfo() + val nearby = mockDeviceInfo(id = "device-456", path = "/dev/trezor1") + whenever(trezorStore.loadKnownDevices()).thenReturn(listOf(knownDevice)) + whenever(trezorService.scan()).thenReturn(listOf(known, nearby)) + sut = createSut() + + sut.initialize() + val result = sut.scan() + + assertTrue(result.isSuccess) + assertEquals(listOf(known, nearby), result.getOrNull()) + assertEquals(listOf(nearby), sut.state.value.nearbyDevices) + } + @Test fun `scan should set error on failure`() = test { whenever(trezorService.scan()).thenThrow(RuntimeException("scan failed")) @@ -181,6 +217,56 @@ class TrezorRepoTest : BaseUnitTest() { assertFalse(sut.state.value.isConnecting) } + @Test + fun `connect should persist connected device as known device`() = test { + val features = mockFeatures(label = "Savings", model = "Safe 5") + val device = mockDeviceInfo() + whenever(trezorService.connect(DEVICE_ID)).thenReturn(features) + whenever(trezorService.scan()).thenReturn(listOf(device)) + sut = createSut() + + sut.scan() + val result = sut.connect(DEVICE_ID) + + assertTrue(result.isSuccess) + val captor = argumentCaptor>() + verify(trezorStore).saveKnownDevices(captor.capture()) + val saved = captor.firstValue.single() + assertEquals(DEVICE_ID, saved.id) + assertEquals(KnownDeviceTransportType.USB, saved.transportType) + assertEquals("Savings", saved.label) + assertEquals("Safe 5", saved.model) + } + + @Test + fun `connect should retry once for retryable THP errors`() = test { + val features = mockFeatures() + val device = mockDeviceInfo() + whenever(trezorService.connect(DEVICE_ID)) + .thenThrow(RuntimeException("thp timeout")) + .thenReturn(features) + whenever(trezorService.scan()).thenReturn(listOf(device)) + sut = createSut() + + sut.scan() + val result = sut.connect(DEVICE_ID) + + assertTrue(result.isSuccess) + assertEquals(features, result.getOrNull()) + verify(trezorService, times(2)).connect(DEVICE_ID) + } + + @Test + fun `connect should not retry non-retryable errors`() = test { + whenever(trezorService.connect(DEVICE_ID)).thenThrow(RuntimeException("bad pin")) + sut = createSut() + + val result = sut.connect(DEVICE_ID) + + assertTrue(result.isFailure) + verify(trezorService, times(1)).connect(DEVICE_ID) + } + @Test fun `connect should set error on failure`() = test { whenever(trezorService.connect(DEVICE_ID)).thenThrow(RuntimeException("connect failed")) @@ -218,6 +304,48 @@ class TrezorRepoTest : BaseUnitTest() { assertNull(sut.state.value.lastPublicKey) } + @Test + fun `disconnect should clear connectedDevice state on service failure`() = test { + val features = mockFeatures() + val device = mockDeviceInfo() + val addressResponse = mock() + val publicKeyResponse = mock() + whenever(trezorService.connect(DEVICE_ID)).thenReturn(features) + whenever(trezorService.scan()).thenReturn(listOf(device)) + whenever(trezorService.isConnected()).thenReturn(true) + whenever( + trezorService.getAddress( + path = any(), + coin = any(), + showOnTrezor = any(), + scriptType = anyOrNull(), + ) + ).thenReturn(addressResponse) + whenever( + trezorService.getPublicKey( + path = any(), + coin = any(), + showOnTrezor = any(), + ) + ).thenReturn(publicKeyResponse) + sut = createSut() + + sut.scan() + sut.connect(DEVICE_ID) + sut.getAddress() + sut.getPublicKey() + whenever(trezorService.disconnect()).thenThrow(RuntimeException("disconnect failed")) + + val result = sut.disconnect() + + assertTrue(result.isFailure) + assertNull(sut.state.value.connectedDevice) + assertNull(sut.state.value.connectedDeviceId) + assertNull(sut.state.value.lastAddress) + assertNull(sut.state.value.lastPublicKey) + assertEquals("disconnect failed", sut.state.value.error) + } + // endregion // region getAddress @@ -349,6 +477,62 @@ class TrezorRepoTest : BaseUnitTest() { // endregion + // region autoReconnect + + @Test + fun `autoReconnect should fail when no known devices exist`() = test { + sut = createSut() + + val result = sut.autoReconnect() + + assertTrue(result.isFailure) + assertEquals("No known devices", result.exceptionOrNull()?.message) + } + + @Test + fun `autoReconnect should scan and connect known nearby device`() = test { + val knownDevice = mockKnownDevice() + val device = mockDeviceInfo() + val features = mockFeatures() + whenever(trezorStore.loadKnownDevices()).thenReturn(listOf(knownDevice)) + whenever(trezorService.scan()).thenReturn(listOf(device)) + whenever(trezorService.connect(DEVICE_ID)).thenReturn(features) + whenever(trezorService.isConnected()).thenReturn(false) + sut = createSut() + + sut.initialize() + val result = sut.autoReconnect() + + assertTrue(result.isSuccess) + assertEquals(features, result.getOrNull()) + assertEquals(DEVICE_ID, sut.state.value.connectedDeviceId) + assertFalse(sut.state.value.isAutoReconnecting) + } + + // endregion + + // region connectKnownDevice + + @Test + fun `connectKnownDevice should connect exact known device match`() = test { + val knownDevice = mockKnownDevice() + val device = mockDeviceInfo() + val features = mockFeatures() + whenever(trezorStore.loadKnownDevices()).thenReturn(listOf(knownDevice)) + whenever(trezorService.scan()).thenReturn(listOf(device)) + whenever(trezorService.connect(DEVICE_ID)).thenReturn(features) + sut = createSut() + + sut.initialize() + val result = sut.connectKnownDevice(DEVICE_ID) + + assertTrue(result.isSuccess) + assertEquals(features, result.getOrNull()) + assertEquals(DEVICE_ID, sut.state.value.connectedDeviceId) + } + + // endregion + // region clearError @Test @@ -383,6 +567,72 @@ class TrezorRepoTest : BaseUnitTest() { // endregion + // region ensureConnected + + @Test + fun `getAddress should reconnect known device before reading address`() = test { + val knownDevice = mockKnownDevice() + val device = mockDeviceInfo() + val features = mockFeatures() + val addressResponse = mock() + whenever(trezorStore.loadKnownDevices()).thenReturn(listOf(knownDevice)) + whenever(trezorService.isConnected()).thenReturn(false) + whenever(trezorService.scan()).thenReturn(listOf(device)) + whenever(trezorService.connect(DEVICE_ID)).thenReturn(features) + whenever( + trezorService.getAddress( + path = any(), + coin = any(), + showOnTrezor = any(), + scriptType = anyOrNull(), + ) + ).thenReturn(addressResponse) + sut = createSut() + + sut.initialize() + val result = sut.getAddress() + + assertTrue(result.isSuccess) + assertEquals(addressResponse, result.getOrNull()) + assertEquals(DEVICE_ID, sut.state.value.connectedDeviceId) + verify(trezorService).scan() + verify(trezorService).connect(DEVICE_ID) + } + + // endregion + + // region forgetDevice + + @Test + fun `forgetDevice should remove known device when service cleanup fails`() = test { + val knownDevice = mockKnownDevice() + val features = mockFeatures() + val device = mockDeviceInfo() + whenever(trezorStore.loadKnownDevices()).thenReturn(listOf(knownDevice)) + whenever(trezorService.connect(DEVICE_ID)).thenReturn(features) + whenever(trezorService.scan()).thenReturn(listOf(device)) + sut = createSut() + + sut.initialize() + sut.scan() + sut.connect(DEVICE_ID) + whenever(trezorService.disconnect()).thenThrow(RuntimeException("disconnect failed")) + whenever(trezorService.clearCredentials(DEVICE_ID)).thenThrow(RuntimeException("clear failed")) + + val result = sut.forgetDevice(DEVICE_ID) + + assertTrue(result.isFailure) + assertTrue(sut.state.value.knownDevices.isEmpty()) + assertNull(sut.state.value.connectedDevice) + assertNull(sut.state.value.connectedDeviceId) + assertEquals("disconnect failed", sut.state.value.error) + verify(trezorTransport).clearDeviceCredential(DEVICE_ID) + verify(trezorService).clearCredentials(DEVICE_ID) + verify(trezorStore).saveKnownDevices(emptyList()) + } + + // endregion + // region initial state @Test diff --git a/app/src/test/java/to/bitkit/ui/screens/trezor/TrezorViewModelTest.kt b/app/src/test/java/to/bitkit/ui/screens/trezor/TrezorViewModelTest.kt index 1c6a2f572..e99521578 100644 --- a/app/src/test/java/to/bitkit/ui/screens/trezor/TrezorViewModelTest.kt +++ b/app/src/test/java/to/bitkit/ui/screens/trezor/TrezorViewModelTest.kt @@ -1,11 +1,15 @@ package to.bitkit.ui.screens.trezor +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.advanceUntilIdle import org.junit.Before import org.junit.Test import org.mockito.kotlin.any +import org.mockito.kotlin.anyOrNull +import org.mockito.kotlin.doSuspendableAnswer import org.mockito.kotlin.mock import org.mockito.kotlin.never import org.mockito.kotlin.verify @@ -137,7 +141,7 @@ class TrezorViewModelTest : BaseUnitTest() { assertFalse(state.isSigning) assertNull(state.composeResult) assertNull(state.signedTxResult) - assertEquals(SendStep.FORM, state.sendStep) + assertEquals(SendStep.Form, state.sendStep) assertFalse(state.isBroadcasting) assertNull(state.broadcastTxid) } @@ -243,6 +247,59 @@ class TrezorViewModelTest : BaseUnitTest() { verify(trezorRepo, never()).broadcastRawTx(any(), any()) } + @Test + fun `broadcastSignedTx should not restore signed step after reset`() = test { + loadSignedTx() + val broadcastResult = CompletableDeferred>() + whenever(trezorRepo.broadcastRawTx(any(), any())) + .doSuspendableAnswer { broadcastResult.await() } + + sut.broadcastSignedTx() + assertTrue(sut.uiState.value.isBroadcasting) + + sut.resetSendFlow() + broadcastResult.complete(Result.success("broadcast-txid")) + advanceUntilIdle() + + val state = sut.uiState.value + assertEquals(SendStep.Form, state.sendStep) + assertFalse(state.isBroadcasting) + assertNull(state.broadcastTxid) + } + + @Test + fun `composeTx should not call repo when destination address is blank`() = test { + loadAccountInfo() + sut.setSendAmount("1000") + sut.setSendFeeRate("2") + + sut.composeTx() + advanceUntilIdle() + + verify(trezorRepo, never()).composeTransaction(any(), any(), any(), any(), anyOrNull(), any()) + } + + @Test + fun `composeTx should not call repo when fee rate is invalid`() = test { + loadAccountInfo() + sut.setSendAddress("bc1qtest123") + sut.setSendAmount("1000") + sut.setSendFeeRate("0") + + sut.composeTx() + advanceUntilIdle() + + verify(trezorRepo, never()).composeTransaction(any(), any(), any(), any(), anyOrNull(), any()) + } + + @Test + fun `signComposedTx should not call repo when no compose result exists`() = test { + sut.signComposedTx() + advanceUntilIdle() + + verify(trezorRepo, never()).signTxFromPsbt(any(), anyOrNull()) + } + @Test fun `clearError should call trezorRepo clearError`() { sut.clearError() @@ -279,4 +336,28 @@ class TrezorViewModelTest : BaseUnitTest() { bgDispatcher = testDispatcher, trezorRepo = trezorRepo, ) + + private suspend fun TestScope.loadAccountInfo() { + whenever(trezorRepo.getAccountInfo(any(), any(), anyOrNull())) + .thenReturn(Result.success(TrezorPreviewData.sampleAccountInfoResult)) + + sut.setLookupInput("xpub6test123") + sut.lookupBalanceInfo() + advanceUntilIdle() + } + + private suspend fun TestScope.loadSignedTx() { + loadAccountInfo() + whenever(trezorRepo.composeTransaction(any(), any(), any(), any(), anyOrNull(), any())) + .thenReturn(Result.success(listOf(TrezorPreviewData.sampleComposeResult))) + whenever(trezorRepo.signTxFromPsbt(any(), anyOrNull())) + .thenReturn(Result.success(TrezorPreviewData.sampleSignedTx)) + + sut.setSendAddress("bc1qtest123") + sut.setSendAmount("1000") + sut.composeTx() + advanceUntilIdle() + sut.signComposedTx() + advanceUntilIdle() + } } diff --git a/changelog.d/next/792.added.md b/changelog.d/next/792.added.md new file mode 100644 index 000000000..5d9878e50 --- /dev/null +++ b/changelog.d/next/792.added.md @@ -0,0 +1 @@ +Added Trezor hardware wallet support for connecting devices, signing messages, and managing on-chain transactions.