Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/bin/keystone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,12 @@ async fn main() -> Result<(), Report> {
let mut app = Router::new().merge(main_router.with_state(shared_state.clone()));

if shared_state.config.webauthn.enabled {
info!("Not enabling the WebAuthN extension due to the `config.webauthn.enabled` flag.");
let webauthn_extension = webauthn::api::init_extension(shared_state.clone())?;
let webauthn_cloned_token = token.clone();
let webauthn_extension =
webauthn::api::init_extension(shared_state.clone(), webauthn_cloned_token)?;
app = app.nest("/v4", webauthn_extension);
} else {
info!("Not enabling the WebAuthN extension due to the `config.webauthn.enabled` flag.");
}

app = app
Expand All @@ -274,7 +277,6 @@ async fn main() -> Result<(), Report> {
async fn cleanup(cancel: CancellationToken, state: ServiceState) {
let mut interval = time::interval(Duration::from_secs(60));
interval.tick().await;
// TODO: Clean passkeys expired states
info!("Start the periodic cleanup thread");
loop {
tokio::select! {
Expand Down
33 changes: 31 additions & 2 deletions src/webauthn/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

use axum::Router;
use std::sync::Arc;
use std::time::Duration;
use tokio::{spawn, time};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace};
use utoipa::OpenApi;
use utoipa_axum::router::OpenApiRouter;
use webauthn_rs::WebauthnBuilder;
Expand All @@ -26,7 +30,7 @@ mod auth;
mod register;
pub mod types;

use crate::webauthn::{WebauthnError, driver::SqlDriver};
use crate::webauthn::{WebauthnApi, WebauthnError, driver::SqlDriver};
use types::{CombinedExtensionState, ExtensionState};

/// OpenApi specification for the user passkey support.
Expand All @@ -46,7 +50,10 @@ pub fn openapi_router() -> OpenApiRouter<CombinedExtensionState> {
}

/// Initialize the extension.
pub fn init_extension(main_state: ServiceState) -> Result<Router, KeystoneError> {
pub fn init_extension(
main_state: ServiceState,
cancellation_token: CancellationToken,
) -> Result<Router, KeystoneError> {
// Url containing the effective domain name
// MUST include the port number!
let rp = main_state
Expand Down Expand Up @@ -77,9 +84,31 @@ pub fn init_extension(main_state: ServiceState) -> Result<Router, KeystoneError>
core: main_state,
extension: extension_state,
};
spawn(cleanup(cancellation_token, combined_state.clone()));
let (router, _openapi) = OpenApiRouter::new()
.merge(openapi_router())
.with_state(combined_state)
.split_for_parts();
Ok(router)
}

/// Periodic cleanup job.
async fn cleanup(cancel: CancellationToken, state: CombinedExtensionState) {
let mut interval = time::interval(Duration::from_secs(60));
interval.tick().await;
info!("Start the periodic cleanup thread of the webauthn extension");
loop {
tokio::select! {
_ = interval.tick() => {
trace!("cleanup job tick");
if let Err(e) = state.extension.provider.cleanup(&state.core).await {
error!("Error during cleanup job: {}", e);
}
},
() = cancel.cancelled() => {
info!("Cancellation requested. Stopping webauthn cleanup task.");
break; // Exit the loop
}
}
}
}
6 changes: 6 additions & 0 deletions src/webauthn/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ pub struct SqlDriver {}

#[async_trait]
impl WebauthnApi for SqlDriver {
/// Cleanup expired Webauthn states.
#[tracing::instrument(level = "debug", skip(self, state))]
async fn cleanup(&self, state: &ServiceState) -> Result<(), WebauthnError> {
state::delete_expired(&state.db).await
}

/// Create webauthn credential for the user.
#[tracing::instrument(level = "debug", skip(self, state))]
async fn create_user_webauthn_credential(
Expand Down
2 changes: 1 addition & 1 deletion src/webauthn/driver/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ mod delete;
mod get;

pub use create::{create_auth, create_register};
pub use delete::delete;
pub use delete::{delete, delete_expired};
pub use get::{get_auth, get_register};
50 changes: 49 additions & 1 deletion src/webauthn/driver/state/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
//
// SPDX-License-Identifier: Apache-2.0

use chrono::{TimeDelta, Utc};
use sea_orm::DatabaseConnection;
use sea_orm::entity::*;
use sea_orm::query::*;

use crate::db::entity::prelude::WebauthnState as DbPasskeyState;
use crate::db::entity::{
prelude::WebauthnState as DbPasskeyState, webauthn_state as db_webauthn_state,
};
use crate::error::DbContextExt;
use crate::webauthn::WebauthnError;

Expand All @@ -30,6 +34,18 @@ pub async fn delete<U: AsRef<str>>(
Ok(())
}

/// Delete expired states.
pub async fn delete_expired(db: &DatabaseConnection) -> Result<(), WebauthnError> {
if let Some(oldest_date) = Utc::now().checked_sub_signed(TimeDelta::minutes(5)) {
DbPasskeyState::delete_many()
.filter(db_webauthn_state::Column::CreatedAt.lt(oldest_date))
.exec(db)
.await
.context("deleting expired passkey states")?;
}
Ok(())
}

#[cfg(test)]
mod tests {
use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction};
Expand All @@ -56,4 +72,36 @@ mod tests {
),]
);
}

#[tokio::test]
async fn test_delete_expired() {
let db = MockDatabase::new(DatabaseBackend::Postgres)
.append_exec_results([MockExecResult {
rows_affected: 1,
..Default::default()
}])
.into_connection();

delete_expired(&db).await.unwrap();
for (l, r) in db
.into_transaction_log()
.iter()
.zip([Transaction::from_sql_and_values(
DatabaseBackend::Postgres,
r#"DELETE FROM "webauthn_state" WHERE "webauthn_state"."created_at" < $1"#,
[],
)])
{
assert_eq!(
l.statements()
.iter()
.map(|x| x.sql.clone())
.collect::<Vec<_>>(),
r.statements()
.iter()
.map(|x| x.sql.clone())
.collect::<Vec<_>>()
);
}
}
}
3 changes: 3 additions & 0 deletions src/webauthn/types/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ use crate::webauthn::{WebauthnError, types::WebauthnCredential};
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait WebauthnApi: Send + Sync {
/// Cleanup expired Webauthn states.
async fn cleanup(&self, state: &ServiceState) -> Result<(), WebauthnError>;

/// Create passkey.
async fn create_user_webauthn_credential(
&self,
Expand Down
Loading