From 07b98987e7fdc7f1bd46a5a61ff47f896ec3b77b Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 3 Jun 2026 14:22:52 -0700 Subject: [PATCH 1/3] Add `--run` Flag for Broker vs. Migrations --- src/config.rs | 94 ++++++-- src/lib.rs | 18 +- src/main.rs | 348 +--------------------------- src/run.rs | 401 +++++++++++++++++++++++++++++++++ src/store/adapters/postgres.rs | 33 --- 5 files changed, 501 insertions(+), 393 deletions(-) create mode 100644 src/run.rs diff --git a/src/config.rs b/src/config.rs index 5231b2d4..0bc7d049 100644 --- a/src/config.rs +++ b/src/config.rs @@ -274,12 +274,18 @@ pub struct Config { /// The port of the postgres database to use for the activation store. pub pg_port: u16, + // User permitted to run DDL operations. + pub pg_ddl_username: String, + /// The username of the postgres database to use for the activation store. pub pg_username: String, /// The password of the postgres database to use for the activation store. pub pg_password: String, + /// Password for the user permitted to run DDL operations. + pub pg_ddl_password: String, + /// The name of the postgres database to use for the activation store. pub pg_database_name: String, @@ -523,6 +529,8 @@ impl Default for Config { run_migrations: false, pg_host: "sentry-postgres-1".to_owned(), pg_port: 5432, + pg_ddl_username: "postgres".to_owned(), + pg_ddl_password: "password".to_owned(), pg_username: "postgres".to_owned(), pg_password: "password".to_owned(), pg_database_name: "default".to_owned(), @@ -1042,8 +1050,8 @@ mod tests { use figment::Jail; use validator::Validate; - use crate::Args; use crate::logging::LogFormat; + use crate::{Args, Run}; use super::{Config, DatabaseAdapter, DeliveryMode}; @@ -1184,6 +1192,7 @@ mod tests { jail.set_env("TASKBROKER_LOG_FILTER", "error"); let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let config = Config::from_args(&args).unwrap(); @@ -1239,7 +1248,10 @@ mod tests { jail.set_env("TASKBROKER_DATABASE_ADAPTER", "postgres"); jail.set_env("TASKBROKER_MAX_PROCESSING_ATTEMPTS", "5"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); assert_eq!(config.log_filter, "error"); assert_eq!(config.database_adapter, DatabaseAdapter::Postgres); @@ -1256,7 +1268,10 @@ mod tests { jail.set_env("TASKBROKER_MAX_PROCESSING_ATTEMPTS", "5"); jail.set_env("TASKBROKER_DEFAULT_METRICS_TAGS", "{key=value}"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); assert_eq!(config.sentry_dsn, None); assert_eq!(config.sentry_env, None); @@ -1299,7 +1314,10 @@ mod tests { "{sentry=http://127.0.0.1:60052,launchpad=http://127.0.0.1:60053}", ); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); assert_eq!( config.worker_map, @@ -1315,7 +1333,10 @@ mod tests { #[test] fn test_kafka_consumer_config() { - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let consumer_config = config.kafka_consumer_config(); @@ -1335,7 +1356,10 @@ mod tests { jail.set_env("TASKBROKER_KAFKA_SASL_USERNAME", "taskbroker"); jail.set_env("TASKBROKER_KAFKA_SASL_PASSWORD", "secret-tech"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let consumer_config = config.kafka_consumer_config(); @@ -1370,7 +1394,10 @@ mod tests { "/etc/ssl/taskbroker/private.key", ); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let consumer_config = config.kafka_consumer_config(); @@ -1393,7 +1420,10 @@ mod tests { #[test] fn test_kafka_producer_config() { - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let producer_config = config.kafka_producer_config(); @@ -1419,7 +1449,10 @@ mod tests { jail.set_env("TASKBROKER_KAFKA_DEADLETTER_SASL_USERNAME", "taskbroker"); jail.set_env("TASKBROKER_KAFKA_DEADLETTER_SASL_PASSWORD", "secret-tech"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let producer_config = config.kafka_producer_config(); @@ -1454,7 +1487,10 @@ mod tests { "/etc/ssl/taskbroker/private.key", ); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let producer_config = config.kafka_producer_config(); @@ -1486,7 +1522,10 @@ mod tests { Jail::expect_with(|jail| { jail.set_env("TASKBROKER_DELIVERY_MODE", "push"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); assert_eq!(config.delivery_mode, DeliveryMode::Push); @@ -1500,6 +1539,7 @@ mod tests { jail.create_file("config.yaml", "delivery_mode: push")?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let config = Config::from_args(&args).unwrap(); @@ -1541,6 +1581,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let config = Config::from_args(&args).unwrap(); @@ -1627,7 +1668,10 @@ kafka_clusters: "10.0.0.2:9092", ); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).unwrap(); let topics = &config.kafka_topics; @@ -1672,6 +1716,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -1700,6 +1745,7 @@ kafka_topics: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -1731,6 +1777,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -1766,6 +1813,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -1809,6 +1857,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let config = Config::from_args(&args).unwrap(); @@ -1849,6 +1898,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -1881,6 +1931,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -1905,7 +1956,10 @@ kafka_clusters: jail.set_env("TASKBROKER_KAFKA_TOPIC", "taskworker"); jail.set_env("TASKBROKER_KAFKA_DEADLETTER_TOPIC", "taskworker"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let err = Config::from_args(&args).unwrap_err(); assert!( err.to_string().contains( @@ -1929,7 +1983,10 @@ kafka_clusters: jail.set_env("TASKBROKER_KAFKA_RETRY_TOPIC", "taskworker-dlq"); jail.set_env("TASKBROKER_KAFKA_DEADLETTER_TOPIC", "taskworker-dlq"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let err = Config::from_args(&args).unwrap_err(); assert!( err.to_string().contains( @@ -1973,6 +2030,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let config = Config::from_args(&args).unwrap(); @@ -2020,6 +2078,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let config = Config::from_args(&args).unwrap(); @@ -2068,6 +2127,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -2112,6 +2172,7 @@ kafka_clusters: )?; let args = Args { + run: Run::Broker, config: Some("config.yaml".to_owned()), }; let err = Config::from_args(&args).unwrap_err(); @@ -2141,7 +2202,10 @@ kafka_clusters: jail.set_env("TASKBROKER_KAFKA_DEADLETTER_TOPIC", "taskworker-ingest-dlq"); jail.set_env("TASKBROKER_KAFKA_DEADLETTER_CLUSTER", "kafka-small:9092"); - let args = Args { config: None }; + let args = Args { + run: Run::Broker, + config: None, + }; let config = Config::from_args(&args).expect("legacy retry config should validate"); // The retry topic resolves to the deadletter cluster (where the diff --git a/src/lib.rs b/src/lib.rs index a0f8bd99..c654df5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -use clap::Parser; +use clap::{Parser, ValueEnum}; use std::fs; pub mod config; @@ -9,6 +9,7 @@ pub mod kafka; pub mod logging; pub mod metrics; pub mod push; +pub mod run; pub mod runtime_config; pub mod store; pub mod test_utils; @@ -25,10 +26,21 @@ pub fn get_version() -> &'static str { Box::leak(release_name.into_boxed_str()) } +/// What are we running? +#[derive(Debug, Clone, Copy, ValueEnum)] +pub enum Run { + Migrations, + Broker, +} + #[derive(Parser, Debug)] pub struct Args { - /// Path to the configuration file - #[arg(short, long, help = "The path to a config file")] + /// What are we running? + #[arg(short, long, default_value = "broker")] + pub run: Run, + + /// Path to the configuration file. + #[arg(short, long)] pub config: Option, } diff --git a/src/main.rs b/src/main.rs index dd3f9939..d9e400ed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,62 +1,17 @@ -use std::collections::HashMap; use std::sync::Arc; -use std::time::Duration; -use anyhow::{Error, anyhow}; -use chrono::Utc; +use anyhow::Error; use clap::Parser; -use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; -use taskbroker::worker::{Worker, WorkerClient, WorkerMap}; -use tokio::signal::unix::SignalKind; -use tokio::task::JoinHandle; -use tokio::{select, time}; -use tonic::transport::Server; -use tonic_health::ServingStatus; -use tracing::{debug, error, info, warn}; -use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; -use taskbroker::fetch::FetchPool; -use taskbroker::grpc::auth_middleware::AuthLayer; -use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; -use taskbroker::kafka::activation_batcher::{ActivationBatcher, ActivationBatcherConfig}; -use taskbroker::kafka::activation_writer::{ActivationWriter, ActivationWriterConfig}; -use taskbroker::kafka::admin::create_missing_topics; -use taskbroker::kafka::consumer::start_consumer; -use taskbroker::kafka::deserialize::{self, DeserializeConfig}; -use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; -use taskbroker::logging; +use taskbroker::config::Config; use taskbroker::metrics; -use taskbroker::processing_strategy; -use taskbroker::push::PushPool; -use taskbroker::runtime_config::RuntimeConfigManager; -use taskbroker::store::adapters::postgres::{PostgresStore, PostgresStoreConfig}; -use taskbroker::store::adapters::sqlite::{SqliteStore, SqliteStoreConfig}; -use taskbroker::store::traits::ActivationStore; -use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; -use taskbroker::{SERVICE_NAME, flusher}; - -async fn log_task_completion>(name: T, task: JoinHandle>) { - match task.await { - Ok(Ok(())) => { - info!("Task {} completed", name.as_ref()); - } - Ok(Err(e)) => { - error!("Task {} failed: {:?}", name.as_ref(), e); - } - Err(e) => { - error!("Task {} panicked: {:?}", name.as_ref(), e); - } - } -} +use taskbroker::{Run, logging, run}; #[tokio::main] async fn main() -> Result<(), Error> { let args = Args::parse(); let config = Arc::new(Config::from_args(&args)?); - let runtime_config_manager = - Arc::new(RuntimeConfigManager::new(config.runtime_config_path.clone()).await); println!("taskbroker starting"); println!("version: {}", get_version().trim()); @@ -64,299 +19,8 @@ async fn main() -> Result<(), Error> { logging::init(logging::LoggingConfig::from_config(&config)); metrics::init(metrics::MetricsConfig::from_config(&config)); - let store: Arc = match config.database_adapter { - DatabaseAdapter::Sqlite => Arc::new( - SqliteStore::new(&config.db_path, SqliteStoreConfig::from_config(&config)).await?, - ), - DatabaseAdapter::Postgres => { - Arc::new(PostgresStore::new(PostgresStoreConfig::from_config(&config)).await?) - } - }; - - // If this is an environment where the topics might not exist, check and create them. - if config.create_missing_topics { - let kafka_client_config = config.kafka_consumer_config(); - let (main_topic, _) = config - .consumable_topic() - .map_err(|e| anyhow!("invalid config: {}", e))?; - create_missing_topics( - kafka_client_config.clone(), - main_topic, - config.default_topic_partitions, - ) - .await?; - - // Create retry topic if configured - if let Some(ref retry_topic) = config.kafka_retry_topic { - create_missing_topics( - kafka_client_config, - retry_topic, - config.default_topic_partitions, - ) - .await?; - } - } - - if config.full_vacuum_on_start { - info!("Running full vacuum on database"); - match store.full_vacuum_db().await { - Ok(_) => info!("Full vacuum completed."), - Err(err) => error!("Failed to run full vacuum on startup: {:?}", err), - } + match args.run { + Run::Broker => run::broker(config).await, + Run::Migrations => run::migrations(config).await, } - // Get startup time after migrations and vacuum - let startup_time = Utc::now(); - - // Taskbroker exposes a grpc.v1.health endpoint. We use upkeep to track the health - // of the application. - let (health_reporter, health_service) = tonic_health::server::health_reporter(); - health_reporter - .set_service_status(SERVICE_NAME, ServingStatus::Serving) - .await; - - // Upkeep loop - let upkeep_task = taskbroker::tokio::spawn({ - let upkeep_store = store.clone(); - let upkeep_config = config.clone(); - let runtime_config_manager = runtime_config_manager.clone(); - async move { - upkeep( - upkeep_config, - upkeep_store, - startup_time, - runtime_config_manager.clone(), - health_reporter.clone(), - ) - .await?; - Ok(()) - } - }); - - // Maintenance task loop - let maintenance_task = taskbroker::tokio::spawn({ - let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); - let maintenance_store = store.clone(); - let mut timer = time::interval(Duration::from_millis(config.maintenance_task_interval_ms)); - timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip); - - async move { - loop { - select! { - _ = timer.tick() => { - match maintenance_store.vacuum_db().await { - Ok(_) => debug!("ran maintenance vacuum"), - Err(err) => warn!("failed to run maintenance vacuum {:?}", err), - } - }, - _ = guard.wait() => { - break; - } - } - } - Ok(()) - } - }); - - // Consumer from kafka - let consumer_task = taskbroker::tokio::spawn({ - let consumer_store = store.clone(); - let consumer_config = config.clone(); - let runtime_config_manager = runtime_config_manager.clone(); - - // Build list of topics to consume from - let (main_topic, _) = consumer_config - .consumable_topic() - .expect("invalid config: no consumable topic"); - let topics_to_consume = [main_topic.to_owned()]; - - async move { - // The consumer has an internal thread that listens for cancellations, so it doesn't need - // an outer select here like the other tasks. - let topic_refs: Vec<&str> = topics_to_consume.iter().map(|s| s.as_str()).collect(); - start_consumer( - &topic_refs, - &consumer_config.kafka_consumer_config(), - consumer_store.clone(), - processing_strategy!({ - err: - OsStreamWriter::new( - Duration::from_secs(1), - OsStream::StdErr, - ), - - map: - deserialize::new(DeserializeConfig::from_config(&consumer_config)), - - reduce: - ActivationBatcher::new( - ActivationBatcherConfig::from_config(&consumer_config), - runtime_config_manager.clone() - ), - ActivationWriter::new( - consumer_store.clone(), - ActivationWriterConfig::from_config(&consumer_config) - ), - - }), - ) - .await - } - }); - - // Status update flush task - let (status_update_tx, status_update_task) = if config.batch_status_updates { - let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size); - - let flusher_store = store.clone(); - let flusher_config = config.clone(); - - let handle = taskbroker::tokio::spawn(async move { - flusher::run_flusher( - rx, - flusher_config.status_update_batch_size, - flusher_config.status_update_interval_ms, - move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), - ) - .await - }); - - (Some(tx), Some(handle)) - } else { - (None, None) - }; - - // GRPC server - only start if port is configured (port 0 disables it) - let grpc_server_task = if config.grpc_port > 0 { - Some(taskbroker::tokio::spawn({ - let grpc_store = store.clone(); - let grpc_config = config.clone(); - let grpc_status_tx = status_update_tx.clone(); - - async move { - let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) - .parse() - .expect("Failed to parse address"); - - let layers = tower::ServiceBuilder::new() - .layer(MetricsLayer::default()) - .layer(AuthLayer::new(&grpc_config)) - .into_inner(); - - let server = Server::builder() - .layer(layers) - .add_service(ConsumerServiceServer::new(TaskbrokerServer { - store: grpc_store, - config: grpc_config, - update_tx: grpc_status_tx, - })) - .add_service(health_service.clone()) - .serve(addr); - - let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); - info!("GRPC server listening on {}", addr); - select! { - biased; - - res = server => { - info!("GRPC server task failed, shutting down"); - - // Wait for any running requests to drain - tokio::time::sleep(Duration::from_secs(5)).await; - match res { - Ok(()) => Ok(()), - Err(e) => Err(anyhow!("GRPC server task failed: {:?}", e)), - } - } - _ = guard.wait() => { - info!("Cancellation token received, shutting down GRPC server"); - - // Wait for any running requests to drain - tokio::time::sleep(Duration::from_secs(5)).await; - Ok(()) - } - } - } - })) - } else { - info!("GRPC server disabled (grpc_port=0)"); - None - }; - - // Initialize push queue - let (sender, receiver) = flume::bounded(config.push_queue_size); - - // Initialize push and fetch pools - let push_pool = PushPool::new(receiver, config.clone(), store.clone()); - let fetch_pool = FetchPool::new(sender, store.clone(), config.clone()); - - // Initialize push threads - let push_task = if config.delivery_mode == DeliveryMode::Push { - let mut workers: Vec = vec![]; - - // For every push thread, create a map from applications to worker connections - for _ in 0..config.push_threads { - let mut map = HashMap::new(); - - for (application, endpoint) in config.worker_map.clone() { - let worker = match Worker::connect(config.clone(), endpoint).await { - Ok(w) => { - debug!("Connected to worker!"); - Box::new(w) as Box - } - - Err(e) => { - error!(error = ?e, "Failed to connect to worker"); - return Err(e); - } - }; - - map.insert(application, worker); - } - - workers.push(map); - } - - Some(taskbroker::tokio::spawn(async move { - push_pool.start(workers).await - })) - } else { - None - }; - - // Initialize fetch threads - let fetch_task = if config.delivery_mode == DeliveryMode::Push { - Some(taskbroker::tokio::spawn( - async move { fetch_pool.start().await }, - )) - } else { - None - }; - - let mut departure = elegant_departure::tokio::depart() - .on_termination() - .on_sigint() - .on_signal(SignalKind::hangup()) - .on_signal(SignalKind::quit()) - .on_completion(log_task_completion("consumer", consumer_task)) - .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)); - - if let Some(task) = grpc_server_task { - departure = departure.on_completion(log_task_completion("grpc_server", task)); - } - - if let Some(task) = push_task { - departure = departure.on_completion(log_task_completion("push_task", task)); - } - - if let Some(task) = fetch_task { - departure = departure.on_completion(log_task_completion("fetch_task", task)); - } - - if let Some(task) = status_update_task { - departure = departure.on_completion(log_task_completion("status_update_task", task)); - } - - departure.await; - Ok(()) } diff --git a/src/run.rs b/src/run.rs new file mode 100644 index 00000000..00891a35 --- /dev/null +++ b/src/run.rs @@ -0,0 +1,401 @@ +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{Error, anyhow}; +use chrono::Utc; +use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; +use sqlx::ConnectOptions; +use sqlx::postgres::PgConnectOptions; +use tokio::signal::unix::SignalKind; +use tokio::task::JoinHandle; +use tokio::{select, time}; +use tonic::transport::Server; +use tonic_health::ServingStatus; +use tracing::{debug, error, info, warn}; + +use crate::config::{Config, DatabaseAdapter, DeliveryMode}; +use crate::fetch::FetchPool; +use crate::grpc::auth_middleware::AuthLayer; +use crate::grpc::metrics_middleware::MetricsLayer; +use crate::grpc::server::{TaskbrokerServer, flush_updates}; +use crate::kafka::activation_batcher::{ActivationBatcher, ActivationBatcherConfig}; +use crate::kafka::activation_writer::{ActivationWriter, ActivationWriterConfig}; +use crate::kafka::admin::create_missing_topics; +use crate::kafka::consumer::start_consumer; +use crate::kafka::deserialize::{self, DeserializeConfig}; +use crate::kafka::os_stream_writer::{OsStream, OsStreamWriter}; +use crate::processing_strategy; +use crate::push::PushPool; +use crate::runtime_config::RuntimeConfigManager; +use crate::store::adapters::postgres::{ + PostgresStore, PostgresStoreConfig, create_default_postgres_pool, +}; +use crate::store::adapters::sqlite::{SqliteStore, SqliteStoreConfig}; +use crate::store::traits::ActivationStore; +use crate::upkeep::upkeep; +use crate::worker::{Worker, WorkerClient, WorkerMap}; +use crate::{SERVICE_NAME, flusher}; + +async fn log_task_completion>(name: T, task: JoinHandle>) { + match task.await { + Ok(Ok(())) => { + info!("Task {} completed", name.as_ref()); + } + Ok(Err(e)) => { + error!("Task {} failed: {:?}", name.as_ref(), e); + } + Err(e) => { + error!("Task {} panicked: {:?}", name.as_ref(), e); + } + } +} + +/// Run taskbroker. +pub async fn broker(config: Arc) -> Result<(), Error> { + let runtime_config_manager = + Arc::new(RuntimeConfigManager::new(config.runtime_config_path.clone()).await); + + let store: Arc = match config.database_adapter { + DatabaseAdapter::Sqlite => Arc::new( + SqliteStore::new(&config.db_path, SqliteStoreConfig::from_config(&config)).await?, + ), + DatabaseAdapter::Postgres => { + Arc::new(PostgresStore::new(PostgresStoreConfig::from_config(&config)).await?) + } + }; + + // If this is an environment where the topics might not exist, check and create them. + if config.create_missing_topics { + let kafka_client_config = config.kafka_consumer_config(); + let (main_topic, _) = config + .consumable_topic() + .map_err(|e| anyhow!("invalid config: {}", e))?; + create_missing_topics( + kafka_client_config.clone(), + main_topic, + config.default_topic_partitions, + ) + .await?; + + // Create retry topic if configured + if let Some(ref retry_topic) = config.kafka_retry_topic { + create_missing_topics( + kafka_client_config, + retry_topic, + config.default_topic_partitions, + ) + .await?; + } + } + + if config.full_vacuum_on_start { + info!("Running full vacuum on database"); + match store.full_vacuum_db().await { + Ok(_) => info!("Full vacuum completed."), + Err(err) => error!("Failed to run full vacuum on startup: {:?}", err), + } + } + // Get startup time after migrations and vacuum + let startup_time = Utc::now(); + + // Taskbroker exposes a grpc.v1.health endpoint. We use upkeep to track the health + // of the application. + let (health_reporter, health_service) = tonic_health::server::health_reporter(); + health_reporter + .set_service_status(SERVICE_NAME, ServingStatus::Serving) + .await; + + // Upkeep loop + let upkeep_task = crate::tokio::spawn({ + let upkeep_store = store.clone(); + let upkeep_config = config.clone(); + let runtime_config_manager = runtime_config_manager.clone(); + async move { + upkeep( + upkeep_config, + upkeep_store, + startup_time, + runtime_config_manager.clone(), + health_reporter.clone(), + ) + .await?; + Ok(()) + } + }); + + // Maintenance task loop + let maintenance_task = crate::tokio::spawn({ + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + let maintenance_store = store.clone(); + let mut timer = time::interval(Duration::from_millis(config.maintenance_task_interval_ms)); + timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip); + + async move { + loop { + select! { + _ = timer.tick() => { + match maintenance_store.vacuum_db().await { + Ok(_) => debug!("ran maintenance vacuum"), + Err(err) => warn!("failed to run maintenance vacuum {:?}", err), + } + }, + _ = guard.wait() => { + break; + } + } + } + Ok(()) + } + }); + + // Consumer from kafka + let consumer_task = crate::tokio::spawn({ + let consumer_store = store.clone(); + let consumer_config = config.clone(); + let runtime_config_manager = runtime_config_manager.clone(); + + // Build list of topics to consume from + let (main_topic, _) = consumer_config + .consumable_topic() + .expect("invalid config: no consumable topic"); + let topics_to_consume = [main_topic.to_owned()]; + + async move { + // The consumer has an internal thread that listens for cancellations, so it doesn't need + // an outer select here like the other tasks. + let topic_refs: Vec<&str> = topics_to_consume.iter().map(|s| s.as_str()).collect(); + start_consumer( + &topic_refs, + &consumer_config.kafka_consumer_config(), + consumer_store.clone(), + processing_strategy!({ + err: + OsStreamWriter::new( + Duration::from_secs(1), + OsStream::StdErr, + ), + + map: + deserialize::new(DeserializeConfig::from_config(&consumer_config)), + + reduce: + ActivationBatcher::new( + ActivationBatcherConfig::from_config(&consumer_config), + runtime_config_manager.clone() + ), + ActivationWriter::new( + consumer_store.clone(), + ActivationWriterConfig::from_config(&consumer_config) + ), + + }), + ) + .await + } + }); + + // Status update flush task + let (status_update_tx, status_update_task) = if config.batch_status_updates { + let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = crate::tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.status_update_batch_size, + flusher_config.status_update_interval_ms, + move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + + // GRPC server - only start if port is configured (port 0 disables it) + let grpc_server_task = if config.grpc_port > 0 { + Some(crate::tokio::spawn({ + let grpc_store = store.clone(); + let grpc_config = config.clone(); + let grpc_status_tx = status_update_tx.clone(); + + async move { + let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) + .parse() + .expect("Failed to parse address"); + + let layers = tower::ServiceBuilder::new() + .layer(MetricsLayer::default()) + .layer(AuthLayer::new(&grpc_config)) + .into_inner(); + + let server = Server::builder() + .layer(layers) + .add_service(ConsumerServiceServer::new(TaskbrokerServer { + store: grpc_store, + config: grpc_config, + update_tx: grpc_status_tx, + })) + .add_service(health_service.clone()) + .serve(addr); + + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + info!("GRPC server listening on {}", addr); + select! { + biased; + + res = server => { + info!("GRPC server task failed, shutting down"); + + // Wait for any running requests to drain + tokio::time::sleep(Duration::from_secs(5)).await; + match res { + Ok(()) => Ok(()), + Err(e) => Err(anyhow!("GRPC server task failed: {:?}", e)), + } + } + _ = guard.wait() => { + info!("Cancellation token received, shutting down GRPC server"); + + // Wait for any running requests to drain + tokio::time::sleep(Duration::from_secs(5)).await; + Ok(()) + } + } + } + })) + } else { + info!("GRPC server disabled (grpc_port=0)"); + None + }; + + // Initialize push queue + let (sender, receiver) = flume::bounded(config.push_queue_size); + + // Initialize push and fetch pools + let push_pool = PushPool::new(receiver, config.clone(), store.clone()); + let fetch_pool = FetchPool::new(sender, store.clone(), config.clone()); + + // Initialize push threads + let push_task = if config.delivery_mode == DeliveryMode::Push { + let mut workers: Vec = vec![]; + + // For every push thread, create a map from applications to worker connections + for _ in 0..config.push_threads { + let mut map = HashMap::new(); + + for (application, endpoint) in config.worker_map.clone() { + let worker = match Worker::connect(config.clone(), endpoint).await { + Ok(w) => { + debug!("Connected to worker!"); + Box::new(w) as Box + } + + Err(e) => { + error!(error = ?e, "Failed to connect to worker"); + return Err(e); + } + }; + + map.insert(application, worker); + } + + workers.push(map); + } + + Some(crate::tokio::spawn(async move { + push_pool.start(workers).await + })) + } else { + None + }; + + // Initialize fetch threads + let fetch_task = if config.delivery_mode == DeliveryMode::Push { + Some(crate::tokio::spawn(async move { fetch_pool.start().await })) + } else { + None + }; + + let mut departure = elegant_departure::tokio::depart() + .on_termination() + .on_sigint() + .on_signal(SignalKind::hangup()) + .on_signal(SignalKind::quit()) + .on_completion(log_task_completion("consumer", consumer_task)) + .on_completion(log_task_completion("upkeep_task", upkeep_task)) + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + if let Some(task) = grpc_server_task { + departure = departure.on_completion(log_task_completion("grpc_server", task)); + } + + if let Some(task) = push_task { + departure = departure.on_completion(log_task_completion("push_task", task)); + } + + if let Some(task) = fetch_task { + departure = departure.on_completion(log_task_completion("fetch_task", task)); + } + + if let Some(task) = status_update_task { + departure = departure.on_completion(log_task_completion("status_update_task", task)); + } + + departure.await; + Ok(()) +} + +/// Run migrations. +pub async fn migrations(config: Arc) -> Result<(), Error> { + if config.database_adapter == DatabaseAdapter::Sqlite { + return Ok(()); + } + + let mut conn_opts = PgConnectOptions::new() + .username(&config.pg_ddl_username) + .password(&config.pg_ddl_password) + .host(&config.pg_host) + .port(config.pg_port); + + if let Some(extra_query_params) = config.pg_extra_query_params.as_ref() { + let url = conn_opts.to_url_lossy(); + let new_url = + url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; + conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); + } + + let default_pool = + create_default_postgres_pool(&conn_opts, &config.pg_default_database_name).await?; + + // Create the database if it doesn't exist + let row: (bool,) = + sqlx::query_as("SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )") + .bind(&config.pg_database_name) + .fetch_one(&default_pool) + .await?; + + if !row.0 { + println!("Creating database {}", &config.pg_database_name); + sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) + .bind(&config.pg_database_name) + .execute(&default_pool) + .await?; + } + + // Close the default pool + default_pool.close().await; + + println!("Running migrations on database"); + sqlx::migrate!("./migrations/postgres") + .run(&default_pool) + .await?; + + Ok(()) +} diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 9917dd98..7d7ce264 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -207,42 +207,9 @@ impl PostgresStore { #[framed] pub async fn new(config: PostgresStoreConfig) -> Result { - if config.run_migrations { - let default_pool = create_default_postgres_pool( - &config.pg_connection, - &config.pg_default_database_name, - ) - .await?; - - // Create the database if it doesn't exist - let row: (bool,) = sqlx::query_as( - "SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )", - ) - .bind(&config.pg_database_name) - .fetch_one(&default_pool) - .await?; - - if !row.0 { - println!("Creating database {}", &config.pg_database_name); - sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) - .bind(&config.pg_database_name) - .execute(&default_pool) - .await?; - } - // Close the default pool - default_pool.close().await; - } - let (read_pool, write_pool) = create_postgres_pool(&config.pg_connection, &config.pg_database_name).await?; - if config.run_migrations { - println!("Running migrations on database"); - sqlx::migrate!("./migrations/postgres") - .run(&write_pool) - .await?; - } - Ok(Self { read_pool, write_pool, From 714044c3406022168d099f82538c7943c378b66a Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 3 Jun 2026 14:40:19 -0700 Subject: [PATCH 2/3] Add GoCD Stage --- gocd/templates/bash/run-migrations.sh | 13 +++++++++++ gocd/templates/pipelines/taskbroker.libsonnet | 22 ++++++++++++++++++- src/run.rs | 11 +++++++--- 3 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 gocd/templates/bash/run-migrations.sh diff --git a/gocd/templates/bash/run-migrations.sh b/gocd/templates/bash/run-migrations.sh new file mode 100644 index 00000000..650702ff --- /dev/null +++ b/gocd/templates/bash/run-migrations.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +eval $(regions-project-env-vars --region="${SENTRY_REGION}") +/devinfra/scripts/get-cluster-credentials + +k8s-spawn-job \ + --label-selector="${LABEL_SELECTOR}" \ + --image="us-central1-docker.pkg.dev/sentryio/taskbroker/image:${GO_REVISION_TASKBROKER_REPO}" \ + --container-name="taskbroker" \ + --name="taskbroker-migrations" \ + -- \ + /opt/taskbroker \ + --run migrations diff --git a/gocd/templates/pipelines/taskbroker.libsonnet b/gocd/templates/pipelines/taskbroker.libsonnet index a90c31be..3a0b08f5 100644 --- a/gocd/templates/pipelines/taskbroker.libsonnet +++ b/gocd/templates/pipelines/taskbroker.libsonnet @@ -18,6 +18,24 @@ local checks_stage = { }, }; +local run_migrations_stage = { + 'run-migrations': { + fetch_materials: true, + jobs: { + 'run-migrations': { + timeout: 60, + elastic_profile_id: 'taskbroker', + environment_variables: { + LABEL_SELECTOR: 'service=taskbroker', + }, + tasks: [ + gocdtasks.script(importstr '../bash/run-migrations.sh'), + ], + }, + }, + }, +}; + local deploy_canary_stage(region) = if region == 'us' || region == 'de' then [ @@ -73,5 +91,7 @@ function(region) { }, }, lock_behavior: 'unlockWhenFinished', - stages: [checks_stage] + deploy_canary_stage(region) + [deployPrimaryStage], + stages: [checks_stage, run_migrations_stage] + + deploy_canary_stage(region) + + [deployPrimaryStage], } diff --git a/src/run.rs b/src/run.rs index 00891a35..a528f782 100644 --- a/src/run.rs +++ b/src/run.rs @@ -7,7 +7,7 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; use sqlx::ConnectOptions; -use sqlx::postgres::PgConnectOptions; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; use tokio::{select, time}; @@ -389,13 +389,18 @@ pub async fn migrations(config: Arc) -> Result<(), Error> { .await?; } - // Close the default pool default_pool.close().await; + let migration_pool = PgPoolOptions::new() + .max_connections(1) + .connect_with(conn_opts.database(&config.pg_database_name)) + .await?; + println!("Running migrations on database"); sqlx::migrate!("./migrations/postgres") - .run(&default_pool) + .run(&migration_pool) .await?; + migration_pool.close().await; Ok(()) } From adc7cf024e8de8406a6832ea41a2f742e4070254 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 3 Jun 2026 16:32:13 -0700 Subject: [PATCH 3/3] Fixes, Tests --- src/config.rs | 13 ++ src/lib.rs | 3 +- src/main.rs | 359 ++++++++++++++++++++++++++++- src/run.rs | 410 --------------------------------- src/store/adapters/postgres.rs | 53 ++++- src/test_utils.rs | 14 +- 6 files changed, 427 insertions(+), 425 deletions(-) delete mode 100644 src/run.rs diff --git a/src/config.rs b/src/config.rs index 179ee82e..79b8e14f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,6 +13,7 @@ use validator::{Validate, ValidationError}; use crate::Args; use crate::fetch::MAX_FETCH_THREADS; use crate::logging::LogFormat; +use crate::store::adapters::postgres; /// Configuration for a single Kafka topic in multi-topic mode. #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] @@ -130,6 +131,18 @@ pub enum DatabaseAdapter { Postgres, } +impl DatabaseAdapter { + pub async fn migrate(&self, config: &Config) -> Result<()> { + match self { + Self::Postgres => postgres::migrate(config).await, + Self::Sqlite => { + warn!("Standalone migration not supported for SQLite"); + Ok(()) + } + } + } +} + /// How the taskbroker delivers tasks to workers. #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Deserialize, Serialize)] #[serde(rename_all = "lowercase")] diff --git a/src/lib.rs b/src/lib.rs index c654df5c..23b562ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,6 @@ pub mod kafka; pub mod logging; pub mod metrics; pub mod push; -pub mod run; pub mod runtime_config; pub mod store; pub mod test_utils; @@ -27,7 +26,7 @@ pub fn get_version() -> &'static str { } /// What are we running? -#[derive(Debug, Clone, Copy, ValueEnum)] +#[derive(Debug, Clone, Copy, ValueEnum, PartialEq)] pub enum Run { Migrations, Broker, diff --git a/src/main.rs b/src/main.rs index d9e400ed..8ad579c9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,62 @@ +use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; +use std::time::Duration; -use anyhow::Error; +use anyhow::{Error, anyhow}; +use chrono::Utc; use clap::Parser; +use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; +use taskbroker::worker::{Worker, WorkerClient, WorkerMap}; +use tokio::signal::unix::SignalKind; +use tokio::task::JoinHandle; +use tokio::{select, time}; +use tonic::transport::Server; +use tonic_health::ServingStatus; +use tracing::{debug, error, info, warn}; -use taskbroker::config::Config; +use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; +use taskbroker::fetch::FetchPool; +use taskbroker::grpc::auth_middleware::AuthLayer; +use taskbroker::grpc::metrics_middleware::MetricsLayer; +use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; +use taskbroker::kafka::activation_batcher::{ActivationBatcher, ActivationBatcherConfig}; +use taskbroker::kafka::activation_writer::{ActivationWriter, ActivationWriterConfig}; +use taskbroker::kafka::admin::create_missing_topics; +use taskbroker::kafka::consumer::start_consumer; +use taskbroker::kafka::deserialize::{self, DeserializeConfig}; +use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; use taskbroker::metrics; +use taskbroker::processing_strategy; +use taskbroker::push::PushPool; +use taskbroker::runtime_config::RuntimeConfigManager; +use taskbroker::store::adapters::postgres::{self, PostgresStore, PostgresStoreConfig}; +use taskbroker::store::adapters::sqlite::{SqliteStore, SqliteStoreConfig}; +use taskbroker::store::traits::ActivationStore; +use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; -use taskbroker::{Run, logging, run}; +use taskbroker::{Run, logging}; +use taskbroker::{SERVICE_NAME, flusher}; + +async fn log_task_completion>(name: T, task: JoinHandle>) { + match task.await { + Ok(Ok(())) => { + info!("Task {} completed", name.as_ref()); + } + Ok(Err(e)) => { + error!("Task {} failed: {:?}", name.as_ref(), e); + } + Err(e) => { + error!("Task {} panicked: {:?}", name.as_ref(), e); + } + } +} #[tokio::main] async fn main() -> Result<(), Error> { let args = Args::parse(); let config = Arc::new(Config::from_args(&args)?); + let runtime_config_manager = + Arc::new(RuntimeConfigManager::new(config.runtime_config_path.clone()).await); println!("taskbroker starting"); println!("version: {}", get_version().trim()); @@ -19,8 +64,310 @@ async fn main() -> Result<(), Error> { logging::init(logging::LoggingConfig::from_config(&config)); metrics::init(metrics::MetricsConfig::from_config(&config)); - match args.run { - Run::Broker => run::broker(config).await, - Run::Migrations => run::migrations(config).await, + if args.run == Run::Migrations { + return config.database_adapter.migrate(&config).await; + } + + let store: Arc = match config.database_adapter { + DatabaseAdapter::Sqlite => Arc::new( + SqliteStore::new(&config.db_path, SqliteStoreConfig::from_config(&config)).await?, + ), + DatabaseAdapter::Postgres => { + if config.run_migrations { + postgres::migrate(&config).await?; + } + + Arc::new(PostgresStore::new(PostgresStoreConfig::from_config(&config)).await?) + } + }; + + // If this is an environment where the topics might not exist, check and create them. + if config.create_missing_topics { + // Group every declared topic by its cluster so each is created on the + // right brokers (main, retry, deadletter and produce-only topics, which + // may live on different clusters). + let mut topics_by_cluster: BTreeMap<&str, Vec<(&str, i32)>> = BTreeMap::new(); + for (topic_name, topic_config) in &config.kafka_topics { + topics_by_cluster + .entry(topic_config.cluster.as_str()) + .or_default() + .push((topic_name.as_str(), config.default_topic_partitions)); + } + for (cluster, topics) in topics_by_cluster { + create_missing_topics(config.kafka_admin_config(cluster), &topics).await?; + } + } + + if config.full_vacuum_on_start { + info!("Running full vacuum on database"); + match store.full_vacuum_db().await { + Ok(_) => info!("Full vacuum completed."), + Err(err) => error!("Failed to run full vacuum on startup: {:?}", err), + } + } + // Get startup time after migrations and vacuum + let startup_time = Utc::now(); + + // Taskbroker exposes a grpc.v1.health endpoint. We use upkeep to track the health + // of the application. + let (health_reporter, health_service) = tonic_health::server::health_reporter(); + health_reporter + .set_service_status(SERVICE_NAME, ServingStatus::Serving) + .await; + + // Upkeep loop + let upkeep_task = taskbroker::tokio::spawn({ + let upkeep_store = store.clone(); + let upkeep_config = config.clone(); + let runtime_config_manager = runtime_config_manager.clone(); + async move { + upkeep( + upkeep_config, + upkeep_store, + startup_time, + runtime_config_manager.clone(), + health_reporter.clone(), + ) + .await?; + Ok(()) + } + }); + + // Maintenance task loop + let maintenance_task = taskbroker::tokio::spawn({ + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + let maintenance_store = store.clone(); + let mut timer = time::interval(Duration::from_millis(config.maintenance_task_interval_ms)); + timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip); + + async move { + loop { + select! { + _ = timer.tick() => { + match maintenance_store.vacuum_db().await { + Ok(_) => debug!("ran maintenance vacuum"), + Err(err) => warn!("failed to run maintenance vacuum {:?}", err), + } + }, + _ = guard.wait() => { + break; + } + } + } + Ok(()) + } + }); + + // Consumer(s) from kafka. Each consumed topic gets its own consumer (own + // group.id and cluster), so we spawn one consumer task per consumable topic, + // all sharing the one activation store. + let consumer_topics: Vec = config + .consumable_topics() + .expect("invalid config: no consumable topic") + .into_iter() + .map(|(name, _)| name.to_owned()) + .collect(); + + let mut consumer_tasks: Vec<(String, JoinHandle>)> = Vec::new(); + for topic in consumer_topics { + let consumer_store = store.clone(); + let consumer_config = config.clone(); + let runtime_config_manager = runtime_config_manager.clone(); + let task_topic = topic.clone(); + + let handle = taskbroker::tokio::spawn(async move { + // The consumer has an internal thread that listens for cancellations, so it doesn't need + // an outer select here like the other tasks. + let topic_refs = [task_topic.as_str()]; + start_consumer( + &topic_refs, + &consumer_config.kafka_consumer_config_for(&task_topic), + consumer_store.clone(), + processing_strategy!({ + err: + OsStreamWriter::new( + Duration::from_secs(1), + OsStream::StdErr, + ), + + map: + deserialize::new(DeserializeConfig::from_topic(&consumer_config, &task_topic)), + + reduce: + ActivationBatcher::new( + ActivationBatcherConfig::from_topic(&consumer_config, &task_topic), + runtime_config_manager.clone() + ), + ActivationWriter::new( + consumer_store.clone(), + ActivationWriterConfig::from_topic(&consumer_config, &task_topic) + ), + + }), + ) + .await + }); + consumer_tasks.push((topic, handle)); + } + + // Status update flush task + let (status_update_tx, status_update_task) = if config.batch_status_updates { + let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = taskbroker::tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.status_update_batch_size, + flusher_config.status_update_interval_ms, + move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + + // GRPC server - only start if port is configured (port 0 disables it) + let grpc_server_task = if config.grpc_port > 0 { + Some(taskbroker::tokio::spawn({ + let grpc_store = store.clone(); + let grpc_config = config.clone(); + let grpc_status_tx = status_update_tx.clone(); + + async move { + let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) + .parse() + .expect("Failed to parse address"); + + let layers = tower::ServiceBuilder::new() + .layer(MetricsLayer::default()) + .layer(AuthLayer::new(&grpc_config)) + .into_inner(); + + let server = Server::builder() + .layer(layers) + .add_service(ConsumerServiceServer::new(TaskbrokerServer { + store: grpc_store, + config: grpc_config, + update_tx: grpc_status_tx, + })) + .add_service(health_service.clone()) + .serve(addr); + + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + info!("GRPC server listening on {}", addr); + select! { + biased; + + res = server => { + info!("GRPC server task failed, shutting down"); + + // Wait for any running requests to drain + tokio::time::sleep(Duration::from_secs(5)).await; + match res { + Ok(()) => Ok(()), + Err(e) => Err(anyhow!("GRPC server task failed: {:?}", e)), + } + } + _ = guard.wait() => { + info!("Cancellation token received, shutting down GRPC server"); + + // Wait for any running requests to drain + tokio::time::sleep(Duration::from_secs(5)).await; + Ok(()) + } + } + } + })) + } else { + info!("GRPC server disabled (grpc_port=0)"); + None + }; + + // Initialize push queue + let (sender, receiver) = flume::bounded(config.push_queue_size); + + // Initialize push and fetch pools + let push_pool = PushPool::new(receiver, config.clone(), store.clone()); + let fetch_pool = FetchPool::new(sender, store.clone(), config.clone()); + + // Initialize push threads + let push_task = if config.delivery_mode == DeliveryMode::Push { + let mut workers: Vec = vec![]; + + // For every push thread, create a map from applications to worker connections + for _ in 0..config.push_threads { + let mut map = HashMap::new(); + + for (application, endpoint) in config.worker_map.clone() { + let worker = match Worker::connect(config.clone(), endpoint).await { + Ok(w) => { + debug!("Connected to worker!"); + Box::new(w) as Box + } + + Err(e) => { + error!(error = ?e, "Failed to connect to worker"); + return Err(e); + } + }; + + map.insert(application, worker); + } + + workers.push(map); + } + + Some(taskbroker::tokio::spawn(async move { + push_pool.start(workers).await + })) + } else { + None + }; + + // Initialize fetch threads + let fetch_task = if config.delivery_mode == DeliveryMode::Push { + Some(taskbroker::tokio::spawn( + async move { fetch_pool.start().await }, + )) + } else { + None + }; + + let mut departure = elegant_departure::tokio::depart() + .on_termination() + .on_sigint() + .on_signal(SignalKind::hangup()) + .on_signal(SignalKind::quit()) + .on_completion(log_task_completion("upkeep_task", upkeep_task)) + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + for (topic, handle) in consumer_tasks { + departure = + departure.on_completion(log_task_completion(format!("consumer:{topic}"), handle)); } + + if let Some(task) = grpc_server_task { + departure = departure.on_completion(log_task_completion("grpc_server", task)); + } + + if let Some(task) = push_task { + departure = departure.on_completion(log_task_completion("push_task", task)); + } + + if let Some(task) = fetch_task { + departure = departure.on_completion(log_task_completion("fetch_task", task)); + } + + if let Some(task) = status_update_task { + departure = departure.on_completion(log_task_completion("status_update_task", task)); + } + + departure.await; + Ok(()) } diff --git a/src/run.rs b/src/run.rs deleted file mode 100644 index bcb16167..00000000 --- a/src/run.rs +++ /dev/null @@ -1,410 +0,0 @@ -use std::collections::{BTreeMap, HashMap}; -use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::{Error, anyhow}; -use chrono::Utc; -use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; -use sqlx::ConnectOptions; -use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::signal::unix::SignalKind; -use tokio::task::JoinHandle; -use tokio::{select, time}; -use tonic::transport::Server; -use tonic_health::ServingStatus; -use tracing::{debug, error, info, warn}; - -use crate::config::{Config, DatabaseAdapter, DeliveryMode}; -use crate::fetch::FetchPool; -use crate::grpc::auth_middleware::AuthLayer; -use crate::grpc::metrics_middleware::MetricsLayer; -use crate::grpc::server::{TaskbrokerServer, flush_updates}; -use crate::kafka::activation_batcher::{ActivationBatcher, ActivationBatcherConfig}; -use crate::kafka::activation_writer::{ActivationWriter, ActivationWriterConfig}; -use crate::kafka::admin::create_missing_topics; -use crate::kafka::consumer::start_consumer; -use crate::kafka::deserialize::{self, DeserializeConfig}; -use crate::kafka::os_stream_writer::{OsStream, OsStreamWriter}; -use crate::processing_strategy; -use crate::push::PushPool; -use crate::runtime_config::RuntimeConfigManager; -use crate::store::adapters::postgres::{ - PostgresStore, PostgresStoreConfig, create_default_postgres_pool, -}; -use crate::store::adapters::sqlite::{SqliteStore, SqliteStoreConfig}; -use crate::store::traits::ActivationStore; -use crate::upkeep::upkeep; -use crate::worker::{Worker, WorkerClient, WorkerMap}; -use crate::{SERVICE_NAME, flusher}; - -async fn log_task_completion>(name: T, task: JoinHandle>) { - match task.await { - Ok(Ok(())) => { - info!("Task {} completed", name.as_ref()); - } - Ok(Err(e)) => { - error!("Task {} failed: {:?}", name.as_ref(), e); - } - Err(e) => { - error!("Task {} panicked: {:?}", name.as_ref(), e); - } - } -} - -/// Run taskbroker. -pub async fn broker(config: Arc) -> Result<(), Error> { - let runtime_config_manager = - Arc::new(RuntimeConfigManager::new(config.runtime_config_path.clone()).await); - - let store: Arc = match config.database_adapter { - DatabaseAdapter::Sqlite => Arc::new( - SqliteStore::new(&config.db_path, SqliteStoreConfig::from_config(&config)).await?, - ), - DatabaseAdapter::Postgres => { - Arc::new(PostgresStore::new(PostgresStoreConfig::from_config(&config)).await?) - } - }; - - // If this is an environment where the topics might not exist, check and create them. - if config.create_missing_topics { - // Group every declared topic by its cluster so each is created on the - // right brokers (main, retry, deadletter and produce-only topics, which - // may live on different clusters). - let mut topics_by_cluster: BTreeMap<&str, Vec<(&str, i32)>> = BTreeMap::new(); - - for (topic_name, topic_config) in &config.kafka_topics { - topics_by_cluster - .entry(topic_config.cluster.as_str()) - .or_default() - .push((topic_name.as_str(), config.default_topic_partitions)); - } - - for (cluster, topics) in topics_by_cluster { - create_missing_topics(config.kafka_admin_config(cluster), &topics).await?; - } - } - - if config.full_vacuum_on_start { - info!("Running full vacuum on database"); - match store.full_vacuum_db().await { - Ok(_) => info!("Full vacuum completed."), - Err(err) => error!("Failed to run full vacuum on startup: {:?}", err), - } - } - // Get startup time after migrations and vacuum - let startup_time = Utc::now(); - - // Taskbroker exposes a grpc.v1.health endpoint. We use upkeep to track the health - // of the application. - let (health_reporter, health_service) = tonic_health::server::health_reporter(); - health_reporter - .set_service_status(SERVICE_NAME, ServingStatus::Serving) - .await; - - // Upkeep loop - let upkeep_task = crate::tokio::spawn({ - let upkeep_store = store.clone(); - let upkeep_config = config.clone(); - let runtime_config_manager = runtime_config_manager.clone(); - async move { - upkeep( - upkeep_config, - upkeep_store, - startup_time, - runtime_config_manager.clone(), - health_reporter.clone(), - ) - .await?; - Ok(()) - } - }); - - // Maintenance task loop - let maintenance_task = crate::tokio::spawn({ - let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); - let maintenance_store = store.clone(); - let mut timer = time::interval(Duration::from_millis(config.maintenance_task_interval_ms)); - timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip); - - async move { - loop { - select! { - _ = timer.tick() => { - match maintenance_store.vacuum_db().await { - Ok(_) => debug!("ran maintenance vacuum"), - Err(err) => warn!("failed to run maintenance vacuum {:?}", err), - } - }, - _ = guard.wait() => { - break; - } - } - } - Ok(()) - } - }); - - // Consumer(s) from kafka. Each consumed topic gets its own consumer (own - // group.id and cluster), so we spawn one consumer task per consumable topic, - // all sharing the one activation store. - let consumer_topics: Vec = config - .consumable_topics() - .expect("invalid config: no consumable topic") - .into_iter() - .map(|(name, _)| name.to_owned()) - .collect(); - - let mut consumer_tasks: Vec<(String, JoinHandle>)> = Vec::new(); - for topic in consumer_topics { - let consumer_store = store.clone(); - let consumer_config = config.clone(); - let runtime_config_manager = runtime_config_manager.clone(); - let task_topic = topic.clone(); - - let handle = crate::tokio::spawn(async move { - // The consumer has an internal thread that listens for cancellations, so it doesn't need - // an outer select here like the other tasks. - let topic_refs = [task_topic.as_str()]; - start_consumer( - &topic_refs, - &consumer_config.kafka_consumer_config_for(&task_topic), - consumer_store.clone(), - processing_strategy!({ - err: - OsStreamWriter::new( - Duration::from_secs(1), - OsStream::StdErr, - ), - - map: - deserialize::new(DeserializeConfig::from_topic(&consumer_config, &task_topic)), - - reduce: - ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&consumer_config, &task_topic), - runtime_config_manager.clone() - ), - ActivationWriter::new( - consumer_store.clone(), - ActivationWriterConfig::from_topic(&consumer_config, &task_topic) - ), - }), - ) - .await - }); - consumer_tasks.push((topic, handle)); - } - - // Status update flush task - let (status_update_tx, status_update_task) = if config.batch_status_updates { - let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size); - - let flusher_store = store.clone(); - let flusher_config = config.clone(); - - let handle = crate::tokio::spawn(async move { - flusher::run_flusher( - rx, - flusher_config.status_update_batch_size, - flusher_config.status_update_interval_ms, - move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), - ) - .await - }); - - (Some(tx), Some(handle)) - } else { - (None, None) - }; - - // GRPC server - only start if port is configured (port 0 disables it) - let grpc_server_task = if config.grpc_port > 0 { - Some(crate::tokio::spawn({ - let grpc_store = store.clone(); - let grpc_config = config.clone(); - let grpc_status_tx = status_update_tx.clone(); - - async move { - let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) - .parse() - .expect("Failed to parse address"); - - let layers = tower::ServiceBuilder::new() - .layer(MetricsLayer::default()) - .layer(AuthLayer::new(&grpc_config)) - .into_inner(); - - let server = Server::builder() - .layer(layers) - .add_service(ConsumerServiceServer::new(TaskbrokerServer { - store: grpc_store, - config: grpc_config, - update_tx: grpc_status_tx, - })) - .add_service(health_service.clone()) - .serve(addr); - - let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); - info!("GRPC server listening on {}", addr); - select! { - biased; - - res = server => { - info!("GRPC server task failed, shutting down"); - - // Wait for any running requests to drain - tokio::time::sleep(Duration::from_secs(5)).await; - match res { - Ok(()) => Ok(()), - Err(e) => Err(anyhow!("GRPC server task failed: {:?}", e)), - } - } - _ = guard.wait() => { - info!("Cancellation token received, shutting down GRPC server"); - - // Wait for any running requests to drain - tokio::time::sleep(Duration::from_secs(5)).await; - Ok(()) - } - } - } - })) - } else { - info!("GRPC server disabled (grpc_port=0)"); - None - }; - - // Initialize push queue - let (sender, receiver) = flume::bounded(config.push_queue_size); - - // Initialize push and fetch pools - let push_pool = PushPool::new(receiver, config.clone(), store.clone()); - let fetch_pool = FetchPool::new(sender, store.clone(), config.clone()); - - // Initialize push threads - let push_task = if config.delivery_mode == DeliveryMode::Push { - let mut workers: Vec = vec![]; - - // For every push thread, create a map from applications to worker connections - for _ in 0..config.push_threads { - let mut map = HashMap::new(); - - for (application, endpoint) in config.worker_map.clone() { - let worker = match Worker::connect(config.clone(), endpoint).await { - Ok(w) => { - debug!("Connected to worker!"); - Box::new(w) as Box - } - - Err(e) => { - error!(error = ?e, "Failed to connect to worker"); - return Err(e); - } - }; - - map.insert(application, worker); - } - - workers.push(map); - } - - Some(crate::tokio::spawn(async move { - push_pool.start(workers).await - })) - } else { - None - }; - - // Initialize fetch threads - let fetch_task = if config.delivery_mode == DeliveryMode::Push { - Some(crate::tokio::spawn(async move { fetch_pool.start().await })) - } else { - None - }; - - let mut departure = elegant_departure::tokio::depart() - .on_termination() - .on_sigint() - .on_signal(SignalKind::hangup()) - .on_signal(SignalKind::quit()) - .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)); - - for (topic, handle) in consumer_tasks { - departure = - departure.on_completion(log_task_completion(format!("consumer:{topic}"), handle)); - } - - if let Some(task) = grpc_server_task { - departure = departure.on_completion(log_task_completion("grpc_server", task)); - } - - if let Some(task) = push_task { - departure = departure.on_completion(log_task_completion("push_task", task)); - } - - if let Some(task) = fetch_task { - departure = departure.on_completion(log_task_completion("fetch_task", task)); - } - - if let Some(task) = status_update_task { - departure = departure.on_completion(log_task_completion("status_update_task", task)); - } - - departure.await; - Ok(()) -} - -/// Run migrations. -pub async fn migrations(config: Arc) -> Result<(), Error> { - if config.database_adapter == DatabaseAdapter::Sqlite { - return Ok(()); - } - - let mut conn_opts = PgConnectOptions::new() - .username(&config.pg_ddl_username) - .password(&config.pg_ddl_password) - .host(&config.pg_host) - .port(config.pg_port); - - if let Some(extra_query_params) = config.pg_extra_query_params.as_ref() { - let url = conn_opts.to_url_lossy(); - let new_url = - url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; - conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); - } - - let default_pool = - create_default_postgres_pool(&conn_opts, &config.pg_default_database_name).await?; - - // Create the database if it doesn't exist - let row: (bool,) = - sqlx::query_as("SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )") - .bind(&config.pg_database_name) - .fetch_one(&default_pool) - .await?; - - if !row.0 { - println!("Creating database {}", &config.pg_database_name); - sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) - .bind(&config.pg_database_name) - .execute(&default_pool) - .await?; - } - - default_pool.close().await; - - let migration_pool = PgPoolOptions::new() - .max_connections(1) - .connect_with(conn_opts.database(&config.pg_database_name)) - .await?; - - println!("Running migrations on database"); - sqlx::migrate!("./migrations/postgres") - .run(&migration_pool) - .await?; - migration_pool.close().await; - - Ok(()) -} diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index d34c820b..fd5d1dfb 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -8,7 +8,7 @@ use sqlx::pool::PoolConnection; use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions}; use sqlx::{FromRow, Pool, Postgres, QueryBuilder, Transaction}; -use anyhow::{Error, anyhow}; +use anyhow::{Error, Result, anyhow}; use async_backtrace::framed; use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -23,6 +23,55 @@ use crate::store::retry::{RetryConfig, retry_query}; use crate::store::traits::ActivationStore; use crate::store::types::{BucketRange, DepthCounts, FailedTasksForwarder}; +/// Run migrations. +pub async fn migrate(config: &Config) -> Result<()> { + let mut conn_opts = PgConnectOptions::new() + .username(&config.pg_ddl_username) + .password(&config.pg_ddl_password) + .host(&config.pg_host) + .port(config.pg_port); + + if let Some(extra_query_params) = config.pg_extra_query_params.as_ref() { + let url = conn_opts.to_url_lossy(); + let new_url = + url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; + conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); + } + + let default_pool = + create_default_postgres_pool(&conn_opts, &config.pg_default_database_name).await?; + + // Create the database if it doesn't exist + let row: (bool,) = + sqlx::query_as("SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )") + .bind(&config.pg_database_name) + .fetch_one(&default_pool) + .await?; + + if !row.0 { + println!("Creating database {}", &config.pg_database_name); + sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) + .execute(&default_pool) + .await?; + } + + default_pool.close().await; + + let migration_pool = PgPoolOptions::new() + .max_connections(1) + .connect_with(conn_opts.database(&config.pg_database_name)) + .await?; + + println!("Running migrations on database"); + sqlx::migrate!("./migrations/postgres") + .run(&migration_pool) + .await?; + + migration_pool.close().await; + + Ok(()) +} + #[derive(Debug, FromRow)] struct TableRow { pub id: String, @@ -152,12 +201,14 @@ impl PostgresStoreConfig { .password(&config.pg_password) .host(&config.pg_host) .port(config.pg_port); + if let Some(extra_query_params) = config.pg_extra_query_params.as_ref() { let url = conn_opts.to_url_lossy(); let new_url = url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); } + Self { pg_connection: conn_opts, pg_database_name: config.pg_database_name.clone(), diff --git a/src/test_utils.rs b/src/test_utils.rs index 0d13c4c8..524b8177 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -17,7 +17,7 @@ use uuid::Uuid; use crate::config::Config; use crate::store::activation::{Activation, ActivationBuilder, ActivationStatus}; -use crate::store::adapters::postgres::{PostgresStore, PostgresStoreConfig}; +use crate::store::adapters::postgres::{self, PostgresStore, PostgresStoreConfig}; use crate::store::adapters::sqlite::{SqliteStore, SqliteStoreConfig}; use crate::store::traits::ActivationStore; @@ -284,13 +284,15 @@ pub async fn create_test_store(adapter: &str) -> Arc { .unwrap(), ) as Arc, "postgres" => { + let config = create_integration_config(); + postgres::migrate(&config).await.unwrap(); + let store = Arc::new( - PostgresStore::new(PostgresStoreConfig::from_config( - &create_integration_config(), - )) - .await - .unwrap(), + PostgresStore::new(PostgresStoreConfig::from_config(&config)) + .await + .unwrap(), ) as Arc; + store.assign_partitions(vec![0]).unwrap(); store }