Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -744,11 +744,7 @@ class SshConnection(
setMethodSpecificFields(noneAuth)
}

var noneResult = channel.receive()
while (noneResult is InternalAuthResult.Banner) {
handler.onBanner(noneResult.message)
noneResult = channel.receive()
}
val noneResult = receiveAuthResult(channel, handler)
if (noneResult is InternalAuthResult.Success) return PublicAuthResult.Success
if (noneResult !is InternalAuthResult.Failure) return PublicAuthResult.Error("Unexpected response to 'none' auth: $noneResult")

Expand Down Expand Up @@ -823,12 +819,7 @@ class SshConnection(
}
setMethodSpecificFields(pubkeyAuth)
}
var response = channel.receive()
while (response is InternalAuthResult.Banner) {
handler.onBanner(response.message)
response = channel.receive()
}
return response
return receiveAuthResult(channel, handler)
}

private suspend fun signPublicKey(
Expand Down Expand Up @@ -881,11 +872,7 @@ class SshConnection(
}
}

var response = channel.receive()
while (response is InternalAuthResult.Banner) {
handler.onBanner(response.message)
response = channel.receive()
}
val response = receiveAuthResult(channel, handler)
return when (response) {
is InternalAuthResult.Success -> true
else -> false
Expand Down Expand Up @@ -965,11 +952,7 @@ class SshConnection(
setMethodSpecificFields(passAuth)
}

var response = channel.receive()
while (response is InternalAuthResult.Banner) {
handler.onBanner(response.message)
response = channel.receive()
}
val response = receiveAuthResult(channel, handler)
return when (val result = response) {
is InternalAuthResult.Success -> true

Expand All @@ -982,6 +965,18 @@ class SshConnection(
}
}

private suspend fun receiveAuthResult(
channel: Channel<InternalAuthResult>,
handler: AuthHandler,
): InternalAuthResult {
var result = channel.receive()
while (result is InternalAuthResult.Banner) {
handler.onBanner(result.message)
result = channel.receive()
}
return result
}

private suspend fun sendAuthRequest(
username: String,
method: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,28 +549,24 @@ class FakeSshServer(
}
}

fun sendUserauthBanner(message: String) {
scope.launch(coroutineContext) {
val banner = SshMsgUserauthBanner()
val utf8 = createUtf8String(message)
banner.setMessage(utf8)
banner.setLanguageTag(createByteString(ByteArray(0)))
banner._check()
writeMutex.withLock {
serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_BANNER.id().toInt(), banner.toByteArray())
}
suspend fun sendUserauthBanner(message: String) {
val banner = SshMsgUserauthBanner()
val utf8 = createUtf8String(message)
banner.setMessage(utf8)
banner.setLanguageTag(createByteString(ByteArray(0)))
banner._check()
writeMutex.withLock {
serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_BANNER.id().toInt(), banner.toByteArray())
}
}

fun sendUserauthFailure(allowedMethods: Set<String>, partialSuccess: Boolean) {
scope.launch(coroutineContext) {
val failure = SshMsgUserauthFailure()
failure.setValidAuthentications(createNameList(allowedMethods.joinToString(",")))
failure.setPartialSuccess(if (partialSuccess) 1 else 0)
failure._check()
writeMutex.withLock {
serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_FAILURE.id().toInt(), failure.toByteArray())
}
suspend fun sendUserauthFailure(allowedMethods: Set<String>, partialSuccess: Boolean) {
val failure = SshMsgUserauthFailure()
failure.setValidAuthentications(createNameList(allowedMethods.joinToString(",")))
failure.setPartialSuccess(if (partialSuccess) 1 else 0)
failure._check()
writeMutex.withLock {
serverIo.writePacket(SshEnums.MessageType.SSH_MSG_USERAUTH_FAILURE.id().toInt(), failure.toByteArray())
}
}

Expand Down
Loading