diff --git a/src/dve/core_engine/backends/base/reader.py b/src/dve/core_engine/backends/base/reader.py
index 9862e7e..54abaa9 100644
--- a/src/dve/core_engine/backends/base/reader.py
+++ b/src/dve/core_engine/backends/base/reader.py
@@ -8,9 +8,11 @@
 from pydantic import BaseModel
 from typing_extensions import Protocol
 
-from dve.core_engine.backends.exceptions import ReaderLacksEntityTypeSupport
+from dve.core_engine.backends.exceptions import MessageBearingError, ReaderLacksEntityTypeSupport
 from dve.core_engine.backends.types import EntityName, EntityType
+from dve.core_engine.message import FeedbackMessage
 from dve.core_engine.type_hints import URI, ArbitraryFunction, WrapDecorator
+from dve.parser.file_handling.service import open_stream
 
 T = TypeVar("T")
 ET_co = TypeVar("ET_co", covariant=True)
@@ -116,6 +118,8 @@ def read_to_entity_type(
         if entity_name == Iterator[dict[str, Any]]:
             return self.read_to_py_iterator(resource, entity_name, schema)  # type: ignore
 
+        self.raise_if_not_sensible_file(resource, entity_name)
+
         try:
             reader_func = self.__read_methods__[entity_type]
         except KeyError as err:
@@ -137,3 +141,36 @@ def write_parquet(
         """
 
         raise NotImplementedError(f"write_parquet not implemented in {self.__class__}")
+
+    @staticmethod
+    def _check_likely_text_file(resource: URI) -> bool:
+        """Quick sense check of a file to see whether it looks like text.
+        Not 100% foolproof, but hopefully enough to weed out most
+        non-text files."""
+        with open_stream(resource, "rb") as fle:
+            start_chunk = fle.read(4096)
+        # check for a UTF-16 BOM - UTF-16 text can legitimately contain NUL bytes
+        if start_chunk.startswith((b"\xff\xfe", b"\xfe\xff")):
+            return True
+        # otherwise a NUL byte makes it unlikely that this is text
+        if b"\x00" in start_chunk:
+            return False
+        return True
+
+    def raise_if_not_sensible_file(self, resource: URI, entity_name: str):
+        """Sense check that the file is a text file. Raise an error if that
+        doesn't appear to be the case."""
+        if not self._check_likely_text_file(resource):
+            raise MessageBearingError(
+                "The submitted file doesn't appear to be text",
+                messages=[
+                    FeedbackMessage(
+                        entity=entity_name,
+                        record=None,
+                        failure_type="submission",
+                        error_location="Whole File",
+                        error_code="MalformedFile",
+                        error_message="The resource doesn't seem to be a valid text file",
+                    )
+                ],
+            )
diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
index 3998bf5..ff65d9f 100644
--- a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
+++ b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
@@ -16,6 +16,7 @@
     get_duckdb_type_from_annotation,
 )
 from dve.core_engine.backends.implementations.duckdb.types import SQLType
+from dve.core_engine.backends.readers.utilities import check_csv_header_expected
 from dve.core_engine.backends.utilities import get_polars_type_from_annotation
 from dve.core_engine.message import FeedbackMessage
 from dve.core_engine.type_hints import URI, EntityName
@@ -24,7 +25,14 @@
 
 @duckdb_write_parquet
 class DuckDBCSVReader(BaseFileReader):
-    """A reader for CSV files"""
+    """A reader for CSV files, including the ability to compare the passed model
+    to the file header, if it exists.
+
+    field_check: flag to compare the submitted file header to the accompanying pydantic model
+    field_check_error_code: The error code to provide if the file header doesn't contain
+        the expected fields
+    field_check_error_message: The error message to provide if the file header doesn't contain
+        the expected fields"""
 
     # TODO - the read_to_relation should include the schema and determine whether to
     # TODO - stringify or not
@@ -35,15 +43,43 @@ def __init__(
         delim: str = ",",
         quotechar: str = '"',
         connection: Optional[DuckDBPyConnection] = None,
+        field_check: bool = False,
+        field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch",
+        field_check_error_message: Optional[str] = "The submitted header is missing fields",
         **_,
     ):
         self.header = header
         self.delim = delim
         self.quotechar = quotechar
         self._connection = connection if connection else default_connection
+        self.field_check = field_check
+        self.field_check_error_code = field_check_error_code
+        self.field_check_error_message = field_check_error_message
 
         super().__init__()
 
+    def perform_field_check(
+        self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
+    ):
+        """Check that the header of the CSV aligns with the provided model."""
+        if not self.header:
+            raise ValueError("Cannot perform field check without a CSV header")
+
+        if missing := check_csv_header_expected(resource, expected_schema, self.delim):
+            raise MessageBearingError(
+                "The CSV header doesn't match what is expected",
+                messages=[
+                    FeedbackMessage(
+                        entity=entity_name,
+                        record=None,
+                        failure_type="submission",
+                        error_location="Whole File",
+                        error_code=self.field_check_error_code,
+                        error_message=f"{self.field_check_error_message} - missing fields: {missing}",  # pylint: disable=line-too-long
+                    )
+                ],
+            )
+
     def read_to_py_iterator(
         self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
     ) -> Iterator[dict[str, Any]]:
@@ -58,6 +94,9 @@ def read_to_relation(  # pylint: disable=unused-argument
         if get_content_length(resource) == 0:
             raise EmptyFileError(f"File at {resource} is empty.")
 
+        if self.field_check:
+            self.perform_field_check(resource, entity_name, schema)
+
         reader_options: dict[str, Any] = {
             "header": self.header,
             "delimiter": self.delim,
@@ -89,6 +128,9 @@ def read_to_relation(  # pylint: disable=unused-argument
         if get_content_length(resource) == 0:
             raise EmptyFileError(f"File at {resource} is empty.")
 
+        if self.field_check:
+            self.perform_field_check(resource, entity_name, schema)
+
         reader_options: dict[str, Any] = {
             "has_header": self.header,
             "separator": self.delim,
@@ -132,6 +174,17 @@ class DuckDBCSVRepeatingHeaderReader(PolarsToDuckDBCSVReader):
     | shop1 | clothes | 2025-01-01 |
     """
 
+    def __init__(
+        self,
+        *args,
+        non_unique_header_error_code: Optional[str] = "NonUniqueHeader",
+        non_unique_header_error_message: Optional[str] = None,
+        **kwargs,
+    ):
+        self._non_unique_header_code = non_unique_header_error_code
+        self._non_unique_header_message = non_unique_header_error_message
+        super().__init__(*args, **kwargs)
+
     @read_function(DuckDBPyRelation)
     def read_to_relation(  # pylint: disable=unused-argument
         self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
@@ -156,10 +209,12 @@ def read_to_relation(  # pylint: disable=unused-argument
                         failure_type="submission",
                         error_message=(
                             f"Found {no_records} distinct combination of header values."
+                            if not self._non_unique_header_message
+                            else self._non_unique_header_message
                         ),
                         error_location=entity_name,
                         category="Bad file",
-                        error_code="NonUniqueHeader",
+                        error_code=self._non_unique_header_code,
                     )
                 ],
             )
diff --git a/src/dve/core_engine/backends/readers/utilities.py b/src/dve/core_engine/backends/readers/utilities.py
new file mode 100644
index 0000000..642c0b2
--- /dev/null
+++ b/src/dve/core_engine/backends/readers/utilities.py
@@ -0,0 +1,21 @@
+"""General utilities for file readers"""
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+from dve.core_engine.type_hints import URI
+from dve.parser.file_handling.service import open_stream
+
+
+def check_csv_header_expected(
+    resource: URI,
+    expected_schema: type[BaseModel],
+    delimiter: Optional[str] = ",",
+    quote_char: str = '"',
+) -> set[str]:
+    """Check that a CSV header contains the expected fields, returning any that are missing."""
+    with open_stream(resource) as fle:
+        header_fields = fle.readline().rstrip().replace(quote_char, "").split(delimiter)
+    expected_fields = expected_schema.__fields__.keys()
+    return set(expected_fields).difference(header_fields)
diff --git a/src/dve/pipeline/foundry_ddb_pipeline.py b/src/dve/pipeline/foundry_ddb_pipeline.py
index 4c72375..f667d6e 100644
--- a/src/dve/pipeline/foundry_ddb_pipeline.py
+++ b/src/dve/pipeline/foundry_ddb_pipeline.py
@@ -109,6 +109,8 @@ def error_report(
             self._logger.exception(exc)
             sub_stats = None
             report_uri = None
+            submission_status = submission_status if submission_status else SubmissionStatus()
+            submission_status.processing_failed = True
             dump_processing_errors(
                 fh.joinuri(self.processed_files_path, submission_info.submission_id),
                 "error_report",
@@ -148,7 +150,8 @@ def run_pipeline(
                 sub_info, sub_status, sub_stats, report_uri = self.error_report(
                     submission_info=submission_info, submission_status=sub_status
                 )
-                self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats])
+                if sub_stats:
+                    self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats])
             except Exception as err:  # pylint: disable=W0718
                 self._logger.error(
                     f"During processing of submission_id: {sub_id}, this exception was raised: {err}"
diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py
index 8490ab5..2899dc6 100644
--- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py
+++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py
@@ -1,4 +1,3 @@
-from typing import Dict, List
 import pytest
 
 from dve.core_engine.backends.implementations.duckdb.utilities import (
@@ -16,7 +15,7 @@
         ),
     ],
 )
-def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str]):
+def test_expr_mapping_to_columns(expressions: dict[str, str], expected: list[str]):
     observed = expr_mapping_to_columns(expressions)
     assert observed == expected
 
@@ -51,6 +50,7 @@ def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str
         ),
     ],
 )
-def test_expr_array_to_columns(expressions: Dict[str, str], expected: list[str]):
+def test_expr_array_to_columns(expressions: dict[str, str], expected: list[str]):
     observed = expr_array_to_columns(expressions)
     assert observed == expected
+
diff --git a/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py
index 900632d..c326fef 100644
--- a/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py
+++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py
@@ -57,7 +57,7 @@ def test_ddb_json_reader_all_str(temp_json_file):
     expected_fields = [fld for fld in mdl.__fields__]
     reader = DuckDBJSONReader()
     rel: DuckDBPyRelation = reader.read_to_entity_type(
-        DuckDBPyRelation, uri, "test", stringify_model(mdl)
+        DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl)
     )
     assert rel.columns == expected_fields
     assert dict(zip(rel.columns, rel.dtypes)) == {fld: "VARCHAR" for fld in expected_fields}
@@ -68,7 +68,7 @@ def test_ddb_json_reader_cast(temp_json_file):
     uri, data, mdl = temp_json_file
     expected_fields = [fld for fld in mdl.__fields__]
     reader = DuckDBJSONReader()
-    rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri, "test", mdl)
+    rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri.as_posix(), "test", mdl)
 
     assert rel.columns == expected_fields
     assert dict(zip(rel.columns, rel.dtypes)) == {
@@ -82,7 +82,7 @@ def test_ddb_csv_write_parquet(temp_json_file):
     uri, _, mdl = temp_json_file
     reader = DuckDBJSONReader()
     rel: DuckDBPyRelation = reader.read_to_entity_type(
-        DuckDBPyRelation, uri, "test", stringify_model(mdl)
+        DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl)
     )
     target_loc: Path = uri.parent.joinpath("test_parquet.parquet").as_posix()
     reader.write_parquet(rel, target_loc)
diff --git a/tests/test_core_engine/test_backends/test_readers/test_utilities.py b/tests/test_core_engine/test_backends/test_readers/test_utilities.py
new file mode 100644
index 0000000..4426769
--- /dev/null
+++ b/tests/test_core_engine/test_backends/test_readers/test_utilities.py
@@ -0,0 +1,55 @@
+import datetime as dt
+from pathlib import Path
+import tempfile
+from uuid import uuid4
+
+import pytest
+from pydantic import BaseModel, create_model
+
+from dve.core_engine.backends.readers.utilities import check_csv_header_expected
+
+@pytest.mark.parametrize(
+    ["header_row", "delim", "schema", "expected"],
+    [
+        (
+            "field1,field2,field3",
+            ",",
+            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
+            set(),
+        ),
+        (
+            "field2,field3,field1",
+            ",",
+            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
+            set(),
+        ),
+        (
+            "str_field|int_field|date_field|",
+            ",",
+            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
+            {"str_field", "int_field", "date_field"},
+        ),
+        (
+            '"str_field"|"int_field"|"date_field"',
+            "|",
+            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
+            set(),
+        ),
+        (
+            "str_field,int_field,date_field\n",
+            ",",
+            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
+            set(),
+        ),
+
+    ],
+)
+def test_check_csv_header_expected(
+    header_row: str, delim: str, schema: type[BaseModel], expected: set[str]
+):
+    mdl = create_model("TestModel", **schema)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        fle = Path(tmpdir).joinpath(f"test_file_{uuid4().hex}.csv")
+        fle.write_text(header_row)
+        res = check_csv_header_expected(fle.as_posix(), mdl, delim)
+        assert res == expected
\ No newline at end of file
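
A minimal usage sketch of the behaviour this patch adds, for reviewers; it is illustrative only and not part of the diff. The CSV path, entity name, and MySubmission model are assumed example values, while the imports and call signatures follow the modules changed above.

# Usage sketch (illustrative only): exercising the new CSV header field check.
# "/data/submission.csv", "submission", and MySubmission are hypothetical examples.
from duckdb import DuckDBPyRelation
from pydantic import BaseModel

from dve.core_engine.backends.exceptions import MessageBearingError
from dve.core_engine.backends.implementations.duckdb.readers.csv import DuckDBCSVReader
from dve.core_engine.backends.readers.utilities import check_csv_header_expected


class MySubmission(BaseModel):
    """Hypothetical submission schema used only for this example."""

    field1: str
    field2: int


# Standalone utility: returns the expected fields that are missing from the CSV header.
missing = check_csv_header_expected("/data/submission.csv", MySubmission, ",")
print(missing)  # e.g. set() when the header contains every model field

# Reader-driven check: with field_check=True the header is compared to the model before
# reading, and a mismatch surfaces as a MessageBearingError carrying FeedbackMessage records.
reader = DuckDBCSVReader(header=True, delim=",", field_check=True)
try:
    rel = reader.read_to_entity_type(
        DuckDBPyRelation, "/data/submission.csv", "submission", MySubmission
    )
except MessageBearingError:
    ...  # the feedback messages attached to the error describe the missing header fields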