diff --git a/Cargo.lock b/Cargo.lock index e307307b..ed2fab0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1824,9 +1824,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.172" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libloading" @@ -1940,7 +1940,7 @@ version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0aeb26bf5e836cc1c341c8106051b573f1766dfa05aa87f0b98be5e51b02303" dependencies = [ - "nix", + "nix 0.29.0", "winapi", ] @@ -2078,6 +2078,18 @@ dependencies = [ "memoffset", ] +[[package]] +name = "nix" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -2421,6 +2433,7 @@ dependencies = [ "lazy_static", "lru 0.16.0", "md5", + "nix 0.31.1", "once_cell", "parking_lot", "pg_query", @@ -2512,7 +2525,7 @@ dependencies = [ "pgdog-vector", "rust_decimal", "serde", - "thiserror 1.0.69", + "thiserror 2.0.12", "uuid", ] @@ -3879,7 +3892,6 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.52.0", "windows-sys 0.59.0", ] @@ -4067,7 +4079,7 @@ dependencies = [ "libc", "log", "memmem", - "nix", + "nix 0.29.0", "num-derive", "num-traits", "ordered-float", diff --git a/Dockerfile b/Dockerfile index b4ff5a87..a9bb234b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ WORKDIR /build RUN rm /bin/sh && ln -s /bin/bash /bin/sh RUN source ~/.cargo/env && \ if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then \ - export RUSTFLAGS="-Ctarget-feature=+lse"; \ + export RUSTFLAGS="-Ctarget-feature=+lse"; \ fi && \ cd pgdog && \ cargo build --release @@ -31,10 +31,13 @@ RUN install -d /usr/share/postgresql-common/pgdg && \ . 
/etc/os-release && \ sh -c "echo 'deb [signed-by=/usr/share/postgresql-common/pgdg/apt.postgresql.org.asc] https://apt.postgresql.org/pub/repos/apt $VERSION_CODENAME-pgdg main' > /etc/apt/sources.list.d/pgdg.list" -RUN apt update && apt install -y postgresql-client-${PSQL_VERSION} +RUN apt update && apt install -y postgresql-${PSQL_VERSION} && \ + systemctl disable postgresql COPY --from=builder /build/target/release/pgdog /usr/local/bin/pgdog +RUN mkdir -p /pgdog && chown postgres:postgres /pgdog WORKDIR /pgdog +USER postgres STOPSIGNAL SIGINT CMD ["/usr/local/bin/pgdog"] diff --git a/integration/postgres_fdw/dev.sh b/integration/postgres_fdw/dev.sh new file mode 100644 index 00000000..49ff2020 --- /dev/null +++ b/integration/postgres_fdw/dev.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -ex -o pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PGDOG="$SCRIPT_DIR/../../target/debug/pgdog" + +dropdb shard_0_fdw || true +dropdb shard_1_fdw || true + +createdb shard_0_fdw +createdb shard_1_fdw + +psql -f "$SCRIPT_DIR/../schema_sync/ecommerce_schema.sql" shard_0_fdw +psql -f "$SCRIPT_DIR/../schema_sync/ecommerce_schema.sql" shard_1_fdw + +${PGDOG} diff --git a/integration/postgres_fdw/pgdog.toml b/integration/postgres_fdw/pgdog.toml new file mode 100644 index 00000000..4f1fb5ad --- /dev/null +++ b/integration/postgres_fdw/pgdog.toml @@ -0,0 +1,36 @@ + +[general] +cross_shard_backend = "fdw" + +[[databases]] +name = "pgdog" +shard = 0 +host = "127.0.0.1" +database_name = "shard_0_fdw" + +[[databases]] +name = "pgdog" +shard = 1 +host = "127.0.0.1" +database_name = "shard_1_fdw" + +[[databases]] +name = "pgdog" +shard = 0 +host = "127.0.0.1" +database_name = "shard_0_fdw" +role = "replica" + +[[databases]] +name = "pgdog" +shard = 1 +host = "127.0.0.1" +database_name = "shard_1_fdw" +role = "replica" + +[[sharded_tables]] +column = "user_id" +database = "pgdog" + +[admin] +password = "pgdog" diff --git a/integration/postgres_fdw/users.toml b/integration/postgres_fdw/users.toml new file mode 100644 index 00000000..ebda3e9d --- /dev/null +++ b/integration/postgres_fdw/users.toml @@ -0,0 +1,9 @@ +[[users]] +name = "pgdog" +password = "pgdog" +database = "pgdog" + +[[users]] +name = "lev" +password = "lev" +database = "pgdog" diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index cb7fd7b0..85a5bba0 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -6,7 +6,7 @@ use tracing::{info, warn}; use crate::sharding::ShardedSchema; use crate::{ - system_catalogs, EnumeratedDatabase, Memory, OmnishardedTable, PassthoughAuth, + system_catalogs, EnumeratedDatabase, Fdw, Memory, OmnishardedTable, PassthoughAuth, PreparedStatements, QueryParserEngine, QueryParserLevel, ReadWriteSplit, RewriteMode, Role, SystemCatalogsBehavior, }; @@ -187,6 +187,9 @@ pub struct Config { /// Memory tweaks #[serde(default)] pub memory: Memory, + + #[serde(default)] + pub fdw: Fdw, } impl Config { diff --git a/pgdog-config/src/fdw.rs b/pgdog-config/src/fdw.rs new file mode 100644 index 00000000..27af750e --- /dev/null +++ b/pgdog-config/src/fdw.rs @@ -0,0 +1,28 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Copy)] +#[serde(deny_unknown_fields)] +pub struct Fdw { + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default = "default_launch_timeout")] + pub launch_timeout: u64, +} + +impl Default for Fdw { + fn default() -> Self { + Self { + port: default_port(), + launch_timeout: default_launch_timeout(), + } + } +} 
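Note: a minimal sketch (not part of this patch) of how the [fdw] defaults above are expected to behave, assuming the `Fdw` type is deserialized through serde with the `toml` crate; the port 6433 and the 5_000 launch timeout (presumably milliseconds) come from the default functions that follow.

use pgdog_config::Fdw;

fn fdw_defaults_sketch() {
    // An omitted or empty [fdw] section falls back to both defaults.
    let fdw: Fdw = toml::from_str("").unwrap();
    assert_eq!(fdw.port, 6433);
    assert_eq!(fdw.launch_timeout, 5_000);

    // Overriding one field leaves the other at its default;
    // unknown keys are rejected because of deny_unknown_fields.
    let fdw: Fdw = toml::from_str("port = 6500").unwrap();
    assert_eq!(fdw.port, 6500);
    assert_eq!(fdw.launch_timeout, 5_000);
}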
+ +fn default_port() -> u16 { + 6433 +} + +fn default_launch_timeout() -> u64 { + 5_000 +} diff --git a/pgdog-config/src/general.rs b/pgdog-config/src/general.rs index 23f7ebb6..9787f76c 100644 --- a/pgdog-config/src/general.rs +++ b/pgdog-config/src/general.rs @@ -5,7 +5,11 @@ use std::path::PathBuf; use std::time::Duration; use crate::pooling::ConnectionRecovery; -use crate::{CopyFormat, LoadSchema, QueryParserEngine, QueryParserLevel, SystemCatalogsBehavior}; + +use crate::{ + CopyFormat, CrossShardBackend, LoadSchema, QueryParserEngine, QueryParserLevel, + SystemCatalogsBehavior, +}; use super::auth::{AuthType, PassthoughAuth}; use super::database::{LoadBalancingStrategy, ReadWriteSplit, ReadWriteStrategy}; @@ -212,6 +216,9 @@ pub struct General { /// Load database schema. #[serde(default = "General::load_schema")] pub load_schema: LoadSchema, + /// Cross-shard backend. + #[serde(default = "General::cross_shard_backend")] + pub cross_shard_backend: CrossShardBackend, } impl Default for General { @@ -286,6 +293,7 @@ impl Default for General { resharding_copy_format: CopyFormat::default(), reload_schema_on_ddl: Self::reload_schema_on_ddl(), load_schema: Self::load_schema(), + cross_shard_backend: Self::cross_shard_backend(), } } } @@ -414,6 +422,10 @@ impl General { ) } + fn cross_shard_backend() -> CrossShardBackend { + Self::env_enum_or_default("PGDOG_CROSS_SHARD_BACKEND") + } + pub fn query_timeout(&self) -> Duration { Duration::from_millis(self.query_timeout) } diff --git a/pgdog-config/src/lib.rs b/pgdog-config/src/lib.rs index a08b302c..be02d553 100644 --- a/pgdog-config/src/lib.rs +++ b/pgdog-config/src/lib.rs @@ -4,6 +4,7 @@ pub mod core; pub mod data_types; pub mod database; pub mod error; +pub mod fdw; pub mod general; pub mod memory; pub mod networking; @@ -24,6 +25,7 @@ pub use database::{ Database, EnumeratedDatabase, LoadBalancingStrategy, ReadWriteSplit, ReadWriteStrategy, Role, }; pub use error::Error; +pub use fdw::Fdw; pub use general::General; pub use memory::*; pub use networking::{MultiTenant, Tcp, TlsVerifyMode}; diff --git a/pgdog-config/src/sharding.rs b/pgdog-config/src/sharding.rs index 0d47d5d9..0dfb2324 100644 --- a/pgdog-config/src/sharding.rs +++ b/pgdog-config/src/sharding.rs @@ -313,6 +313,15 @@ impl ListShards { Ok(None) } } + + /// Get all values that map to a specific shard. 
+ pub fn values_for_shard(&self, shard: usize) -> Vec<&FlexibleType> { + self.mapping + .iter() + .filter(|(_, &s)| s == shard) + .map(|(v, _)| v) + .collect() + } } #[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] @@ -391,6 +400,44 @@ impl FromStr for LoadSchema { } } +#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +#[serde(rename_all = "snake_case", deny_unknown_fields)] +pub enum CrossShardBackend { + #[default] + Pgdog, + Fdw, + Hybrid, +} + +impl CrossShardBackend { + pub fn need_fdw(&self) -> bool { + matches!(self, Self::Fdw | Self::Hybrid) + } +} + +impl Display for CrossShardBackend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Pgdog => write!(f, "pgdog"), + Self::Fdw => write!(f, "fdw"), + Self::Hybrid => write!(f, "hybrid"), + } + } +} + +impl FromStr for CrossShardBackend { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "pgdog" => Ok(Self::Pgdog), + "fdw" => Ok(Self::Fdw), + "hybrid" => Ok(Self::Hybrid), + _ => Err(()), + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/pgdog-plugin/src/bindings.rs b/pgdog-plugin/src/bindings.rs index 561d24e5..6f47703d 100644 --- a/pgdog-plugin/src/bindings.rs +++ b/pgdog-plugin/src/bindings.rs @@ -1,213 +1,338 @@ /* automatically generated by rust-bindgen 0.71.1 */ -pub const _STDINT_H: u32 = 1; -pub const _FEATURES_H: u32 = 1; -pub const _DEFAULT_SOURCE: u32 = 1; -pub const __GLIBC_USE_ISOC2Y: u32 = 0; -pub const __GLIBC_USE_ISOC23: u32 = 0; -pub const __USE_ISOC11: u32 = 1; -pub const __USE_ISOC99: u32 = 1; -pub const __USE_ISOC95: u32 = 1; -pub const __USE_POSIX_IMPLICITLY: u32 = 1; -pub const _POSIX_SOURCE: u32 = 1; -pub const _POSIX_C_SOURCE: u32 = 200809; -pub const __USE_POSIX: u32 = 1; -pub const __USE_POSIX2: u32 = 1; -pub const __USE_POSIX199309: u32 = 1; -pub const __USE_POSIX199506: u32 = 1; -pub const __USE_XOPEN2K: u32 = 1; -pub const __USE_XOPEN2K8: u32 = 1; -pub const _ATFILE_SOURCE: u32 = 1; pub const __WORDSIZE: u32 = 64; -pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; -pub const __SYSCALL_WORDSIZE: u32 = 64; -pub const __TIMESIZE: u32 = 64; -pub const __USE_TIME_BITS64: u32 = 1; -pub const __USE_MISC: u32 = 1; -pub const __USE_ATFILE: u32 = 1; -pub const __USE_FORTIFY_LEVEL: u32 = 0; -pub const __GLIBC_USE_DEPRECATED_GETS: u32 = 0; -pub const __GLIBC_USE_DEPRECATED_SCANF: u32 = 0; -pub const __GLIBC_USE_C23_STRTOL: u32 = 0; -pub const _STDC_PREDEF_H: u32 = 1; -pub const __STDC_IEC_559__: u32 = 1; -pub const __STDC_IEC_60559_BFP__: u32 = 201404; -pub const __STDC_IEC_559_COMPLEX__: u32 = 1; -pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404; -pub const __STDC_ISO_10646__: u32 = 201706; -pub const __GNU_LIBRARY__: u32 = 6; -pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 42; -pub const _SYS_CDEFS_H: u32 = 1; -pub const __glibc_c99_flexarr_available: u32 = 1; -pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0; -pub const __HAVE_GENERIC_SELECTION: u32 = 1; -pub const __GLIBC_USE_LIB_EXT2: u32 = 0; -pub const __GLIBC_USE_IEC_60559_BFP_EXT: u32 = 0; -pub const __GLIBC_USE_IEC_60559_BFP_EXT_C23: u32 = 0; -pub const __GLIBC_USE_IEC_60559_EXT: u32 = 0; -pub const __GLIBC_USE_IEC_60559_FUNCS_EXT: u32 = 0; -pub const __GLIBC_USE_IEC_60559_FUNCS_EXT_C23: u32 = 0; -pub const __GLIBC_USE_IEC_60559_TYPES_EXT: u32 = 0; -pub const _BITS_TYPES_H: u32 = 1; -pub const _BITS_TYPESIZES_H: u32 = 1; -pub const __OFF_T_MATCHES_OFF64_T: u32 = 1; -pub const 
__INO_T_MATCHES_INO64_T: u32 = 1; -pub const __RLIM_T_MATCHES_RLIM64_T: u32 = 1; -pub const __STATFS_MATCHES_STATFS64: u32 = 1; -pub const __KERNEL_OLD_TIMEVAL_MATCHES_TIMEVAL64: u32 = 1; -pub const __FD_SETSIZE: u32 = 1024; -pub const _BITS_TIME64_H: u32 = 1; -pub const _BITS_WCHAR_H: u32 = 1; -pub const _BITS_STDINT_INTN_H: u32 = 1; -pub const _BITS_STDINT_UINTN_H: u32 = 1; -pub const _BITS_STDINT_LEAST_H: u32 = 1; -pub const INT8_MIN: i32 = -128; -pub const INT16_MIN: i32 = -32768; -pub const INT32_MIN: i32 = -2147483648; +pub const __has_safe_buffers: u32 = 1; +pub const __DARWIN_ONLY_64_BIT_INO_T: u32 = 1; +pub const __DARWIN_ONLY_UNIX_CONFORMANCE: u32 = 1; +pub const __DARWIN_ONLY_VERS_1050: u32 = 1; +pub const __DARWIN_UNIX03: u32 = 1; +pub const __DARWIN_64_BIT_INO_T: u32 = 1; +pub const __DARWIN_VERS_1050: u32 = 1; +pub const __DARWIN_NON_CANCELABLE: u32 = 0; +pub const __DARWIN_SUF_EXTSN: &[u8; 14] = b"$DARWIN_EXTSN\0"; +pub const __DARWIN_C_ANSI: u32 = 4096; +pub const __DARWIN_C_FULL: u32 = 900000; +pub const __DARWIN_C_LEVEL: u32 = 900000; +pub const __STDC_WANT_LIB_EXT1__: u32 = 1; +pub const __DARWIN_NO_LONG_LONG: u32 = 0; +pub const _DARWIN_FEATURE_64_BIT_INODE: u32 = 1; +pub const _DARWIN_FEATURE_ONLY_64_BIT_INODE: u32 = 1; +pub const _DARWIN_FEATURE_ONLY_VERS_1050: u32 = 1; +pub const _DARWIN_FEATURE_ONLY_UNIX_CONFORMANCE: u32 = 1; +pub const _DARWIN_FEATURE_UNIX_CONFORMANCE: u32 = 3; +pub const __has_ptrcheck: u32 = 0; +pub const USE_CLANG_TYPES: u32 = 0; +pub const __PTHREAD_SIZE__: u32 = 8176; +pub const __PTHREAD_ATTR_SIZE__: u32 = 56; +pub const __PTHREAD_MUTEXATTR_SIZE__: u32 = 8; +pub const __PTHREAD_MUTEX_SIZE__: u32 = 56; +pub const __PTHREAD_CONDATTR_SIZE__: u32 = 8; +pub const __PTHREAD_COND_SIZE__: u32 = 40; +pub const __PTHREAD_ONCE_SIZE__: u32 = 8; +pub const __PTHREAD_RWLOCK_SIZE__: u32 = 192; +pub const __PTHREAD_RWLOCKATTR_SIZE__: u32 = 16; pub const INT8_MAX: u32 = 127; pub const INT16_MAX: u32 = 32767; pub const INT32_MAX: u32 = 2147483647; +pub const INT64_MAX: u64 = 9223372036854775807; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT64_MIN: i64 = -9223372036854775808; pub const UINT8_MAX: u32 = 255; pub const UINT16_MAX: u32 = 65535; pub const UINT32_MAX: u32 = 4294967295; +pub const UINT64_MAX: i32 = -1; pub const INT_LEAST8_MIN: i32 = -128; pub const INT_LEAST16_MIN: i32 = -32768; pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST64_MIN: i64 = -9223372036854775808; pub const INT_LEAST8_MAX: u32 = 127; pub const INT_LEAST16_MAX: u32 = 32767; pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const INT_LEAST64_MAX: u64 = 9223372036854775807; pub const UINT_LEAST8_MAX: u32 = 255; pub const UINT_LEAST16_MAX: u32 = 65535; pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const UINT_LEAST64_MAX: i32 = -1; pub const INT_FAST8_MIN: i32 = -128; -pub const INT_FAST16_MIN: i64 = -9223372036854775808; -pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST16_MIN: i32 = -32768; +pub const INT_FAST32_MIN: i32 = -2147483648; +pub const INT_FAST64_MIN: i64 = -9223372036854775808; pub const INT_FAST8_MAX: u32 = 127; -pub const INT_FAST16_MAX: u64 = 9223372036854775807; -pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const INT_FAST16_MAX: u32 = 32767; +pub const INT_FAST32_MAX: u32 = 2147483647; +pub const INT_FAST64_MAX: u64 = 9223372036854775807; pub const UINT_FAST8_MAX: u32 = 255; -pub const UINT_FAST16_MAX: i32 = -1; -pub const 
UINT_FAST32_MAX: i32 = -1; -pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const UINT_FAST16_MAX: u32 = 65535; +pub const UINT_FAST32_MAX: u32 = 4294967295; +pub const UINT_FAST64_MAX: i32 = -1; pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const INTPTR_MIN: i64 = -9223372036854775808; pub const UINTPTR_MAX: i32 = -1; -pub const PTRDIFF_MIN: i64 = -9223372036854775808; -pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIZE_MAX: i32 = -1; +pub const RSIZE_MAX: i32 = -1; +pub const WINT_MIN: i32 = -2147483648; +pub const WINT_MAX: u32 = 2147483647; pub const SIG_ATOMIC_MIN: i32 = -2147483648; pub const SIG_ATOMIC_MAX: u32 = 2147483647; -pub const SIZE_MAX: i32 = -1; -pub const WINT_MIN: u32 = 0; -pub const WINT_MAX: u32 = 4294967295; pub type wchar_t = ::std::os::raw::c_int; -#[repr(C)] -#[repr(align(16))] -#[derive(Debug, Copy, Clone)] -pub struct max_align_t { - pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, - pub __bindgen_padding_0: u64, - pub __clang_max_align_nonce2: u128, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of max_align_t"][::std::mem::size_of::() - 32usize]; - ["Alignment of max_align_t"][::std::mem::align_of::() - 16usize]; - ["Offset of field: max_align_t::__clang_max_align_nonce1"] - [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce1) - 0usize]; - ["Offset of field: max_align_t::__clang_max_align_nonce2"] - [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce2) - 16usize]; -}; -pub type __u_char = ::std::os::raw::c_uchar; -pub type __u_short = ::std::os::raw::c_ushort; -pub type __u_int = ::std::os::raw::c_uint; -pub type __u_long = ::std::os::raw::c_ulong; +pub type max_align_t = f64; +pub type int_least8_t = i8; +pub type int_least16_t = i16; +pub type int_least32_t = i32; +pub type int_least64_t = i64; +pub type uint_least8_t = u8; +pub type uint_least16_t = u16; +pub type uint_least32_t = u32; +pub type uint_least64_t = u64; +pub type int_fast8_t = i8; +pub type int_fast16_t = i16; +pub type int_fast32_t = i32; +pub type int_fast64_t = i64; +pub type uint_fast8_t = u8; +pub type uint_fast16_t = u16; +pub type uint_fast32_t = u32; +pub type uint_fast64_t = u64; pub type __int8_t = ::std::os::raw::c_schar; pub type __uint8_t = ::std::os::raw::c_uchar; pub type __int16_t = ::std::os::raw::c_short; pub type __uint16_t = ::std::os::raw::c_ushort; pub type __int32_t = ::std::os::raw::c_int; pub type __uint32_t = ::std::os::raw::c_uint; -pub type __int64_t = ::std::os::raw::c_long; -pub type __uint64_t = ::std::os::raw::c_ulong; -pub type __int_least8_t = __int8_t; -pub type __uint_least8_t = __uint8_t; -pub type __int_least16_t = __int16_t; -pub type __uint_least16_t = __uint16_t; -pub type __int_least32_t = __int32_t; -pub type __uint_least32_t = __uint32_t; -pub type __int_least64_t = __int64_t; -pub type __uint_least64_t = __uint64_t; -pub type __quad_t = ::std::os::raw::c_long; -pub type __u_quad_t = ::std::os::raw::c_ulong; -pub type __intmax_t = ::std::os::raw::c_long; -pub type __uintmax_t = ::std::os::raw::c_ulong; -pub type __dev_t = ::std::os::raw::c_ulong; -pub type __uid_t = ::std::os::raw::c_uint; -pub type __gid_t = ::std::os::raw::c_uint; -pub type __ino_t = ::std::os::raw::c_ulong; -pub type __ino64_t = ::std::os::raw::c_ulong; -pub type __mode_t = ::std::os::raw::c_uint; -pub type __nlink_t = ::std::os::raw::c_ulong; -pub type __off_t = ::std::os::raw::c_long; -pub type __off64_t = ::std::os::raw::c_long; -pub type __pid_t = ::std::os::raw::c_int; +pub 
type __int64_t = ::std::os::raw::c_longlong; +pub type __uint64_t = ::std::os::raw::c_ulonglong; +pub type __darwin_intptr_t = ::std::os::raw::c_long; +pub type __darwin_natural_t = ::std::os::raw::c_uint; +pub type __darwin_ct_rune_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Copy, Clone)] +pub union __mbstate_t { + pub __mbstate8: [::std::os::raw::c_char; 128usize], + pub _mbstateL: ::std::os::raw::c_longlong, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __mbstate_t"][::std::mem::size_of::<__mbstate_t>() - 128usize]; + ["Alignment of __mbstate_t"][::std::mem::align_of::<__mbstate_t>() - 8usize]; + ["Offset of field: __mbstate_t::__mbstate8"] + [::std::mem::offset_of!(__mbstate_t, __mbstate8) - 0usize]; + ["Offset of field: __mbstate_t::_mbstateL"] + [::std::mem::offset_of!(__mbstate_t, _mbstateL) - 0usize]; +}; +pub type __darwin_mbstate_t = __mbstate_t; +pub type __darwin_ptrdiff_t = ::std::os::raw::c_long; +pub type __darwin_size_t = ::std::os::raw::c_ulong; +pub type __darwin_va_list = __builtin_va_list; +pub type __darwin_wchar_t = ::std::os::raw::c_int; +pub type __darwin_rune_t = __darwin_wchar_t; +pub type __darwin_wint_t = ::std::os::raw::c_int; +pub type __darwin_clock_t = ::std::os::raw::c_ulong; +pub type __darwin_socklen_t = __uint32_t; +pub type __darwin_ssize_t = ::std::os::raw::c_long; +pub type __darwin_time_t = ::std::os::raw::c_long; +pub type __darwin_blkcnt_t = __int64_t; +pub type __darwin_blksize_t = __int32_t; +pub type __darwin_dev_t = __int32_t; +pub type __darwin_fsblkcnt_t = ::std::os::raw::c_uint; +pub type __darwin_fsfilcnt_t = ::std::os::raw::c_uint; +pub type __darwin_gid_t = __uint32_t; +pub type __darwin_id_t = __uint32_t; +pub type __darwin_ino64_t = __uint64_t; +pub type __darwin_ino_t = __darwin_ino64_t; +pub type __darwin_mach_port_name_t = __darwin_natural_t; +pub type __darwin_mach_port_t = __darwin_mach_port_name_t; +pub type __darwin_mode_t = __uint16_t; +pub type __darwin_off_t = __int64_t; +pub type __darwin_pid_t = __int32_t; +pub type __darwin_sigset_t = __uint32_t; +pub type __darwin_suseconds_t = __int32_t; +pub type __darwin_uid_t = __uint32_t; +pub type __darwin_useconds_t = __uint32_t; +pub type __darwin_uuid_t = [::std::os::raw::c_uchar; 16usize]; +pub type __darwin_uuid_string_t = [::std::os::raw::c_char; 37usize]; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct __darwin_pthread_handler_rec { + pub __routine: ::std::option::Option, + pub __arg: *mut ::std::os::raw::c_void, + pub __next: *mut __darwin_pthread_handler_rec, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __darwin_pthread_handler_rec"] + [::std::mem::size_of::<__darwin_pthread_handler_rec>() - 24usize]; + ["Alignment of __darwin_pthread_handler_rec"] + [::std::mem::align_of::<__darwin_pthread_handler_rec>() - 8usize]; + ["Offset of field: __darwin_pthread_handler_rec::__routine"] + [::std::mem::offset_of!(__darwin_pthread_handler_rec, __routine) - 0usize]; + ["Offset of field: __darwin_pthread_handler_rec::__arg"] + [::std::mem::offset_of!(__darwin_pthread_handler_rec, __arg) - 8usize]; + ["Offset of field: __darwin_pthread_handler_rec::__next"] + [::std::mem::offset_of!(__darwin_pthread_handler_rec, __next) - 16usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_attr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 56usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { 
+ ["Size of _opaque_pthread_attr_t"][::std::mem::size_of::<_opaque_pthread_attr_t>() - 64usize]; + ["Alignment of _opaque_pthread_attr_t"] + [::std::mem::align_of::<_opaque_pthread_attr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_attr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_attr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_attr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_attr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_cond_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 40usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_cond_t"][::std::mem::size_of::<_opaque_pthread_cond_t>() - 48usize]; + ["Alignment of _opaque_pthread_cond_t"] + [::std::mem::align_of::<_opaque_pthread_cond_t>() - 8usize]; + ["Offset of field: _opaque_pthread_cond_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_cond_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_cond_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_cond_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_condattr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 8usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_condattr_t"] + [::std::mem::size_of::<_opaque_pthread_condattr_t>() - 16usize]; + ["Alignment of _opaque_pthread_condattr_t"] + [::std::mem::align_of::<_opaque_pthread_condattr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_condattr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_condattr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_condattr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_condattr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_mutex_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 56usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_mutex_t"][::std::mem::size_of::<_opaque_pthread_mutex_t>() - 64usize]; + ["Alignment of _opaque_pthread_mutex_t"] + [::std::mem::align_of::<_opaque_pthread_mutex_t>() - 8usize]; + ["Offset of field: _opaque_pthread_mutex_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_mutex_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_mutex_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_mutex_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_mutexattr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 8usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_mutexattr_t"] + [::std::mem::size_of::<_opaque_pthread_mutexattr_t>() - 16usize]; + ["Alignment of _opaque_pthread_mutexattr_t"] + [::std::mem::align_of::<_opaque_pthread_mutexattr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_mutexattr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_mutexattr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_once_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 8usize], +} 
+#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_once_t"][::std::mem::size_of::<_opaque_pthread_once_t>() - 16usize]; + ["Alignment of _opaque_pthread_once_t"] + [::std::mem::align_of::<_opaque_pthread_once_t>() - 8usize]; + ["Offset of field: _opaque_pthread_once_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_once_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_once_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_once_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_rwlock_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 192usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_rwlock_t"] + [::std::mem::size_of::<_opaque_pthread_rwlock_t>() - 200usize]; + ["Alignment of _opaque_pthread_rwlock_t"] + [::std::mem::align_of::<_opaque_pthread_rwlock_t>() - 8usize]; + ["Offset of field: _opaque_pthread_rwlock_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_rwlock_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_rwlockattr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 16usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_rwlockattr_t"] + [::std::mem::size_of::<_opaque_pthread_rwlockattr_t>() - 24usize]; + ["Alignment of _opaque_pthread_rwlockattr_t"] + [::std::mem::align_of::<_opaque_pthread_rwlockattr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_rwlockattr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_rwlockattr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __opaque) - 8usize]; +}; #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct __fsid_t { - pub __val: [::std::os::raw::c_int; 2usize], +pub struct _opaque_pthread_t { + pub __sig: ::std::os::raw::c_long, + pub __cleanup_stack: *mut __darwin_pthread_handler_rec, + pub __opaque: [::std::os::raw::c_char; 8176usize], } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of __fsid_t"][::std::mem::size_of::<__fsid_t>() - 8usize]; - ["Alignment of __fsid_t"][::std::mem::align_of::<__fsid_t>() - 4usize]; - ["Offset of field: __fsid_t::__val"][::std::mem::offset_of!(__fsid_t, __val) - 0usize]; + ["Size of _opaque_pthread_t"][::std::mem::size_of::<_opaque_pthread_t>() - 8192usize]; + ["Alignment of _opaque_pthread_t"][::std::mem::align_of::<_opaque_pthread_t>() - 8usize]; + ["Offset of field: _opaque_pthread_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_t::__cleanup_stack"] + [::std::mem::offset_of!(_opaque_pthread_t, __cleanup_stack) - 8usize]; + ["Offset of field: _opaque_pthread_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_t, __opaque) - 16usize]; }; -pub type __clock_t = ::std::os::raw::c_long; -pub type __rlim_t = ::std::os::raw::c_ulong; -pub type __rlim64_t = ::std::os::raw::c_ulong; -pub type __id_t = ::std::os::raw::c_uint; -pub type __time_t = ::std::os::raw::c_long; -pub type __useconds_t = ::std::os::raw::c_uint; -pub type __suseconds_t = ::std::os::raw::c_long; -pub type __suseconds64_t = ::std::os::raw::c_long; -pub type 
__daddr_t = ::std::os::raw::c_int; -pub type __key_t = ::std::os::raw::c_int; -pub type __clockid_t = ::std::os::raw::c_int; -pub type __timer_t = *mut ::std::os::raw::c_void; -pub type __blksize_t = ::std::os::raw::c_long; -pub type __blkcnt_t = ::std::os::raw::c_long; -pub type __blkcnt64_t = ::std::os::raw::c_long; -pub type __fsblkcnt_t = ::std::os::raw::c_ulong; -pub type __fsblkcnt64_t = ::std::os::raw::c_ulong; -pub type __fsfilcnt_t = ::std::os::raw::c_ulong; -pub type __fsfilcnt64_t = ::std::os::raw::c_ulong; -pub type __fsword_t = ::std::os::raw::c_long; -pub type __ssize_t = ::std::os::raw::c_long; -pub type __syscall_slong_t = ::std::os::raw::c_long; -pub type __syscall_ulong_t = ::std::os::raw::c_ulong; -pub type __loff_t = __off64_t; -pub type __caddr_t = *mut ::std::os::raw::c_char; -pub type __intptr_t = ::std::os::raw::c_long; -pub type __socklen_t = ::std::os::raw::c_uint; -pub type __sig_atomic_t = ::std::os::raw::c_int; -pub type int_least8_t = __int_least8_t; -pub type int_least16_t = __int_least16_t; -pub type int_least32_t = __int_least32_t; -pub type int_least64_t = __int_least64_t; -pub type uint_least8_t = __uint_least8_t; -pub type uint_least16_t = __uint_least16_t; -pub type uint_least32_t = __uint_least32_t; -pub type uint_least64_t = __uint_least64_t; -pub type int_fast8_t = ::std::os::raw::c_schar; -pub type int_fast16_t = ::std::os::raw::c_long; -pub type int_fast32_t = ::std::os::raw::c_long; -pub type int_fast64_t = ::std::os::raw::c_long; -pub type uint_fast8_t = ::std::os::raw::c_uchar; -pub type uint_fast16_t = ::std::os::raw::c_ulong; -pub type uint_fast32_t = ::std::os::raw::c_ulong; -pub type uint_fast64_t = ::std::os::raw::c_ulong; -pub type intmax_t = __intmax_t; -pub type uintmax_t = __uintmax_t; +pub type __darwin_pthread_attr_t = _opaque_pthread_attr_t; +pub type __darwin_pthread_cond_t = _opaque_pthread_cond_t; +pub type __darwin_pthread_condattr_t = _opaque_pthread_condattr_t; +pub type __darwin_pthread_key_t = ::std::os::raw::c_ulong; +pub type __darwin_pthread_mutex_t = _opaque_pthread_mutex_t; +pub type __darwin_pthread_mutexattr_t = _opaque_pthread_mutexattr_t; +pub type __darwin_pthread_once_t = _opaque_pthread_once_t; +pub type __darwin_pthread_rwlock_t = _opaque_pthread_rwlock_t; +pub type __darwin_pthread_rwlockattr_t = _opaque_pthread_rwlockattr_t; +pub type __darwin_pthread_t = *mut _opaque_pthread_t; +pub type intmax_t = ::std::os::raw::c_long; +pub type uintmax_t = ::std::os::raw::c_ulong; #[doc = " Wrapper around Rust's [`&str`], without allocating memory, unlike [`std::ffi::CString`].\n The caller must use it as a Rust string. 
This is not a C-string."] #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -324,3 +449,4 @@ const _: () = { ["Offset of field: PdRoute::shard"][::std::mem::offset_of!(PdRoute, shard) - 0usize]; ["Offset of field: PdRoute::read_write"][::std::mem::offset_of!(PdRoute, read_write) - 8usize]; }; +pub type __builtin_va_list = *mut ::std::os::raw::c_char; diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 9464bb7e..c58cc933 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -63,6 +63,8 @@ hickory-resolver = "0.25.2" lazy_static = "1" dashmap = "6" derive_builder = "0.20.2" +tempfile = "3.23.0" +nix = { version = "0.31", features = ["signal"] } pgdog-config = { path = "../pgdog-config" } pgdog-vector = { path = "../pgdog-vector" } pgdog-stats = { path = "../pgdog-stats" } @@ -76,5 +78,4 @@ tikv-jemallocator = "0.6" cc = "1" [dev-dependencies] -tempfile = "3.23.0" stats_alloc = "0.1.10" diff --git a/pgdog/src/admin/set.rs b/pgdog/src/admin/set.rs index af180ec2..fd2fbd7b 100644 --- a/pgdog/src/admin/set.rs +++ b/pgdog/src/admin/set.rs @@ -180,6 +180,10 @@ impl Command for Set { config.config.general.connect_timeout = self.value.parse()?; } + "cross_shard_backend" => { + config.config.general.cross_shard_backend = Self::from_json(&self.value)?; + } + _ => return Err(Error::Syntax), } diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index d8028c35..1b3c9554 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -1,6 +1,6 @@ //! Databases behind pgDog. -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::HashMap; use std::sync::Arc; use arc_swap::ArcSwap; @@ -9,8 +9,9 @@ use parking_lot::lock_api::MutexGuard; use parking_lot::{Mutex, RawMutex}; use tracing::{debug, error, info, warn}; +use crate::backend::fdw::PostgresLauncher; use crate::backend::replication::ShardedSchemas; -use crate::config::PoolerMode; +use crate::config::{set, PoolerMode}; use crate::frontend::client::query_engine::two_pc::Manager; use crate::frontend::router::parser::Cache; use crate::frontend::router::sharding::mapping::mapping_valid; @@ -82,12 +83,24 @@ pub fn reload_from_existing() -> Result<(), Error> { let databases = from_config(&config); replace_databases(databases, true)?; + + // Reconfigure FDW with new schema. + if config.config.general.cross_shard_backend.need_fdw() { + PostgresLauncher::get().reconfigure(); + } + Ok(()) } /// Initialize the databases for the first time. pub fn init() -> Result<(), Error> { let config = config(); + + // Start postgres_fdw compatibility engine. + if config.config.general.cross_shard_backend.need_fdw() { + PostgresLauncher::get().launch(); + } + replace_databases(from_config(&config), false)?; // Resize query cache @@ -114,6 +127,26 @@ pub fn reload() -> Result<(), Error> { tls::reload()?; + // Reconfigure FDW with new schema. + match ( + old_config.config.general.cross_shard_backend.need_fdw(), + new_config.config.general.cross_shard_backend.need_fdw(), + ) { + (true, true) => { + PostgresLauncher::get().reconfigure(); + } + + (false, true) => { + PostgresLauncher::get().launch(); + } + + (true, false) => { + PostgresLauncher::get().shutdown(); + } + + (false, false) => {} + } + // Remove any unused prepared statements. PreparedStatements::global() .write() @@ -126,35 +159,38 @@ pub fn reload() -> Result<(), Error> { } /// Add new user to pool. 
-pub(crate) fn add(mut user: crate::config::User) { +pub(crate) fn add(user: crate::config::User) -> Result<(), Error> { + use std::ops::Deref; + // One user at a time. - let _lock = lock(); + let lock = lock(); debug!( "adding user \"{}\" for database \"{}\" via auth passthrough", user.name, user.database ); - let config = config(); - for existing in &config.users.users { + let mut config = config().deref().clone(); + let mut found = false; + for existing in &mut config.users.users { if existing.name == user.name && existing.database == user.database { - let mut existing = existing.clone(); - existing.password = user.password.clone(); - user = existing; + found = true; + if existing.password().is_empty() { + existing.password = user.password.clone(); + } } } - let pool = new_pool(&user, &config.config); - if let Some((user, cluster)) = pool { - let databases = (*databases()).clone(); - let (added, databases) = databases.add(user, cluster); - if added { - // Launch the new pool (idempotent). - databases.launch(); - // Don't use replace_databases because Arc refers to the same DBs, - // and we'll shut them down. - DATABASES.store(Arc::new(databases)); - } + + if !found { + config.users.users.push(user); } + + set(config)?; + drop(lock); + + reload_from_existing()?; + + Ok(()) } /// Database/user pair that identifies a database cluster pool. @@ -196,15 +232,6 @@ impl ToUser for (&str, Option<&str>) { } } -// impl ToUser for &pgdog_config::User { -// fn to_user(&self) -> User { -// User { -// user: self.name.clone(), -// database: self.database.clone(), -// } -// } -// } - /// Databases. #[derive(Default, Clone)] pub struct Databases { @@ -215,24 +242,6 @@ pub struct Databases { } impl Databases { - /// Add new connection pools to the databases. - fn add(mut self, user: User, cluster: Cluster) -> (bool, Databases) { - match self.databases.entry(user) { - Entry::Vacant(e) => { - e.insert(cluster); - (true, self) - } - Entry::Occupied(mut e) => { - if e.get().password().is_empty() { - e.insert(cluster); - (true, self) - } else { - (false, self) - } - } - } - } - /// Check if a cluster exists, quickly. 
pub fn exists(&self, user: impl ToUser) -> bool { if let Some(cluster) = self.databases.get(&user.to_user()) { @@ -619,8 +628,8 @@ mod tests { use super::*; use crate::config::{Config, ConfigAndUsers, Database, Role}; - #[test] - fn test_mirror_user_isolation() { + #[tokio::test] + async fn test_mirror_user_isolation() { // Test that each user gets their own mirror cluster let mut config = Config::default(); @@ -700,8 +709,8 @@ mod tests { assert_eq!(bob_mirrors[0].name(), "db1_mirror"); } - #[test] - fn test_mirror_user_mismatch_handling() { + #[tokio::test] + async fn test_mirror_user_mismatch_handling() { // Test that mirroring is disabled gracefully when users don't match let mut config = Config::default(); @@ -776,8 +785,8 @@ mod tests { ); } - #[test] - fn test_precomputed_mirror_configs() { + #[tokio::test] + async fn test_precomputed_mirror_configs() { // Test that mirror configs are precomputed correctly during initialization let mut config = Config::default(); config.general.mirror_queue = 100; @@ -853,8 +862,8 @@ mod tests { ); } - #[test] - fn test_mirror_config_with_global_defaults() { + #[tokio::test] + async fn test_mirror_config_with_global_defaults() { // Test that global defaults are used when mirror-specific values aren't provided let mut config = Config::default(); config.general.mirror_queue = 150; @@ -926,8 +935,8 @@ mod tests { ); } - #[test] - fn test_mirror_config_partial_overrides() { + #[tokio::test] + async fn test_mirror_config_partial_overrides() { // Test that we can override just queue or just exposure let mut config = Config::default(); config.general.mirror_queue = 100; @@ -1026,8 +1035,8 @@ mod tests { ); } - #[test] - fn test_invalid_mirror_not_precomputed() { + #[tokio::test] + async fn test_invalid_mirror_not_precomputed() { // Test that invalid mirror configs (user mismatch) are not precomputed let mut config = Config::default(); @@ -1089,8 +1098,8 @@ mod tests { ); } - #[test] - fn test_mirror_config_no_users() { + #[tokio::test] + async fn test_mirror_config_no_users() { // Test that mirror configs without any users are not precomputed let mut config = Config::default(); config.general.mirror_queue = 100; @@ -1198,8 +1207,8 @@ mod tests { ); } - #[test] - fn test_user_all_databases_creates_pools_for_all_dbs() { + #[tokio::test] + async fn test_user_all_databases_creates_pools_for_all_dbs() { let mut config = Config::default(); config.databases = vec![ @@ -1261,8 +1270,8 @@ mod tests { assert_eq!(databases.all().len(), 3); } - #[test] - fn test_user_multiple_databases_creates_pools_for_specified_dbs() { + #[tokio::test] + async fn test_user_multiple_databases_creates_pools_for_specified_dbs() { let mut config = Config::default(); config.databases = vec![ @@ -1324,8 +1333,8 @@ mod tests { assert_eq!(databases.all().len(), 2); } - #[test] - fn test_all_databases_takes_priority_over_databases_list() { + #[tokio::test] + async fn test_all_databases_takes_priority_over_databases_list() { let mut config = Config::default(); config.databases = vec![ @@ -1406,8 +1415,8 @@ mod tests { ); } - #[test] - fn test_user_with_single_database_creates_one_pool() { + #[tokio::test] + async fn test_user_with_single_database_creates_one_pool() { let mut config = Config::default(); config.databases = vec![ @@ -1456,8 +1465,8 @@ mod tests { assert_eq!(databases.all().len(), 1); } - #[test] - fn test_multiple_users_with_different_database_access() { + #[tokio::test] + async fn test_multiple_users_with_different_database_access() { let mut config = Config::default(); 
config.databases = vec![ @@ -1534,8 +1543,8 @@ mod tests { assert_eq!(databases.all().len(), 6); } - #[test] - fn test_databases_list_with_nonexistent_database_skipped() { + #[tokio::test] + async fn test_databases_list_with_nonexistent_database_skipped() { let mut config = Config::default(); config.databases = vec![Database { diff --git a/pgdog/src/backend/error.rs b/pgdog/src/backend/error.rs index f2465198..2cdc0379 100644 --- a/pgdog/src/backend/error.rs +++ b/pgdog/src/backend/error.rs @@ -134,6 +134,9 @@ pub enum Error { #[error("unsupported aggregation {function}: {reason}")] UnsupportedAggregation { function: String, reason: String }, + + #[error("{0}")] + ForeignTable(#[from] crate::backend::schema::postgres_fdw::Error), } impl From for Error { diff --git a/pgdog/src/backend/fdw/bins.rs b/pgdog/src/backend/fdw/bins.rs new file mode 100644 index 00000000..6c6183d8 --- /dev/null +++ b/pgdog/src/backend/fdw/bins.rs @@ -0,0 +1,55 @@ +use std::path::PathBuf; + +use tokio::process::Command; +use tracing::error; + +use super::Error; + +pub(super) struct Bins { + pub(super) postgres: PathBuf, + pub(super) initdb: PathBuf, + pub(super) version: f32, +} + +impl Bins { + pub(super) async fn new() -> Result { + let pg_config = Command::new("pg_config").output().await?; + + if !pg_config.status.success() { + error!( + "[fdw] pg_config: {}", + String::from_utf8_lossy(&pg_config.stderr) + ); + return Err(Error::PgConfig); + } + + let pg_config = String::from_utf8_lossy(&pg_config.stdout); + let mut path = PathBuf::new(); + let mut version = 0.0; + + for line in pg_config.lines() { + if line.starts_with("BINDIR") { + let bin_dir = line.split("BINDIR = ").last().unwrap_or_default().trim(); + path = PathBuf::from(bin_dir); + } + if line.starts_with("VERSION") { + version = line + .split("VERSION = ") + .last() + .unwrap_or_default() + .trim() + .split(" ") + .nth(1) + .unwrap_or_default() + .trim() + .parse()?; + } + } + + Ok(Self { + postgres: path.join("postgres"), + initdb: path.join("initdb"), + version, + }) + } +} diff --git a/pgdog/src/backend/fdw/error.rs b/pgdog/src/backend/fdw/error.rs new file mode 100644 index 00000000..66d39e82 --- /dev/null +++ b/pgdog/src/backend/fdw/error.rs @@ -0,0 +1,36 @@ +use std::num::ParseFloatError; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("io: {0}")] + Io(#[from] std::io::Error), + + #[error("initdb failed")] + InitDb, + + #[error("pg_config failed")] + PgConfig, + + #[error("backend: {0}")] + Backend(#[from] crate::backend::Error), + + #[error("pool: {0}")] + Pool(#[from] crate::backend::pool::Error), + + #[error("postgres didn't launch in time")] + Timeout(#[from] tokio::time::error::Elapsed), + + #[error("nix: {0}")] + Nix(#[from] nix::Error), + + #[error("shards don't have the same number of replicas/primary")] + ShardsHostsMismatch, + + #[error("error parsing postgres version")] + PostgresVersion(#[from] ParseFloatError), + + #[error("postgres process exited unexpectedly")] + ProcessExited, +} diff --git a/pgdog/src/backend/fdw/launcher.rs b/pgdog/src/backend/fdw/launcher.rs new file mode 100644 index 00000000..1332af9c --- /dev/null +++ b/pgdog/src/backend/fdw/launcher.rs @@ -0,0 +1,352 @@ +use std::{ + ops::Deref, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use crate::{backend::databases::databases, config::config}; + +use super::{Error, PostgresProcess}; +use once_cell::sync::Lazy; +use tokio::{ + select, spawn, + sync::broadcast, + time::{sleep, Duration}, +}; +use tracing::{error, info}; + 
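Note: the launcher defined below is driven entirely through broadcast events rather than constructed directly; a hedged usage sketch, mirroring the test_postgres_launcher test further down with the timeout wrappers omitted:

async fn launcher_usage_sketch() {
    // Process-wide singleton; launch() is idempotent, it only broadcasts Start.
    let launcher = PostgresLauncher::get();
    launcher.launch();

    // Resolves once the embedded postgres accepts connections on fdw.port
    // (the real test wraps this in tokio::time::timeout).
    launcher.wait_ready().await;

    // ... cross-shard queries can now be served through the FDW backend ...

    // Broadcasts Shutdown and waits for ShutdownComplete; on unix the
    // subprocess is stopped with SIGINT.
    launcher.shutdown_wait().await;
}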
+static LAUNCHER: Lazy<PostgresLauncher> = Lazy::new(PostgresLauncher::new);
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum LauncherEvent {
+    // Commands
+    Start,
+    Shutdown,
+    Reconfigure,
+
+    // Status
+    Ready,
+    ShutdownComplete,
+}
+
+#[derive(Debug, Clone)]
+pub struct PostgresLauncher {
+    inner: Arc<Inner>,
+}
+
+impl Deref for PostgresLauncher {
+    type Target = Inner;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+#[derive(Debug)]
+pub struct Inner {
+    events: broadcast::Sender<LauncherEvent>,
+    ready: AtomicBool,
+}
+
+impl PostgresLauncher {
+    fn new() -> Self {
+        let (events, _) = broadcast::channel(16);
+
+        let launcher = Self {
+            inner: Arc::new(Inner {
+                events,
+                ready: AtomicBool::new(false),
+            }),
+        };
+
+        // Subscribe before spawning to avoid race condition.
+        let receiver = launcher.events.subscribe();
+        launcher.spawn(receiver);
+        launcher
+    }
+
+    /// Get the launcher singleton instance.
+    pub(crate) fn get() -> Self {
+        LAUNCHER.clone()
+    }
+
+    /// Get configured port.
+    pub(crate) fn port(&self) -> u16 {
+        config().config.fdw.port
+    }
+
+    /// Start the launcher. Idempotent.
+    pub(crate) fn launch(&self) {
+        let _ = self.events.send(LauncherEvent::Start);
+    }
+
+    pub(crate) fn shutdown(&self) {
+        let _ = self.events.send(LauncherEvent::Shutdown);
+    }
+
+    /// Request reconfiguration.
+    pub(crate) fn reconfigure(&self) {
+        let _ = self.events.send(LauncherEvent::Reconfigure);
+    }
+
+    /// Shutdown and wait for completion.
+    pub(crate) async fn shutdown_wait(&self) {
+        // Subscribe before sending to avoid race condition.
+        let receiver = self.events.subscribe();
+        let _ = self.events.send(LauncherEvent::Shutdown);
+        Self::wait_for(receiver, LauncherEvent::ShutdownComplete).await;
+    }
+
+    /// Wait for Postgres to be ready.
+    pub(crate) async fn wait_ready(&self) {
+        // Subscribe first to avoid race with Ready event.
+        let receiver = self.events.subscribe();
+        if self.ready.load(Ordering::Acquire) {
+            return;
+        }
+        Self::wait_for(receiver, LauncherEvent::Ready).await;
+    }
+
+    async fn wait_for(mut receiver: broadcast::Receiver<LauncherEvent>, target: LauncherEvent) {
+        loop {
+            match receiver.recv().await {
+                Ok(event) if event == target => return,
+                Ok(_) => continue,
+                Err(broadcast::error::RecvError::Closed) => return,
+                Err(broadcast::error::RecvError::Lagged(_)) => continue,
+            }
+        }
+    }
+
+    fn send(&self, event: LauncherEvent) {
+        match &event {
+            LauncherEvent::Ready => self.ready.store(true, Ordering::Release),
+            LauncherEvent::Start | LauncherEvent::Shutdown => {
+                self.ready.store(false, Ordering::Release)
+            }
+            _ => {}
+        }
+        let _ = self.events.send(event);
+    }
+
+    fn spawn(&self, receiver: broadcast::Receiver<LauncherEvent>) {
+        let launcher = self.clone();
+
+        spawn(async move {
+            let mut receiver = receiver;
+
+            loop {
+                // Wait for Start or Shutdown.
+ match receiver.recv().await { + Ok(LauncherEvent::Start) => {} + Ok(LauncherEvent::Shutdown) => { + launcher.send(LauncherEvent::Ready); + launcher.send(LauncherEvent::ShutdownComplete); + return; + } + Ok(_) => continue, + Err(broadcast::error::RecvError::Closed) => return, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + } + + info!("[fdw] launching fdw backend on 0.0.0.0:{}", launcher.port()); + + match launcher.run(&mut receiver).await { + Ok(()) => return, // Clean shutdown + Err(err) => { + error!("[fdw] launcher error: {}", err); + sleep(Duration::from_millis(1000)).await; + } + } + } + }); + } + + async fn run(&self, receiver: &mut broadcast::Receiver) -> Result<(), Error> { + let port = self.port(); + let mut process = PostgresProcess::new(None, port).await?; + let mut shutdown_receiver = process.shutdown_receiver(); + + process.launch().await?; + process.wait_ready().await; + + for cluster in databases().all().values() { + if cluster.shards().len() > 1 { + process.configure(cluster).await?; + } + } + + process.configuration_complete(); + + self.send(LauncherEvent::Ready); + + loop { + select! { + event = receiver.recv() => { + match event { + Ok(LauncherEvent::Shutdown) => { + process.stop_wait().await; + self.send(LauncherEvent::ShutdownComplete); + return Ok(()); + } + + Ok(LauncherEvent::Reconfigure) => { + for cluster in databases().all().values() { + if cluster.shards().len() > 1 { + if let Err(err) = process.configure(cluster).await { + error!("[fdw] reconfigure error: {}", err); + } + } + } + + process.configuration_complete(); + } + + Ok(_) => continue, + Err(broadcast::error::RecvError::Closed) => { + process.stop_wait().await; + return Ok(()); + } + Err(broadcast::error::RecvError::Lagged(_)) => continue, + } + } + + _ = shutdown_receiver.changed() => { + process.clear_pid(); + self.send(LauncherEvent::ShutdownComplete); + return Err(Error::ProcessExited); + } + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::backend::{pool::Address, ConnectReason, Server, ServerOptions}; + use crate::config::config; + use tokio::time::timeout; + + fn test_launcher() -> PostgresLauncher { + let (events, _) = broadcast::channel(16); + let launcher = PostgresLauncher { + inner: Arc::new(Inner { + events, + ready: AtomicBool::new(false), + }), + }; + let receiver = launcher.events.subscribe(); + launcher.spawn(receiver); + launcher + } + + #[tokio::test] + async fn test_postgres_launcher() { + crate::logger(); + let fdw = config().config.fdw; + + let address = Address { + host: "127.0.0.1".into(), + port: fdw.port, + user: "postgres".into(), + database_name: "postgres".into(), + ..Default::default() + }; + + let launcher = PostgresLauncher::get(); + launcher.launch(); + + timeout(Duration::from_secs(10), launcher.wait_ready()) + .await + .expect("timeout waiting for ready"); + + let mut conn = + Server::connect(&address, ServerOptions::default(), ConnectReason::default()) + .await + .unwrap(); + conn.execute("SELECT 1").await.unwrap(); + + timeout(Duration::from_secs(10), launcher.shutdown_wait()) + .await + .expect("timeout waiting for shutdown"); + } + + #[tokio::test] + async fn test_shutdown_without_start() { + let launcher = test_launcher(); + + // Give spawn task time to start waiting. + sleep(Duration::from_millis(10)).await; + + // Shutdown without ever starting - should not hang. 
+ timeout(Duration::from_secs(5), launcher.shutdown_wait()) + .await + .expect("shutdown_wait() hung when FDW was never started"); + } + + #[tokio::test] + async fn test_wait_ready_no_race() { + // Test that wait_ready doesn't miss Ready event due to race condition. + // Run multiple iterations to increase chance of hitting race window. + for _ in 0..100 { + let launcher = test_launcher(); + + // Spawn task that sends Ready immediately. + let launcher_clone = launcher.clone(); + spawn(async move { + launcher_clone.send(LauncherEvent::Ready); + }); + + // wait_ready should not hang even if Ready is sent + // between subscribe and wait. + timeout(Duration::from_millis(100), launcher.wait_ready()) + .await + .expect("wait_ready() missed Ready event - race condition"); + } + } + + #[tokio::test] + async fn test_shutdown_wait_no_race() { + // Test that shutdown_wait doesn't miss ShutdownComplete due to race. + for _ in 0..100 { + let launcher = test_launcher(); + + // Give spawn task time to start. + sleep(Duration::from_millis(1)).await; + + // shutdown_wait sends Shutdown and waits for ShutdownComplete. + // The spawn loop should receive Shutdown and send ShutdownComplete. + // This should not hang even with tight timing. + timeout(Duration::from_millis(100), launcher.shutdown_wait()) + .await + .expect("shutdown_wait() missed ShutdownComplete - race condition"); + } + } + + #[tokio::test] + async fn test_concurrent_wait_ready() { + // Multiple tasks waiting for Ready concurrently. + let launcher = test_launcher(); + + let mut handles = vec![]; + for _ in 0..10 { + let l = launcher.clone(); + handles.push(spawn(async move { + timeout(Duration::from_millis(100), l.wait_ready()) + .await + .expect("concurrent wait_ready timed out"); + })); + } + + // Small delay then send Ready. 
+ sleep(Duration::from_millis(5)).await; + launcher.send(LauncherEvent::Ready); + + for handle in handles { + handle.await.unwrap(); + } + } +} diff --git a/pgdog/src/backend/fdw/lb.rs b/pgdog/src/backend/fdw/lb.rs new file mode 100644 index 00000000..0ad66f3e --- /dev/null +++ b/pgdog/src/backend/fdw/lb.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use pgdog_config::Role; +use tokio::spawn; + +use crate::backend::fdw::PostgresLauncher; +use crate::backend::pool::{Guard, Request}; +use crate::backend::{Cluster, LoadBalancer, Pool}; + +use super::Error; +use super::PostgresProcess; + +#[derive(Clone, Debug)] +pub(crate) struct FdwLoadBalancer { + lb: Arc, +} + +impl FdwLoadBalancer { + pub(crate) fn new(cluster: &Cluster) -> Result { + let port = PostgresLauncher::get().port(); + let configs = PostgresProcess::connection_pool_configs(port, cluster)?; + let primary = configs + .iter() + .find(|p| p.0 == Role::Primary) + .map(|p| Pool::new(&p.1)); + let addrs: Vec<_> = configs.iter().map(|c| c.1.clone()).collect(); + + let lb = Arc::new(LoadBalancer::new( + &primary, + &addrs, + cluster.lb_strategy(), + cluster.rw_split(), + )); + + Ok(Self { lb }) + } + + pub(crate) fn launch(&self) { + let lb = self.lb.clone(); + spawn(async move { + let launcher = PostgresLauncher::get(); + launcher.wait_ready().await; + lb.launch(); + }); + } + + pub(crate) fn primary(&self) -> Option { + self.lb.primary().cloned() + } + + pub(crate) async fn get( + &self, + request: &Request, + ) -> Result { + self.lb.get(request).await + } +} diff --git a/pgdog/src/backend/fdw/mod.rs b/pgdog/src/backend/fdw/mod.rs new file mode 100644 index 00000000..236c38d0 --- /dev/null +++ b/pgdog/src/backend/fdw/mod.rs @@ -0,0 +1,12 @@ +pub mod bins; +pub mod error; +pub mod launcher; +pub mod lb; +pub mod postgres; +pub mod postgres_config; + +pub use error::Error; +pub(crate) use launcher::PostgresLauncher; +pub(crate) use lb::FdwLoadBalancer; +pub(crate) use postgres::PostgresProcess; +pub(crate) use postgres_config::PostgresConfig; diff --git a/pgdog/src/backend/fdw/postgres.rs b/pgdog/src/backend/fdw/postgres.rs new file mode 100644 index 00000000..c5834822 --- /dev/null +++ b/pgdog/src/backend/fdw/postgres.rs @@ -0,0 +1,678 @@ +use std::{ + collections::{HashMap, HashSet}, + path::{Path, PathBuf}, + process::Stdio, + time::Duration, +}; + +#[cfg(unix)] +use nix::{ + sys::signal::{kill, Signal}, + unistd::Pid, +}; + +use once_cell::sync::Lazy; +use pgdog_config::Role; +use rand::random_range; +use regex::Regex; +use tempfile::TempDir; +use tokio::{ + fs::remove_dir_all, + io::{AsyncBufReadExt, BufReader}, + process::{Child, Command}, + select, spawn, + sync::watch, + time::{sleep, Instant}, +}; +use tracing::{error, info, warn}; + +use crate::backend::{ + pool::{Address, Config, PoolConfig, Request}, + schema::postgres_fdw::{quote_identifier, FdwServerDef, ForeignTableSchema}, + Cluster, ConnectReason, Server, ServerOptions, +}; + +use super::{bins::Bins, Error, PostgresConfig}; + +static LOG_PREFIX: Lazy = + Lazy::new(|| Regex::new(r"^(LOG|WARNING|ERROR|FATAL|PANIC|DEBUG\d?|INFO|NOTICE):\s+").unwrap()); + +struct PostgresProcessAsync { + child: Child, + initdb_dir: PathBuf, + shutdown: watch::Receiver, + shutdown_complete: watch::Sender, + version: f32, + port: u16, +} + +impl PostgresProcessAsync { + /// Stop Postgres and cleanup. 
+ async fn stop(&mut self) -> Result<(), Error> { + info!( + "[fdw] stopping PostgreSQL {} running on 0.0.0.0:{}", + self.version, self.port + ); + + #[cfg(unix)] + { + let pid = self.child.id().expect("child has no pid") as i32; + let pid = Pid::from_raw(pid); + kill(pid, Signal::SIGINT)?; + } + + #[cfg(not(unix))] + self.child.kill().await?; + + self.child.wait().await?; + + // Delete data dir, its ephemeral. + remove_dir_all(&self.initdb_dir).await?; + + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct FdwBackend { + pub(crate) config: Config, + pub(crate) address: Address, + pub(crate) database_name: String, + pub(crate) role: Role, +} + +#[derive(Debug)] +pub(crate) struct PostgresProcess { + postres: PathBuf, + initdb: PathBuf, + initdb_dir: PathBuf, + shutdown: watch::Sender, + shutdown_complete: watch::Sender, + port: u16, + pid: Option, + version: f32, + /// Tracks which cluster databases have been fully configured. + /// Subsequent clusters with the same database only get user mappings. + configured_databases: HashSet, +} + +impl PostgresProcess { + pub(crate) async fn new(initdb_path: Option<&Path>, port: u16) -> Result { + let initdb_path = if let Some(path) = initdb_path { + path.to_owned() + } else { + TempDir::new()?.keep() + }; + + let bins = Bins::new().await?; + + let (shutdown, _) = watch::channel(false); + let (shutdown_complete, _) = watch::channel(false); + + Ok(Self { + postres: bins.postgres, + initdb: bins.initdb, + initdb_dir: initdb_path, + shutdown, + shutdown_complete, + port, + pid: None, + version: bins.version, + configured_databases: HashSet::new(), + }) + } + + /// Kill any existing process listening on the given port. + /// This handles orphaned postgres processes from previous crashes. + #[cfg(unix)] + async fn kill_existing_on_port(port: u16) { + // Use fuser to find and kill any process on the port + let result = Command::new("fuser") + .arg("-k") + .arg(format!("{}/tcp", port)) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .await; + + if let Ok(status) = result { + if status.success() { + warn!( + "[fdw] killed orphaned process on port {} from previous run", + port + ); + // Give it a moment to fully terminate + sleep(Duration::from_millis(100)).await; + } + } + } + + #[cfg(not(unix))] + async fn kill_existing_on_port(_port: u16) { + // Not implemented for non-unix platforms + } + + /// Setup and launch Postgres process. + pub(crate) async fn launch(&mut self) -> Result<(), Error> { + // Clean up any orphaned postgres from previous crashes + Self::kill_existing_on_port(self.port).await; + + info!( + "[fdw] initializing \"{}\" (PostgreSQL {})", + self.initdb_dir.display(), + self.version + ); + + let process = Command::new(&self.initdb) + .arg("-D") + .arg(&self.initdb_dir) + .arg("--username") + .arg("postgres") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .await?; + + if !process.status.success() { + error!("{}", String::from_utf8_lossy(&process.stdout)); + error!("{}", String::from_utf8_lossy(&process.stderr)); + return Err(Error::InitDb); + } + + // Configure Postgres. + PostgresConfig::new(&self.initdb_dir.join("postgresql.conf")) + .await? 
+ .configure_and_save(self.port, self.version) + .await?; + + info!( + "[fdw] launching PostgreSQL {} on 0.0.0.0:{}", + self.version, self.port + ); + + let mut cmd = Command::new(&self.postres); + cmd.arg("-D") + .arg(&self.initdb_dir) + .arg("-k") + .arg(&self.initdb_dir) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + #[cfg(unix)] + cmd.process_group(0); // Prevent sigint from terminal. + + // SAFETY: prctl(PR_SET_PDEATHSIG) is async-signal-safe and doesn't + // access any shared state. It tells the kernel to send SIGKILL to + // this process when its parent dies, preventing orphaned processes. + #[cfg(target_os = "linux")] + { + #[allow(unused_imports)] + use std::os::unix::process::CommandExt; + unsafe { + cmd.pre_exec(|| { + const PR_SET_PDEATHSIG: nix::libc::c_int = 1; + nix::libc::prctl(PR_SET_PDEATHSIG, nix::libc::SIGKILL); + Ok(()) + }); + } + } + + let child = cmd.spawn()?; + + self.pid = child.id().map(|pid| pid as i32); + + let mut process = PostgresProcessAsync { + child, + shutdown: self.shutdown.subscribe(), + shutdown_complete: self.shutdown_complete.clone(), + initdb_dir: self.initdb_dir.clone(), + port: self.port, + version: self.version, + }; + + spawn(async move { + info!("[fdw] postgres process running"); + + let reader = process + .child + .stderr + .take() + .map(|stdout| BufReader::new(stdout)); + + let mut reader = if let Some(reader) = reader { + reader + } else { + error!("[fdw] failed to start subprocess: no stderr"); + if let Err(err) = process.stop().await { + error!("[fdw] failed to abort subprocess: {}", err); + } + return; + }; + + loop { + let mut line = String::new(); + select! { + _ = process.shutdown.changed() => { + if *process.shutdown.borrow() { + if let Err(err) = process.stop().await { + error!("[fdw] shutdown error: {}", err); + } + break; + } + } + + exit_status = process.child.wait() => { + // Drain remaining stderr before reporting shutdown + loop { + let mut remaining = String::new(); + match reader.read_line(&mut remaining).await { + Ok(0) => break, // EOF + Ok(_) => { + if !remaining.is_empty() { + let remaining = LOG_PREFIX.replace(&remaining, ""); + info!("[fdw::subprocess] {}", remaining.trim()); + } + } + Err(_) => break, + } + } + error!("[fdw] postgres shut down unexpectedly, exit status: {:?}", exit_status); + break; + } + + res = reader.read_line(&mut line) => { + if let Err(err) = res { + error!("[fdw] process error: {}", err); + break; + } + + if !line.is_empty() { + let line = LOG_PREFIX.replace(&line, ""); + info!("[fdw::subprocess] {}", line.trim()); + } + } + } + } + + let _ = process.shutdown_complete.send(true); + }); + + Ok(()) + } + + /// Get a receiver for shutdown completion notification. + pub(super) fn shutdown_receiver(&self) -> watch::Receiver { + self.shutdown_complete.subscribe() + } + + pub(crate) fn connection_pool_configs( + port: u16, + cluster: &Cluster, + ) -> Result, Error> { + Ok(Self::pools_to_fdw_backends(cluster, 0)? 
+ .into_iter() + .enumerate() + .map(|(database_number, backend)| { + let address = Address { + host: "127.0.0.1".into(), + port, + database_name: backend.database_name.clone(), + password: "".into(), // We use trust + user: cluster.identifier().user.clone(), + database_number, + }; + + ( + backend.role, + PoolConfig { + address, + config: backend.config, + }, + ) + }) + .collect()) + } + + pub(super) fn pools_to_fdw_backends( + cluster: &Cluster, + shard: usize, + ) -> Result, Error> { + let shard = cluster + .shards() + .get(shard) + .ok_or(Error::ShardsHostsMismatch)?; + + Ok(shard + .pools_with_roles() + .into_iter() + .enumerate() + .map(|(number, (role, pool))| { + let database_name = format!("{}_{}", cluster.identifier().database, number); + FdwBackend { + config: pool.config().clone(), + address: pool.addr().clone(), + database_name, + role, + } + }) + .collect()) + } + + async fn setup_databases(&mut self, cluster: &Cluster) -> Result<(), Error> { + let hosts: Vec<_> = cluster + .shards() + .iter() + .map(|shard| { + let roles: Vec<_> = shard + .pools_with_roles() + .iter() + .map(|(role, _)| role) + .cloned() + .collect(); + roles + }) + .collect(); + let identical = hosts.windows(2).all(|w| w.get(0) == w.get(1)); + if !identical { + return Err(Error::ShardsHostsMismatch); + } + + let mut admin_connection = self.admin_connection().await?; + + for backend in Self::pools_to_fdw_backends(cluster, 0)? { + let exists: Vec = admin_connection + .fetch_all(&format!( + "SELECT datname FROM pg_database WHERE datname = '{}'", + backend.database_name.replace('\'', "''") + )) + .await?; + + if exists.is_empty() { + admin_connection + .execute(format!( + "CREATE DATABASE {}", + quote_identifier(&backend.database_name) + )) + .await?; + } + } + + let user = cluster.identifier().user.clone(); + + let user_exists: Vec = admin_connection + .fetch_all(&format!( + "SELECT rolname FROM pg_roles WHERE rolname = '{}'", + user.replace('\'', "''") + )) + .await?; + + if user_exists.is_empty() { + admin_connection + .execute(format!( + "CREATE USER {} SUPERUSER LOGIN", + quote_identifier(&user) + )) + .await?; + } + + Ok(()) + } + + /// Create the same load-balancing and sharding setup we have in pgdog.toml + /// for this cluster. This function is idempotent. + pub(crate) async fn configure(&mut self, cluster: &Cluster) -> Result<(), Error> { + self.setup_databases(cluster).await?; + let now = Instant::now(); + + let cluster_db = cluster.identifier().database.clone(); + let first_setup = !self.configured_databases.contains(&cluster_db); + + info!( + "[fdw] setting up database={} user={} initial={}", + cluster.identifier().database, + cluster.identifier().user, + first_setup, + ); + + let sharding_schema = cluster.sharding_schema(); + + let schema = if first_setup { + // TODO: Double check schemas are identical on all shards. + let shard = random_range(0..sharding_schema.shards); + let mut server = cluster + .primary_or_replica(shard, &Request::default()) + .await?; + Some(ForeignTableSchema::load(&mut server).await?) + } else { + None + }; + + // Setup persistent connections. + let mut connections = HashMap::new(); + + // We checked that all shards have the same number of replicas. + let databases: Vec<_> = Self::pools_to_fdw_backends(cluster, 0)? 
+ .into_iter() + .map(|backend| backend.database_name) + .collect(); + + for database in &databases { + let identifier = (cluster.identifier().user.clone(), database.clone()); + let mut connection = self.connection(&identifier.0, database).await?; + + if first_setup { + // Create extension in a dedicated schema that won't be dropped. + // This prevents DROP SCHEMA public CASCADE from removing postgres_fdw and its servers. + connection + .execute("CREATE SCHEMA IF NOT EXISTS pgdog_internal") + .await?; + connection + .execute("CREATE EXTENSION IF NOT EXISTS postgres_fdw SCHEMA pgdog_internal") + .await?; + } + + connections.insert(identifier, connection); + } + + // Build server definitions for each database and run setup + let num_pools = Self::pools_to_fdw_backends(cluster, 0)?.len(); + for pool_position in 0..num_pools { + let database = format!("{}_{}", cluster.identifier().database, pool_position); + let identifier = (cluster.identifier().user.clone(), database); + let mut connection = connections + .get_mut(&identifier) + .expect("connection is gone"); + + // Collect server definitions for all shards using this pool position + let mut server_defs = Vec::new(); + for (shard_num, _) in cluster.shards().iter().enumerate() { + let backends = Self::pools_to_fdw_backends(cluster, shard_num)?; + if let Some(backend) = backends.get(pool_position) { + server_defs.push(FdwServerDef { + shard_num, + host: backend.address.host.clone(), + port: backend.address.port, + database_name: backend.address.database_name.clone(), + user: backend.address.user.clone(), + password: backend.address.password.clone(), + mapping_user: cluster.identifier().user.clone(), + }); + } + } + + if first_setup { + schema + .as_ref() + .unwrap() + .setup(&mut connection, &sharding_schema, &server_defs) + .await?; + } else { + ForeignTableSchema::setup_user_mappings(&mut connection, &server_defs).await?; + } + } + + if first_setup { + self.configured_databases.insert(cluster_db); + } + + let elapsed = now.elapsed(); + + info!( + "[fdw] setup complete for database={} user={} in {:.3}ms", + cluster.identifier().database, + cluster.identifier().user, + elapsed.as_secs_f32() * 1000.0, + ); + + Ok(()) + } + + pub(crate) fn configuration_complete(&mut self) { + self.configured_databases.clear(); + } + + /// Create server connection. + pub(crate) async fn admin_connection(&self) -> Result { + self.connection("postgres", "postgres").await + } + + /// Get a connection with the user and database. + pub(crate) async fn connection(&self, user: &str, database: &str) -> Result { + let address = self.address(user, database); + + let server = + Server::connect(&address, ServerOptions::default(), ConnectReason::Other).await?; + + Ok(server) + } + + fn address(&self, user: &str, database: &str) -> Address { + Address { + host: "127.0.0.1".into(), + port: self.port, + user: user.into(), + database_name: database.into(), + ..Default::default() + } + } + + /// Wait until process is ready and accepting connections. + pub(crate) async fn wait_ready(&self) { + self.wait_ready_internal().await; + } + + async fn wait_ready_internal(&self) { + while let Err(_) = self.admin_connection().await { + sleep(Duration::from_millis(100)).await; + continue; + } + } + + pub(crate) async fn stop_wait(&mut self) { + let mut receiver = self.shutdown_complete.subscribe(); + + // Check if already complete (process may have exited). + if *receiver.borrow() { + self.pid.take(); + return; + } + + // Signal shutdown. 
+ self.shutdown.send_modify(|v| *v = true); + + // Wait for shutdown to complete. + while receiver.changed().await.is_ok() { + if *receiver.borrow() { + break; + } + } + self.pid.take(); + } + + /// Clear the pid to prevent dirty shutdown warning. + /// Used when the process has already exited. + pub(crate) fn clear_pid(&mut self) { + self.pid.take(); + } +} + +impl Drop for PostgresProcess { + fn drop(&mut self) { + if let Some(pid) = self.pid.take() { + warn!("[fdw] dirty shutdown initiated for pid {}", pid); + + #[cfg(unix)] + { + if let Err(err) = kill(Pid::from_raw(pid), Signal::SIGKILL) { + error!("[fdw] dirty shutdown error: {}", err); + } + + if let Err(err) = std::fs::remove_dir_all(&self.initdb_dir) { + error!("[fdw] dirty shutdown cleanup error: {}", err); + } + } + } + } +} + +#[cfg(test)] +mod test { + + use crate::config::config; + + use super::*; + + #[tokio::test] + async fn test_postgres_process() { + crate::logger(); + let cluster = Cluster::new_test(&config()); + cluster.launch(); + + { + let mut primary = cluster.primary(0, &Request::default()).await.unwrap(); + primary + .execute("CREATE TABLE IF NOT EXISTS test_postgres_process (customer_id BIGINT)") + .await + .unwrap(); + } + + let mut process = PostgresProcess::new(None, 45012).await.unwrap(); + + process.launch().await.unwrap(); + process.wait_ready().await; + process.configure(&cluster).await.unwrap(); + let mut server = process.admin_connection().await.unwrap(); + let backends = server + .fetch_all::("SELECT backend_type::text FROM pg_stat_activity ORDER BY 1") + .await + .unwrap(); + + assert_eq!( + backends, + [ + "background writer", + "checkpointer", + "client backend", + "walwriter" + ] + ); + + let mut server = process.connection("pgdog", "pgdog_0").await.unwrap(); + server + .execute("SELECT * FROM pgdog.test_postgres_process") + .await + .unwrap(); + + process.stop_wait().await; + + { + let mut primary = cluster.primary(0, &Request::default()).await.unwrap(); + primary + .execute("DROP TABLE test_postgres_process") + .await + .unwrap(); + } + + cluster.shutdown(); + } +} diff --git a/pgdog/src/backend/fdw/postgres_config.rs b/pgdog/src/backend/fdw/postgres_config.rs new file mode 100644 index 00000000..bc7a4247 --- /dev/null +++ b/pgdog/src/backend/fdw/postgres_config.rs @@ -0,0 +1,70 @@ +use std::path::{Path, PathBuf}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use super::Error; + +#[derive(Debug, Clone)] +pub(crate) struct PostgresConfig { + path: PathBuf, + content: String, +} + +impl PostgresConfig { + /// Load configuration from path. + pub(crate) async fn new(path: impl AsRef) -> Result { + let path = PathBuf::from(path.as_ref()); + + Ok(Self { + path, + content: String::new(), + }) + } + + /// Add a setting + pub(crate) fn set(&mut self, name: &str, value: &str) { + self.content.push_str(&format!("{} = {}\n", name, value)); + } + + /// Save configuration. + pub(crate) async fn save(&self) -> Result<(), Error> { + let mut file = File::create(&self.path).await?; + file.write_all(self.content.as_bytes()).await?; + Ok(()) + } + + /// Configure default settings we need off/on. + pub(crate) async fn configure_and_save( + &mut self, + port: u16, + version: f32, + ) -> Result<(), Error> { + // Make it accessible via psql for debugging. + self.set("listen_addresses", "'0.0.0.0'"); + self.set("port", &port.to_string()); + + // Disable logical replication workers. 
+ self.set("max_logical_replication_workers", "0"); + self.set("max_sync_workers_per_subscription", "0"); + self.set("max_parallel_apply_workers_per_subscription", "0"); + + self.set("max_connections", "1000"); + self.set("log_line_prefix", "''"); + self.set("log_connections", "on"); + self.set("log_disconnections", "on"); + // self.set("log_statement", "off"); + // Disable autovacuum. This is safe, this database doesn't write anything locally. + self.set("autovacuum", "off"); + // Make the background writer do nothing. + self.set("bgwriter_lru_maxpages", "0"); + self.set("bgwriter_delay", "10s"); + + if version >= 18.0 { + // Disable async io workers. + self.set("io_method", "sync"); + } + + self.save().await?; + + Ok(()) + } +} diff --git a/pgdog/src/backend/mod.rs b/pgdog/src/backend/mod.rs index 2bc40068..80bd6935 100644 --- a/pgdog/src/backend/mod.rs +++ b/pgdog/src/backend/mod.rs @@ -4,6 +4,7 @@ pub mod connect_reason; pub mod databases; pub mod disconnect_reason; pub mod error; +pub mod fdw; pub mod maintenance_mode; pub mod pool; pub mod prepared_statements; diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 33c0798a..95f8a4fd 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -2,7 +2,8 @@ use parking_lot::Mutex; use pgdog_config::{ - LoadSchema, PreparedStatements, QueryParserEngine, QueryParserLevel, Rewrite, RewriteMode, + CrossShardBackend, LoadSchema, PreparedStatements, QueryParserEngine, QueryParserLevel, + Rewrite, RewriteMode, }; use std::{ sync::{ @@ -17,6 +18,7 @@ use tracing::{error, info}; use crate::{ backend::{ databases::{databases, User as DatabaseUser}, + fdw::FdwLoadBalancer, pool::ee::schema_changed_hook, replication::{ReplicationConfig, ShardedSchemas}, Schema, ShardedTables, @@ -78,6 +80,10 @@ pub struct Cluster { query_parser_engine: QueryParserEngine, reload_schema_on_ddl: bool, load_schema: LoadSchema, + lb_strategy: LoadBalancingStrategy, + rw_split: ReadWriteSplit, + fdw_lb: Option, + cross_shard_backend: CrossShardBackend, } /// Sharding configuration from the cluster. @@ -153,6 +159,7 @@ pub struct ClusterConfig<'a> { pub lsn_check_interval: Duration, pub reload_schema_on_ddl: bool, pub load_schema: LoadSchema, + pub cross_shard_backend: CrossShardBackend, } impl<'a> ClusterConfig<'a> { @@ -203,6 +210,7 @@ impl<'a> ClusterConfig<'a> { lsn_check_interval: Duration::from_millis(general.lsn_check_interval), reload_schema_on_ddl: general.reload_schema_on_ddl, load_schema: general.load_schema, + cross_shard_backend: general.cross_shard_backend, } } } @@ -239,6 +247,7 @@ impl Cluster { query_parser_engine, reload_schema_on_ddl, load_schema, + cross_shard_backend, } = config; let identifier = Arc::new(DatabaseUser { @@ -246,7 +255,7 @@ impl Cluster { database: name.to_owned(), }); - Self { + let mut cluster = Self { identifier: identifier.clone(), shards: shards .iter() @@ -287,7 +296,17 @@ impl Cluster { query_parser_engine, reload_schema_on_ddl, load_schema, + lb_strategy, + rw_split, + fdw_lb: None, + cross_shard_backend, + }; + + if cross_shard_backend.need_fdw() && cluster.shards().len() > 1 { + cluster.fdw_lb = FdwLoadBalancer::new(&cluster).ok(); } + + cluster } /// Change config to work with logical replication streaming. @@ -312,6 +331,37 @@ impl Cluster { shard.replica(request).await } + /// Get a connection from the primary fdw conn pool. 
+ pub async fn primary_fdw(&self, request: &Request) -> Result { + if let Some(ref lb) = self.fdw_lb { + Ok(lb.primary().ok_or(Error::NoPrimary)?.get(request).await?) + } else { + Err(Error::NoFdw) + } + } + + /// Get a connection from one of the replica fdw pools. + pub async fn replica_fdw(&self, request: &Request) -> Result { + if let Some(ref lb) = self.fdw_lb { + lb.get(request).await + } else { + Err(Error::NoFdw) + } + } + + /// Get a connection to either a primary or a replica. + pub async fn primary_or_replica( + &self, + shard: usize, + request: &Request, + ) -> Result { + self.shards + .get(shard) + .ok_or(Error::NoShard(shard))? + .primary_or_replica(request) + .await + } + /// The two clusters have the same databases. pub(crate) fn can_move_conns_to(&self, other: &Cluster) -> bool { self.shards.len() == other.shards.len() @@ -345,6 +395,16 @@ impl Cluster { &self.shards } + /// Get the load balancing strategy. + pub fn lb_strategy(&self) -> LoadBalancingStrategy { + self.lb_strategy + } + + /// Get the read/write split strategy. + pub fn rw_split(&self) -> ReadWriteSplit { + self.rw_split + } + /// Get the password the user should use to connect to the database. pub fn password(&self) -> &str { &self.password @@ -430,6 +490,14 @@ impl Cluster { true } + pub fn cross_shard_backend(&self) -> CrossShardBackend { + self.cross_shard_backend + } + + pub fn fdw_fallback_enabled(&self) -> bool { + self.cross_shard_backend().need_fdw() + } + /// This database/user pair is responsible for schema management. pub fn schema_admin(&self) -> bool { self.schema_admin @@ -535,6 +603,10 @@ impl Cluster { shard.launch(); } + if let Some(ref fdw_lb) = self.fdw_lb { + fdw_lb.launch(); + } + // Only spawn schema loading once per cluster, even if launch() is called multiple times. let already_started = self .readiness @@ -679,10 +751,7 @@ mod test { name: Some("sharded".into()), column: "id".into(), primary: true, - centroids: vec![], data_type: DataType::Bigint, - centroids_path: None, - centroid_probes: 1, hasher: Hasher::Postgres, ..Default::default() }], @@ -736,6 +805,10 @@ mod test { pub fn set_read_write_strategy(&mut self, rw_strategy: ReadWriteStrategy) { self.rw_strategy = rw_strategy; } + + pub fn set_cross_shard_backend(&mut self, backend: pgdog_config::CrossShardBackend) { + self.cross_shard_backend = backend; + } } #[test] diff --git a/pgdog/src/backend/pool/connection/buffer.rs b/pgdog/src/backend/pool/connection/buffer.rs index 4b9f028c..c07420c2 100644 --- a/pgdog/src/backend/pool/connection/buffer.rs +++ b/pgdog/src/backend/pool/connection/buffer.rs @@ -12,7 +12,7 @@ use crate::{ }, net::{ messages::{DataRow, FromBytes, Message, Protocol, ToBytes, Vector}, - Decoder, + BackendKeyData, Decoder, }, }; @@ -213,7 +213,11 @@ impl Buffer { /// Take messages from buffer. pub(super) fn take(&mut self) -> Option { if self.full { - self.buffer.pop_front().and_then(|s| s.message().ok()) + self.buffer.pop_front().and_then(|s| { + s.message() + .ok() + .map(|m| m.backend(BackendKeyData::default())) + }) } else { None } diff --git a/pgdog/src/backend/pool/connection/mod.rs b/pgdog/src/backend/pool/connection/mod.rs index 7ea42794..c7699c6f 100644 --- a/pgdog/src/backend/pool/connection/mod.rs +++ b/pgdog/src/backend/pool/connection/mod.rs @@ -139,7 +139,25 @@ impl Connection { /// Try to get a connection for the given route. 
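+    /// Routes flagged as FDW fallback are served from the embedded postgres_fdw
+    /// backend rather than an individual shard.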
async fn try_conn(&mut self, request: &Request, route: &Route) -> Result<(), Error> { - if let Shard::Direct(shard) = route.shard() { + if route.is_fdw_fallback() { + let server = if route.is_read() { + self.cluster()?.replica_fdw(request).await? + } else { + self.cluster()?.primary_fdw(request).await? + }; + + match &mut self.binding { + Binding::Direct(existing) => { + let _ = existing.replace(server); + } + + Binding::MultiShard(_, _) => { + self.binding = Binding::Direct(Some(server)); + } + + _ => (), + }; + } else if let Shard::Direct(shard) = route.shard() { let mut server = if route.is_read() { self.cluster()?.replica(*shard, request).await? } else { @@ -329,7 +347,7 @@ impl Connection { if config().config.general.passthrough_auth() && !databases().exists(user) { if let Some(ref passthrough_password) = self.passthrough_password { let new_user = User::new(&self.user, passthrough_password, &self.database); - databases::add(new_user); + databases::add(new_user)?; } } diff --git a/pgdog/src/backend/pool/error.rs b/pgdog/src/backend/pool/error.rs index c1fddccb..edab5441 100644 --- a/pgdog/src/backend/pool/error.rs +++ b/pgdog/src/backend/pool/error.rs @@ -50,6 +50,9 @@ pub enum Error { #[error("no databases")] NoDatabases, + #[error("fdw backend not configured")] + NoFdw, + #[error("config values contain null bytes")] NullBytes, diff --git a/pgdog/src/backend/pool/lb/mod.rs b/pgdog/src/backend/pool/lb/mod.rs index 8e030384..f08cee79 100644 --- a/pgdog/src/backend/pool/lb/mod.rs +++ b/pgdog/src/backend/pool/lb/mod.rs @@ -76,7 +76,7 @@ pub struct LoadBalancer { pub(super) round_robin: Arc, /// Chosen load balancing strategy. pub(super) lb_strategy: LoadBalancingStrategy, - /// Maintenance. notification. + /// Maintenance notification. pub(super) maintenance: Arc, /// Read/write split. pub(super) rw_split: ReadWriteSplit, diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index 009bb63e..bae06068 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -1,5 +1,6 @@ //! Schema operations. pub mod columns; +pub mod postgres_fdw; pub mod relation; pub mod sync; diff --git a/pgdog/src/backend/schema/postgres_fdw/custom_types.rs b/pgdog/src/backend/schema/postgres_fdw/custom_types.rs new file mode 100644 index 00000000..86b1cf3e --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/custom_types.rs @@ -0,0 +1,277 @@ +//! Custom type definitions (enums, domains, composite types) for foreign tables. + +use std::fmt::Write; + +use crate::net::messages::DataRow; + +use super::quote_identifier; + +/// Query to fetch custom type definitions. +pub static CUSTOM_TYPES_QUERY: &str = include_str!("custom_types.sql"); + +/// Kind of custom type. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CustomTypeKind { + Enum, + Domain, + Composite, +} + +impl CustomTypeKind { + fn from_str(s: &str) -> Option { + match s { + "enum" => Some(Self::Enum), + "domain" => Some(Self::Domain), + "composite" => Some(Self::Composite), + _ => None, + } + } +} + +/// A custom type definition from the database. +#[derive(Debug, Clone)] +pub struct CustomType { + pub kind: CustomTypeKind, + pub schema_name: String, + pub type_name: String, + /// Base type for domains. + pub base_type: String, + /// Constraint definition for domains (e.g., "CHECK (VALUE > 0)"). + pub constraint_def: String, + /// Default value for domains. + pub default_value: String, + /// Collation name. + pub collation_name: String, + /// Collation schema. 
+ pub collation_schema: String, + /// Comma-separated enum labels. + pub enum_labels: String, + /// Comma-separated composite attributes (name type pairs). + pub composite_attributes: String, +} + +impl From for CustomType { + fn from(value: DataRow) -> Self { + let kind_str = value.get_text(0).unwrap_or_default(); + Self { + kind: CustomTypeKind::from_str(&kind_str).unwrap_or(CustomTypeKind::Enum), + schema_name: value.get_text(1).unwrap_or_default(), + type_name: value.get_text(2).unwrap_or_default(), + base_type: value.get_text(3).unwrap_or_default(), + constraint_def: value.get_text(4).unwrap_or_default(), + default_value: value.get_text(5).unwrap_or_default(), + collation_name: value.get_text(6).unwrap_or_default(), + collation_schema: value.get_text(7).unwrap_or_default(), + enum_labels: value.get_text(8).unwrap_or_default(), + composite_attributes: value.get_text(9).unwrap_or_default(), + } + } +} + +impl CustomType { + /// Fully qualified type name. + pub fn qualified_name(&self) -> String { + format!( + "{}.{}", + quote_identifier(&self.schema_name), + quote_identifier(&self.type_name) + ) + } + + /// Generate the CREATE statement for this type. + pub fn create_statement(&self) -> Result { + match self.kind { + CustomTypeKind::Enum => self.create_enum_statement(), + CustomTypeKind::Domain => self.create_domain_statement(), + CustomTypeKind::Composite => self.create_composite_statement(), + } + } + + fn create_enum_statement(&self) -> Result { + let mut sql = String::new(); + write!(sql, "CREATE TYPE {} AS ENUM (", self.qualified_name())?; + + let labels: Vec<&str> = self.enum_labels.split(',').collect(); + for (i, label) in labels.iter().enumerate() { + if i > 0 { + sql.push_str(", "); + } + write!(sql, "'{}'", label.replace('\'', "''"))?; + } + + sql.push(')'); + Ok(sql) + } + + fn create_domain_statement(&self) -> Result { + let mut sql = String::new(); + write!( + sql, + "CREATE DOMAIN {} AS {}", + self.qualified_name(), + self.base_type + )?; + + if self.has_collation() { + write!( + sql, + " COLLATE {}.{}", + quote_identifier(&self.collation_schema), + quote_identifier(&self.collation_name) + )?; + } + + if !self.default_value.is_empty() { + write!(sql, " DEFAULT {}", self.default_value)?; + } + + if !self.constraint_def.is_empty() { + write!(sql, " {}", self.constraint_def)?; + } + + Ok(sql) + } + + fn create_composite_statement(&self) -> Result { + let mut sql = String::new(); + write!(sql, "CREATE TYPE {} AS (", self.qualified_name())?; + + // Split on newlines since type definitions can contain commas + let attrs: Vec<&str> = self.composite_attributes.split('\n').collect(); + for (i, attr) in attrs.iter().enumerate() { + if i > 0 { + sql.push_str(", "); + } + let attr = attr.trim(); + if let Some((name, typ)) = attr.split_once(' ') { + write!(sql, "{} {}", quote_identifier(name), typ)?; + } else { + sql.push_str(attr); + } + } + + sql.push(')'); + Ok(sql) + } + + fn has_collation(&self) -> bool { + !self.collation_name.is_empty() && !self.collation_schema.is_empty() + } +} + +/// Collection of custom types from a database. +#[derive(Debug, Clone, Default)] +pub struct CustomTypes { + types: Vec, +} + +impl CustomTypes { + /// Load custom types from a server. + pub(crate) async fn load( + server: &mut crate::backend::Server, + ) -> Result { + let types: Vec = server.fetch_all(CUSTOM_TYPES_QUERY).await?; + Ok(Self { types }) + } + + /// Create all custom types on the target server. 
+ pub(crate) async fn setup( + &self, + server: &mut crate::backend::Server, + ) -> Result<(), crate::backend::Error> { + for custom_type in &self.types { + let stmt = custom_type.create_statement()?; + tracing::debug!("[fdw::setup] {} [{}]", stmt, server.addr()); + server.execute(&stmt).await?; + } + + Ok(()) + } + + /// Get the types. + pub fn types(&self) -> &[CustomType] { + &self.types + } + + /// Check if there are any custom types. + pub fn is_empty(&self) -> bool { + self.types.is_empty() + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn test_enum() -> CustomType { + CustomType { + kind: CustomTypeKind::Enum, + schema_name: "core".into(), + type_name: "user_status".into(), + base_type: String::new(), + constraint_def: String::new(), + default_value: String::new(), + collation_name: String::new(), + collation_schema: String::new(), + enum_labels: "active,inactive,suspended".into(), + composite_attributes: String::new(), + } + } + + fn test_domain() -> CustomType { + CustomType { + kind: CustomTypeKind::Domain, + schema_name: "core".into(), + type_name: "email".into(), + base_type: "character varying(255)".into(), + constraint_def: "CHECK ((VALUE)::text ~ '^[A-Za-z0-9._%+-]+@'::text)".into(), + default_value: String::new(), + collation_name: String::new(), + collation_schema: String::new(), + enum_labels: String::new(), + composite_attributes: String::new(), + } + } + + fn test_composite() -> CustomType { + CustomType { + kind: CustomTypeKind::Composite, + schema_name: "core".into(), + type_name: "geo_point".into(), + base_type: String::new(), + constraint_def: String::new(), + default_value: String::new(), + collation_name: String::new(), + collation_schema: String::new(), + enum_labels: String::new(), + composite_attributes: "latitude numeric(9,6)\nlongitude numeric(9,6)".into(), + } + } + + #[test] + fn test_create_enum_statement() { + let t = test_enum(); + let sql = t.create_statement().unwrap(); + assert_eq!( + sql, + r#"CREATE TYPE "core"."user_status" AS ENUM ('active', 'inactive', 'suspended')"# + ); + } + + #[test] + fn test_create_domain_statement() { + let t = test_domain(); + let sql = t.create_statement().unwrap(); + assert!(sql.contains(r#"CREATE DOMAIN "core"."email" AS character varying(255)"#)); + assert!(sql.contains("CHECK")); + } + + #[test] + fn test_create_composite_statement() { + let t = test_composite(); + let sql = t.create_statement().unwrap(); + assert!(sql.contains(r#"CREATE TYPE "core"."geo_point" AS ("#)); + assert!(sql.contains(r#""latitude" numeric(9,6)"#)); + assert!(sql.contains(r#""longitude" numeric(9,6)"#)); + } +} diff --git a/pgdog/src/backend/schema/postgres_fdw/custom_types.sql b/pgdog/src/backend/schema/postgres_fdw/custom_types.sql new file mode 100644 index 00000000..95360cce --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/custom_types.sql @@ -0,0 +1,85 @@ +-- Query to fetch custom type definitions (enums, domains, composite types) +-- for recreation on FDW server before creating foreign tables. 
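+-- All three branches return the same ten columns (kind, schema, name, base type,
+-- constraint, default, collation name/schema, enum labels, composite attributes);
+-- fields a branch doesn't use are NULL so rows map directly onto CustomType.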
+ +-- Enums: type info and values +SELECT + 'enum' AS type_kind, + n.nspname::text AS schema_name, + t.typname::text AS type_name, + NULL::text AS base_type, + NULL::text AS constraint_def, + NULL::text AS default_value, + NULL::text AS collation_name, + NULL::text AS collation_schema, + string_agg(e.enumlabel::text, ',' ORDER BY e.enumsortorder)::text AS enum_labels, + NULL::text AS composite_attributes +FROM pg_catalog.pg_type t +JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid +WHERE t.typtype = 'e' + AND n.nspname <> 'pg_catalog' + AND n.nspname !~ '^pg_toast' + AND n.nspname <> 'information_schema' +GROUP BY n.nspname, t.typname + +UNION ALL + +-- Domains: base type, constraints, defaults, collation +SELECT + 'domain' AS type_kind, + n.nspname::text AS schema_name, + t.typname::text AS type_name, + pg_catalog.format_type(t.typbasetype, t.typtypmod)::text AS base_type, + ( + SELECT string_agg(pg_catalog.pg_get_constraintdef(c.oid, true), ' ' ORDER BY c.conname) + FROM pg_catalog.pg_constraint c + WHERE c.contypid = t.oid + )::text AS constraint_def, + t.typdefault::text AS default_value, + coll.collname::text AS collation_name, + collnsp.nspname::text AS collation_schema, + NULL::text AS enum_labels, + NULL::text AS composite_attributes +FROM pg_catalog.pg_type t +JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +LEFT JOIN pg_catalog.pg_collation coll ON coll.oid = t.typcollation +LEFT JOIN pg_catalog.pg_namespace collnsp ON collnsp.oid = coll.collnamespace +WHERE t.typtype = 'd' + AND n.nspname <> 'pg_catalog' + AND n.nspname !~ '^pg_toast' + AND n.nspname <> 'information_schema' + +UNION ALL + +-- Composite types (excluding table row types) +-- Uses newline as separator since type definitions can contain commas +SELECT + 'composite' AS type_kind, + n.nspname::text AS schema_name, + t.typname::text AS type_name, + NULL::text AS base_type, + NULL::text AS constraint_def, + NULL::text AS default_value, + NULL::text AS collation_name, + NULL::text AS collation_schema, + NULL::text AS enum_labels, + ( + SELECT string_agg( + a.attname || ' ' || pg_catalog.format_type(a.atttypid, a.atttypmod), + E'\n' ORDER BY a.attnum + ) + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = t.typrelid + AND a.attnum > 0 + AND NOT a.attisdropped + )::text AS composite_attributes +FROM pg_catalog.pg_type t +JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +JOIN pg_catalog.pg_class c ON c.oid = t.typrelid +WHERE t.typtype = 'c' + AND c.relkind = 'c' -- Only standalone composite types, not table row types + AND n.nspname <> 'pg_catalog' + AND n.nspname !~ '^pg_toast' + AND n.nspname <> 'information_schema' + +ORDER BY type_kind, schema_name, type_name diff --git a/pgdog/src/backend/schema/postgres_fdw/error.rs b/pgdog/src/backend/schema/postgres_fdw/error.rs new file mode 100644 index 00000000..fca71e70 --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/error.rs @@ -0,0 +1,13 @@ +//! Errors for foreign table statement generation. + +use std::fmt; +use thiserror::Error; + +/// Errors that can occur when building foreign table statements. 
+#[derive(Debug, Error)] +pub enum Error { + #[error("no columns provided")] + NoColumns, + #[error("format error: {0}")] + Format(#[from] fmt::Error), +} diff --git a/pgdog/src/backend/schema/postgres_fdw/extensions.rs b/pgdog/src/backend/schema/postgres_fdw/extensions.rs new file mode 100644 index 00000000..61a9cc50 --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/extensions.rs @@ -0,0 +1,125 @@ +//! Extension definitions for foreign tables. + +use std::fmt::Write; + +use crate::net::messages::DataRow; + +use super::quote_identifier; + +/// Query to fetch installed extensions. +pub static EXTENSIONS_QUERY: &str = include_str!("extensions.sql"); + +/// An installed extension. +#[derive(Debug, Clone)] +pub struct Extension { + pub name: String, + pub schema_name: String, + pub version: String, +} + +impl From for Extension { + fn from(value: DataRow) -> Self { + Self { + name: value.get_text(0).unwrap_or_default(), + schema_name: value.get_text(1).unwrap_or_default(), + version: value.get_text(2).unwrap_or_default(), + } + } +} + +impl Extension { + /// Generate the CREATE EXTENSION statement. + pub fn create_statement(&self) -> Result { + let mut sql = String::new(); + write!( + sql, + "CREATE EXTENSION IF NOT EXISTS {}", + quote_identifier(&self.name) + )?; + + // Only specify schema if it's not the default 'public' + if !self.schema_name.is_empty() && self.schema_name != "public" { + write!(sql, " SCHEMA {}", quote_identifier(&self.schema_name))?; + } + + Ok(sql) + } +} + +/// Collection of extensions from a database. +#[derive(Debug, Clone, Default)] +pub struct Extensions { + extensions: Vec, +} + +impl Extensions { + /// Load extensions from a server. + pub(crate) async fn load( + server: &mut crate::backend::Server, + ) -> Result { + let extensions: Vec = server.fetch_all(EXTENSIONS_QUERY).await?; + Ok(Self { extensions }) + } + + /// Create all extensions on the target server. + pub(crate) async fn setup( + &self, + server: &mut crate::backend::Server, + ) -> Result<(), crate::backend::Error> { + for ext in &self.extensions { + let stmt = ext.create_statement()?; + tracing::debug!("[fdw::setup] {} [{}]", stmt, server.addr()); + server.execute(&stmt).await?; + } + + Ok(()) + } + + /// Get the extensions. + pub fn extensions(&self) -> &[Extension] { + &self.extensions + } + + /// Check if there are any extensions. 
+ pub fn is_empty(&self) -> bool { + self.extensions.is_empty() + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn test_extension() -> Extension { + Extension { + name: "ltree".into(), + schema_name: "public".into(), + version: "1.2".into(), + } + } + + fn test_extension_with_schema() -> Extension { + Extension { + name: "pg_trgm".into(), + schema_name: "extensions".into(), + version: "1.6".into(), + } + } + + #[test] + fn test_create_extension_statement() { + let ext = test_extension(); + let sql = ext.create_statement().unwrap(); + assert_eq!(sql, r#"CREATE EXTENSION IF NOT EXISTS "ltree""#); + } + + #[test] + fn test_create_extension_statement_with_schema() { + let ext = test_extension_with_schema(); + let sql = ext.create_statement().unwrap(); + assert_eq!( + sql, + r#"CREATE EXTENSION IF NOT EXISTS "pg_trgm" SCHEMA "extensions""# + ); + } +} diff --git a/pgdog/src/backend/schema/postgres_fdw/extensions.sql b/pgdog/src/backend/schema/postgres_fdw/extensions.sql new file mode 100644 index 00000000..db87a2d2 --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/extensions.sql @@ -0,0 +1,10 @@ +-- Query to fetch installed extensions for recreation on FDW server. +-- Excludes built-in extensions that are always available. +SELECT + e.extname::text AS extension_name, + n.nspname::text AS schema_name, + e.extversion::text AS version +FROM pg_catalog.pg_extension e +JOIN pg_catalog.pg_namespace n ON e.extnamespace = n.oid +WHERE e.extname NOT IN ('plpgsql') -- Exclude always-installed extensions +ORDER BY e.extname diff --git a/pgdog/src/backend/schema/postgres_fdw/mod.rs b/pgdog/src/backend/schema/postgres_fdw/mod.rs new file mode 100644 index 00000000..b2f6957e --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/mod.rs @@ -0,0 +1,21 @@ +//! Schema information for creating foreign tables via postgres_fdw. 
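+//!
+//! `schema` loads table and column metadata from a shard, `custom_types` and
+//! `extensions` recreate user-defined types and installed extensions on the FDW
+//! instance, and `statement` generates the CREATE FOREIGN TABLE and partitioning DDL.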
+ +mod custom_types; +mod error; +mod extensions; +mod schema; +mod statement; + +#[cfg(test)] +mod test; + +pub use custom_types::{CustomType, CustomTypeKind, CustomTypes, CUSTOM_TYPES_QUERY}; +pub use error::Error; +pub use extensions::{Extension, Extensions, EXTENSIONS_QUERY}; +pub use schema::{FdwServerDef, ForeignTableColumn, ForeignTableSchema, FOREIGN_TABLE_SCHEMA}; +pub use statement::{ + create_foreign_table, create_foreign_table_with_children, CreateForeignTableResult, + ForeignTableBuilder, PartitionStrategy, TypeMismatch, +}; + +pub(crate) use statement::quote_identifier; diff --git a/pgdog/src/backend/schema/postgres_fdw/postgres_fdw.sql b/pgdog/src/backend/schema/postgres_fdw/postgres_fdw.sql new file mode 100644 index 00000000..d010ceaf --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/postgres_fdw.sql @@ -0,0 +1,45 @@ +SELECT + n.nspname::text AS schema_name, + c.relname::text AS table_name, + a.attname::text AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod)::text AS column_type, + a.attnotnull::text AS is_not_null, + pg_catalog.pg_get_expr(ad.adbin, ad.adrelid)::text AS column_default, + a.attgenerated::text AS generated, + coll.collname::text AS collation_name, + collnsp.nspname::text AS collation_schema, + c.relispartition::text AS is_partition, + COALESCE(parent_class.relname, '')::text AS parent_table_name, + COALESCE(parent_ns.nspname, '')::text AS parent_schema_name, + COALESCE(pg_catalog.pg_get_expr(c.relpartbound, c.oid), '')::text AS partition_bound, + COALESCE(pg_catalog.pg_get_partkeydef(c.oid), '')::text AS partition_key +FROM pg_catalog.pg_class c +JOIN pg_catalog.pg_namespace n ON + c.relnamespace = n.oid +LEFT JOIN pg_catalog.pg_attribute a ON + a.attrelid = c.oid + AND a.attnum > 0 + AND NOT a.attisdropped +LEFT JOIN pg_catalog.pg_attrdef ad ON + ad.adrelid = c.oid + AND ad.adnum = a.attnum +LEFT JOIN pg_catalog.pg_collation coll ON + coll.oid = a.attcollation +LEFT JOIN pg_catalog.pg_namespace collnsp ON + collnsp.oid = coll.collnamespace +LEFT JOIN pg_catalog.pg_inherits inh ON + inh.inhrelid = c.oid +LEFT JOIN pg_catalog.pg_class parent_class ON + parent_class.oid = inh.inhparent +LEFT JOIN pg_catalog.pg_namespace parent_ns ON + parent_ns.oid = parent_class.relnamespace +WHERE + c.relkind IN ('r', 'v', 'f', 'm', 'p') + AND n.nspname <> 'pg_catalog' + AND n.nspname !~ '^pg_toast' + AND n.nspname <> 'information_schema' + AND NOT (n.nspname = 'pgdog' AND c.relname IN ('validator_bigint', 'validator_uuid', 'config')) +ORDER BY + n.nspname, + c.relname, + a.attnum diff --git a/pgdog/src/backend/schema/postgres_fdw/schema.rs b/pgdog/src/backend/schema/postgres_fdw/schema.rs new file mode 100644 index 00000000..3737affa --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/schema.rs @@ -0,0 +1,341 @@ +//! Foreign table schema query and data structures. + +use std::collections::{HashMap, HashSet}; +use tracing::{debug, warn}; + +use crate::{ + backend::{ + schema::postgres_fdw::{create_foreign_table, create_foreign_table_with_children}, + Server, ShardingSchema, + }, + net::messages::DataRow, +}; + +use super::custom_types::CustomTypes; +use super::extensions::Extensions; +use super::quote_identifier; +use super::TypeMismatch; + +/// Server definition for FDW setup. 
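+///
+/// One entry per shard: `host`, `port` and `database_name` become the SERVER
+/// options, `user` and `password` go into the USER MAPPING, and `mapping_user`
+/// is the local role the mapping is created for.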
+#[derive(Debug, Clone)] +pub struct FdwServerDef { + pub shard_num: usize, + pub host: String, + pub port: u16, + pub database_name: String, + pub user: String, + pub password: String, + pub mapping_user: String, +} + +/// Query to fetch table and column information needed for CREATE FOREIGN TABLE statements. +pub static FOREIGN_TABLE_SCHEMA: &str = include_str!("postgres_fdw.sql"); + +/// Row from the foreign table schema query. +/// Each row represents a single column in a table, or a table with no columns. +#[derive(Debug, Clone)] +pub struct ForeignTableColumn { + pub schema_name: String, + pub table_name: String, + /// Empty if the table has no columns. + pub column_name: String, + /// Column type with modifiers, e.g. "character varying(255)". + pub column_type: String, + pub is_not_null: bool, + /// Default expression, also used for generated column expressions. + pub column_default: String, + /// 's' for stored generated column, empty otherwise. + pub generated: String, + pub collation_name: String, + pub collation_schema: String, + /// Whether this table is a partition of another table. + pub is_partition: bool, + /// Parent table name if this is a partition, empty otherwise. + pub parent_table_name: String, + /// Parent schema name if this is a partition, empty otherwise. + pub parent_schema_name: String, + /// Partition bound expression, e.g. "FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')". + pub partition_bound: String, + /// Partition key definition if this is a partitioned table, e.g. "RANGE (created_at)". + pub partition_key: String, +} + +impl From for ForeignTableColumn { + fn from(value: DataRow) -> Self { + Self { + schema_name: value.get_text(0).unwrap_or_default(), + table_name: value.get_text(1).unwrap_or_default(), + column_name: value.get_text(2).unwrap_or_default(), + column_type: value.get_text(3).unwrap_or_default(), + is_not_null: value.get_text(4).unwrap_or_default() == "true", + column_default: value.get_text(5).unwrap_or_default(), + generated: value.get_text(6).unwrap_or_default(), + collation_name: value.get_text(7).unwrap_or_default(), + collation_schema: value.get_text(8).unwrap_or_default(), + is_partition: value.get_text(9).unwrap_or_default() == "true", + parent_table_name: value.get_text(10).unwrap_or_default(), + parent_schema_name: value.get_text(11).unwrap_or_default(), + partition_bound: value.get_text(12).unwrap_or_default(), + partition_key: value.get_text(13).unwrap_or_default(), + } + } +} + +#[derive(Debug, Clone)] +pub struct ForeignTableSchema { + tables: HashMap<(String, String), Vec>, + extensions: Extensions, + custom_types: CustomTypes, +} + +impl ForeignTableSchema { + pub(crate) async fn load(server: &mut Server) -> Result { + let tables = ForeignTableColumn::load(server).await?; + let extensions = Extensions::load(server).await?; + let custom_types = CustomTypes::load(server).await?; + Ok(Self { + tables, + extensions, + custom_types, + }) + } + + /// Full setup: creates servers, schemas, types, and tables. 
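+    ///
+    /// Extensions are created first (custom types may depend on them); then, in a
+    /// single transaction, the managed schemas are dropped with CASCADE, the
+    /// `shard_N` servers and user mappings are recreated, schemas and custom types
+    /// are rebuilt, and finally the foreign tables (including partition trees).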
+ pub(crate) async fn setup( + &self, + server: &mut Server, + sharding_schema: &ShardingSchema, + servers: &[FdwServerDef], + ) -> Result<(), super::super::Error> { + // Create extensions first (types may depend on them) + self.extensions.setup(server).await?; + + server.execute("BEGIN").await?; + + // Drop and recreate managed schemas (CASCADE drops tables and types) + self.drop_schemas(server).await?; + + // Drop and recreate servers (must happen after schema drop, before foreign table creation) + for srv in servers { + server + .execute(format!( + r#"DROP SERVER IF EXISTS "shard_{}" CASCADE"#, + srv.shard_num + )) + .await?; + + server + .execute(format!( + r#"CREATE SERVER "shard_{}" + FOREIGN DATA WRAPPER postgres_fdw + OPTIONS (host '{}', port '{}', dbname '{}')"#, + srv.shard_num, srv.host, srv.port, srv.database_name, + )) + .await?; + + server + .execute(format!( + r#"CREATE USER MAPPING + FOR {} + SERVER "shard_{}" + OPTIONS (user '{}', password '{}')"#, + quote_identifier(&srv.mapping_user), + srv.shard_num, + srv.user, + srv.password, + )) + .await?; + } + + self.create_schemas(server).await?; + + // Create custom types (enums, domains, composite types) + self.custom_types.setup(server).await?; + + // Build a map of parent tables to their child partitions + let children_map = self.build_children_map(); + + let mut processed_tables = HashSet::new(); + let mut all_type_mismatches: Vec = Vec::new(); + + for ((schema, table), columns) in &self.tables { + // Skip internal PgDog tables + if Self::is_internal_table(schema, table) { + continue; + } + + // Skip partitions - they are handled when processing their parent + if columns.first().is_some_and(|c| c.is_partition) { + continue; + } + + let dedup = (schema.clone(), table.clone()); + if !processed_tables.contains(&dedup) { + // Check if this table has child partitions + let children = children_map + .get(&dedup) + .map(|child_keys| { + child_keys + .iter() + .filter_map(|key| self.tables.get(key).cloned()) + .collect::>() + }) + .unwrap_or_default(); + + let result = if children.is_empty() { + create_foreign_table(columns, sharding_schema)? + } else { + create_foreign_table_with_children(columns, sharding_schema, children)? + }; + + for sql in &result.statements { + debug!("[fdw::setup] {} [{}]", sql, server.addr()); + server.execute(sql).await?; + } + all_type_mismatches.extend(result.type_mismatches); + processed_tables.insert(dedup); + } + } + + // Log summary of type mismatches if any were found + if !all_type_mismatches.is_empty() { + warn!( + "[fdw] {} table(s) skipped due to sharding config type mismatches:", + all_type_mismatches.len() + ); + for mismatch in &all_type_mismatches { + warn!("[fdw] - {}", mismatch); + } + } + + server.execute("COMMIT").await?; + Ok(()) + } + + /// Add user mappings only (for additional users on an already-configured database). + pub(crate) async fn setup_user_mappings( + server: &mut Server, + servers: &[FdwServerDef], + ) -> Result<(), super::super::Error> { + for srv in servers { + server + .execute(format!( + r#"CREATE USER MAPPING IF NOT EXISTS + FOR {} + SERVER "shard_{}" + OPTIONS (user '{}', password '{}')"#, + quote_identifier(&srv.mapping_user), + srv.shard_num, + srv.user, + srv.password, + )) + .await?; + } + Ok(()) + } + + /// Get the extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Get the custom types. + pub fn custom_types(&self) -> &CustomTypes { + &self.custom_types + } + + /// Get the tables map (for testing). 
+ #[cfg(test)] + pub fn tables(&self) -> &HashMap<(String, String), Vec> { + &self.tables + } + + /// Check if a table is an internal PgDog table that shouldn't be exposed via FDW. + fn is_internal_table(schema: &str, table: &str) -> bool { + schema == "pgdog" && matches!(table, "validator_bigint" | "validator_uuid" | "config") + } + + /// Build a map of parent tables to their child partition keys. + fn build_children_map(&self) -> HashMap<(String, String), Vec<(String, String)>> { + let mut children_map: HashMap<(String, String), Vec<(String, String)>> = HashMap::new(); + + for ((schema, table), columns) in &self.tables { + if let Some(first_col) = columns.first() { + if first_col.is_partition && !first_col.parent_table_name.is_empty() { + let parent_key = ( + first_col.parent_schema_name.clone(), + first_col.parent_table_name.clone(), + ); + let child_key = (schema.clone(), table.clone()); + children_map.entry(parent_key).or_default().push(child_key); + } + } + } + + children_map + } + + /// Collect unique schemas from tables and custom types. + fn schemas(&self) -> HashSet { + self.tables + .keys() + .map(|(s, _)| s.clone()) + .chain( + self.custom_types + .types() + .iter() + .map(|t| t.schema_name.clone()), + ) + .collect() + } + + /// Drop schemas we manage (with CASCADE to drop tables and types). + async fn drop_schemas(&self, server: &mut Server) -> Result<(), super::super::Error> { + for schema in &self.schemas() { + server + .execute(&format!( + "DROP SCHEMA IF EXISTS {} CASCADE", + super::quote_identifier(schema) + )) + .await?; + } + Ok(()) + } + + /// Create schemas we manage. + async fn create_schemas(&self, server: &mut Server) -> Result<(), super::super::Error> { + for schema in &self.schemas() { + server + .execute(&format!( + "CREATE SCHEMA {}", + super::quote_identifier(schema) + )) + .await?; + } + Ok(()) + } +} + +impl ForeignTableColumn { + /// Check if this column has a collation. + pub(super) fn has_collation(&self) -> bool { + !self.collation_name.is_empty() && !self.collation_schema.is_empty() + } + + /// Fetch columns and organize by schema and table name. + async fn load( + server: &mut Server, + ) -> Result>, crate::backend::Error> { + let mut result = HashMap::new(); + let rows: Vec = server.fetch_all(FOREIGN_TABLE_SCHEMA).await?; + + for row in rows { + let entry = result + .entry((row.schema_name.clone(), row.table_name.clone())) + .or_insert_with(Vec::new); + entry.push(row); + } + + Ok(result) + } +} diff --git a/pgdog/src/backend/schema/postgres_fdw/statement.rs b/pgdog/src/backend/schema/postgres_fdw/statement.rs new file mode 100644 index 00000000..d8f45cd7 --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/statement.rs @@ -0,0 +1,973 @@ +//! CREATE FOREIGN TABLE statement generation. + +use std::fmt::Write; + +use rand::Rng; + +use crate::backend::pool::ShardingSchema; +use crate::config::{DataType, FlexibleType, ShardedTable}; +use crate::frontend::router::parser::Column; +use crate::frontend::router::sharding::Mapping; + +use super::{Error, ForeignTableColumn}; + +/// A type mismatch between a table column and the configured sharding data type. 
+#[derive(Debug, Clone)] +pub struct TypeMismatch { + pub schema_name: String, + pub table_name: String, + pub column_name: String, + pub column_type: String, + pub configured_type: DataType, +} + +impl std::fmt::Display for TypeMismatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}.{}.{}: column type '{}' does not match configured data_type '{:?}'", + self.schema_name, + self.table_name, + self.column_name, + self.column_type, + self.configured_type + ) + } +} + +/// Result of creating foreign table statements. +pub struct CreateForeignTableResult { + pub statements: Vec, + pub type_mismatches: Vec, +} + +/// Format a FlexibleType as a SQL literal. +fn flexible_type_to_sql(value: &FlexibleType) -> String { + match value { + FlexibleType::Integer(i) => i.to_string(), + FlexibleType::Uuid(u) => format!("'{}'", u), + FlexibleType::String(s) => format!("'{}'", s.replace('\'', "''")), + } +} + +/// Check if a PostgreSQL column type string matches the configured DataType. +fn column_type_matches_data_type(column_type: &str, data_type: DataType) -> bool { + let col_lower = column_type.to_lowercase(); + match data_type { + DataType::Bigint => { + col_lower.starts_with("bigint") + || col_lower.starts_with("int8") + || col_lower.starts_with("bigserial") + || col_lower.starts_with("serial8") + || col_lower.starts_with("integer") + || col_lower.starts_with("int4") + || col_lower.starts_with("int") + || col_lower.starts_with("smallint") + || col_lower.starts_with("int2") + } + DataType::Uuid => col_lower.starts_with("uuid"), + DataType::Varchar => { + col_lower.starts_with("character varying") + || col_lower.starts_with("varchar") + || col_lower.starts_with("text") + } + DataType::Vector => col_lower.starts_with("vector"), + } +} + +/// Partition strategy for a sharded table. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PartitionStrategy { + Hash, + List, + Range, +} + +impl PartitionStrategy { + /// Determine partition strategy from sharded table config. + fn from_sharded_table(table: &ShardedTable) -> Self { + match &table.mapping { + Some(Mapping::List(_)) => Self::List, + Some(Mapping::Range(_)) => Self::Range, + None => Self::Hash, + } + } + + /// SQL keyword for this partition strategy. + fn as_sql(&self) -> &'static str { + match self { + Self::Hash => "HASH", + Self::List => "LIST", + Self::Range => "RANGE", + } + } +} + +/// Quote an identifier if needed (simple Postgres-style quoting). +pub(crate) fn quote_identifier(name: &str) -> String { + let needs_quoting = name.is_empty() + || !name.starts_with(|c: char| c.is_ascii_lowercase() || c == '_') + || name.starts_with('_') && name.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) + || !name + .chars() + .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_'); + + if needs_quoting { + format!("\"{}\"", name.replace('"', "\"\"")) + } else { + format!(r#""{}""#, name.to_string()) + } +} + +/// Escape a string literal for use in SQL. +fn escape_literal(s: &str) -> String { + format!("'{}'", s.replace('\'', "''")) +} + +/// Format a fully qualified table name (schema.table). +fn qualified_table(schema: &str, table: &str) -> String { + format!("{}.{}", quote_identifier(schema), quote_identifier(table)) +} + +/// Builder for CREATE FOREIGN TABLE statements. 
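+///
+/// A minimal usage sketch (assuming columns were loaded via `ForeignTableSchema`):
+///
+/// ```ignore
+/// let result = ForeignTableBuilder::new(&columns, &sharding_schema)
+///     .with_children(children)
+///     .build()?;
+/// for sql in &result.statements {
+///     server.execute(sql).await?;
+/// }
+/// ```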
+pub struct ForeignTableBuilder<'a> { + columns: &'a [ForeignTableColumn], + sharding_schema: &'a ShardingSchema, + type_mismatches: Vec, + /// Child partitions for two-tier partitioning (each Vec is columns for one partition). + children: Vec>, +} + +impl<'a> ForeignTableBuilder<'a> { + /// Create a new builder with required parameters. + pub fn new(columns: &'a [ForeignTableColumn], sharding_schema: &'a ShardingSchema) -> Self { + Self { + columns, + sharding_schema, + type_mismatches: Vec::new(), + children: Vec::new(), + } + } + + /// Add child partitions for two-tier partitioning. + /// Each Vec represents columns for one child partition. + pub fn with_children(mut self, children: Vec>) -> Self { + self.children = children; + self + } + + /// Find the sharding configuration for this table, if any. + /// Records any type mismatches encountered and returns a cloned config. + fn find_sharded_config(&mut self) -> Option { + let first = self.columns.first()?; + let table_name = &first.table_name; + let schema_name = &first.schema_name; + + for col in self.columns { + let column = Column { + name: &col.column_name, + table: Some(table_name.as_str()), + schema: Some(schema_name.as_str()), + }; + + if let Some(sharded) = self.sharding_schema.tables().get_table(column) { + if !column_type_matches_data_type(&col.column_type, sharded.data_type) { + let mismatch = TypeMismatch { + schema_name: schema_name.clone(), + table_name: table_name.clone(), + column_name: col.column_name.clone(), + column_type: col.column_type.clone(), + configured_type: sharded.data_type, + }; + self.type_mismatches.push(mismatch); + continue; + } + return Some(sharded.clone()); + } + } + + None + } + + /// Build column definitions SQL fragment (shared between parent and foreign tables). + fn build_columns(&self) -> Result { + let mut sql = String::new(); + let mut first_col = true; + + for col in self.columns { + if col.column_name.is_empty() { + continue; + } + + if first_col { + first_col = false; + } else { + sql.push_str(",\n"); + } + + write!( + sql, + " {} {}", + quote_identifier(&col.column_name), + col.column_type + )?; + + if col.has_collation() { + write!( + sql, + " COLLATE {}.{}", + quote_identifier(&col.collation_schema), + quote_identifier(&col.collation_name) + )?; + } + + if col.is_not_null { + sql.push_str(" NOT NULL"); + } + } + + Ok(sql) + } + + /// Build the CREATE TABLE / CREATE FOREIGN TABLE statement(s). + pub fn build(mut self) -> Result { + let first = self.columns.first().ok_or(Error::NoColumns)?; + let schema_name = &first.schema_name.clone(); + let table_name = &first.table_name.clone(); + + let statements = if let Some(sharded) = self.find_sharded_config() { + self.build_sharded(table_name, schema_name, &sharded)? + } else if !self.type_mismatches.is_empty() { + // Skip tables with type mismatches entirely + vec![] + } else { + self.build_foreign_table(table_name, schema_name)? + }; + + Ok(CreateForeignTableResult { + statements, + type_mismatches: self.type_mismatches, + }) + } + + /// Build a simple foreign table (non-sharded). 
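+    /// Emits a single statement of roughly the form
+    /// `CREATE FOREIGN TABLE schema.table (col type OPTIONS (column_name 'col'), ...)
+    /// SERVER shard_N OPTIONS (schema_name '...', table_name '...')`, with the
+    /// backing shard chosen at random.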
+ fn build_foreign_table( + &self, + table_name: &str, + schema_name: &str, + ) -> Result, Error> { + let mut sql = String::new(); + writeln!( + sql, + "CREATE FOREIGN TABLE {} (", + qualified_table(schema_name, table_name) + )?; + + // Column definitions with OPTIONS for foreign tables + let mut first_col = true; + for col in self.columns { + if col.column_name.is_empty() { + continue; + } + + if first_col { + first_col = false; + } else { + sql.push_str(",\n"); + } + + write!( + sql, + " {} {}", + quote_identifier(&col.column_name), + col.column_type + )?; + + write!( + sql, + " OPTIONS (column_name {})", + escape_literal(&col.column_name) + )?; + + if col.has_collation() { + write!( + sql, + " COLLATE {}.{}", + quote_identifier(&col.collation_schema), + quote_identifier(&col.collation_name) + )?; + } + + if col.is_not_null { + sql.push_str(" NOT NULL"); + } + } + + sql.push('\n'); + sql.push(')'); + + let shard = rand::rng().random_range(0..self.sharding_schema.shards.max(1)); + write!( + sql, + "\nSERVER shard_{}\nOPTIONS (schema_name {}, table_name {})", + shard, + escape_literal(schema_name), + escape_literal(table_name) + )?; + + Ok(vec![sql]) + } + + /// Build a sharded table: parent table + foreign table partitions. + /// If children partitions exist, creates two-tier partitioning. + fn build_sharded( + &self, + table_name: &str, + schema_name: &str, + sharded: &ShardedTable, + ) -> Result, Error> { + if self.children.is_empty() { + self.build_sharded_single_tier(table_name, schema_name, sharded) + } else { + self.build_sharded_two_tier(table_name, schema_name, sharded) + } + } + + /// Build single-tier sharding: parent table with foreign table partitions. + fn build_sharded_single_tier( + &self, + table_name: &str, + schema_name: &str, + sharded: &ShardedTable, + ) -> Result, Error> { + let strategy = PartitionStrategy::from_sharded_table(sharded); + let mut statements = Vec::new(); + + // Create parent table with PARTITION BY + let mut parent = String::new(); + let qualified_name = qualified_table(schema_name, table_name); + writeln!(parent, "CREATE TABLE {} (", qualified_name)?; + parent.push_str(&self.build_columns()?); + parent.push('\n'); + write!( + parent, + ") PARTITION BY {} ({})", + strategy.as_sql(), + quote_identifier(&sharded.column) + )?; + statements.push(parent); + + // Create foreign table partitions for each shard + self.build_foreign_partitions( + &mut statements, + table_name, + schema_name, + &qualified_name, + sharded, + )?; + + Ok(statements) + } + + /// Build two-tier sharding: parent table -> intermediate partitions -> foreign table partitions. 
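+    /// The parent keeps its original partition key, each existing child partition
+    /// becomes an intermediate table partitioned again by the shard key, and the
+    /// leaves are foreign tables pointing at the `shard_N` servers.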
+    fn build_sharded_two_tier(
+        &self,
+        table_name: &str,
+        schema_name: &str,
+        sharded: &ShardedTable,
+    ) -> Result<Vec<String>, Error> {
+        let shard_strategy = PartitionStrategy::from_sharded_table(sharded);
+        let mut statements = Vec::new();
+
+        // Get the parent's original partition key (e.g., "RANGE (created_at)")
+        let parent_partition_key = self
+            .columns
+            .first()
+            .map(|c| c.partition_key.as_str())
+            .unwrap_or("");
+
+        // Create parent table with original PARTITION BY
+        let mut parent = String::new();
+        let qualified_name = qualified_table(schema_name, table_name);
+        writeln!(parent, "CREATE TABLE {} (", qualified_name)?;
+        parent.push_str(&self.build_columns()?);
+        parent.push('\n');
+        write!(parent, ") PARTITION BY {}", parent_partition_key)?;
+        statements.push(parent);
+
+        // For each child partition, create an intermediate partition that is itself partitioned
+        for child_columns in &self.children {
+            let Some(first_col) = child_columns.first() else {
+                continue;
+            };
+
+            let child_table_name = &first_col.table_name;
+            let child_schema_name = &first_col.schema_name;
+            let partition_bound = &first_col.partition_bound;
+
+            // Create intermediate partition table (partitioned by hash on shard key)
+            let mut intermediate = String::new();
+            let qualified_child = qualified_table(child_schema_name, child_table_name);
+
+            write!(
+                intermediate,
+                "CREATE TABLE {} PARTITION OF {} ",
+                qualified_child, qualified_name
+            )?;
+            write!(intermediate, "{}", partition_bound)?;
+            write!(
+                intermediate,
+                " PARTITION BY {} ({})",
+                shard_strategy.as_sql(),
+                quote_identifier(&sharded.column)
+            )?;
+            statements.push(intermediate);
+
+            // Create foreign table partitions for this intermediate partition
+            self.build_foreign_partitions(
+                &mut statements,
+                child_table_name,
+                child_schema_name,
+                &qualified_child,
+                sharded,
+            )?;
+        }
+
+        Ok(statements)
+    }
+
+    /// Build foreign table partitions for each shard.
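+    ///
+    /// Roughly the output for shard 0 of 2 under hash sharding (`users` is a
+    /// placeholder name):
+    ///
+    /// ```sql
+    /// CREATE FOREIGN TABLE "public"."users_shard_0" PARTITION OF "public"."users"
+    /// FOR VALUES WITH (MODULUS 2, REMAINDER 0)
+    /// SERVER "shard_0"
+    /// OPTIONS (schema_name 'public', table_name 'users')
+    /// ```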
+    fn build_foreign_partitions(
+        &self,
+        statements: &mut Vec<String>,
+        table_name: &str,
+        schema_name: &str,
+        qualified_parent: &str,
+        sharded: &ShardedTable,
+    ) -> Result<(), Error> {
+        for shard in 0..self.sharding_schema.shards {
+            let mut partition = String::new();
+            let partition_table_name = format!("{}_shard_{}", table_name, shard);
+            let qualified_partition = qualified_table(schema_name, &partition_table_name);
+            let server_name = format!("shard_{}", shard);
+
+            write!(
+                partition,
+                "CREATE FOREIGN TABLE {} PARTITION OF {} ",
+                qualified_partition, qualified_parent
+            )?;
+
+            // Partition bounds (always hash for foreign partitions in two-tier)
+            match &sharded.mapping {
+                None => {
+                    write!(
+                        partition,
+                        "FOR VALUES WITH (MODULUS {}, REMAINDER {})",
+                        self.sharding_schema.shards, shard
+                    )?;
+                }
+                Some(Mapping::List(list_shards)) => {
+                    let values = list_shards.values_for_shard(shard);
+                    if values.is_empty() {
+                        write!(partition, "DEFAULT")?;
+                    } else {
+                        let values_sql: Vec<_> =
+                            values.iter().map(|v| flexible_type_to_sql(v)).collect();
+                        write!(partition, "FOR VALUES IN ({})", values_sql.join(", "))?;
+                    }
+                }
+                Some(Mapping::Range(ranges)) => {
+                    if let Some(range) = ranges.iter().find(|r| r.shard == shard) {
+                        let start = range
+                            .start
+                            .as_ref()
+                            .map(flexible_type_to_sql)
+                            .unwrap_or_else(|| "MINVALUE".to_string());
+                        let end = range
+                            .end
+                            .as_ref()
+                            .map(flexible_type_to_sql)
+                            .unwrap_or_else(|| "MAXVALUE".to_string());
+                        write!(partition, "FOR VALUES FROM ({}) TO ({})", start, end)?;
+                    } else {
+                        write!(partition, "DEFAULT")?;
+                    }
+                }
+            }
+
+            write!(
+                partition,
+                "\nSERVER {}\nOPTIONS (schema_name {}, table_name {})",
+                quote_identifier(&server_name),
+                escape_literal(schema_name),
+                escape_literal(table_name)
+            )?;
+
+            statements.push(partition);
+        }
+        Ok(())
+    }
+}
+
+/// Generate CREATE FOREIGN TABLE statements from column definitions.
+///
+/// All columns must belong to the same table. If the table is found in sharded_tables
+/// configuration, creates a partitioned parent table with foreign table partitions
+/// for each shard. Server names are generated as `shard_{n}`.
+///
+/// Returns statements and any type mismatches encountered.
+pub fn create_foreign_table(
+    columns: &[ForeignTableColumn],
+    sharding_schema: &ShardingSchema,
+) -> Result<CreateForeignTableResult, Error> {
+    ForeignTableBuilder::new(columns, sharding_schema).build()
+}
+
+/// Generate CREATE FOREIGN TABLE statements with two-tier partitioning.
+///
+/// For sharded tables with existing partitions, creates:
+/// 1. Parent table with PARTITION BY (on shard key)
+/// 2. Intermediate partition tables (with original bounds, further partitioned by shard key)
+/// 3. Foreign table partitions for each shard
+///
+/// Each entry in `children` is columns for one child partition.
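+///
+/// Call sketch (hypothetical column vectors, mirroring the tests below):
+///
+/// ```ignore
+/// let result = create_foreign_table_with_children(
+///     &parent_columns, &sharding_schema, vec![partition_2024, partition_2025])?;
+/// ```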
+pub fn create_foreign_table_with_children(
+    columns: &[ForeignTableColumn],
+    sharding_schema: &ShardingSchema,
+    children: Vec<Vec<ForeignTableColumn>>,
+) -> Result<CreateForeignTableResult, Error> {
+    ForeignTableBuilder::new(columns, sharding_schema)
+        .with_children(children)
+        .build()
+}
+
+#[cfg(test)]
+mod test {
+    use std::collections::HashSet;
+
+    use super::*;
+    use crate::backend::replication::ShardedTables;
+    use crate::config::{DataType, FlexibleType, ShardedMapping, ShardedMappingKind};
+
+    fn test_column(name: &str, col_type: &str) -> ForeignTableColumn {
+        ForeignTableColumn {
+            schema_name: "public".into(),
+            table_name: "test_table".into(),
+            column_name: name.into(),
+            column_type: col_type.into(),
+            is_not_null: false,
+            column_default: String::new(),
+            generated: String::new(),
+            collation_name: String::new(),
+            collation_schema: String::new(),
+            is_partition: false,
+            parent_table_name: String::new(),
+            parent_schema_name: String::new(),
+            partition_bound: String::new(),
+            partition_key: String::new(),
+        }
+    }
+
+    fn test_sharded_table(table: &str, column: &str) -> ShardedTable {
+        ShardedTable {
+            database: "test".into(),
+            name: Some(table.into()),
+            schema: Some("public".into()),
+            column: column.into(),
+            primary: false,
+            centroids: vec![],
+            centroids_path: None,
+            data_type: DataType::Bigint,
+            centroid_probes: 0,
+            hasher: Default::default(),
+            mapping: None,
+        }
+    }
+
+    fn test_sharded_table_with_mapping(
+        table: &str,
+        column: &str,
+        mapping: Mapping,
+        data_type: DataType,
+    ) -> ShardedTable {
+        ShardedTable {
+            mapping: Some(mapping),
+            data_type,
+            ..test_sharded_table(table, column)
+        }
+    }
+
+    fn sharding_schema_with_tables(tables: ShardedTables, shards: usize) -> ShardingSchema {
+        ShardingSchema {
+            shards,
+            tables,
+            schemas: Default::default(),
+            rewrite: Default::default(),
+            query_parser_engine: Default::default(),
+        }
+    }
+
+    fn list_mapping() -> Mapping {
+        let mapping = ShardedMapping {
+            database: "test".into(),
+            column: "region".into(),
+            table: Some("test_table".into()),
+            schema: Some("public".into()),
+            kind: ShardedMappingKind::List,
+            start: None,
+            end: None,
+            values: HashSet::from([FlexibleType::String("us".into())]),
+            shard: 0,
+        };
+        Mapping::new(&[mapping]).unwrap()
+    }
+
+    fn range_mapping() -> Mapping {
+        let mapping = ShardedMapping {
+            database: "test".into(),
+            column: "id".into(),
+            table: Some("test_table".into()),
+            schema: Some("public".into()),
+            kind: ShardedMappingKind::Range,
+            start: Some(FlexibleType::Integer(0)),
+            end: Some(FlexibleType::Integer(1000)),
+            values: HashSet::new(),
+            shard: 0,
+        };
+        Mapping::new(&[mapping]).unwrap()
+    }
+
+    #[test]
+    fn test_create_foreign_table_basic() {
+        let columns = vec![
+            ForeignTableColumn {
+                is_not_null: true,
+                ..test_column("id", "bigint")
+            },
+            ForeignTableColumn {
+                column_default: "'unknown'::character varying".into(),
+                ..test_column("name", "character varying(100)")
+            },
+        ];
+
+        let schema = sharding_schema_with_tables(ShardedTables::default(), 1);
+        let statements = create_foreign_table(&columns, &schema).unwrap();
+
+        assert_eq!(statements.statements.len(), 1);
+        let sql = &statements.statements[0];
+        assert!(sql.contains(r#"CREATE FOREIGN TABLE "public"."test_table""#));
+        assert!(sql.contains("bigint"));
+        assert!(sql.contains("NOT NULL"));
+        assert!(sql.contains("OPTIONS (column_name 'id')"));
+        assert!(sql.contains("character varying(100)"));
+        assert!(!sql.contains("DEFAULT")); // Defaults handled by remote table
+        assert!(sql.contains("SERVER"));
+        assert!(sql.contains("schema_name 'public'"));
+
assert!(!sql.contains("PARTITION BY")); + } + + #[test] + fn test_create_foreign_table_with_hash_sharding() { + let columns = vec![ + ForeignTableColumn { + is_not_null: true, + ..test_column("id", "bigint") + }, + test_column("name", "text"), + ]; + + let tables: ShardedTables = [test_sharded_table("test_table", "id")].as_slice().into(); + let schema = sharding_schema_with_tables(tables, 2); + + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert_eq!(statements.statements.len(), 3); // parent + 2 partitions + assert!(statements.statements[0].contains(r#"CREATE TABLE "public"."test_table""#)); + assert!(statements.statements[0].contains(r#"PARTITION BY HASH ("id")"#)); + assert!(statements.statements[1].contains( + r#"CREATE FOREIGN TABLE "public"."test_table_shard_0" PARTITION OF "public"."test_table""# + )); + assert!(statements.statements[1].contains("FOR VALUES WITH (MODULUS 2, REMAINDER 0)")); + assert!(statements.statements[1].contains(r#"SERVER "shard_0""#)); + assert!(statements.statements[2].contains( + r#"CREATE FOREIGN TABLE "public"."test_table_shard_1" PARTITION OF "public"."test_table""# + )); + assert!(statements.statements[2].contains("FOR VALUES WITH (MODULUS 2, REMAINDER 1)")); + assert!(statements.statements[2].contains(r#"SERVER "shard_1""#)); + } + + #[test] + fn test_create_foreign_table_with_list_sharding() { + let columns = vec![test_column("id", "bigint"), test_column("region", "text")]; + + let tables: ShardedTables = [test_sharded_table_with_mapping( + "test_table", + "region", + list_mapping(), + DataType::Varchar, + )] + .as_slice() + .into(); + let schema = sharding_schema_with_tables(tables, 1); + + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert!(statements.statements[0].contains(r#"CREATE TABLE "public"."test_table""#)); + assert!(statements.statements[0].contains(r#"PARTITION BY LIST ("region")"#)); + } + + #[test] + fn test_create_foreign_table_with_range_sharding() { + let columns = vec![test_column("id", "bigint"), test_column("name", "text")]; + + let tables: ShardedTables = [test_sharded_table_with_mapping( + "test_table", + "id", + range_mapping(), + DataType::Bigint, + )] + .as_slice() + .into(); + let schema = sharding_schema_with_tables(tables, 1); + + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert!(statements.statements[0].contains(r#"CREATE TABLE "public"."test_table""#)); + assert!(statements.statements[0].contains(r#"PARTITION BY RANGE ("id")"#)); + } + + #[test] + fn test_create_foreign_table_no_shard_match() { + let columns = vec![test_column("id", "bigint"), test_column("name", "text")]; + + let tables: ShardedTables = [test_sharded_table("other_table", "user_id")] + .as_slice() + .into(); + let schema = sharding_schema_with_tables(tables, 2); + + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert_eq!(statements.statements.len(), 1); + assert!(statements.statements[0].contains(r#"CREATE FOREIGN TABLE "public"."test_table""#)); + assert!(!statements.statements[0].contains("PARTITION BY")); + } + + #[test] + fn test_create_foreign_table_column_mismatch() { + let columns = vec![test_column("id", "bigint"), test_column("name", "text")]; + + let tables: ShardedTables = [test_sharded_table("test_table", "user_id")] + .as_slice() + .into(); + let schema = sharding_schema_with_tables(tables, 2); + + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert_eq!(statements.statements.len(), 1); + 
assert!(statements.statements[0].contains(r#"CREATE FOREIGN TABLE "public"."test_table""#)); + assert!(!statements.statements[0].contains("PARTITION BY")); + } + + #[test] + fn test_create_foreign_table_with_generated() { + let columns = vec![ForeignTableColumn { + column_default: "(price * quantity)".into(), + generated: "s".into(), + ..test_column("total", "numeric") + }]; + + let schema = sharding_schema_with_tables(ShardedTables::default(), 1); + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert!(statements.statements[0].contains(r#"CREATE FOREIGN TABLE "public"."test_table""#)); + // Defaults and generated columns handled by remote table + assert!(!statements.statements[0].contains("GENERATED")); + assert!(!statements.statements[0].contains("DEFAULT")); + } + + #[test] + fn test_create_foreign_table_with_collation() { + let columns = vec![ForeignTableColumn { + collation_name: "en_US".into(), + collation_schema: "pg_catalog".into(), + ..test_column("title", "text") + }]; + + let schema = sharding_schema_with_tables(ShardedTables::default(), 1); + let statements = create_foreign_table(&columns, &schema).unwrap(); + + assert!(statements.statements[0].contains(r#"CREATE FOREIGN TABLE "public"."test_table""#)); + assert!(statements.statements[0].contains(r#"COLLATE "pg_catalog"."en_US""#)); + } + + #[test] + fn test_create_foreign_table_empty_columns() { + let schema = sharding_schema_with_tables(ShardedTables::default(), 1); + let result = create_foreign_table(&[], &schema); + assert!(result.is_err()); + } + + #[test] + fn test_quote_identifier() { + // All identifiers are now quoted + assert_eq!(quote_identifier("users"), "\"users\""); + assert_eq!(quote_identifier("my table"), "\"my table\""); + assert_eq!(quote_identifier("123abc"), "\"123abc\""); + assert_eq!(quote_identifier("has\"quote"), "\"has\"\"quote\""); + assert_eq!(quote_identifier("CamelCase"), "\"CamelCase\""); + assert_eq!(quote_identifier("_valid"), "\"_valid\""); + } + + fn test_partition_column( + table_name: &str, + name: &str, + col_type: &str, + parent_table: &str, + partition_bound: &str, + ) -> ForeignTableColumn { + ForeignTableColumn { + schema_name: "public".into(), + table_name: table_name.into(), + column_name: name.into(), + column_type: col_type.into(), + is_not_null: false, + column_default: String::new(), + generated: String::new(), + collation_name: String::new(), + collation_schema: String::new(), + is_partition: true, + parent_table_name: parent_table.into(), + parent_schema_name: "public".into(), + partition_bound: partition_bound.into(), + partition_key: String::new(), + } + } + + fn test_partitioned_parent_column( + name: &str, + col_type: &str, + partition_key: &str, + ) -> ForeignTableColumn { + ForeignTableColumn { + partition_key: partition_key.into(), + ..test_column(name, col_type) + } + } + + #[test] + fn test_create_foreign_table_two_tier_partitioning() { + // Parent table "orders" partitioned by RANGE on date, with children partitioned by hash across shards + let parent_columns = vec![ + test_partitioned_parent_column("id", "bigint", "RANGE (created_at)"), + test_partitioned_parent_column("created_at", "date", "RANGE (created_at)"), + test_partitioned_parent_column("data", "text", "RANGE (created_at)"), + ]; + + // Child partitions with their bounds + let partition_2024 = vec![ + test_partition_column( + "orders_2024", + "id", + "bigint", + "test_table", + "FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')", + ), + test_partition_column( + "orders_2024", + 
"created_at", + "date", + "test_table", + "FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')", + ), + test_partition_column( + "orders_2024", + "data", + "text", + "test_table", + "FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')", + ), + ]; + + let partition_2025 = vec![ + test_partition_column( + "orders_2025", + "id", + "bigint", + "test_table", + "FOR VALUES FROM ('2025-01-01') TO ('2026-01-01')", + ), + test_partition_column( + "orders_2025", + "created_at", + "date", + "test_table", + "FOR VALUES FROM ('2025-01-01') TO ('2026-01-01')", + ), + test_partition_column( + "orders_2025", + "data", + "text", + "test_table", + "FOR VALUES FROM ('2025-01-01') TO ('2026-01-01')", + ), + ]; + + let tables: ShardedTables = [test_sharded_table("test_table", "id")].as_slice().into(); + let schema = sharding_schema_with_tables(tables, 2); + + let result = ForeignTableBuilder::new(&parent_columns, &schema) + .with_children(vec![partition_2024, partition_2025]) + .build() + .unwrap(); + + // Expected structure: + // 1. Parent table with original PARTITION BY RANGE (created_at) + // 2. orders_2024 partition (with original bounds) that is itself PARTITION BY HASH + // 3. orders_2024_shard_0 foreign table + // 4. orders_2024_shard_1 foreign table + // 5. orders_2025 partition (with original bounds) that is itself PARTITION BY HASH + // 6. orders_2025_shard_0 foreign table + // 7. orders_2025_shard_1 foreign table + assert_eq!(result.statements.len(), 7); + + // Parent table - uses original partition key (RANGE on date) + assert!(result.statements[0].contains(r#"CREATE TABLE "public"."test_table""#)); + assert!(result.statements[0].contains("PARTITION BY RANGE (created_at)")); + + // orders_2024 intermediate partition + assert!(result.statements[1] + .contains(r#"CREATE TABLE "public"."orders_2024" PARTITION OF "public"."test_table""#)); + assert!(result.statements[1].contains("FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')")); + assert!(result.statements[1].contains(r#"PARTITION BY HASH ("id")"#)); + + // orders_2024_shard_0 foreign table + assert!(result.statements[2].contains( + r#"CREATE FOREIGN TABLE "public"."orders_2024_shard_0" PARTITION OF "public"."orders_2024""# + )); + assert!(result.statements[2].contains("FOR VALUES WITH (MODULUS 2, REMAINDER 0)")); + assert!(result.statements[2].contains(r#"SERVER "shard_0""#)); + + // orders_2024_shard_1 foreign table + assert!(result.statements[3].contains( + r#"CREATE FOREIGN TABLE "public"."orders_2024_shard_1" PARTITION OF "public"."orders_2024""# + )); + assert!(result.statements[3].contains("FOR VALUES WITH (MODULUS 2, REMAINDER 1)")); + assert!(result.statements[3].contains(r#"SERVER "shard_1""#)); + + // orders_2025 intermediate partition + assert!(result.statements[4] + .contains(r#"CREATE TABLE "public"."orders_2025" PARTITION OF "public"."test_table""#)); + assert!(result.statements[4].contains("FOR VALUES FROM ('2025-01-01') TO ('2026-01-01')")); + assert!(result.statements[4].contains(r#"PARTITION BY HASH ("id")"#)); + + // orders_2025_shard_0 foreign table + assert!(result.statements[5].contains( + r#"CREATE FOREIGN TABLE "public"."orders_2025_shard_0" PARTITION OF "public"."orders_2025""# + )); + + // orders_2025_shard_1 foreign table + assert!(result.statements[6].contains( + r#"CREATE FOREIGN TABLE "public"."orders_2025_shard_1" PARTITION OF "public"."orders_2025""# + )); + } +} diff --git a/pgdog/src/backend/schema/postgres_fdw/test/helpers.rs b/pgdog/src/backend/schema/postgres_fdw/test/helpers.rs new file mode 100644 index 
00000000..ec33a07c --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/test/helpers.rs @@ -0,0 +1,280 @@ +//! Test helpers for postgres_fdw integration tests. + +use std::collections::HashSet; + +use crate::backend::pool::ShardingSchema; +use crate::backend::replication::ShardedTables; +use crate::backend::schema::postgres_fdw::{ + CreateForeignTableResult, ForeignTableBuilder, ForeignTableColumn, ForeignTableSchema, +}; +use crate::backend::Server; +use crate::config::{DataType, FlexibleType, ShardedMapping, ShardedMappingKind, ShardedTable}; +use crate::frontend::router::sharding::Mapping; + +/// Data type configuration for test tables. +#[derive(Debug, Clone, Copy)] +pub enum TestDataType { + Bigint, + Varchar, + Uuid, +} + +impl TestDataType { + pub fn name(&self) -> &'static str { + match self { + Self::Bigint => "bigint", + Self::Varchar => "varchar", + Self::Uuid => "uuid", + } + } + + pub fn sql_type(&self) -> &'static str { + match self { + Self::Bigint => "BIGINT", + Self::Varchar => "VARCHAR(100)", + Self::Uuid => "UUID", + } + } + + pub fn config_type(&self) -> DataType { + match self { + Self::Bigint => DataType::Bigint, + Self::Varchar => DataType::Varchar, + Self::Uuid => DataType::Uuid, + } + } + + pub fn flexible_values(&self) -> Vec { + match self { + Self::Bigint => vec![ + FlexibleType::Integer(1), + FlexibleType::Integer(2), + FlexibleType::Integer(3), + ], + Self::Varchar => vec![ + FlexibleType::String("us".into()), + FlexibleType::String("eu".into()), + FlexibleType::String("asia".into()), + ], + Self::Uuid => vec![ + FlexibleType::Uuid("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11".parse().unwrap()), + FlexibleType::Uuid("b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a22".parse().unwrap()), + FlexibleType::Uuid("c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a33".parse().unwrap()), + ], + } + } +} + +/// Partition strategy for tests. +#[derive(Debug, Clone, Copy)] +pub enum TestPartitionStrategy { + Hash, + List, + Range, +} + +impl TestPartitionStrategy { + pub fn as_str(&self) -> &'static str { + match self { + Self::Hash => "HASH", + Self::List => "LIST", + Self::Range => "RANGE", + } + } +} + +/// Test fixture for FDW statement generation tests. 
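+///
+/// Typical flow (see `run_single_tier_test` in the sibling `mod.rs`): create the
+/// source table, load the schema, generate the statements, verify their shape,
+/// then execute the parent statement against the test server.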
+pub struct FdwTestFixture { + pub table_name: String, + pub shard_column: String, + pub data_type: TestDataType, + pub strategy: TestPartitionStrategy, + pub num_shards: usize, +} + +impl FdwTestFixture { + pub fn new( + table_name: &str, + shard_column: &str, + data_type: TestDataType, + strategy: TestPartitionStrategy, + ) -> Self { + Self { + table_name: table_name.into(), + shard_column: shard_column.into(), + data_type, + strategy, + num_shards: 2, + } + } + + pub async fn create_table(&self, server: &mut Server) -> Result<(), crate::backend::Error> { + self.cleanup(server).await?; + + let sql = format!( + "CREATE TABLE {} ( + id BIGINT NOT NULL, + {} {} NOT NULL, + data TEXT + )", + self.table_name, + self.shard_column, + self.data_type.sql_type() + ); + server.execute(&sql).await?; + Ok(()) + } + + pub async fn cleanup(&self, server: &mut Server) -> Result<(), crate::backend::Error> { + server + .execute(&format!("DROP TABLE IF EXISTS {} CASCADE", self.table_name)) + .await?; + server + .execute(&format!( + "DROP TABLE IF EXISTS {}_fdw CASCADE", + self.table_name + )) + .await?; + Ok(()) + } + + pub fn sharding_schema(&self, schema_name: &str) -> ShardingSchema { + let mapping = self.create_mapping(); + let sharded_table = ShardedTable { + database: "test".into(), + name: Some(self.table_name.clone()), + schema: Some(schema_name.into()), + column: self.shard_column.clone(), + data_type: self.data_type.config_type(), + mapping, + ..Default::default() + }; + + let tables: ShardedTables = [sharded_table].as_slice().into(); + ShardingSchema { + shards: self.num_shards, + tables, + ..Default::default() + } + } + + fn create_mapping(&self) -> Option { + match self.strategy { + TestPartitionStrategy::Hash => None, + TestPartitionStrategy::List => { + let values = self.data_type.flexible_values(); + let mappings: Vec<_> = values + .into_iter() + .enumerate() + .map(|(i, v)| ShardedMapping { + database: "test".into(), + column: self.shard_column.clone(), + table: Some(self.table_name.clone()), + kind: ShardedMappingKind::List, + values: HashSet::from([v]), + shard: i % self.num_shards, + ..Default::default() + }) + .collect(); + Mapping::new(&mappings) + } + TestPartitionStrategy::Range => { + let values = self.data_type.flexible_values(); + let mappings: Vec<_> = (0..self.num_shards) + .map(|shard| { + let (start, end) = if shard == 0 { + (None, Some(values[1].clone())) + } else { + (Some(values[1].clone()), None) + }; + ShardedMapping { + database: "test".into(), + column: self.shard_column.clone(), + table: Some(self.table_name.clone()), + kind: ShardedMappingKind::Range, + start, + end, + shard, + ..Default::default() + } + }) + .collect(); + Mapping::new(&mappings) + } + } + } + + pub fn generate_statements( + &self, + columns: &[ForeignTableColumn], + sharding_schema: &ShardingSchema, + ) -> CreateForeignTableResult { + ForeignTableBuilder::new(columns, sharding_schema) + .build() + .expect("Statement generation should succeed") + } + + fn expected_statement_count(&self) -> usize { + 1 + self.num_shards + } + + pub fn verify_statements(&self, statements: &[String]) { + assert_eq!( + statements.len(), + self.expected_statement_count(), + "Expected {} statements for {}, got {}", + self.expected_statement_count(), + self.table_name, + statements.len() + ); + + assert!( + statements[0].contains("CREATE TABLE"), + "First statement should be CREATE TABLE: {}", + statements[0] + ); + assert!( + statements[0].contains(&format!("PARTITION BY {}", self.strategy.as_str())), + "Parent should use 
{} partitioning: {}", + self.strategy.as_str(), + statements[0] + ); + + for (i, stmt) in statements.iter().skip(1).enumerate() { + assert!( + stmt.contains("CREATE FOREIGN TABLE"), + "Statement {} should be CREATE FOREIGN TABLE: {}", + i + 1, + stmt + ); + assert!( + stmt.contains(&format!("shard_{}", i)), + "Partition {} should reference shard_{}: {}", + i, + i, + stmt + ); + } + } + + pub async fn execute_parent_statement( + &self, + server: &mut Server, + statements: &[String], + ) -> Result<(), crate::backend::Error> { + let parent_stmt = + statements[0].replace(&self.table_name, &format!("{}_fdw", self.table_name)); + server.execute(&parent_stmt).await?; + Ok(()) + } +} + +pub fn find_table_columns<'a>( + schema: &'a ForeignTableSchema, + table_name: &str, +) -> Option<&'a Vec> { + schema + .tables() + .get(&("public".into(), table_name.into())) + .or_else(|| schema.tables().get(&("pgdog".into(), table_name.into()))) +} diff --git a/pgdog/src/backend/schema/postgres_fdw/test/mod.rs b/pgdog/src/backend/schema/postgres_fdw/test/mod.rs new file mode 100644 index 00000000..c36f0a6b --- /dev/null +++ b/pgdog/src/backend/schema/postgres_fdw/test/mod.rs @@ -0,0 +1,309 @@ +//! Integration tests for postgres_fdw statement generation. + +mod helpers; + +use helpers::{find_table_columns, FdwTestFixture, TestDataType, TestPartitionStrategy}; + +use super::ForeignTableSchema; +use crate::backend::server::test::test_server; + +/// Test a single-tier sharding scenario. +async fn run_single_tier_test(data_type: TestDataType, strategy: TestPartitionStrategy) { + let mut server = test_server().await; + let fixture = FdwTestFixture::new( + &format!( + "test_{}_{}", + strategy.as_str().to_lowercase(), + data_type.name() + ), + "shard_key", + data_type, + strategy, + ); + + // Create source table + fixture.create_table(&mut server).await.unwrap(); + + // Load schema + let schema = ForeignTableSchema::load(&mut server).await.unwrap(); + + // Find table columns + let columns = find_table_columns(&schema, &fixture.table_name).expect("Table should be loaded"); + + // Get schema name from loaded columns + let schema_name = &columns.first().unwrap().schema_name; + + // Generate statements + let sharding_schema = fixture.sharding_schema(schema_name); + let result = fixture.generate_statements(columns, &sharding_schema); + + // Verify statement structure + fixture.verify_statements(&result.statements); + + // Execute parent statement to verify SQL validity + fixture + .execute_parent_statement(&mut server, &result.statements) + .await + .expect("Parent statement should execute successfully"); + + // Cleanup + fixture.cleanup(&mut server).await.unwrap(); +} + +// Hash partitioning tests +#[tokio::test] +async fn test_hash_bigint() { + run_single_tier_test(TestDataType::Bigint, TestPartitionStrategy::Hash).await; +} + +#[tokio::test] +async fn test_hash_varchar() { + run_single_tier_test(TestDataType::Varchar, TestPartitionStrategy::Hash).await; +} + +#[tokio::test] +async fn test_hash_uuid() { + run_single_tier_test(TestDataType::Uuid, TestPartitionStrategy::Hash).await; +} + +// List partitioning tests +#[tokio::test] +async fn test_list_bigint() { + run_single_tier_test(TestDataType::Bigint, TestPartitionStrategy::List).await; +} + +#[tokio::test] +async fn test_list_varchar() { + run_single_tier_test(TestDataType::Varchar, TestPartitionStrategy::List).await; +} + +#[tokio::test] +async fn test_list_uuid() { + run_single_tier_test(TestDataType::Uuid, TestPartitionStrategy::List).await; +} + +// Range 
partitioning tests +#[tokio::test] +async fn test_range_bigint() { + run_single_tier_test(TestDataType::Bigint, TestPartitionStrategy::Range).await; +} + +#[tokio::test] +async fn test_range_varchar() { + run_single_tier_test(TestDataType::Varchar, TestPartitionStrategy::Range).await; +} + +#[tokio::test] +async fn test_range_uuid() { + run_single_tier_test(TestDataType::Uuid, TestPartitionStrategy::Range).await; +} + +// Existing tests refactored to use helpers + +#[tokio::test] +async fn test_load_partitioned_table_schema() { + let mut server = test_server().await; + + server + .execute("DROP TABLE IF EXISTS test_partitioned_parent CASCADE") + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_partitioned_parent ( + id BIGINT NOT NULL, + created_at DATE NOT NULL, + data TEXT + ) PARTITION BY RANGE (created_at)", + ) + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_partitioned_parent_2024 PARTITION OF test_partitioned_parent + FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')", + ) + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_partitioned_parent_2025 PARTITION OF test_partitioned_parent + FOR VALUES FROM ('2025-01-01') TO ('2026-01-01')", + ) + .await + .unwrap(); + + let schema = ForeignTableSchema::load(&mut server).await.unwrap(); + + let parent_cols = find_table_columns(&schema, "test_partitioned_parent"); + assert!(parent_cols.is_some(), "Parent table should be loaded"); + let parent_cols = parent_cols.unwrap(); + + let first_parent_col = parent_cols.first().unwrap(); + assert!(!first_parent_col.is_partition); + assert!( + first_parent_col.partition_key.contains("RANGE"), + "Parent should have RANGE partition key, got: {}", + first_parent_col.partition_key + ); + + let child_2024 = find_table_columns(&schema, "test_partitioned_parent_2024"); + assert!(child_2024.is_some(), "Child 2024 should be loaded"); + let first_child = child_2024.unwrap().first().unwrap(); + assert!(first_child.is_partition); + assert_eq!(first_child.parent_table_name, "test_partitioned_parent"); + + server + .execute("DROP TABLE IF EXISTS test_partitioned_parent CASCADE") + .await + .unwrap(); +} + +#[tokio::test] +async fn test_two_tier_partitioning() { + use crate::backend::pool::ShardingSchema; + use crate::backend::replication::ShardedTables; + use crate::config::{DataType, ShardedTable}; + + let mut server = test_server().await; + + server + .execute("DROP TABLE IF EXISTS test_two_tier CASCADE") + .await + .unwrap(); + server + .execute("DROP TABLE IF EXISTS test_two_tier_fdw CASCADE") + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_two_tier ( + id BIGINT NOT NULL, + customer_id BIGINT NOT NULL, + created_at DATE NOT NULL + ) PARTITION BY RANGE (created_at)", + ) + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_two_tier_2024 PARTITION OF test_two_tier + FOR VALUES FROM ('2024-01-01') TO ('2025-01-01')", + ) + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_two_tier_2025 PARTITION OF test_two_tier + FOR VALUES FROM ('2025-01-01') TO ('2026-01-01')", + ) + .await + .unwrap(); + + let schema = ForeignTableSchema::load(&mut server).await.unwrap(); + + let parent_cols = find_table_columns(&schema, "test_two_tier").expect("Parent should exist"); + let child_2024 = find_table_columns(&schema, "test_two_tier_2024") + .expect("Child 2024 should exist") + .clone(); + let child_2025 = find_table_columns(&schema, "test_two_tier_2025") + .expect("Child 2025 should exist") + .clone(); + + let schema_name = 
&parent_cols.first().unwrap().schema_name; + + let sharded_table = ShardedTable { + database: "test".into(), + name: Some("test_two_tier".into()), + schema: Some(schema_name.clone()), + column: "customer_id".into(), + data_type: DataType::Bigint, + ..Default::default() + }; + + let tables: ShardedTables = [sharded_table].as_slice().into(); + let sharding_schema = ShardingSchema { + shards: 2, + tables, + ..Default::default() + }; + + let result = super::ForeignTableBuilder::new(parent_cols, &sharding_schema) + .with_children(vec![child_2024, child_2025]) + .build() + .unwrap(); + + // 7 statements: parent + 2*(intermediate + 2 foreign) + assert_eq!(result.statements.len(), 7); + + // Parent uses original RANGE partitioning + assert!(result.statements[0].contains("PARTITION BY RANGE")); + + // Intermediate partitions use HASH for sharding + assert!(result.statements[1].contains("PARTITION BY HASH")); + + // Execute parent and intermediate statements + let parent_stmt = result.statements[0].replace("test_two_tier", "test_two_tier_fdw"); + server.execute(&parent_stmt).await.unwrap(); + + let intermediate = result.statements[1] + .replace("test_two_tier_2024", "test_two_tier_fdw_2024") + .replace("test_two_tier", "test_two_tier_fdw"); + server.execute(&intermediate).await.unwrap(); + + server + .execute("DROP TABLE IF EXISTS test_two_tier_fdw CASCADE") + .await + .unwrap(); + server + .execute("DROP TABLE IF EXISTS test_two_tier CASCADE") + .await + .unwrap(); +} + +#[tokio::test] +async fn test_load_foreign_table_schema() { + let mut server = test_server().await; + + server + .execute("DROP TABLE IF EXISTS test_fdw_schema") + .await + .unwrap(); + + server + .execute( + "CREATE TABLE test_fdw_schema ( + id BIGINT NOT NULL, + name VARCHAR(100) DEFAULT 'unknown', + score NUMERIC(10, 2), + created_at TIMESTAMP NOT NULL DEFAULT now() + )", + ) + .await + .unwrap(); + + let schema = ForeignTableSchema::load(&mut server).await.unwrap(); + + let test_rows: Vec<_> = schema + .tables() + .values() + .flatten() + .filter(|r| r.table_name == "test_fdw_schema") + .collect(); + + assert_eq!(test_rows.len(), 4); + + let id_col = test_rows.iter().find(|r| r.column_name == "id").unwrap(); + assert!(id_col.is_not_null); + + server + .execute("DROP TABLE IF EXISTS test_fdw_schema") + .await + .unwrap(); +} diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index 3a07451b..a0232588 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -175,7 +175,7 @@ impl Client { if !exists { let user = user_from_params(¶ms, &password).ok(); if let Some(user) = user { - databases::add(user); + databases::add(user)?; } } password.password().map(|p| p.to_owned()) diff --git a/pgdog/src/frontend/client/query_engine/connect.rs b/pgdog/src/frontend/client/query_engine/connect.rs index 240fdbbf..b19cd9dc 100644 --- a/pgdog/src/frontend/client/query_engine/connect.rs +++ b/pgdog/src/frontend/client/query_engine/connect.rs @@ -121,22 +121,18 @@ impl QueryEngine { pub(super) fn transaction_route(&mut self, route: &Route) -> Result { let cluster = self.backend.cluster()?; + let mut route = route.clone(); + if cluster.shards().len() == 1 { - Ok( - Route::write(ShardWithPriority::new_override_transaction(Shard::Direct( - 0, - ))) - .with_read(route.is_read()), - ) - } else if route.is_search_path_driven() { + route.set_shard_mut(ShardWithPriority::new_override_transaction(Shard::Direct( + 0, + ))); + } else if !route.is_search_path_driven() { // Schema-based routing will 
only go to one shard. - Ok(route.clone()) - } else { - Ok( - Route::write(ShardWithPriority::new_override_transaction(Shard::All)) - .with_read(route.is_read()), - ) + route.set_shard_mut(ShardWithPriority::new_override_transaction(Shard::All)); } + + Ok(route) } fn debug_connected(&self, context: &QueryEngineContext<'_>, connected: bool) { diff --git a/pgdog/src/frontend/client/query_engine/route_query.rs b/pgdog/src/frontend/client/query_engine/route_query.rs index 38875dcf..855b6c63 100644 --- a/pgdog/src/frontend/client/query_engine/route_query.rs +++ b/pgdog/src/frontend/client/query_engine/route_query.rs @@ -88,6 +88,7 @@ impl QueryEngine { context.transaction, context.sticky, )?; + match self.router.query(router_context) { Ok(command) => { context.client_request.route = Some(command.route().clone()); diff --git a/pgdog/src/frontend/listener.rs b/pgdog/src/frontend/listener.rs index 250ec653..2274a615 100644 --- a/pgdog/src/frontend/listener.rs +++ b/pgdog/src/frontend/listener.rs @@ -5,6 +5,7 @@ use std::net::SocketAddr; use std::sync::Arc; use crate::backend::databases::{databases, reload, shutdown}; +use crate::backend::fdw::PostgresLauncher; use crate::config::config; use crate::frontend::client::query_engine::two_pc::Manager; use crate::net::messages::BackendKeyData; @@ -150,6 +151,10 @@ impl Listener { } } + if let Err(_) = timeout(shutdown_timeout, PostgresLauncher::get().shutdown_wait()).await { + error!("[fdw] graceful shutdown failed"); + } + self.shutdown.notify_waiters(); } diff --git a/pgdog/src/frontend/router/parameter_hints.rs b/pgdog/src/frontend/router/parameter_hints.rs index 5585ba11..deae1023 100644 --- a/pgdog/src/frontend/router/parameter_hints.rs +++ b/pgdog/src/frontend/router/parameter_hints.rs @@ -1,4 +1,6 @@ -use pgdog_config::Role; +use std::str::FromStr; + +use pgdog_config::{CrossShardBackend, Role}; use super::parser::Error; use crate::{ @@ -16,6 +18,7 @@ pub struct ParameterHints<'a> { pub pgdog_shard: Option<&'a ParameterValue>, pub pgdog_sharding_key: Option<&'a ParameterValue>, pub pgdog_role: Option<&'a ParameterValue>, + pub pgdog_cross_shard_backend: Option<&'a ParameterValue>, hooks: ParserHooks, } @@ -26,6 +29,7 @@ impl<'a> From<&'a Parameters> for ParameterHints<'a> { pgdog_shard: value.get("pgdog.shard"), pgdog_role: value.get("pgdog.role"), pgdog_sharding_key: value.get("pgdog.sharding_key"), + pgdog_cross_shard_backend: value.get("pgdog.cross_shard_backend"), hooks: ParserHooks::default(), } } @@ -112,6 +116,19 @@ impl ParameterHints<'_> { role } + + /// User said use fdw. 
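+    /// That is, the client asked for the FDW cross-shard backend through the
+    /// `pgdog.cross_shard_backend` parameter looked up above; any value that
+    /// `CrossShardBackend::from_str` parses to an FDW-requiring variant counts.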
+ pub(crate) fn use_fdw_fallback(&self) -> bool { + if let Some(ref val) = self.pgdog_cross_shard_backend { + if let Some(s) = val.as_str() { + if let Ok(fdw) = CrossShardBackend::from_str(s) { + return fdw.need_fdw(); + } + } + } + + false + } } #[cfg(test)] @@ -148,6 +165,7 @@ mod tests { pgdog_shard: None, pgdog_sharding_key: Some(&sharding_key), pgdog_role: None, + pgdog_cross_shard_backend: None, hooks: ParserHooks::default(), }; @@ -169,6 +187,7 @@ mod tests { pgdog_shard: None, pgdog_sharding_key: Some(&sharding_key), pgdog_role: None, + pgdog_cross_shard_backend: None, hooks: ParserHooks::default(), }; diff --git a/pgdog/src/frontend/router/parser/cache/ast.rs b/pgdog/src/frontend/router/parser/cache/ast.rs index 2be0e0fb..ff72bc7c 100644 --- a/pgdog/src/frontend/router/parser/cache/ast.rs +++ b/pgdog/src/frontend/router/parser/cache/ast.rs @@ -1,5 +1,5 @@ use pg_query::{parse, parse_raw, protobuf::ObjectType, NodeEnum, NodeRef, ParseResult}; -use pgdog_config::QueryParserEngine; +use pgdog_config::{CrossShardBackend, QueryParserEngine}; use std::fmt::Debug; use std::time::Instant; use std::{collections::HashSet, ops::Deref}; @@ -12,6 +12,7 @@ use super::super::{ }; use super::{Fingerprint, Stats}; use crate::backend::schema::Schema; +use crate::frontend::router::parser::comment::CommentRoute; use crate::frontend::router::parser::rewrite::statement::RewritePlan; use crate::frontend::{BufferedQuery, PreparedStatements}; use crate::net::parameter::ParameterValue; @@ -41,6 +42,8 @@ pub struct AstInner { pub rewrite_plan: RewritePlan, /// Fingerprint. pub fingerprint: Fingerprint, + /// Cross-shard backend. + pub cross_shard_backend: Option, } impl AstInner { @@ -53,6 +56,7 @@ impl AstInner { comment_shard: None, rewrite_plan: RewritePlan::default(), fingerprint: Fingerprint::default(), + cross_shard_backend: None, } } } @@ -81,7 +85,11 @@ impl Ast { QueryParserEngine::PgQueryRaw => parse_raw(query), } .map_err(Error::PgQuery)?; - let (comment_shard, comment_role) = comment(query, schema)?; + let CommentRoute { + shard: comment_shard, + role: comment_role, + cross_shard_backend, + } = comment(query, schema)?; let fingerprint = Fingerprint::new(query, schema.query_parser_engine).map_err(Error::PgQuery)?; @@ -116,6 +124,7 @@ impl Ast { ast, rewrite_plan, fingerprint, + cross_shard_backend, }), }) } diff --git a/pgdog/src/frontend/router/parser/comment.rs b/pgdog/src/frontend/router/parser/comment.rs index a87883ad..15d49a2f 100644 --- a/pgdog/src/frontend/router/parser/comment.rs +++ b/pgdog/src/frontend/router/parser/comment.rs @@ -1,7 +1,7 @@ use once_cell::sync::Lazy; use pg_query::scan_raw; use pg_query::{protobuf::Token, scan}; -use pgdog_config::QueryParserEngine; +use pgdog_config::{CrossShardBackend, QueryParserEngine}; use regex::Regex; use crate::backend::ShardingSchema; @@ -16,6 +16,8 @@ static SHARDING_KEY: Lazy = Lazy::new(|| { Regex::new(r#"pgdog_sharding_key: *(?:"([^"]*)"|'([^']*)'|([0-9a-zA-Z-]+))"#).unwrap() }); static ROLE: Lazy = Lazy::new(|| Regex::new(r#"pgdog_role: *(primary|replica)"#).unwrap()); +static BACKEND: Lazy = + Lazy::new(|| Regex::new(r#"pgdog_cross_shard_backend: fdw"#).unwrap()); fn get_matched_value<'a>(caps: &'a regex::Captures<'a>) -> Option<&'a str> { caps.get(1) @@ -24,6 +26,13 @@ fn get_matched_value<'a>(caps: &'a regex::Captures<'a>) -> Option<&'a str> { .map(|m| m.as_str()) } +#[derive(Debug, Clone, PartialEq, Default)] +pub struct CommentRoute { + pub shard: Option, + pub role: Option, + pub cross_shard_backend: Option, +} + /// Extract shard 
number from a comment. /// /// Comment style uses the C-style comments (not SQL comments!) @@ -31,16 +40,13 @@ fn get_matched_value<'a>(caps: &'a regex::Captures<'a>) -> Option<&'a str> { /// /// See [`SHARD`] and [`SHARDING_KEY`] for the style of comment we expect. /// -pub fn comment( - query: &str, - schema: &ShardingSchema, -) -> Result<(Option, Option), Error> { +pub fn comment(query: &str, schema: &ShardingSchema) -> Result { let tokens = match schema.query_parser_engine { QueryParserEngine::PgQueryProtobuf => scan(query), QueryParserEngine::PgQueryRaw => scan_raw(query), } .map_err(Error::PgQuery)?; - let mut role = None; + let mut comment_route = CommentRoute::default(); for token in tokens.tokens.iter() { if token.token == Token::CComment as i32 { @@ -48,8 +54,8 @@ pub fn comment( if let Some(cap) = ROLE.captures(comment) { if let Some(r) = cap.get(1) { match r.as_str() { - "primary" => role = Some(Role::Primary), - "replica" => role = Some(Role::Replica), + "primary" => comment_route.role = Some(Role::Primary), + "replica" => comment_route.role = Some(Role::Replica), _ => return Err(Error::RegexError), } } @@ -57,33 +63,33 @@ pub fn comment( if let Some(cap) = SHARDING_KEY.captures(comment) { if let Some(sharding_key) = get_matched_value(&cap) { if let Some(schema) = schema.schemas.get(Some(sharding_key.into())) { - return Ok((Some(schema.shard().into()), role)); + comment_route.shard = Some(schema.shard().into()); + } else { + let ctx = ContextBuilder::infer_from_from_and_config(sharding_key, schema)? + .shards(schema.shards) + .build()?; + comment_route.shard = Some(ctx.apply()?); } - let ctx = ContextBuilder::infer_from_from_and_config(sharding_key, schema)? - .shards(schema.shards) - .build()?; - return Ok((Some(ctx.apply()?), role)); } - } - if let Some(cap) = SHARD.captures(comment) { + } else if let Some(cap) = SHARD.captures(comment) { if let Some(shard) = cap.get(1) { - return Ok(( - Some( - shard - .as_str() - .parse::() - .ok() - .map(Shard::Direct) - .unwrap_or(Shard::All), - ), - role, - )); + comment_route.shard = Some( + shard + .as_str() + .parse::() + .ok() + .map(Shard::Direct) + .unwrap_or(Shard::All), + ); } } + if let Some(_) = BACKEND.captures(comment) { + comment_route.cross_shard_backend = Some(CrossShardBackend::Fdw); + } } } - Ok((None, role)) + Ok(comment_route) } #[cfg(test)] @@ -167,7 +173,7 @@ mod tests { let query = "SELECT * FROM users /* pgdog_role: primary */"; let result = comment(query, &schema).unwrap(); - assert_eq!(result.1, Some(Role::Primary)); + assert_eq!(result.role, Some(Role::Primary)); } #[test] @@ -182,8 +188,8 @@ mod tests { let query = "SELECT * FROM users /* pgdog_role: replica pgdog_shard: 2 */"; let result = comment(query, &schema).unwrap(); - assert_eq!(result.0, Some(Shard::Direct(2))); - assert_eq!(result.1, Some(Role::Replica)); + assert_eq!(result.shard, Some(Shard::Direct(2))); + assert_eq!(result.role, Some(Role::Replica)); } #[test] @@ -198,7 +204,7 @@ mod tests { let query = "SELECT * FROM users /* pgdog_role: replica */"; let result = comment(query, &schema).unwrap(); - assert_eq!(result.1, Some(Role::Replica)); + assert_eq!(result.role, Some(Role::Replica)); } #[test] @@ -213,7 +219,7 @@ mod tests { let query = "SELECT * FROM users /* pgdog_role: invalid */"; let result = comment(query, &schema).unwrap(); - assert_eq!(result.1, None); + assert_eq!(result.role, None); } #[test] @@ -228,7 +234,14 @@ mod tests { let query = "SELECT * FROM users"; let result = comment(query, &schema).unwrap(); - assert_eq!(result.1, 
None); + assert_eq!(result.role, None); + } + + #[test] + fn test_fdw_fallback() { + let query = "/* pgdog_cross_shard_backend: fdw */ SELECT * FROM users"; + let result = comment(query, &ShardingSchema::default()).unwrap(); + assert_eq!(result.cross_shard_backend, Some(CrossShardBackend::Fdw)); } #[test] @@ -253,6 +266,6 @@ mod tests { let query = "SELECT * FROM users /* pgdog_sharding_key: sales */"; let result = comment(query, &schema).unwrap(); - assert_eq!(result.0, Some(Shard::Direct(1))); + assert_eq!(result.shard, Some(Shard::Direct(1))); } } diff --git a/pgdog/src/frontend/router/parser/query/ddl.rs b/pgdog/src/frontend/router/parser/query/ddl.rs index c3c2a29c..cd3c4cad 100644 --- a/pgdog/src/frontend/router/parser/query/ddl.rs +++ b/pgdog/src/frontend/router/parser/query/ddl.rs @@ -225,7 +225,9 @@ impl QueryParser { calculator.push(ShardWithPriority::new_table(shard)); Ok(Command::Query( - Route::write(calculator.shard()).with_schema_changed(schema_changed), + Route::write(calculator.shard()) + .with_schema_changed(schema_changed) + .with_ddl(true), )) } diff --git a/pgdog/src/frontend/router/parser/query/fdw_fallback.rs b/pgdog/src/frontend/router/parser/query/fdw_fallback.rs new file mode 100644 index 00000000..bdd85f93 --- /dev/null +++ b/pgdog/src/frontend/router/parser/query/fdw_fallback.rs @@ -0,0 +1,338 @@ +//! FDW fallback detection for queries that cannot be executed across shards. +//! +//! Determines when a query should be sent to postgres_fdw instead of being +//! executed directly by pgdog's cross-shard query engine. +#![allow(dead_code)] + +use pg_query::{protobuf::SelectStmt, Node, NodeEnum}; + +use crate::backend::Schema; +use crate::frontend::router::parser::statement::StatementParser; +use crate::frontend::router::parser::Table; +use crate::net::parameter::ParameterValue; + +/// Context for FDW fallback checking that holds schema lookup information. +pub(crate) struct FdwFallbackContext<'a> { + pub db_schema: &'a Schema, + pub user: &'a str, + pub search_path: Option<&'a ParameterValue>, +} + +impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { + /// Check if a SELECT statement requires FDW fallback due to CTEs, subqueries, + /// or window functions that cannot be correctly executed across shards. + /// + /// Returns true if: + /// 1. CTEs/subqueries reference unsharded tables without sharding keys + /// 2. 
Window functions are present (can't be merged across shards) + pub(crate) fn needs_fdw_fallback_for_subqueries( + &self, + stmt: &SelectStmt, + ctx: &FdwFallbackContext, + has_sharding_key: bool, + ) -> bool { + // If the main query already has a sharding key, subqueries are considered + // correlated and inherit the sharding context + if has_sharding_key { + return false; + } + + // Check for window functions in target list + if self.has_window_functions(stmt) { + return true; + } + + // Check CTEs in WITH clause + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + if let Some(NodeEnum::CommonTableExpr(ref cte_expr)) = cte.node { + if let Some(ref ctequery) = cte_expr.ctequery { + if let Some(NodeEnum::SelectStmt(ref inner_select)) = ctequery.node { + if self.check_select_needs_fallback(inner_select, ctx) { + return true; + } + } + } + } + } + } + + // Check subqueries in FROM clause + for from_node in &stmt.from_clause { + if self.check_node_needs_fallback(from_node, ctx) { + return true; + } + } + + // Check subqueries in WHERE clause + if let Some(ref where_clause) = stmt.where_clause { + if self.check_node_needs_fallback(where_clause, ctx) { + return true; + } + } + + false + } + + /// Check if a SELECT statement contains window functions. + fn has_window_functions(&self, stmt: &SelectStmt) -> bool { + for target in &stmt.target_list { + if self.node_has_window_function(target) { + return true; + } + } + false + } + + /// Recursively check if a node contains a window function. + fn node_has_window_function(&self, node: &Node) -> bool { + match &node.node { + Some(NodeEnum::ResTarget(res_target)) => { + if let Some(ref val) = res_target.val { + return self.node_has_window_function(val); + } + false + } + Some(NodeEnum::FuncCall(func)) => { + // Window function has an OVER clause + func.over.is_some() + } + Some(NodeEnum::AExpr(a_expr)) => { + if let Some(ref lexpr) = a_expr.lexpr { + if self.node_has_window_function(lexpr) { + return true; + } + } + if let Some(ref rexpr) = a_expr.rexpr { + if self.node_has_window_function(rexpr) { + return true; + } + } + false + } + Some(NodeEnum::TypeCast(type_cast)) => { + if let Some(ref arg) = type_cast.arg { + return self.node_has_window_function(arg); + } + false + } + Some(NodeEnum::CoalesceExpr(coalesce)) => { + for arg in &coalesce.args { + if self.node_has_window_function(arg) { + return true; + } + } + false + } + Some(NodeEnum::CaseExpr(case_expr)) => { + if let Some(ref arg) = case_expr.arg { + if self.node_has_window_function(arg) { + return true; + } + } + if let Some(ref defresult) = case_expr.defresult { + if self.node_has_window_function(defresult) { + return true; + } + } + for when in &case_expr.args { + if self.node_has_window_function(when) { + return true; + } + } + false + } + Some(NodeEnum::CaseWhen(case_when)) => { + if let Some(ref result) = case_when.result { + if self.node_has_window_function(result) { + return true; + } + } + false + } + _ => false, + } + } + + /// Recursively check if a SELECT statement needs FDW fallback. 
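+    /// Covers UNION/INTERSECT/EXCEPT arms, window functions, tables and
+    /// subqueries in FROM, nested CTEs, and subqueries in WHERE.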
+ fn check_select_needs_fallback(&self, stmt: &SelectStmt, ctx: &FdwFallbackContext) -> bool { + // Handle UNION/INTERSECT/EXCEPT + if let Some(ref larg) = stmt.larg { + if self.check_select_needs_fallback(larg, ctx) { + return true; + } + } + if let Some(ref rarg) = stmt.rarg { + if self.check_select_needs_fallback(rarg, ctx) { + return true; + } + } + + // Check for window functions + if self.has_window_functions(stmt) { + return true; + } + + // Check tables in FROM clause + for from_node in &stmt.from_clause { + if self.check_from_node_has_unsafe_table(from_node, ctx) { + return true; + } + } + + // Recursively check nested CTEs + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + if let Some(NodeEnum::CommonTableExpr(ref cte_expr)) = cte.node { + if let Some(ref ctequery) = cte_expr.ctequery { + if let Some(NodeEnum::SelectStmt(ref inner_select)) = ctequery.node { + if self.check_select_needs_fallback(inner_select, ctx) { + return true; + } + } + } + } + } + } + + // Recursively check subqueries in FROM + for from_node in &stmt.from_clause { + if self.check_node_needs_fallback(from_node, ctx) { + return true; + } + } + + // Check subqueries in WHERE + if let Some(ref where_clause) = stmt.where_clause { + if self.check_node_needs_fallback(where_clause, ctx) { + return true; + } + } + + false + } + + /// Check if a node contains subqueries that need FDW fallback. + fn check_node_needs_fallback(&self, node: &Node, ctx: &FdwFallbackContext) -> bool { + match &node.node { + Some(NodeEnum::RangeSubselect(subselect)) => { + if let Some(ref subquery) = subselect.subquery { + if let Some(NodeEnum::SelectStmt(ref inner_select)) = subquery.node { + return self.check_select_needs_fallback(inner_select, ctx); + } + } + false + } + Some(NodeEnum::SubLink(sublink)) => { + if let Some(ref subselect) = sublink.subselect { + if let Some(NodeEnum::SelectStmt(ref inner_select)) = subselect.node { + return self.check_select_needs_fallback(inner_select, ctx); + } + } + false + } + Some(NodeEnum::JoinExpr(join)) => { + let mut needs_fallback = false; + if let Some(ref larg) = join.larg { + needs_fallback |= self.check_node_needs_fallback(larg, ctx); + } + if let Some(ref rarg) = join.rarg { + needs_fallback |= self.check_node_needs_fallback(rarg, ctx); + } + needs_fallback + } + Some(NodeEnum::BoolExpr(bool_expr)) => { + for arg in &bool_expr.args { + if self.check_node_needs_fallback(arg, ctx) { + return true; + } + } + false + } + Some(NodeEnum::AExpr(a_expr)) => { + if let Some(ref lexpr) = a_expr.lexpr { + if self.check_node_needs_fallback(lexpr, ctx) { + return true; + } + } + if let Some(ref rexpr) = a_expr.rexpr { + if self.check_node_needs_fallback(rexpr, ctx) { + return true; + } + } + false + } + _ => false, + } + } + + /// Check if a FROM clause node references an unsafe (unsharded) table. 
+ fn check_from_node_has_unsafe_table(&self, node: &Node, ctx: &FdwFallbackContext) -> bool { + match &node.node { + Some(NodeEnum::RangeVar(range_var)) => { + let table = Table::from(range_var); + !self.is_table_safe(&table, ctx) + } + Some(NodeEnum::JoinExpr(join)) => { + let mut has_unsafe = false; + if let Some(ref larg) = join.larg { + has_unsafe |= self.check_from_node_has_unsafe_table(larg, ctx); + } + if let Some(ref rarg) = join.rarg { + has_unsafe |= self.check_from_node_has_unsafe_table(rarg, ctx); + } + has_unsafe + } + Some(NodeEnum::RangeSubselect(_)) => { + // Subselects are checked separately for their contents + false + } + _ => false, + } + } + + /// Check if a table is "safe" (sharded or omnisharded). + fn is_table_safe(&self, table: &Table, ctx: &FdwFallbackContext) -> bool { + let sharded_tables = self.sharding_schema().tables(); + + // Check named sharded table configs + for config in sharded_tables.tables() { + if let Some(ref config_name) = config.name { + if table.name == config_name { + // Also check schema match if specified in config + if let Some(ref config_schema) = config.schema { + let config_schema_str: &str = config_schema.as_str(); + if table.schema != Some(config_schema_str) { + continue; + } + } + return true; + } + } + } + + // Check nameless configs by looking up the table in the db schema + let nameless_configs: Vec<_> = sharded_tables + .tables() + .iter() + .filter(|t| t.name.is_none()) + .collect(); + + if !nameless_configs.is_empty() { + if let Some(relation) = ctx.db_schema.table(*table, ctx.user, ctx.search_path) { + for config in &nameless_configs { + if relation.has_column(&config.column) { + return true; + } + } + } + } + + // Check if it's an omnisharded table + if sharded_tables.omnishards().contains_key(table.name) { + return true; + } + + false + } +} diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 1207f22a..4981e0b1 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -24,6 +24,7 @@ use super::{ mod ddl; mod delete; mod explain; +mod fdw_fallback; mod plugins; mod schema_sharding; mod select; @@ -347,7 +348,7 @@ impl QueryParser { // e.g. Parse, Describe, Flush-style flow. if !context.router_context.executable { - if let Command::Query(ref query) = command { + if let Command::Query(ref mut query) = command { if query.is_cross_shard() && statement.rewrite_plan.insert_split.is_empty() { context .shards_calculator @@ -355,13 +356,11 @@ impl QueryParser { round_robin::next() % context.shards, ))); + query.set_shard_mut(context.shards_calculator.shard().clone()); + // Since this query isn't executable and we decided // to route it to any shard, we can early return here. - return Ok(Command::Query( - query - .clone() - .with_shard(context.shards_calculator.shard().clone()), - )); + return Ok(command); } } } @@ -382,6 +381,14 @@ impl QueryParser { if shard.is_direct() { route.set_shard_mut(shard); } + + // User requested fdw backend. Cool, but never for DDL. + if context.router_context.parameter_hints.use_fdw_fallback() + && !route.is_ddl() + && context.router_context.cluster.fdw_fallback_enabled() + { + route.set_fdw_fallback(true); + } } // Set plugin-specified route, if available. 
diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index 56770175..d76eb1c2 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -32,10 +32,17 @@ impl QueryParser { writes.writes = true; } + let fdw_fallback = cached_ast + .cross_shard_backend + .map(|backend| backend.need_fdw()) + .unwrap_or_default(); + // Early return for any direct-to-shard queries. if context.shards_calculator.shard().is_direct() { return Ok(Command::Query( - Route::read(context.shards_calculator.shard().clone()).with_write(writes), + Route::read(context.shards_calculator.shard().clone()) + .with_write(writes) + .with_fdw_fallback(fdw_fallback), )); } @@ -201,6 +208,8 @@ impl QueryParser { query.with_aggregate_rewrite_plan_mut(cached_ast.rewrite_plan.aggregates.clone()); } + query.set_fdw_fallback(fdw_fallback); + Ok(Command::Query(query.with_write(writes))) } diff --git a/pgdog/src/frontend/router/parser/query/test/mod.rs b/pgdog/src/frontend/router/parser/query/test/mod.rs index 8f84467e..a090f98e 100644 --- a/pgdog/src/frontend/router/parser/query/test/mod.rs +++ b/pgdog/src/frontend/router/parser/query/test/mod.rs @@ -27,6 +27,7 @@ pub mod test_ddl; pub mod test_delete; pub mod test_dml; pub mod test_explain; +pub mod test_fdw_fallback; pub mod test_functions; pub mod test_insert; pub mod test_rr; diff --git a/pgdog/src/frontend/router/parser/query/test/setup.rs b/pgdog/src/frontend/router/parser/query/test/setup.rs index 36457e86..e15d22a2 100644 --- a/pgdog/src/frontend/router/parser/query/test/setup.rs +++ b/pgdog/src/frontend/router/parser/query/test/setup.rs @@ -1,6 +1,6 @@ use std::ops::Deref; -use pgdog_config::ConfigAndUsers; +use pgdog_config::{ConfigAndUsers, CrossShardBackend}; use crate::{ backend::Cluster, @@ -101,7 +101,12 @@ impl QueryParserTest { self } - /// Startup parameters. + /// Enable FDW fallback (sets cross_shard_backend to Hybrid). + pub(crate) fn with_fdw_fallback(mut self) -> Self { + self.cluster + .set_cross_shard_backend(CrossShardBackend::Hybrid); + self + } /// Execute a request and return the command (panics on error). 
pub(crate) fn execute(&mut self, request: Vec) -> Command { diff --git a/pgdog/src/frontend/router/parser/query/test/test_ddl.rs b/pgdog/src/frontend/router/parser/query/test/test_ddl.rs index bf2307d1..c3704b54 100644 --- a/pgdog/src/frontend/router/parser/query/test/test_ddl.rs +++ b/pgdog/src/frontend/router/parser/query/test/test_ddl.rs @@ -14,6 +14,10 @@ fn test_create_table() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -24,6 +28,10 @@ fn test_drop_table() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -37,6 +45,10 @@ fn test_alter_table() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -49,6 +61,10 @@ fn test_create_index() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -59,6 +75,10 @@ fn test_drop_index() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -69,6 +89,10 @@ fn test_truncate() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -79,6 +103,10 @@ fn test_create_sequence() { assert!(command.route().is_write()); assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -88,6 +116,10 @@ fn test_vacuum() { let command = test.execute(vec![Query::new("VACUUM sharded").into()]); assert!(command.route().is_write()); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] @@ -97,6 +129,10 @@ fn test_analyze() { let command = test.execute(vec![Query::new("ANALYZE sharded").into()]); assert!(command.route().is_write()); + assert!( + !command.route().is_fdw_fallback(), + "DDL should not trigger FDW fallback" + ); } #[test] diff --git a/pgdog/src/frontend/router/parser/query/test/test_fdw_fallback.rs b/pgdog/src/frontend/router/parser/query/test/test_fdw_fallback.rs new file mode 100644 index 00000000..b7f3c362 --- /dev/null +++ b/pgdog/src/frontend/router/parser/query/test/test_fdw_fallback.rs @@ -0,0 +1,435 @@ +use crate::net::messages::Query; + +use super::setup::*; + +#[test] +fn test_fdw_fallback_comment() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + let command = test.execute(vec![Query::new( + "/* pgdog_cross_shard_backend: fdw */ SELECT * FROM sharded ORDER BY id LIMIT 10 OFFSET 5", + ) + .into()]); + + let route = command.route(); + assert!(route.is_fdw_fallback(),); +} + +#[test] +fn test_fdw_fallback_comment_if_direct() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + let command = test.execute(vec![Query::new( + "/* pgdog_cross_shard_backend: fdw */ SELECT * FROM sharded WHERE id = 1", + ) + .into()]); + + let route = command.route(); + assert!(route.is_fdw_fallback(),); +} + +// 
============================================================================= +// Config verification tests +// ============================================================================= + +/// FDW fallback should NOT be triggered when cross_shard_backend is not +/// configured for FDW (default is Pgdog). +#[test] +#[ignore] +fn test_fdw_fallback_requires_config() { + // Without with_fdw_fallback(), cross_shard_backend defaults to Pgdog + let mut test = QueryParserTest::new(); + + // This query would normally trigger FDW fallback (OFFSET > 0) + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded ORDER BY id LIMIT 10 OFFSET 5", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "FDW fallback should NOT be triggered when cross_shard_backend is Pgdog (default)" + ); +} + +// ============================================================================= +// OFFSET tests +// ============================================================================= + +/// Cross-shard SELECT with OFFSET > 0 should trigger FDW fallback +/// because OFFSET cannot be correctly applied across shards without +/// fetching all rows first. +#[test] +#[ignore] +fn test_cross_shard_offset_triggers_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // This query goes to all shards (no sharding key) with OFFSET + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded ORDER BY id LIMIT 10 OFFSET 5", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Cross-shard query with OFFSET > 0 should trigger FDW fallback" + ); +} + +/// Cross-shard SELECT with OFFSET = 0 should NOT trigger FDW fallback +#[test] +#[ignore] +fn test_cross_shard_offset_zero_no_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded ORDER BY id LIMIT 10 OFFSET 0", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "Cross-shard query with OFFSET = 0 should NOT trigger FDW fallback" + ); +} + +/// Direct-to-shard SELECT with OFFSET should NOT trigger FDW fallback +#[test] +#[ignore] +fn test_direct_shard_offset_no_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // This query goes to a specific shard (has sharding key) + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded WHERE id = 1 ORDER BY id LIMIT 10 OFFSET 5", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "Direct-to-shard query with OFFSET should NOT trigger FDW fallback" + ); +} + +// ============================================================================= +// CTE tests +// ============================================================================= + +/// CTE that references an unsharded table without a sharding key should trigger +/// FDW fallback when the main query is cross-shard. 
+#[test] +#[ignore] +fn test_cte_unsharded_table_triggers_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // `users` is not in the sharded tables config, making it "unsharded" + // The CTE has no sharding key, so this should trigger FDW fallback + let command = test.execute(vec![Query::new( + "WITH user_data AS (SELECT * FROM users WHERE email = 'test@test.com') + SELECT s.* FROM sharded s JOIN user_data u ON s.value = u.id", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "CTE with unsharded table and no sharding key should trigger FDW fallback" + ); +} + +/// CTE that only references sharded tables should NOT trigger FDW fallback +#[test] +#[ignore] +fn test_cte_sharded_table_no_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // CTE references only the sharded table + let command = test.execute(vec![Query::new( + "WITH shard_data AS (SELECT * FROM sharded WHERE id = 5) + SELECT * FROM shard_data", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "CTE with sharded table and sharding key should NOT trigger FDW fallback" + ); +} + +/// CTE that only references omnisharded tables should NOT trigger FDW fallback +#[test] +#[ignore] +fn test_cte_omnisharded_table_no_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // CTE references only omnisharded table + let command = test.execute(vec![Query::new( + "WITH omni_data AS (SELECT * FROM sharded_omni WHERE id = 1) + SELECT * FROM omni_data", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "CTE with omnisharded table should NOT trigger FDW fallback" + ); +} + +// ============================================================================= +// Subquery tests +// ============================================================================= + +/// Subquery in FROM that references unsharded table without sharding key +/// should trigger FDW fallback when main query is cross-shard. +#[test] +#[ignore] +fn test_subquery_unsharded_table_triggers_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // Subquery references unsharded table without sharding key + let command = test.execute(vec![Query::new( + "SELECT s.* FROM sharded s + JOIN (SELECT * FROM users WHERE active = true) u ON s.value = u.id", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Subquery with unsharded table should trigger FDW fallback" + ); +} + +/// Subquery with correlated reference to outer sharding key should NOT trigger +/// FDW fallback (inherits sharding context from outer query). 
+#[test] +#[ignore] +fn test_subquery_correlated_no_fdw_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // Correlated subquery references outer query's sharded column + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded s WHERE s.id = 5 AND EXISTS ( + SELECT 1 FROM sharded_omni o WHERE o.id = s.id + )", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "Correlated subquery with sharding key in outer query should NOT trigger FDW fallback" + ); +} + +// ============================================================================= +// Edge case tests +// ============================================================================= + +/// Multiple CTEs where one is safe (sharded table) and one is unsafe (unsharded +/// table) should trigger FDW fallback when there's no sharding key. +#[test] +#[ignore] +fn test_multiple_ctes_mixed_safe_unsafe_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // First CTE uses sharded table (safe), second CTE uses unsharded table (unsafe) + // No sharding key in either CTE, so unsafe CTE triggers FDW fallback + let command = test.execute(vec![Query::new( + "WITH safe_data AS (SELECT * FROM sharded), + unsafe_data AS (SELECT * FROM users WHERE active = true) + SELECT s.*, u.* FROM safe_data s, unsafe_data u", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Multiple CTEs with one unsafe table should trigger FDW fallback" + ); +} + +/// Nested subqueries where the innermost references an unsharded table +/// should trigger FDW fallback. +#[test] +#[ignore] +fn test_deeply_nested_subquery_unsharded_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // Three levels deep: outer query -> subquery -> subquery with unsharded table + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded WHERE value IN ( + SELECT id FROM sharded_omni WHERE id IN ( + SELECT id FROM users WHERE active = true + ) + )", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Deeply nested subquery with unsharded table should trigger FDW fallback" + ); +} + +/// JOIN inside a subquery mixing sharded and unsharded tables should trigger +/// FDW fallback. +#[test] +#[ignore] +fn test_subquery_join_mixed_tables_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // Subquery JOINs sharded and unsharded tables + let command = test.execute(vec![Query::new( + "SELECT * FROM sharded WHERE id IN ( + SELECT s.id FROM sharded s + JOIN users u ON s.value = u.id + WHERE u.active = true + )", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Subquery JOIN mixing sharded and unsharded tables should trigger FDW fallback" + ); +} + +/// OFFSET with bind parameter should trigger FDW fallback when value > 0. 
+#[test] +#[ignore] +fn test_offset_bind_parameter_triggers_fallback() { + use crate::net::messages::Parameter; + + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // OFFSET using $1 bind parameter with value 5 + let command = test.execute(vec![ + Parse::named( + "__offset_test", + "SELECT * FROM sharded ORDER BY id LIMIT 10 OFFSET $1", + ) + .into(), + Bind::new_params("__offset_test", &[Parameter::new(b"5")]).into(), + Execute::new().into(), + Sync.into(), + ]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Cross-shard query with OFFSET bind parameter > 0 should trigger FDW fallback" + ); +} + +/// Schema-qualified unsharded table should still trigger FDW fallback. +#[test] +#[ignore] +fn test_schema_qualified_unsharded_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // Using schema-qualified name for unsharded table + let command = test.execute(vec![Query::new( + "WITH user_data AS (SELECT * FROM public.users WHERE email = 'test@test.com') + SELECT s.* FROM sharded s JOIN user_data u ON s.value = u.id", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Schema-qualified unsharded table should trigger FDW fallback" + ); +} + +// ============================================================================= +// Window function tests +// ============================================================================= + +/// Cross-shard query with window function should trigger FDW fallback +/// because window functions can't be correctly merged across shards. +#[test] +#[ignore] +fn test_window_function_cross_shard_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // ROW_NUMBER() without sharding key = cross-shard + let command = test.execute(vec![Query::new( + "SELECT id, ROW_NUMBER() OVER (ORDER BY id) FROM sharded", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Cross-shard query with window function should trigger FDW fallback" + ); +} + +/// Direct-to-shard query with window function should NOT trigger FDW fallback. +#[test] +#[ignore] +fn test_window_function_single_shard_no_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + // ROW_NUMBER() with sharding key = single shard, no fallback needed + let command = test.execute(vec![Query::new( + "SELECT id, ROW_NUMBER() OVER (ORDER BY id) FROM sharded WHERE id = 1", + ) + .into()]); + + let route = command.route(); + assert!( + !route.is_fdw_fallback(), + "Single-shard query with window function should NOT trigger FDW fallback" + ); +} + +/// Multiple window functions in cross-shard query should trigger FDW fallback. +#[test] +#[ignore] +fn test_multiple_window_functions_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + let command = test.execute(vec![Query::new( + "SELECT id, + ROW_NUMBER() OVER (ORDER BY id) as rn, + RANK() OVER (PARTITION BY email ORDER BY id) as rnk + FROM sharded", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Cross-shard query with multiple window functions should trigger FDW fallback" + ); +} + +/// Window function in subquery should trigger FDW fallback for cross-shard. 
+#[test] +#[ignore] +fn test_window_function_in_subquery_triggers_fallback() { + let mut test = QueryParserTest::new().with_fdw_fallback(); + + let command = test.execute(vec![Query::new( + "SELECT * FROM ( + SELECT id, ROW_NUMBER() OVER (ORDER BY id) as rn FROM sharded + ) sub WHERE rn <= 10", + ) + .into()]); + + let route = command.route(); + assert!( + route.is_fdw_fallback(), + "Cross-shard subquery with window function should trigger FDW fallback" + ); +} diff --git a/pgdog/src/frontend/router/parser/query/test/test_subqueries.rs b/pgdog/src/frontend/router/parser/query/test/test_subqueries.rs index db9ef150..bec0e586 100644 --- a/pgdog/src/frontend/router/parser/query/test/test_subqueries.rs +++ b/pgdog/src/frontend/router/parser/query/test/test_subqueries.rs @@ -5,15 +5,19 @@ use super::setup::{QueryParserTest, *}; #[test] fn test_subquery_in_where() { - let mut test = QueryParserTest::new(); + let mut test = QueryParserTest::new().with_fdw_fallback(); + // Subquery references `other_table` which is unsharded, so FDW fallback is triggered let command = test.execute(vec![Query::new( "SELECT * FROM sharded WHERE id IN (SELECT id FROM other_table WHERE status = 'active')", ) .into()]); assert!(command.route().is_read()); - assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "Subquery with unsharded table should not trigger FDW fallback" + ); } #[test] @@ -55,8 +59,9 @@ fn test_scalar_subquery() { #[test] fn test_subquery_with_sharding_key() { - let mut test = QueryParserTest::new(); + let mut test = QueryParserTest::new().with_fdw_fallback(); + // Subquery references `other` which is unsharded, so FDW fallback is triggered let command = test.execute(vec![ Parse::named( "__test_sub", @@ -68,21 +73,28 @@ fn test_subquery_with_sharding_key() { Sync.into(), ]); - // Can't route to specific shard because we don't know the subquery result - assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "Subquery with unsharded table should not trigger FDW fallback" + ); } #[test] fn test_nested_subqueries() { - let mut test = QueryParserTest::new(); + let mut test = QueryParserTest::new().with_fdw_fallback(); + // Nested subqueries reference `other` and `statuses` which are unsharded, + // so FDW fallback is triggered let command = test.execute(vec![Query::new( "SELECT * FROM sharded WHERE id IN (SELECT id FROM other WHERE status IN (SELECT status FROM statuses))", ) .into()]); assert!(command.route().is_read()); - assert_eq!(command.route().shard(), &Shard::All); + assert!( + !command.route().is_fdw_fallback(), + "Nested subqueries with unsharded tables should not trigger FDW fallback" + ); } #[test] diff --git a/pgdog/src/frontend/router/parser/route.rs b/pgdog/src/frontend/router/parser/route.rs index 4505c053..d3798c69 100644 --- a/pgdog/src/frontend/router/parser/route.rs +++ b/pgdog/src/frontend/router/parser/route.rs @@ -90,6 +90,8 @@ pub struct Route { rollback_savepoint: bool, search_path_driven: bool, schema_changed: bool, + fdw_fallback: bool, + ddl: bool, } impl Display for Route { @@ -140,6 +142,15 @@ impl Route { } } + /// Create new fdw fallback route. + pub fn fdw_fallback() -> Self { + Self { + shard: ShardWithPriority::new_override_fdw_fallback(), + fdw_fallback: true, + ..Default::default() + } + } + /// Returns true if this is a query that /// can be sent to a replica. 
pub fn is_read(&self) -> bool { @@ -152,6 +163,10 @@ impl Route { !self.is_read() } + pub fn set_fdw_fallback(&mut self, fallback: bool) { + self.fdw_fallback = fallback; + } + /// Get shard if any. pub fn shard(&self) -> &Shard { &self.shard @@ -178,6 +193,10 @@ impl Route { self.is_all_shards() || self.is_multi_shard() } + pub fn is_fdw_fallback(&self) -> bool { + self.fdw_fallback + } + pub fn order_by(&self) -> &[OrderBy] { &self.order_by } @@ -199,6 +218,11 @@ impl Route { self } + pub fn with_fdw_fallback(mut self, fdw_fallback: bool) -> Self { + self.set_fdw_fallback(fdw_fallback); + self + } + pub fn set_schema_changed(&mut self, changed: bool) { self.schema_changed = changed; } @@ -212,6 +236,15 @@ impl Route { self } + pub fn with_ddl(mut self, ddl: bool) -> Self { + self.ddl = ddl; + self + } + + pub fn is_ddl(&self) -> bool { + self.ddl + } + pub fn set_search_path_driven_mut(&mut self, schema_driven: bool) { self.search_path_driven = schema_driven; } @@ -345,6 +378,7 @@ pub enum OverrideReason { Transaction, OnlyOneShard, RewriteUpdate, + FdwFallback, } #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)] @@ -426,6 +460,13 @@ impl ShardWithPriority { } } + pub fn new_override_fdw_fallback() -> Self { + Self { + shard: Shard::Direct(0), + source: ShardSource::Override(OverrideReason::FdwFallback), + } + } + pub fn new_default_unset(shard: Shard) -> Self { Self { shard, diff --git a/pgdog/src/frontend/router/parser/statement.rs b/pgdog/src/frontend/router/parser/statement.rs index 7e962ddf..ba30a53d 100644 --- a/pgdog/src/frontend/router/parser/statement.rs +++ b/pgdog/src/frontend/router/parser/statement.rs @@ -245,6 +245,11 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { self } + /// Get the sharding schema reference. + pub fn sharding_schema(&self) -> &ShardingSchema { + self.schema + } + pub fn from_select( stmt: &'a SelectStmt, bind: Option<&'b Bind>,