diff --git a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt index e14d755..dfe29a3 100644 --- a/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt +++ b/sshlib/src/main/kotlin/org/connectbot/sshlib/client/SshConnection.kt @@ -744,11 +744,7 @@ class SshConnection( setMethodSpecificFields(noneAuth) } - var noneResult = channel.receive() - while (noneResult is InternalAuthResult.Banner) { - handler.onBanner(noneResult.message) - noneResult = channel.receive() - } + val noneResult = receiveAuthResult(channel, handler) if (noneResult is InternalAuthResult.Success) return PublicAuthResult.Success if (noneResult !is InternalAuthResult.Failure) return PublicAuthResult.Error("Unexpected response to 'none' auth: $noneResult") @@ -823,12 +819,7 @@ class SshConnection( } setMethodSpecificFields(pubkeyAuth) } - var response = channel.receive() - while (response is InternalAuthResult.Banner) { - handler.onBanner(response.message) - response = channel.receive() - } - return response + return receiveAuthResult(channel, handler) } private suspend fun signPublicKey( @@ -881,11 +872,7 @@ class SshConnection( } } - var response = channel.receive() - while (response is InternalAuthResult.Banner) { - handler.onBanner(response.message) - response = channel.receive() - } + val response = receiveAuthResult(channel, handler) return when (response) { is InternalAuthResult.Success -> true else -> false @@ -965,11 +952,7 @@ class SshConnection( setMethodSpecificFields(passAuth) } - var response = channel.receive() - while (response is InternalAuthResult.Banner) { - handler.onBanner(response.message) - response = channel.receive() - } + val response = receiveAuthResult(channel, handler) return when (val result = response) { is InternalAuthResult.Success -> true @@ -982,6 +965,18 @@ class SshConnection( } } + private suspend fun receiveAuthResult( + channel: Channel, + handler: AuthHandler, + ): InternalAuthResult { + var result = channel.receive() + while (result is InternalAuthResult.Banner) { + handler.onBanner(result.message) + result = channel.receive() + } + return result + } + private suspend fun sendAuthRequest( username: String, method: String, diff --git a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt index 933836a..a33726f 100644 --- a/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt +++ b/sshlib/src/test/kotlin/org/connectbot/sshlib/client/FakeSshServer.kt @@ -549,28 +549,24 @@ class FakeSshServer( } } - fun sendUserauthBanner(message: String) { - scope.launch(coroutineContext) { - val banner = SshMsgUserauthBanner() - val utf8 = createUtf8String(message) - banner.setMessage(utf8) - banner.setLanguageTag(createByteString(ByteArray(0))) - banner._check() - writeMutex.withLock { - serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_BANNER.id().toInt(), banner.toByteArray()) - } + suspend fun sendUserauthBanner(message: String) { + val banner = SshMsgUserauthBanner() + val utf8 = createUtf8String(message) + banner.setMessage(utf8) + banner.setLanguageTag(createByteString(ByteArray(0))) + banner._check() + writeMutex.withLock { + serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_BANNER.id().toInt(), banner.toByteArray()) } } - fun sendUserauthFailure(allowedMethods: Set, partialSuccess: Boolean) { - scope.launch(coroutineContext) { - val failure = SshMsgUserauthFailure() - failure.setValidAuthentications(createNameList(allowedMethods.joinToString(","))) - failure.setPartialSuccess(if (partialSuccess) 1 else 0) - failure._check() - writeMutex.withLock { - serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_FAILURE.id().toInt(), failure.toByteArray()) - } + suspend fun sendUserauthFailure(allowedMethods: Set, partialSuccess: Boolean) { + val failure = SshMsgUserauthFailure() + failure.setValidAuthentications(createNameList(allowedMethods.joinToString(","))) + failure.setPartialSuccess(if (partialSuccess) 1 else 0) + failure._check() + writeMutex.withLock { + serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_FAILURE.id().toInt(), failure.toByteArray()) } }