Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 29 additions & 39 deletions pgdog/src/backend/auth/azure_workload_identity.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,48 @@
use crate::backend::{pool::Address, Error};
use std::time::SystemTime;

use azure_core::credentials::TokenCredential;
use azure_identity::WorkloadIdentityCredential;

pub async fn token(addr: &Address) -> Result<String, Error> {
#[cfg(test)]
if let Some(token) = test_token_override() {
return Ok(token);
}
use crate::backend::{pool::Address, Error};

/// Fetch a fresh Azure Workload Identity token for `addr`.
///
/// This is the raw fetcher passed to [`TokenCache::get_or_fetch`] and
/// called by the monitor's refresh loop. Callers should never invoke it
/// directly — go through [`TokenCache::global`] instead.
pub(crate) async fn token(addr: Address) -> Result<(String, SystemTime), Error> {
let credential = WorkloadIdentityCredential::new(None).map_err(|error| {
Error::AzureWorkloadIdentityToken(format!(
"failed to build workload identity credential for {}@{}:{}: {}",
addr.user, addr.host, addr.port, error
))
})?;

credential
let access_token = credential
.get_token(
&["https://ossrdbms-aad.database.windows.net/.default"],
None,
)
.await
.map(|token| token.token.secret().to_string())
.map_err(|error| {
Error::AzureWorkloadIdentityToken(format!(
"failed to get Azure AD token for {}@{}:{}: {}",
addr.user, addr.host, addr.port, error
))
})
}

#[cfg(test)]
fn test_token_override() -> Option<String> {
TEST_TOKEN_OVERRIDE.lock().clone()
}
})?;

#[cfg(test)]
pub(crate) fn set_test_token_override(token: Option<String>) {
*TEST_TOKEN_OVERRIDE.lock() = token;
let expires_at = SystemTime::from(access_token.expires_on);
Ok((access_token.token.secret().to_string(), expires_at))
}

#[cfg(test)]
static TEST_TOKEN_OVERRIDE: once_cell::sync::Lazy<parking_lot::Mutex<Option<String>>> =
once_cell::sync::Lazy::new(|| parking_lot::Mutex::new(None));

#[cfg(test)]
mod tests {
use crate::backend::pool::Address;
use crate::config::ServerAuth;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use pgdog_config::Role;
use std::env;

use super::*;
use crate::config::ServerAuth;
use pgdog_config::Role;

struct EnvVarGuard {
key: &'static str,
Expand All @@ -69,10 +59,9 @@ mod tests {

impl Drop for EnvVarGuard {
fn drop(&mut self) {
if let Some(previous) = self.previous.take() {
env::set_var(self.key, previous);
} else {
env::remove_var(self.key);
match self.previous.take() {
Some(v) => env::set_var(self.key, v),
None => env::remove_var(self.key),
}
}
}
Expand All @@ -89,27 +78,28 @@ mod tests {
port: 5432,
database_name: "postgres".into(),
user: "db_user".into(),
passwords: vec![String::new().into()],
passwords: vec![],
database_number: 0,
server_auth: ServerAuth::AzureWorkloadIdentity,
server_iam_region: None,
configured_role: Role::Auto,
};

let b64_token = token(&addr).await.unwrap();
let (b64_token, expires_at) = token(addr).await.unwrap();

assert!(expires_at > std::time::SystemTime::now());

// Use functional chaining to extract and decode
let token = b64_token
let payload = b64_token
.split('.')
.nth(1)
.map(|payload| URL_SAFE_NO_PAD.decode(payload))
.map(|p| URL_SAFE_NO_PAD.decode(p))
.transpose()
.expect("Invalid JWT format") // Converts Option<Result<T, E>> to Result<Option<T>, E>
.expect("invalid JWT format")
.and_then(|bytes| String::from_utf8(bytes).ok())
.expect("Failed to parse JWT payload as valid UTF-8 JSON");
.expect("failed to parse JWT payload as UTF-8 JSON");

assert!(token.contains("https://sts.windows.net/"));
assert!(token.contains("https://management.azure.com"));
assert!(token.contains("appid"));
assert!(payload.contains("https://sts.windows.net/"));
assert!(payload.contains("https://management.azure.com"));
assert!(payload.contains("appid"));
}
}
130 changes: 87 additions & 43 deletions pgdog/src/backend/auth/rds_iam.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::{Duration, SystemTime};

use aws_config::{BehaviorVersion, Region};
use aws_sdk_rds::auth_token::{AuthTokenGenerator, Config as AuthTokenConfig};

Expand Down Expand Up @@ -43,13 +45,13 @@ fn resolve_region(addr: &Address) -> Result<String, Error> {
})
}

pub async fn token(addr: &Address) -> Result<String, Error> {
#[cfg(test)]
if let Some(token) = test_token_override() {
return Ok(token);
}

let region = resolve_region(addr)?;
/// Fetch a fresh RDS IAM token for `addr`.
///
/// This is the raw fetcher passed to [`TokenCache::get_or_fetch`] and
/// called by the monitor's refresh loop. Callers should never invoke it
/// directly — go through [`TokenCache::global`] instead.
pub(crate) async fn token(addr: Address) -> Result<(String, SystemTime), Error> {
let region = resolve_region(&addr)?;
let sdk_config = aws_config::load_defaults(BehaviorVersion::latest()).await;

let config = AuthTokenConfig::builder()
Expand All @@ -65,7 +67,7 @@ pub async fn token(addr: &Address) -> Result<String, Error> {
))
})?;

AuthTokenGenerator::new(config)
let token = AuthTokenGenerator::new(config)
.auth_token(&sdk_config)
.await
.map(|token| token.to_string())
Expand All @@ -74,33 +76,20 @@ pub async fn token(addr: &Address) -> Result<String, Error> {
"failed to generate RDS IAM token for {}@{}:{} in region {}: {}",
addr.user, addr.host, addr.port, region, error
))
})
}

#[cfg(test)]
fn test_token_override() -> Option<String> {
TEST_TOKEN_OVERRIDE.lock().clone()
}
})?;

#[cfg(test)]
pub(crate) fn set_test_token_override(token: Option<String>) {
*TEST_TOKEN_OVERRIDE.lock() = token;
// RDS IAM tokens are valid for 15 minutes.
let expires_at = SystemTime::now() + Duration::from_secs(900);
Ok((token, expires_at))
}

#[cfg(test)]
static TEST_TOKEN_OVERRIDE: once_cell::sync::Lazy<parking_lot::Mutex<Option<String>>> =
once_cell::sync::Lazy::new(|| parking_lot::Mutex::new(None));

#[cfg(test)]
mod tests {
use std::env;

use pgdog_config::Role;

use crate::backend::pool::Address;
use crate::config::ServerAuth;
use std::env;

use super::*;
use crate::config::ServerAuth;

struct EnvVarGuard {
key: &'static str,
Expand All @@ -117,14 +106,29 @@ mod tests {

impl Drop for EnvVarGuard {
fn drop(&mut self) {
if let Some(previous) = self.previous.take() {
env::set_var(self.key, previous);
} else {
env::remove_var(self.key);
match self.previous.take() {
Some(v) => env::set_var(self.key, v),
None => env::remove_var(self.key),
}
}
}

fn make_addr() -> Address {
Address {
host: "db.cluster-abc123.us-east-1.rds.amazonaws.com".into(),
port: 5432,
database_name: "postgres".into(),
user: "db_user".into(),
passwords: vec![],
database_number: 0,
server_auth: ServerAuth::RdsIam,
server_iam_region: Some("us-east-1".into()),
configured_role: Role::Auto,
}
}

// ── infer_region_from_rds_host ───────────────────────────────────────────

#[test]
fn test_infer_region_commercial_endpoint() {
let region = infer_region_from_rds_host("db.cluster-abc123.us-east-1.rds.amazonaws.com");
Expand All @@ -144,25 +148,65 @@ mod tests {
assert!(region.is_none());
}

#[test]
fn test_infer_region_fails_when_rds_is_first_label() {
let region = infer_region_from_rds_host("rds.amazonaws.com");
assert!(region.is_none());
}

#[test]
fn test_infer_region_fails_for_unknown_tld() {
let region = infer_region_from_rds_host("db.us-east-1.rds.example.com");
assert!(region.is_none());
}

// ── resolve_region ───────────────────────────────────────────────────────

#[test]
fn resolve_region_prefers_explicit_override() {
let mut addr = make_addr();
addr.server_iam_region = Some("eu-west-1".into());
// Host implies us-east-1, but the explicit override must win.
assert_eq!(resolve_region(&addr).unwrap(), "eu-west-1");
}

#[test]
fn resolve_region_falls_back_to_host_inference() {
let mut addr = make_addr();
addr.server_iam_region = None;
assert_eq!(resolve_region(&addr).unwrap(), "us-east-1");
}

#[test]
fn resolve_region_treats_empty_override_as_absent() {
let mut addr = make_addr();
addr.server_iam_region = Some("".into());
// Empty string must fall through to host inference.
assert_eq!(resolve_region(&addr).unwrap(), "us-east-1");
}

#[test]
fn resolve_region_errors_when_neither_override_nor_inference() {
let addr = Address {
host: "postgres.internal.example.com".into(),
port: 5432,
user: "u".into(),
server_iam_region: None,
..Default::default()
};
assert!(resolve_region(&addr).is_err());
}

#[tokio::test]
async fn test_token_contains_expected_query_fields() {
let _access_key = EnvVarGuard::set("AWS_ACCESS_KEY_ID", "AKIDEXAMPLE");
let _secret_key = EnvVarGuard::set("AWS_SECRET_ACCESS_KEY", "SECRETEXAMPLE");
let _session = EnvVarGuard::set("AWS_SESSION_TOKEN", "SESSIONEXAMPLE");

let addr = Address {
host: "db.cluster-abc123.us-east-1.rds.amazonaws.com".into(),
port: 5432,
database_name: "postgres".into(),
user: "db_user".into(),
passwords: vec![String::new().into()],
database_number: 0,
server_auth: ServerAuth::RdsIam,
server_iam_region: Some("us-east-1".into()),
configured_role: Role::Auto,
};
let addr = make_addr();
let (token, expires_at) = token(addr).await.unwrap();

let token = token(&addr).await.unwrap();
assert!(expires_at > SystemTime::now());
assert!(token.starts_with(
"db.cluster-abc123.us-east-1.rds.amazonaws.com:5432/?Action=connect&DBUser=db_user"
));
Expand Down
Loading
Loading