From adb23de0cbc999fb94757e1c884c60d8cfb58fd0 Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Wed, 29 Apr 2026 00:48:14 +0100 Subject: [PATCH 1/3] refactor: configured refdata loader to be instantiated when required without need for class vars --- src/dve/core_engine/backends/base/backend.py | 33 +------- .../backends/base/reference_data.py | 9 +- src/dve/core_engine/backends/exceptions.py | 13 +++ .../implementations/duckdb/reference_data.py | 20 +++-- .../backends/implementations/spark/backend.py | 31 +++++-- .../implementations/spark/reference_data.py | 16 ++-- src/dve/pipeline/duckdb_pipeline.py | 13 ++- src/dve/pipeline/pipeline.py | 19 +++-- src/dve/pipeline/spark_pipeline.py | 14 +++- tests/features/steps/steps_pipeline.py | 23 ++---- .../test_duckdb/test_ddb_refdata.py | 82 ++++++++++--------- .../test_spark/test_spark_refdata.py | 59 ++++++------- tests/test_core_engine/test_engine.py | 3 +- tests/test_pipeline/test_duckdb_pipeline.py | 9 -- .../test_foundry_ddb_pipeline.py | 25 +----- tests/test_pipeline/test_pipeline.py | 3 - tests/test_pipeline/test_spark_pipeline.py | 26 +----- 17 files changed, 180 insertions(+), 218 deletions(-) diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py index 29e8644..507ede8 100644 --- a/src/dve/core_engine/backends/base/backend.py +++ b/src/dve/core_engine/backends/base/backend.py @@ -41,14 +41,12 @@ def __init__( # pylint: disable=unused-argument self, contract: BaseDataContract[EntityType], steps: BaseStepImplementations[EntityType], - reference_data_loader_type: Optional[type[BaseRefDataLoader[EntityType]]], logger: Optional[logging.Logger] = None, **kwargs: Any, ) -> None: for component_name, component in ( ("Contract", contract), ("Step implementation", steps), - ("Reference data loader", reference_data_loader_type), ): component_entity_type = getattr(component, "__entity_type__", None) if component_entity_type != 
self.__entity_type__: @@ -61,12 +59,6 @@ def __init__( # pylint: disable=unused-argument """The data contract implementation used by the backend.""" self.step_implementations = steps """The step implementations used by the backend.""" - self.reference_data_loader_type = reference_data_loader_type - """ - The loader type to use for the reference data. If `None`, do not - load any reference data and error if it is provided. - - """ self.logger = logger or get_logger(type(self).__name__) """The `logging.Logger instance for the backend.""" @@ -74,29 +66,8 @@ def load_reference_data( self, reference_entity_config: dict[EntityName, ReferenceConfigUnion], submission_info: Optional[SubmissionInfo], - ) -> Mapping[EntityName, EntityType]: - """Load the reference data as specified in the reference entity config.""" - sub_info_entity: Optional[EntityType] = None - if submission_info: - sub_info_entity = self.convert_submission_info(submission_info) - - if self.reference_data_loader_type is None: - if reference_entity_config: - raise ValueError( - "Reference data has been specified but no reference data loader is " - + "configured for this backend" - ) - - reference_data_dict = {} - if sub_info_entity is not None: - reference_data_dict["dve_submission_info"] = sub_info_entity - return reference_data_dict - - reference_data_loader = self.reference_data_loader_type(reference_entity_config) - if sub_info_entity is not None: - reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity - - return reference_data_loader + ) -> BaseRefDataLoader[EntityType]: + raise NotImplementedError() @abstractmethod def convert_submission_info(self, submission_info: SubmissionInfo) -> EntityType: diff --git a/src/dve/core_engine/backends/base/reference_data.py b/src/dve/core_engine/backends/base/reference_data.py index 5be0ec0..187b798 100644 --- a/src/dve/core_engine/backends/base/reference_data.py +++ b/src/dve/core_engine/backends/base/reference_data.py @@ -11,6 +11,7 @@ from 
dve.core_engine.backends.base.core import get_entity_type from dve.core_engine.backends.exceptions import ( MissingRefDataEntity, + NoRefDataConfigSupplied, RefdataLacksFileExtensionSupport, ) from dve.core_engine.backends.types import EntityType @@ -147,11 +148,11 @@ class variable for the subclass. # pylint: disable=unused-argument def __init__( self, - reference_entity_config: dict[EntityName, ReferenceConfig], - dataset_config_uri: Optional[URI] = None, + reference_data_config: dict[EntityName, ReferenceConfig], + dataset_config_uri: URI, **kwargs, ) -> None: - self.reference_entity_config = reference_entity_config + self.reference_entity_config = reference_data_config self.dataset_config_uri = dataset_config_uri """ Configuration options for the reference data. This is likely to vary @@ -207,6 +208,8 @@ def __getitem__(self, key: EntityName) -> EntityType: try: config = self.reference_entity_config[key] return self.load_entity(entity_name=key, config=config) + except TypeError: + raise NoRefDataConfigSupplied() except Exception as err: raise MissingRefDataEntity(entity_name=key) from err diff --git a/src/dve/core_engine/backends/exceptions.py b/src/dve/core_engine/backends/exceptions.py index 8dd50ef..d5b90cf 100644 --- a/src/dve/core_engine/backends/exceptions.py +++ b/src/dve/core_engine/backends/exceptions.py @@ -118,6 +118,19 @@ def get_message_preamble(self) -> str: """ return f"Missing reference data entity {self.entity_name!r}" +class NoRefDataConfigSupplied(BackendError): + """An error raised when trying to load a refdata entity when no refdata + config has been supplied. 
+ + """ + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + def get_message_preamble(self) -> EntityName: + """Message for logging purposes""" + return f"Refdata loader not supplied with refdata config - unable to load refdata entities" + class ConstraintError(ValueError, BackendErrorMixin): """Raised when a given constraint is violated.""" diff --git a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py index af815ce..b3f47e3 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py +++ b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py @@ -7,6 +7,7 @@ from dve.core_engine.backends.base.reference_data import ( BaseRefDataLoader, + ReferenceConfig, ReferenceConfigUnion, ReferenceTable, mark_refdata_file_extension, @@ -17,19 +18,20 @@ # pylint: disable=too-few-public-methods class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]): - """A reference data loader using already existing DuckDB tables.""" - - connection: DuckDBPyConnection - """The DuckDB connection for the backend.""" - dataset_config_uri: Optional[URI] = None - """The location of the dischema file""" + """A reference data loader using already existing DuckDB tables. 
+ reference_entity_config and dataset_config_uri (if config uses relative paths) + should be supplied as constructor arguments for the dataset being processed before running.""" def __init__( self, - reference_entity_config: dict[EntityName, ReferenceConfigUnion], - **kwargs, + connection: DuckDBPyConnection, + reference_data_config: dict[EntityName, ReferenceConfig], + dataset_config_uri: URI, + **kwargs ) -> None: - super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs) + super().__init__(reference_data_config, dataset_config_uri, **kwargs) + + self.connection = connection if not self.connection: raise AttributeError("DuckDBConnection must be specified") diff --git a/src/dve/core_engine/backends/implementations/spark/backend.py b/src/dve/core_engine/backends/implementations/spark/backend.py index 3999b62..e943b03 100644 --- a/src/dve/core_engine/backends/implementations/spark/backend.py +++ b/src/dve/core_engine/backends/implementations/spark/backend.py @@ -6,16 +6,18 @@ from pyspark.sql import DataFrame, SparkSession from dve.core_engine.backends.base.backend import BaseBackend +from dve.core_engine.backends.base.reference_data import ReferenceConfigUnion from dve.core_engine.backends.implementations.spark.contract import SparkDataContract from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations from dve.core_engine.backends.implementations.spark.spark_helpers import get_type_from_annotation from dve.core_engine.backends.implementations.spark.types import SparkEntities +from dve.core_engine.backends.types import EntityType from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger from dve.core_engine.models import SubmissionInfo -from dve.core_engine.type_hints import URI, EntityParquetLocations -from dve.parser.file_handling import get_resource_exists, 
joinuri +from dve.core_engine.type_hints import URI, EntityName, EntityParquetLocations +from dve.parser.file_handling import get_resource_exists, joinuri, get_parent class SparkBackend(BaseBackend[DataFrame]): @@ -26,7 +28,6 @@ def __init__( dataset_config_uri: Optional[URI] = None, contract: Optional[SparkDataContract] = None, steps: Optional[SparkStepImplementations] = None, - reference_data_loader: Optional[type[SparkRefDataLoader]] = None, logger: Optional[logging.Logger] = None, spark_session: Optional[SparkSession] = None, **kwargs: Any, @@ -36,6 +37,8 @@ def __init__( self.spark_session = spark_session or SparkSession.builder.getOrCreate() """The Spark session for the backend.""" + self.dataset_config_uri = dataset_config_uri + """The uri of the dischema specifying the DVE config""" if contract is None: contract = SparkDataContract( @@ -46,11 +49,23 @@ def __init__( steps = SparkStepImplementations.register_udfs( logger=get_child_logger("SparkStepImplementations", logger) ) - if reference_data_loader is None: - reference_data_loader = SparkRefDataLoader - reference_data_loader.spark = self.spark_session - reference_data_loader.dataset_config_uri = dataset_config_uri - super().__init__(contract, steps, reference_data_loader, logger, **kwargs) + super().__init__(contract, steps, logger, **kwargs) + + def load_reference_data(self, + reference_entity_config: dict[EntityName, ReferenceConfigUnion], + submission_info: Optional[SubmissionInfo],): + """Load the reference data as specified in the reference entity config.""" + sub_info_entity: Optional[EntityType] = None + if submission_info: + sub_info_entity = self.convert_submission_info(submission_info) + + reference_data_loader = SparkRefDataLoader(spark=self.spark_session, + reference_data_config=reference_entity_config, + dataset_config_uri=self.dataset_config_uri) + if sub_info_entity is not None: + reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity + + return reference_data_loader 
def write_entities_to_parquet( self, entities: SparkEntities, cache_prefix: URI diff --git a/src/dve/core_engine/backends/implementations/spark/reference_data.py b/src/dve/core_engine/backends/implementations/spark/reference_data.py index 90ba4f6..08507ee 100644 --- a/src/dve/core_engine/backends/implementations/spark/reference_data.py +++ b/src/dve/core_engine/backends/implementations/spark/reference_data.py @@ -17,19 +17,19 @@ # pylint: disable=too-few-public-methods class SparkRefDataLoader(BaseRefDataLoader[DataFrame]): - """A reference data loader using already existing Apache Spark Tables.""" - - spark: SparkSession - """The Spark session for the backend.""" - dataset_config_uri: Optional[URI] = None - """The location of the dischema file defining business rules""" + """A reference data loader using already existing Apache Spark Tables. + reference_entity_config and dataset_config_uri (if config uses relative paths) + should be supplied using setter methods for the dataset being processed before running.""" def __init__( self, - reference_entity_config: dict[EntityName, ReferenceConfig], + spark: SparkSession, + reference_data_config: dict[EntityName, ReferenceConfig], + dataset_config_uri: URI, **kwargs, ) -> None: - super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs) + super().__init__(reference_data_config, dataset_config_uri, **kwargs) + self.spark = spark if not self.spark: raise AttributeError("Spark session must be provided") diff --git a/src/dve/pipeline/duckdb_pipeline.py b/src/dve/pipeline/duckdb_pipeline.py index 87e927d..713001f 100644 --- a/src/dve/pipeline/duckdb_pipeline.py +++ b/src/dve/pipeline/duckdb_pipeline.py @@ -5,14 +5,16 @@ from duckdb import DuckDBPyConnection, DuckDBPyRelation -from dve.core_engine.backends.base.reference_data import BaseRefDataLoader +from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig from dve.core_engine.backends.implementations.duckdb.auditing 
import DDBAuditingManager from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_get_entity_count +from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader from dve.core_engine.backends.implementations.duckdb.rules import DuckDBStepImplementations from dve.core_engine.models import SubmissionInfo from dve.core_engine.type_hints import URI from dve.pipeline.pipeline import BaseDVEPipeline +import dve.parser.file_handling as fh # pylint: disable=abstract-method @@ -30,7 +32,6 @@ def __init__( connection: DuckDBPyConnection, rules_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[type[BaseRefDataLoader]] = None, job_run_id: Optional[int] = None, logger: Optional[logging.Logger] = None, ): @@ -42,11 +43,17 @@ def __init__( DuckDBStepImplementations.register_udfs(connection=self._connection), rules_path, submitted_files_path, - reference_data_loader, job_run_id, logger, ) + def get_reference_data_loader(self, + reference_data_config: dict[str, ReferenceConfig], + **kwargs) -> BaseRefDataLoader[DuckDBPyRelation]: + return DuckDBRefDataLoader(connection=self._connection, + reference_data_config=reference_data_config, + dataset_config_uri=fh.get_parent(self._rules_path), + **kwargs) # pylint: disable=arguments-differ def write_file_to_parquet( # type: ignore self, submission_file_uri: URI, submission_info: SubmissionInfo, output: URI diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index 26b682e..f22073e 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -26,7 +26,7 @@ from dve.core_engine.backends.base.auditing import BaseAuditingManager from dve.core_engine.backends.base.contract import BaseDataContract from dve.core_engine.backends.base.core import EntityManager -from dve.core_engine.backends.base.reference_data import 
BaseRefDataLoader +from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig from dve.core_engine.backends.base.rules import BaseStepImplementations from dve.core_engine.backends.exceptions import MessageBearingError from dve.core_engine.backends.readers import BaseFileReader @@ -36,7 +36,7 @@ from dve.core_engine.loggers import get_logger from dve.core_engine.message import FeedbackMessage from dve.core_engine.models import SubmissionInfo, SubmissionStatisticsRecord -from dve.core_engine.type_hints import URI, DVEStageName, FileURI, InfoURI +from dve.core_engine.type_hints import URI, DVEStageName, EntityName, FileURI, InfoURI from dve.parser import file_handling as fh from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation from dve.parser.file_handling.service import _get_implementation @@ -62,14 +62,13 @@ def __init__( step_implementations: Optional[BaseStepImplementations[EntityType]], rules_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[type[BaseRefDataLoader]] = None, job_run_id: Optional[int] = None, logger: Optional[logging.Logger] = None, ): self._submitted_files_path = submitted_files_path self._processed_files_path = processed_files_path self._rules_path = rules_path - self._reference_data_loader = reference_data_loader + self._reference_data_loader = None self._job_run_id = job_run_id self._audit_tables = audit_tables self._data_contract = data_contract @@ -113,6 +112,13 @@ def step_implementations(self) -> Optional[BaseStepImplementations[EntityType]]: def get_entity_count(entity: EntityType) -> int: """Get a row count of an entity stored as parquet""" raise NotImplementedError() + + def get_reference_data_loader(self, + reference_data_config: dict[EntityName, ReferenceConfig], + **kwargs) -> BaseRefDataLoader[EntityType]: + """Get reference data loader if required for business rules""" + raise NotImplementedError() + def 
get_submission_status( self, step_name: DVEStageName, submission_id: str @@ -542,9 +548,6 @@ def apply_business_rules( # pylint: disable=R0914 if not self.rules_path: raise AttributeError("business rules path not provided.") - if not self._reference_data_loader: - raise AttributeError("reference data loader not provided.") - if not self.processed_files_path: raise AttributeError("processed files path has not been provided.") @@ -556,8 +559,8 @@ def apply_business_rules( # pylint: disable=R0914 self._processed_files_path, submission_info.submission_id ) ref_data = config.get_reference_data_config() + reference_data = self.get_reference_data_loader(reference_data_config=ref_data) rules = config.get_rule_metadata() - reference_data = self._reference_data_loader(ref_data) # type: ignore entities = {} contract = fh.joinuri( self.processed_files_path, submission_info.submission_id, "data_contract" diff --git a/src/dve/pipeline/spark_pipeline.py b/src/dve/pipeline/spark_pipeline.py index 71fdb32..30d42ad 100644 --- a/src/dve/pipeline/spark_pipeline.py +++ b/src/dve/pipeline/spark_pipeline.py @@ -6,15 +6,17 @@ from pyspark.sql import DataFrame, SparkSession -from dve.core_engine.backends.base.reference_data import BaseRefDataLoader +from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager from dve.core_engine.backends.implementations.spark.contract import SparkDataContract +from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations from dve.core_engine.backends.implementations.spark.spark_helpers import spark_get_entity_count from dve.core_engine.models import SubmissionInfo from dve.core_engine.type_hints import URI from dve.pipeline.pipeline import BaseDVEPipeline from dve.pipeline.utils import SubmissionStatus, 
unpersist_all_rdds +import dve.parser.file_handling as fh # pylint: disable=abstract-method @@ -31,7 +33,6 @@ def __init__( audit_tables: SparkAuditingManager, rules_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[type[BaseRefDataLoader]] = None, spark: Optional[SparkSession] = None, job_run_id: Optional[int] = None, logger: Optional[logging.Logger] = None, @@ -44,10 +45,17 @@ def __init__( SparkStepImplementations.register_udfs(self._spark), rules_path, submitted_files_path, - reference_data_loader, job_run_id, logger, ) + + def get_reference_data_loader(self, + reference_data_config: dict[str, ReferenceConfig], + **kwargs) -> BaseRefDataLoader[DataFrame]: + return SparkRefDataLoader(spark=self._spark, + reference_data_config=reference_data_config, + dataset_config_uri=fh.get_parent(self._rules_path), + **kwargs) # pylint: disable=arguments-differ def write_file_to_parquet( # type: ignore diff --git a/tests/features/steps/steps_pipeline.py b/tests/features/steps/steps_pipeline.py index fa1e848..55acadd 100644 --- a/tests/features/steps/steps_pipeline.py +++ b/tests/features/steps/steps_pipeline.py @@ -48,9 +48,6 @@ def setup_spark_pipeline( schema_file_name = f"{dataset_id}.dischema.json" if not schema_file_name else schema_file_name rules_path = get_test_file_path(f"{dataset_id}/{schema_file_name}").resolve().as_uri() - # configure reference data - SparkRefDataLoader.spark = spark - SparkRefDataLoader.dataset_config_uri = fh.get_parent(rules_path) return SparkDVEPipeline( processed_files_path=processing_path.as_uri(), @@ -61,7 +58,6 @@ def setup_spark_pipeline( job_run_id=12345, rules_path=rules_path, submitted_files_path=processing_path.as_uri(), - reference_data_loader=SparkRefDataLoader, spark=spark, ) @@ -78,9 +74,6 @@ def setup_duckdb_pipeline( # create duckdbpyconnection with dve database file in context.tempdir # TODO - doesn't like file scheme - need to provide absolute path db_file = Path(processing_path, 
"dve.duckdb") - # configure refdata - DuckDBRefDataLoader.connection = connection - DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(rules_path) return DDBDVEPipeline( processed_files_path=processing_path.as_posix(), audit_tables=DDBAuditingManager( @@ -91,8 +84,7 @@ def setup_duckdb_pipeline( job_run_id=12345, connection=connection, rules_path=rules_path, - submitted_files_path=processing_path.as_posix(), - reference_data_loader=DuckDBRefDataLoader + submitted_files_path=processing_path.as_posix() ) @@ -314,18 +306,17 @@ def create_refdata_tables(context: Context, database: str): record = row.as_dict() refdata_tables[record["table_name"]] = record["parquet_path"] pipeline = ctxt.get_pipeline(context) - refdata_loader = getattr(pipeline, "_reference_data_loader") - if refdata_loader == SparkRefDataLoader: - refdata_loader.spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}") + if isinstance(pipeline, SparkDVEPipeline): + pipeline._spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}") for tbl, source in refdata_tables.items(): - (refdata_loader.spark.read.parquet(source) + (pipeline._spark.read.parquet(source) .write.saveAsTable(f"{database}.{tbl}")) - if refdata_loader == DuckDBRefDataLoader: + if isinstance(pipeline, DDBDVEPipeline): ref_db_file = Path(ctxt.get_processing_location(context), f"{database}.duckdb").as_posix() - refdata_loader.connection.sql(f"ATTACH '{ref_db_file}' AS {database}") + pipeline._connection.sql(f"ATTACH '{ref_db_file}' AS {database}") for tbl, source in refdata_tables.items(): - refdata_loader.connection.read_parquet(source).to_table(f"{database}.{tbl}") + pipeline._connection.read_parquet(source).to_table(f"{database}.{tbl}") diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py index 7ae4858..ff73f85 100644 --- 
a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py @@ -19,110 +19,118 @@ def temp_working_dir(): shutil.copytree(refdata_path.as_posix(), tmp, dirs_exist_ok=True) yield tmp -@pytest.fixture(scope="function") -def ddb_refdata_loader(temp_working_dir, temp_ddb_conn): - _, conn = temp_ddb_conn - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = temp_working_dir - yield DuckDBRefDataLoader, temp_working_dir @pytest.fixture(scope="function") -def ddb_refdata_table(ddb_refdata_loader): - refdata_loader, _ = ddb_refdata_loader +def ddb_refdata_table(temp_ddb_conn): + _, conn = temp_ddb_conn schema = "dve_" + uuid4().hex tbl = "movies_sequels" - refdata_loader.connection.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}") - refdata_loader.connection.read_parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).to_table(f"{schema}.{tbl}") + conn.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}") + conn.read_parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).to_table(f"{schema}.{tbl}") yield schema, tbl - refdata_loader.connection.sql(f"DROP TABLE IF EXISTS {schema}.{tbl}") - refdata_loader.connection.sql(f"DROP SCHEMA IF EXISTS {schema}") + conn.sql(f"DROP TABLE IF EXISTS {schema}.{tbl}") + conn.sql(f"DROP SCHEMA IF EXISTS {schema}") -def test_load_arrow_file(ddb_refdata_loader): - refdata_loader, _ = ddb_refdata_loader +def test_load_arrow_file(temp_working_dir, temp_ddb_conn): + _, conn = temp_ddb_conn config = { "test_refdata": ReferenceFile(type="filename", filename="./movies_sequels.arrow") } - duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = 
duckdb_refdata_loader.load_file(config.get("test_refdata")) assert test.shape == (3, 3) -def test_load_parquet_file(ddb_refdata_loader): - refdata_loader, _ = ddb_refdata_loader +def test_load_parquet_file(temp_working_dir, temp_ddb_conn): + _, conn = temp_ddb_conn config = { "test_refdata": ReferenceFile(type="filename", filename="./movies_sequels.parquet") } - duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = duckdb_refdata_loader.load_file(config.get("test_refdata")) assert test.shape == (2, 3) -def test_load_uri_parquet(ddb_refdata_loader): - refdata_dir: Path - refdata_loader, refdata_dir = ddb_refdata_loader +def test_load_uri_parquet(temp_working_dir, temp_ddb_conn): + _, conn = temp_ddb_conn config = { "test_refdata": ReferenceURI(type="uri", - uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()) + uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix()) } - duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = duckdb_refdata_loader.load_uri(config.get("test_refdata")) assert test.shape == (2, 3) -def test_load_uri_arrow(ddb_refdata_loader): - refdata_loader, refdata_dir = ddb_refdata_loader +def test_load_uri_arrow(temp_working_dir, temp_ddb_conn): + _, conn = temp_ddb_conn config = { "test_refdata": ReferenceURI(type="uri", - uri=Path(refdata_dir).joinpath("movies_sequels.arrow").as_posix()) + uri=Path(temp_working_dir).joinpath("movies_sequels.arrow").as_posix()) } - duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + 
dataset_config_uri=temp_working_dir) test = duckdb_refdata_loader.load_uri(config.get("test_refdata")) assert test.shape == (3, 3) -def test_table_read(ddb_refdata_loader, ddb_refdata_table): - refdata_loader, _ = ddb_refdata_loader +def test_table_read(temp_working_dir, temp_ddb_conn, ddb_refdata_table): + _, conn = temp_ddb_conn db, tbl = ddb_refdata_table config = { "test_refdata": ReferenceTable(type="table", table_name=tbl, database=db) } - duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = duckdb_refdata_loader.load_table(config.get("test_refdata")) assert test.shape == (2, 3) -def test_via_entity_manager(ddb_refdata_loader, ddb_refdata_table): - refdata_loader, refdata_dir = ddb_refdata_loader +def test_via_entity_manager(temp_working_dir, temp_ddb_conn, ddb_refdata_table): + _, conn = temp_ddb_conn db, tbl = ddb_refdata_table config = { "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.arrow"), "test_refdata_uri": ReferenceURI(type="uri", - uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()), + uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix()), "test_refdata_table": ReferenceTable(type="table", table_name=tbl, database=db) } - em = EntityManager({}, reference_data=refdata_loader(config)) + refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + dataset_config_uri=temp_working_dir) + em = EntityManager({}, reference_data=refdata_loader) assert em.get("refdata_test_refdata_file").shape == (3, 3) assert em.get("refdata_test_refdata_uri").shape == (2, 3) assert em.get("refdata_test_refdata_table").shape == (2, 3) -def test_refdata_error(ddb_refdata_loader): - refdata_loader, refdata_dir = ddb_refdata_loader +def test_refdata_error(temp_working_dir, temp_ddb_conn): 
+ _, conn = temp_ddb_conn config = { "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.arrow") } - duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config) + duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn, + reference_data_config=config, + dataset_config_uri=temp_working_dir) with pytest.raises(MissingRefDataEntity): duckdb_refdata_loader["missing_refdata"] \ No newline at end of file diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py index b50b9bb..8c60619 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py @@ -2,7 +2,7 @@ import shutil import pytest -from dve.core_engine.backends.exceptions import MissingRefDataEntity, RefdataLacksFileExtensionSupport +from dve.core_engine.backends.exceptions import MissingRefDataEntity from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.backends.base.core import EntityManager from dve.core_engine.backends.base.reference_data import ReferenceFile, ReferenceTable, ReferenceURI @@ -19,83 +19,84 @@ def temp_working_dir(): yield tmp @pytest.fixture(scope="function") -def spark_refdata_loader(spark, temp_working_dir): - SparkRefDataLoader.spark = spark - SparkRefDataLoader.dataset_config_uri = temp_working_dir - yield SparkRefDataLoader, temp_working_dir - -@pytest.fixture(scope="function") -def spark_refdata_table(spark_refdata_loader, spark_test_database): - refdata_loader, _ = spark_refdata_loader +def spark_refdata_table(spark, spark_test_database): tbl = "movies_sequels" - 
refdata_loader.spark.read.parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).write.saveAsTable(f"{spark_test_database}.{tbl}") + spark.read.parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).write.saveAsTable(f"{spark_test_database}.{tbl}") yield spark_test_database, tbl - refdata_loader.spark.sql(f"DROP TABLE IF EXISTS {spark_test_database}.{tbl}") + spark.sql(f"DROP TABLE IF EXISTS {spark_test_database}.{tbl}") -def test_load_parquet_file(spark_refdata_loader): - refdata_loader, _ = spark_refdata_loader +def test_load_parquet_file(spark, temp_working_dir): config = { "test_refdata": ReferenceFile(type="filename", filename="./movies_sequels.parquet") } - spk_refdata_loader: SparkRefDataLoader = refdata_loader(config) + spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = spk_refdata_loader.load_file(config.get("test_refdata")) assert test.count() == 2 -def test_load_uri_parquet(spark_refdata_loader): - refdata_dir: Path - refdata_loader, refdata_dir = spark_refdata_loader +def test_load_uri_parquet(spark, temp_working_dir): config = { "test_refdata": ReferenceURI(type="uri", - uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()) + uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix()) } - spk_refdata_loader: SparkRefDataLoader = refdata_loader(config) + spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = spk_refdata_loader.load_uri(config.get("test_refdata")) assert test.count() == 2 -def test_table_read(spark_refdata_loader, spark_refdata_table): - refdata_loader, _ = spark_refdata_loader +def test_table_read(spark, temp_working_dir, spark_refdata_table): db, tbl = spark_refdata_table config = { "test_refdata": ReferenceTable(type="table", table_name=tbl, database=db) } - 
spk_refdata_loader: SparkRefDataLoader = refdata_loader(config) + spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark, + reference_data_config=config, + dataset_config_uri=temp_working_dir) test = spk_refdata_loader.load_table(config.get("test_refdata")) assert test.count() == 2 -def test_via_entity_manager(spark_refdata_loader, spark_refdata_table): - refdata_loader, refdata_dir = spark_refdata_loader +def test_via_entity_manager(spark, temp_working_dir, spark_refdata_table): db, tbl = spark_refdata_table config = { "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.parquet"), "test_refdata_uri": ReferenceURI(type="uri", - uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()), + uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix()), "test_refdata_table": ReferenceTable(type="table", table_name=tbl, database=db) } - em = EntityManager({}, reference_data=refdata_loader(config)) + + spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark, + reference_data_config=config, + dataset_config_uri=temp_working_dir) + em = EntityManager({}, reference_data=spk_refdata_loader) assert em.get("refdata_test_refdata_file").count() == 2 assert em.get("refdata_test_refdata_uri").count() == 2 assert em.get("refdata_test_refdata_table").count() == 2 -def test_refdata_error(spark_refdata_loader): - refdata_loader, _ = spark_refdata_loader +def test_refdata_error(spark, temp_working_dir): config = { "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.arrow") } - em = EntityManager({}, reference_data=refdata_loader(config)) + + spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark, + reference_data_config=config, + dataset_config_uri=temp_working_dir) + em = EntityManager({}, reference_data=spk_refdata_loader) with pytest.raises(MissingRefDataEntity): em["refdata_missing"] em["refdata_test_refdata_file"] diff --git 
a/tests/test_core_engine/test_engine.py b/tests/test_core_engine/test_engine.py index ef23d71..7e0fd6e 100644 --- a/tests/test_core_engine/test_engine.py +++ b/tests/test_core_engine/test_engine.py @@ -29,8 +29,7 @@ def test_dummy_planet_run(self, spark: SparkSession, temp_dir: str): dataset_config_path=config_path.as_posix(), output_prefix=Path(temp_dir), backend=SparkBackend(dataset_config_uri=config_path.parent.as_posix(), - spark_session=spark, - reference_data_loader=refdata_loader) + spark_session=spark) ) with test_instance: diff --git a/tests/test_pipeline/test_duckdb_pipeline.py b/tests/test_pipeline/test_duckdb_pipeline.py index 29e0734..aa65516 100644 --- a/tests/test_pipeline/test_duckdb_pipeline.py +++ b/tests/test_pipeline/test_duckdb_pipeline.py @@ -148,9 +148,6 @@ def test_business_rule_step( db_file, conn = temp_ddb_conn sub_info, processed_files_path = planets_data_after_data_contract - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) - with DDBAuditingManager(db_file.as_uri(), ThreadPoolExecutor(1), conn) as audit_manager: dve_pipeline = DDBDVEPipeline( processed_files_path=processed_files_path, @@ -159,7 +156,6 @@ def test_business_rule_step( connection=conn, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=DuckDBRefDataLoader, ) audit_manager.add_new_submissions([sub_info], job_run_id=1) @@ -187,9 +183,6 @@ def test_error_report_step( db_file, conn = temp_ddb_conn submitted_file_info, processed_files_path, status = planets_data_after_business_rules - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) - with DDBAuditingManager(db_file.as_uri(), ThreadPoolExecutor(1), conn) as audit_manager: dve_pipeline = DDBDVEPipeline( processed_files_path=processed_files_path, @@ -198,7 +191,6 @@ def test_error_report_step( connection=conn, rules_path=None, submitted_files_path=None, - 
reference_data_loader=DuckDBRefDataLoader, ) reports = dve_pipeline.error_report_step( @@ -222,7 +214,6 @@ def test_get_submission_status(temp_ddb_conn): connection=conn, rules_path=None, submitted_files_path=None, - reference_data_loader=DuckDBRefDataLoader, ) dve_pipeline._logger = Mock(spec=logging.Logger) # add four submissions diff --git a/tests/test_pipeline/test_foundry_ddb_pipeline.py b/tests/test_pipeline/test_foundry_ddb_pipeline.py index 350b990..666bd90 100644 --- a/tests/test_pipeline/test_foundry_ddb_pipeline.py +++ b/tests/test_pipeline/test_foundry_ddb_pipeline.py @@ -34,10 +34,6 @@ def test_foundry_runner_validation_fail(planet_test_files, temp_ddb_conn): shutil.copytree(planet_test_files, sub_folder) - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) - - with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager: dve_pipeline = FoundryDDBPipeline( processed_files_path=processing_folder, @@ -45,7 +41,6 @@ def test_foundry_runner_validation_fail(planet_test_files, temp_ddb_conn): connection=conn, rules_path=get_test_file_path("planets/planets_ddb.dischema.json").as_posix(), submitted_files_path=None, - reference_data_loader=DuckDBRefDataLoader, ) output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info) assert fh.get_resource_exists(report_uri) @@ -69,11 +64,7 @@ def test_foundry_runner_validation_success(movies_test_files, temp_ddb_conn): datetime_received=datetime(2025,11,5)) sub_folder = processing_folder + f"/{sub_id}" - shutil.copytree(movies_test_files, sub_folder) - - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = None - + shutil.copytree(movies_test_files, sub_folder) with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager: dve_pipeline = FoundryDDBPipeline( @@ -82,7 +73,6 @@ def test_foundry_runner_validation_success(movies_test_files, temp_ddb_conn): connection=conn, 
rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(), submitted_files_path=None, - reference_data_loader=DuckDBRefDataLoader, ) output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info) assert fh.get_resource_exists(report_uri) @@ -100,10 +90,6 @@ def test_foundry_runner_error(planet_test_files, temp_ddb_conn): shutil.copytree(planet_test_files, sub_folder) - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) - - with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager: dve_pipeline = FoundryDDBPipeline( processed_files_path=processing_folder, @@ -111,7 +97,6 @@ def test_foundry_runner_error(planet_test_files, temp_ddb_conn): connection=conn, rules_path=get_test_file_path("planets/planets.dischema.json").as_posix(), submitted_files_path=None, - reference_data_loader=DuckDBRefDataLoader, ) output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info) assert not fh.get_resource_exists(report_uri) @@ -174,9 +159,6 @@ def test_foundry_runner_with_submitted_files_path(movies_test_files, temp_ddb_co datetime_received=datetime(2025,11,5) ) - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = None - with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager: dve_pipeline = FoundryDDBPipeline( processed_files_path=processing_folder, @@ -184,7 +166,6 @@ def test_foundry_runner_with_submitted_files_path(movies_test_files, temp_ddb_co connection=conn, rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(), submitted_files_path=submitted_files_path, - reference_data_loader=DuckDBRefDataLoader, ) output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info) @@ -209,9 +190,6 @@ def test_foundry_runner_error_at_bi_rules(movies_test_files, temp_ddb_conn): datetime_received=datetime(2025,11,5) ) - DuckDBRefDataLoader.connection = conn - DuckDBRefDataLoader.dataset_config_uri = None 
- with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager: dve_pipeline = FoundryDDBPipeline( processed_files_path=processing_folder, @@ -219,7 +197,6 @@ def test_foundry_runner_error_at_bi_rules(movies_test_files, temp_ddb_conn): connection=conn, rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(), submitted_files_path=submitted_files_path, - reference_data_loader=DuckDBRefDataLoader, ) output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info) diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 38418d6..a8f59c7 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -25,7 +25,6 @@ def test_get_submission_files_for_run(planet_test_files): # pylint: disable=red rules_path=None, processed_files_path=planet_test_files, submitted_files_path=planet_test_files, - reference_data_loader=None, ) result = list(dve_pipeline._get_submission_files_for_run()) @@ -42,7 +41,6 @@ def test_write_file_to_parquet(planet_test_files): # pylint: disable=redefined- rules_path=PLANETS_RULES_PATH, processed_files_path=planet_test_files, submitted_files_path=planet_test_files, - reference_data_loader=None, ) sub_id = uuid4().hex @@ -80,7 +78,6 @@ def test_file_transformation(planet_test_files): # pylint: disable=redefined-ou rules_path=PLANETS_RULES_PATH, processed_files_path=tdir, submitted_files_path=planet_test_files, - reference_data_loader=None, ) sub_id = uuid4().hex diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py index 262d84f..b3048a1 100644 --- a/tests/test_pipeline/test_spark_pipeline.py +++ b/tests/test_pipeline/test_spark_pipeline.py @@ -49,7 +49,6 @@ def test_audit_received_step(planet_test_files, spark, spark_test_database): job_run_id=1, rules_path=None, submitted_files_path=planet_test_files, - reference_data_loader=None, ) sub_ids: Dict[str, SubmissionInfo] = {} @@ -91,7 +90,6 @@ def 
test_file_transformation_step( job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=planet_test_files, - reference_data_loader=None, spark=spark, ) sub_id = uuid4().hex @@ -129,7 +127,6 @@ def test_apply_data_contract_success( job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=None, spark=spark, ) sub_status = SubmissionStatus() @@ -150,7 +147,6 @@ def test_apply_data_contract_failed( # pylint: disable=redefined-outer-name job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=None, spark=spark, ) sub_status = SubmissionStatus() @@ -228,7 +224,6 @@ def test_data_contract_step( job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=None, ) success, failed = dve_pipeline.data_contract_step( @@ -252,9 +247,6 @@ def test_apply_business_rules_success( ): # pylint: disable=redefined-outer-name sub_info, processed_file_path = planets_data_after_data_contract - SparkRefDataLoader.spark = spark - SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) - with SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) as audit_manager: dve_pipeline = SparkDVEPipeline( processed_files_path=processed_file_path, @@ -262,7 +254,6 @@ def test_apply_business_rules_success( job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=SparkRefDataLoader, spark=spark, ) @@ -296,9 +287,6 @@ def test_apply_business_rules_with_data_errors( # pylint: disable=redefined-out spark_test_database, ): sub_info, processed_file_path = planets_data_after_data_contract_that_break_business_rules - - SparkRefDataLoader.spark = spark - SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) with SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) as audit_manager: dve_pipeline = SparkDVEPipeline( @@ -307,7 +295,6 @@ def 
test_apply_business_rules_with_data_errors( # pylint: disable=redefined-out job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=SparkRefDataLoader, spark=spark, ) @@ -380,9 +367,6 @@ def test_business_rule_step( ): # pylint: disable=redefined-outer-name sub_info, processed_files_path = planets_data_after_data_contract - SparkRefDataLoader.spark = spark - SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) - with SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) as audit_manager: dve_pipeline = SparkDVEPipeline( processed_files_path=processed_files_path, @@ -390,7 +374,6 @@ def test_business_rule_step( job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=SparkRefDataLoader, spark=spark, ) audit_manager.add_new_submissions([sub_info], job_run_id=1) @@ -416,15 +399,12 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out ): sub_info, processed_file_path = error_data_after_business_rules - SparkRefDataLoader.spark = spark - dve_pipeline = SparkDVEPipeline( processed_files_path=processed_file_path, audit_tables=None, job_run_id=1, rules_path=PLANETS_RULES_PATH, submitted_files_path=None, - reference_data_loader=SparkRefDataLoader, spark=spark, ) @@ -545,7 +525,6 @@ def test_error_report_step( job_run_id=1, rules_path=None, submitted_files_path=None, - reference_data_loader=None, spark=spark, ) @@ -564,8 +543,7 @@ def test_error_report_step( def test_cluster_pipeline_run( spark: SparkSession, planet_test_files: str, spark_test_database ): # pylint: disable=redefined-outer-name - SparkRefDataLoader.spark = spark - SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH) + audit_manager = SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) dve_pipeline = SparkDVEPipeline( @@ -574,7 +552,6 @@ def test_cluster_pipeline_run( job_run_id=1, rules_path=PLANETS_RULES_PATH, 
submitted_files_path=planet_test_files, - reference_data_loader=SparkRefDataLoader, spark=spark, ) @@ -595,7 +572,6 @@ def test_get_submission_status(spark, spark_test_database): job_run_id=1, rules_path=None, submitted_files_path=None, - reference_data_loader=None, spark=spark, ) dve_pipeline._logger = Mock(spec=logging.Logger) From a4d49b66e1d7537d44d9f5307335d360e383311d Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Wed, 29 Apr 2026 10:09:50 +0100 Subject: [PATCH 2/3] style: address formatting, linting and type checking issues --- src/dve/core_engine/backends/base/backend.py | 3 ++- .../backends/base/reference_data.py | 4 ++-- src/dve/core_engine/backends/exceptions.py | 5 +++-- .../implementations/duckdb/duckdb_helpers.py | 2 +- .../implementations/duckdb/reference_data.py | 13 +++++------- .../backends/implementations/spark/backend.py | 20 +++++++++++-------- .../implementations/spark/reference_data.py | 6 ++---- src/dve/pipeline/duckdb_pipeline.py | 19 ++++++++++-------- src/dve/pipeline/pipeline.py | 18 +++++++---------- src/dve/pipeline/spark_pipeline.py | 20 ++++++++++--------- 10 files changed, 56 insertions(+), 54 deletions(-) diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py index 507ede8..f627412 100644 --- a/src/dve/core_engine/backends/base/backend.py +++ b/src/dve/core_engine/backends/base/backend.py @@ -3,7 +3,7 @@ import logging import warnings from abc import ABC, abstractmethod -from collections.abc import Mapping, MutableMapping +from collections.abc import MutableMapping from typing import Any, ClassVar, Generic, Optional from pyspark.sql import DataFrame, SparkSession @@ -67,6 +67,7 @@ def load_reference_data( reference_entity_config: dict[EntityName, ReferenceConfigUnion], submission_info: Optional[SubmissionInfo], ) -> BaseRefDataLoader[EntityType]: + """Supply configured reference data loader for use with business rules""" raise 
NotImplementedError() @abstractmethod diff --git a/src/dve/core_engine/backends/base/reference_data.py b/src/dve/core_engine/backends/base/reference_data.py index 187b798..9010e8d 100644 --- a/src/dve/core_engine/backends/base/reference_data.py +++ b/src/dve/core_engine/backends/base/reference_data.py @@ -208,8 +208,8 @@ def __getitem__(self, key: EntityName) -> EntityType: try: config = self.reference_entity_config[key] return self.load_entity(entity_name=key, config=config) - except TypeError: - raise NoRefDataConfigSupplied() + except TypeError as err: + raise NoRefDataConfigSupplied() from err except Exception as err: raise MissingRefDataEntity(entity_name=key) from err diff --git a/src/dve/core_engine/backends/exceptions.py b/src/dve/core_engine/backends/exceptions.py index d5b90cf..6878fc2 100644 --- a/src/dve/core_engine/backends/exceptions.py +++ b/src/dve/core_engine/backends/exceptions.py @@ -118,9 +118,10 @@ def get_message_preamble(self) -> str: """ return f"Missing reference data entity {self.entity_name!r}" + class NoRefDataConfigSupplied(BackendError): """An error raised when trying to load a refdata entity when no refdata - config has been supplied. + config has been supplied. 
""" @@ -129,7 +130,7 @@ def __init__(self, *args: object) -> None: def get_message_preamble(self) -> EntityName: """Message for logging purposes""" - return f"Refdata loader not supplied with refdata config - unable to load refdata entities" + return "Refdata loader not supplied with refdata config - unable to load refdata entities" class ConstraintError(ValueError, BackendErrorMixin): diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index 394cd01..627822b 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -411,7 +411,7 @@ def get_duckdb_cast_statement_from_annotation( stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 return stmt if issubclass(type_, time): - stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301 + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301 return stmt duck_type = get_duckdb_type_from_annotation(type_) if duck_type: diff --git a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py index b3f47e3..431c16e 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py +++ b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py @@ -1,14 +1,11 @@ """A reference data loader for duckdb.""" -from typing import Optional - from duckdb import DuckDBPyConnection, DuckDBPyRelation from pyarrow import ipc # type: ignore from dve.core_engine.backends.base.reference_data import ( BaseRefDataLoader, ReferenceConfig, - 
ReferenceConfigUnion, ReferenceTable, mark_refdata_file_extension, ) @@ -19,18 +16,18 @@ # pylint: disable=too-few-public-methods class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]): """A reference data loader using already existing DuckDB tables. - reference_entity_config and dataset_config_uri (if config uses relative paths) - should be supplied using setter methods for the dataset being processed before running.""" + reference_entity_config and dataset_config_uri (if config uses relative paths) + should be supplied using setter methods for the dataset being processed before running.""" def __init__( self, connection: DuckDBPyConnection, reference_data_config: dict[EntityName, ReferenceConfig], dataset_config_uri: URI, - **kwargs + **kwargs, ) -> None: - super().__init__(reference_data_config, dataset_config_uri,**kwargs) - + super().__init__(reference_data_config, dataset_config_uri, **kwargs) + self.connection = connection if not self.connection: diff --git a/src/dve/core_engine/backends/implementations/spark/backend.py b/src/dve/core_engine/backends/implementations/spark/backend.py index e943b03..94b0650 100644 --- a/src/dve/core_engine/backends/implementations/spark/backend.py +++ b/src/dve/core_engine/backends/implementations/spark/backend.py @@ -17,7 +17,7 @@ from dve.core_engine.loggers import get_child_logger, get_logger from dve.core_engine.models import SubmissionInfo from dve.core_engine.type_hints import URI, EntityName, EntityParquetLocations -from dve.parser.file_handling import get_resource_exists, joinuri, get_parent +from dve.parser.file_handling import get_resource_exists, joinuri class SparkBackend(BaseBackend[DataFrame]): @@ -50,18 +50,22 @@ def __init__( logger=get_child_logger("SparkStepImplementations", logger) ) super().__init__(contract, steps, logger, **kwargs) - - def load_reference_data(self, + + def load_reference_data( + self, reference_entity_config: dict[EntityName, ReferenceConfigUnion], - submission_info: 
Optional[SubmissionInfo],): + submission_info: Optional[SubmissionInfo], + ): """Load the reference data as specified in the reference entity config.""" - sub_info_entity: Optional[EntityType] = None + sub_info_entity: Optional[DataFrame] = None if submission_info: sub_info_entity = self.convert_submission_info(submission_info) - reference_data_loader = SparkRefDataLoader(spark=self.spark_session, - reference_data_config=reference_entity_config, - dataset_config_uri=self.dataset_config_uri) + reference_data_loader = SparkRefDataLoader( + spark=self.spark_session, + reference_data_config=reference_entity_config, + dataset_config_uri=self.dataset_config_uri, # type: ignore + ) if sub_info_entity is not None: reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity diff --git a/src/dve/core_engine/backends/implementations/spark/reference_data.py b/src/dve/core_engine/backends/implementations/spark/reference_data.py index 08507ee..60f8613 100644 --- a/src/dve/core_engine/backends/implementations/spark/reference_data.py +++ b/src/dve/core_engine/backends/implementations/spark/reference_data.py @@ -1,8 +1,6 @@ # pylint: disable=no-member """A reference data loader for Spark.""" -from typing import Optional - from pyspark.sql import DataFrame, SparkSession from dve.core_engine.backends.base.reference_data import ( @@ -18,8 +16,8 @@ # pylint: disable=too-few-public-methods class SparkRefDataLoader(BaseRefDataLoader[DataFrame]): """A reference data loader using already existing Apache Spark Tables. 
- reference_entity_config and dataset_config_uri (if config uses relative paths) - should be supplied using setter methods for the dataset being processed before running.""" + reference_entity_config and dataset_config_uri (if config uses relative paths) + should be supplied using setter methods for the dataset being processed before running.""" def __init__( self, diff --git a/src/dve/pipeline/duckdb_pipeline.py b/src/dve/pipeline/duckdb_pipeline.py index 713001f..6dc446c 100644 --- a/src/dve/pipeline/duckdb_pipeline.py +++ b/src/dve/pipeline/duckdb_pipeline.py @@ -5,6 +5,7 @@ from duckdb import DuckDBPyConnection, DuckDBPyRelation +import dve.parser.file_handling as fh from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract @@ -14,7 +15,6 @@ from dve.core_engine.models import SubmissionInfo from dve.core_engine.type_hints import URI from dve.pipeline.pipeline import BaseDVEPipeline -import dve.parser.file_handling as fh # pylint: disable=abstract-method @@ -47,13 +47,16 @@ def __init__( logger, ) - def get_reference_data_loader(self, - reference_data_config: dict[str, ReferenceConfig], - **kwargs) -> BaseRefDataLoader[DuckDBPyRelation]: - return DuckDBRefDataLoader(connection=self._connection, - reference_data_config=reference_data_config, - dataset_config_uri=fh.get_parent(self._rules_path), - **kwargs) + def get_reference_data_loader( + self, reference_data_config: dict[str, ReferenceConfig], **kwargs + ) -> DuckDBRefDataLoader: + return DuckDBRefDataLoader( + connection=self._connection, + reference_data_config=reference_data_config, + dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore + **kwargs + ) + # pylint: disable=arguments-differ def write_file_to_parquet( # type: ignore self, submission_file_uri: URI, submission_info: 
SubmissionInfo, output: URI diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index f22073e..bc6dbf8 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -112,13 +112,12 @@ def step_implementations(self) -> Optional[BaseStepImplementations[EntityType]]: def get_entity_count(entity: EntityType) -> int: """Get a row count of an entity stored as parquet""" raise NotImplementedError() - - def get_reference_data_loader(self, - reference_data_config: dict[EntityName, ReferenceConfig], - **kwargs) -> BaseRefDataLoader[EntityType]: + + def get_reference_data_loader( + self, reference_data_config: dict[EntityName, ReferenceConfig], **kwargs + ) -> BaseRefDataLoader: """Get reference data loader if required for business rules""" raise NotImplementedError() - def get_submission_status( self, step_name: DVEStageName, submission_id: str @@ -533,7 +532,7 @@ def data_contract_step( return processed_files, failed_processing - def apply_business_rules( # pylint: disable=R0914 + def apply_business_rules( # pylint: disable=R0914 self, submission_info: SubmissionInfo, submission_status: Optional[SubmissionStatus] = None ) -> tuple[SubmissionInfo, SubmissionStatus]: """Apply the business rules to a given submission, the submission may have failed at the @@ -559,7 +558,7 @@ def apply_business_rules( # pylint: disable=R0914 self._processed_files_path, submission_info.submission_id ) ref_data = config.get_reference_data_config() - reference_data = self.get_reference_data_loader(reference_data_config=ref_data) + reference_data: BaseRefDataLoader = self.get_reference_data_loader(reference_data_config=ref_data) rules = config.get_rule_metadata() entities = {} contract = fh.joinuri( @@ -585,10 +584,7 @@ def apply_business_rules( # pylint: disable=R0914 key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} _errors_uri, rules_success = self.step_implementations.apply_rules( # type: ignore - working_directory, - 
entity_manager, - rules, - key_fields + working_directory, entity_manager, rules, key_fields ) rule_messages = load_feedback_messages( diff --git a/src/dve/pipeline/spark_pipeline.py b/src/dve/pipeline/spark_pipeline.py index 30d42ad..4a1ecc9 100644 --- a/src/dve/pipeline/spark_pipeline.py +++ b/src/dve/pipeline/spark_pipeline.py @@ -6,6 +6,7 @@ from pyspark.sql import DataFrame, SparkSession +import dve.parser.file_handling as fh from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager from dve.core_engine.backends.implementations.spark.contract import SparkDataContract @@ -16,7 +17,6 @@ from dve.core_engine.type_hints import URI from dve.pipeline.pipeline import BaseDVEPipeline from dve.pipeline.utils import SubmissionStatus, unpersist_all_rdds -import dve.parser.file_handling as fh # pylint: disable=abstract-method @@ -48,14 +48,16 @@ def __init__( job_run_id, logger, ) - - def get_reference_data_loader(self, - reference_data_config: dict[str, ReferenceConfig], - **kwargs) -> BaseRefDataLoader[DataFrame]: - return SparkRefDataLoader(spark=self._spark, - reference_data_config=reference_data_config, - dataset_config_uri=fh.get_parent(self._rules_path), - **kwargs) + + def get_reference_data_loader( + self, reference_data_config: dict[str, ReferenceConfig], **kwargs + ) -> SparkRefDataLoader: + return SparkRefDataLoader( + spark=self._spark, + reference_data_config=reference_data_config, + dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore + **kwargs + ) # pylint: disable=arguments-differ def write_file_to_parquet( # type: ignore From ae9f3aae439cad5bd026f47b4568822fb3c23a72 Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Wed, 29 Apr 2026 11:45:53 +0100 Subject: [PATCH 3/3] style: address review comments and linting issues --- .../backends/implementations/duckdb/reference_data.py | 4 
+--- .../core_engine/backends/implementations/spark/backend.py | 3 +-- .../backends/implementations/spark/reference_data.py | 4 +--- src/dve/pipeline/duckdb_pipeline.py | 6 +++--- src/dve/pipeline/pipeline.py | 7 ++++--- src/dve/pipeline/spark_pipeline.py | 6 +++--- 6 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py index 431c16e..c059811 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py +++ b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py @@ -15,9 +15,7 @@ # pylint: disable=too-few-public-methods class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]): - """A reference data loader using already existing DuckDB tables. - reference_entity_config and dataset_config_uri (if config uses relative paths) - should be supplied using setter methods for the dataset being processed before running.""" + """A reference data loader using already existing DuckDB tables.""" def __init__( self, diff --git a/src/dve/core_engine/backends/implementations/spark/backend.py b/src/dve/core_engine/backends/implementations/spark/backend.py index 94b0650..126e07a 100644 --- a/src/dve/core_engine/backends/implementations/spark/backend.py +++ b/src/dve/core_engine/backends/implementations/spark/backend.py @@ -12,7 +12,6 @@ from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations from dve.core_engine.backends.implementations.spark.spark_helpers import get_type_from_annotation from dve.core_engine.backends.implementations.spark.types import SparkEntities -from dve.core_engine.backends.types import EntityType from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger from dve.core_engine.models import SubmissionInfo @@ -64,7 +63,7 @@ def load_reference_data( reference_data_loader = 
SparkRefDataLoader( spark=self.spark_session, reference_data_config=reference_entity_config, - dataset_config_uri=self.dataset_config_uri, # type: ignore + dataset_config_uri=self.dataset_config_uri, # type: ignore ) if sub_info_entity is not None: reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity diff --git a/src/dve/core_engine/backends/implementations/spark/reference_data.py b/src/dve/core_engine/backends/implementations/spark/reference_data.py index 60f8613..44f49af 100644 --- a/src/dve/core_engine/backends/implementations/spark/reference_data.py +++ b/src/dve/core_engine/backends/implementations/spark/reference_data.py @@ -15,9 +15,7 @@ # pylint: disable=too-few-public-methods class SparkRefDataLoader(BaseRefDataLoader[DataFrame]): - """A reference data loader using already existing Apache Spark Tables. - reference_entity_config and dataset_config_uri (if config uses relative paths) - should be supplied using setter methods for the dataset being processed before running.""" + """A reference data loader using already existing Apache Spark Tables.""" def __init__( self, diff --git a/src/dve/pipeline/duckdb_pipeline.py b/src/dve/pipeline/duckdb_pipeline.py index 6dc446c..0370106 100644 --- a/src/dve/pipeline/duckdb_pipeline.py +++ b/src/dve/pipeline/duckdb_pipeline.py @@ -6,7 +6,7 @@ from duckdb import DuckDBPyConnection, DuckDBPyRelation import dve.parser.file_handling as fh -from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig +from dve.core_engine.backends.base.reference_data import ReferenceConfig from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_get_entity_count @@ -47,13 +47,13 @@ def __init__( logger, ) - def get_reference_data_loader( + def init_reference_data_loader( self, 
reference_data_config: dict[str, ReferenceConfig], **kwargs ) -> DuckDBRefDataLoader: return DuckDBRefDataLoader( connection=self._connection, reference_data_config=reference_data_config, - dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore + dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore **kwargs ) diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index bc6dbf8..91ff2ee 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -68,7 +68,6 @@ def __init__( self._submitted_files_path = submitted_files_path self._processed_files_path = processed_files_path self._rules_path = rules_path - self._reference_data_loader = None self._job_run_id = job_run_id self._audit_tables = audit_tables self._data_contract = data_contract @@ -113,7 +112,7 @@ def get_entity_count(entity: EntityType) -> int: """Get a row count of an entity stored as parquet""" raise NotImplementedError() - def get_reference_data_loader( + def init_reference_data_loader( self, reference_data_config: dict[EntityName, ReferenceConfig], **kwargs ) -> BaseRefDataLoader: """Get reference data loader if required for business rules""" @@ -558,7 +557,9 @@ def apply_business_rules( # pylint: disable=R0914 self._processed_files_path, submission_info.submission_id ) ref_data = config.get_reference_data_config() - reference_data: BaseRefDataLoader = self.get_reference_data_loader(reference_data_config=ref_data) + reference_data: BaseRefDataLoader = self.init_reference_data_loader( + reference_data_config=ref_data + ) rules = config.get_rule_metadata() entities = {} contract = fh.joinuri( diff --git a/src/dve/pipeline/spark_pipeline.py b/src/dve/pipeline/spark_pipeline.py index 4a1ecc9..201abbf 100644 --- a/src/dve/pipeline/spark_pipeline.py +++ b/src/dve/pipeline/spark_pipeline.py @@ -7,7 +7,7 @@ from pyspark.sql import DataFrame, SparkSession import dve.parser.file_handling as fh -from dve.core_engine.backends.base.reference_data import 
BaseRefDataLoader, ReferenceConfig +from dve.core_engine.backends.base.reference_data import ReferenceConfig from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager from dve.core_engine.backends.implementations.spark.contract import SparkDataContract from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader @@ -49,13 +49,13 @@ def __init__( logger, ) - def get_reference_data_loader( + def init_reference_data_loader( self, reference_data_config: dict[str, ReferenceConfig], **kwargs ) -> SparkRefDataLoader: return SparkRefDataLoader( spark=self._spark, reference_data_config=reference_data_config, - dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore + dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore **kwargs )