39 changes: 38 additions & 1 deletion src/dve/core_engine/backends/base/reader.py
@@ -8,9 +8,11 @@
from pydantic import BaseModel
from typing_extensions import Protocol

from dve.core_engine.backends.exceptions import ReaderLacksEntityTypeSupport
from dve.core_engine.backends.exceptions import MessageBearingError, ReaderLacksEntityTypeSupport
from dve.core_engine.backends.types import EntityName, EntityType
from dve.core_engine.message import FeedbackMessage
from dve.core_engine.type_hints import URI, ArbitraryFunction, WrapDecorator
from dve.parser.file_handling.service import open_stream

T = TypeVar("T")
ET_co = TypeVar("ET_co", covariant=True)
@@ -116,6 +118,8 @@ def read_to_entity_type(
if entity_name == Iterator[dict[str, Any]]:
return self.read_to_py_iterator(resource, entity_name, schema) # type: ignore

self.raise_if_not_sensible_file(resource, entity_name)

try:
reader_func = self.__read_methods__[entity_type]
except KeyError as err:
@@ -137,3 +141,36 @@ def write_parquet(

"""
raise NotImplementedError(f"write_parquet not implemented in {self.__class__}")

@staticmethod
def _check_likely_text_file(resource: URI) -> bool:
"""Quick sense check of file to see if it looks like text
- not 100% full proof, but hopefully enough to weed out most
non-text files"""
with open_stream(resource, "rb") as fle:
start_chunk = fle.read(4096)
# check for a BOM - UTF-16 text can legitimately contain NUL bytes
if start_chunk.startswith((b"\xff\xfe", b"\xfe\xff")):
return True
# a NUL byte elsewhere means the file is unlikely to be text
if b"\x00" in start_chunk:
return False
return True

def raise_if_not_sensible_file(self, resource: URI, entity_name: str):
"""Sense check that the file is a text file. Raise error if doesn't
appear to be the case."""
if not self._check_likely_text_file(resource):
raise MessageBearingError(
"The submitted file doesn't appear to be text",
messages=[
FeedbackMessage(
entity=entity_name,
record=None,
failure_type="submission",
error_location="Whole File",
error_code="MalformedFile",
error_message="The resource doesn't seem to be a valid text file",
)
],
)
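
For reference, a minimal sketch of the BOM/NUL-byte heuristic used above, applied to in-memory byte strings (the inputs are illustrative; the real method reads the first 4096 bytes of the resource via `open_stream`):

```python
# Minimal sketch of the text-file heuristic, assuming the same BOM/NUL rules as above.
def looks_like_text(start_chunk: bytes) -> bool:
    # UTF-16 files start with a BOM and can legitimately contain NUL bytes.
    if start_chunk.startswith((b"\xff\xfe", b"\xfe\xff")):
        return True
    # A NUL byte elsewhere strongly suggests a binary file.
    return b"\x00" not in start_chunk

assert looks_like_text(b"id,name\n1,alice\n")          # plain ASCII CSV
assert looks_like_text("id,name".encode("utf-16"))     # UTF-16 with BOM
assert not looks_like_text(b"PK\x03\x04\x00\x00data")  # zip-style binary header
```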
59 changes: 57 additions & 2 deletions src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
@@ -16,6 +16,7 @@
get_duckdb_type_from_annotation,
)
from dve.core_engine.backends.implementations.duckdb.types import SQLType
from dve.core_engine.backends.readers.utilities import check_csv_header_expected
from dve.core_engine.backends.utilities import get_polars_type_from_annotation
from dve.core_engine.message import FeedbackMessage
from dve.core_engine.type_hints import URI, EntityName
@@ -24,7 +25,14 @@

@duckdb_write_parquet
class DuckDBCSVReader(BaseFileReader):
"""A reader for CSV files"""
"""A reader for CSV files including the ability to compare the passed model
to the file header, if it exists.

field_check: flag to compare submitted file header to the accompanying pydantic model
field_check_error_code: The error code to provide if the file header doesn't contain
the expected fields
field_check_error_message: The error message to provide if the file header doesn't contain
the expected fields"""

# TODO - the read_to_relation should include the schema and determine whether to
# TODO - stringify or not
@@ -35,15 +43,43 @@ def __init__(
delim: str = ",",
quotechar: str = '"',
connection: Optional[DuckDBPyConnection] = None,
field_check: bool = False,
Contributor

I think if we are adding these attributes then we need to add them to the docstring, explaining to the user what they are. They're not as obvious as existing attrs like header, delim, etc.

field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch",
field_check_error_message: Optional[str] = "The submitted header is missing fields",
**_,
):
self.header = header
self.delim = delim
self.quotechar = quotechar
self._connection = connection if connection else default_connection
self.field_check = field_check
self.field_check_error_code = field_check_error_code
self.field_check_error_message = field_check_error_message

super().__init__()

def perform_field_check(
self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
):
"""Check that the header of the CSV aligns with the provided model"""
if not self.header:
raise ValueError("Cannot perform field check without a CSV header")

if missing := check_csv_header_expected(resource, expected_schema, self.delim):
raise MessageBearingError(
"The CSV header doesn't match what is expected",
messages=[
FeedbackMessage(
entity=entity_name,
record=None,
failure_type="submission",
error_location="Whole File",
error_code=self.field_check_error_code,
error_message=f"{self.field_check_error_message} - missing fields: {missing}", # pylint: disable=line-too-long
)
],
)

def read_to_py_iterator(
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
) -> Iterator[dict[str, Any]]:
@@ -58,6 +94,9 @@ def read_to_relation( # pylint: disable=unused-argument
if get_content_length(resource) == 0:
raise EmptyFileError(f"File at {resource} is empty.")

if self.field_check:
self.perform_field_check(resource, entity_name, schema)

reader_options: dict[str, Any] = {
"header": self.header,
"delimiter": self.delim,
@@ -89,6 +128,9 @@ def read_to_relation( # pylint: disable=unused-argument
if get_content_length(resource) == 0:
raise EmptyFileError(f"File at {resource} is empty.")

if self.field_check:
self.perform_field_check(resource, entity_name, schema)

reader_options: dict[str, Any] = {
"has_header": self.header,
"separator": self.delim,
@@ -132,6 +174,17 @@ class DuckDBCSVRepeatingHeaderReader(PolarsToDuckDBCSVReader):
| shop1 | clothes | 2025-01-01 |
"""

def __init__(
self,
*args,
non_unique_header_error_code: Optional[str] = "NonUniqueHeader",
non_unique_header_error_message: Optional[str] = None,
**kwargs,
):
self._non_unique_header_code = non_unique_header_error_code
self._non_unique_header_message = non_unique_header_error_message
super().__init__(*args, **kwargs)

@read_function(DuckDBPyRelation)
def read_to_relation( # pylint: disable=unused-argument
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
@@ -156,10 +209,12 @@ def read_to_relation( # pylint: disable=unused-argument
failure_type="submission",
error_message=(
f"Found {no_records} distinct combination of header values."
if not self._non_unique_header_message
else self._non_unique_header_message
),
error_location=entity_name,
category="Bad file",
error_code="NonUniqueHeader",
error_code=self._non_unique_header_code,
)
],
)
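
As a rough usage sketch of the new options (the model and file path are hypothetical; the import path follows the file location in this diff), the field check is opted into at construction time and needs `header=True`:

```python
# Hedged sketch: constructing the reader with the new field-check options.
# SalesRow and the CSV path are illustrative; a missing header field raises
# MessageBearingError carrying a FeedbackMessage with the configured code/message.
from duckdb import DuckDBPyRelation
from pydantic import BaseModel

from dve.core_engine.backends.implementations.duckdb.readers.csv import DuckDBCSVReader


class SalesRow(BaseModel):  # hypothetical schema for illustration
    shop: str
    item: str
    sale_date: str


reader = DuckDBCSVReader(
    header=True,  # the field check needs a header row to compare against
    delim=",",
    field_check=True,
    field_check_error_code="ExpectedVsActualFieldMismatch",
    field_check_error_message="The submitted header is missing fields",
)
relation = reader.read_to_entity_type(DuckDBPyRelation, "path/to/sales.csv", "sales", SalesRow)
```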
21 changes: 21 additions & 0 deletions src/dve/core_engine/backends/readers/utilities.py
@@ -0,0 +1,21 @@
"""General utilities for file readers"""

from typing import Optional

from pydantic import BaseModel

from dve.core_engine.type_hints import URI
from dve.parser.file_handling.service import open_stream


def check_csv_header_expected(
resource: URI,
expected_schema: type[BaseModel],
delimiter: Optional[str] = ",",
quote_char: str = '"',
) -> set[str]:
"""Check the header of a CSV matches the expected fields"""
with open_stream(resource) as fle:
header_fields = fle.readline().rstrip().replace(quote_char, "").split(delimiter)
expected_fields = expected_schema.__fields__.keys()
return set(expected_fields).difference(header_fields)
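
A small usage sketch of the utility (file contents and model are illustrative):

```python
# Hedged sketch: the helper returns the set of expected fields missing from the header.
from pathlib import Path

from pydantic import create_model

from dve.core_engine.backends.readers.utilities import check_csv_header_expected

Model = create_model("Model", field1=(str, ...), field2=(int, ...))
Path("example.csv").write_text("field1,other\n1,x\n")
missing = check_csv_header_expected("example.csv", Model, ",")
print(missing)  # {'field2'}
```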
5 changes: 4 additions & 1 deletion src/dve/pipeline/foundry_ddb_pipeline.py
@@ -109,6 +109,8 @@ def error_report(
self._logger.exception(exc)
sub_stats = None
report_uri = None
submission_status = submission_status if submission_status else SubmissionStatus()
submission_status.processing_failed = True
Contributor

Build a SubmissionStatus if there isn't one already; then you don't need the `if sub_stats` check on line 152. Check the error_report child method to see whether there are any scenarios where the submission_status is returned.

dump_processing_errors(
fh.joinuri(self.processed_files_path, submission_info.submission_id),
"error_report",
@@ -148,7 +150,8 @@ def run_pipeline(
sub_info, sub_status, sub_stats, report_uri = self.error_report(
submission_info=submission_info, submission_status=sub_status
)
self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats])
if sub_stats:
self._audit_tables.add_submission_statistics_records(sub_stats=[sub_stats])
except Exception as err: # pylint: disable=W0718
self._logger.error(
f"During processing of submission_id: {sub_id}, this exception was raised: {err}"
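
As a rough illustration of the reviewer's suggestion above (hypothetical; the real error_report contract and SubmissionStatus fields may differ), defaulting to fresh objects lets callers drop None guards:

```python
# Hypothetical sketch of the reviewer's suggestion: default to a fresh status
# object so callers can rely on a concrete return value and drop None checks.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SubmissionStatus:  # illustrative stand-in for the real class
    processing_failed: bool = False


def error_report(submission_status: Optional[SubmissionStatus] = None) -> SubmissionStatus:
    status = submission_status or SubmissionStatus()  # build one if not supplied
    status.processing_failed = True
    return status


assert error_report(None).processing_failed  # no `if status:` guard needed by the caller
```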
@@ -1,4 +1,3 @@
from typing import Dict, List
import pytest

from dve.core_engine.backends.implementations.duckdb.utilities import (
@@ -16,7 +15,7 @@
),
],
)
def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str]):
def test_expr_mapping_to_columns(expressions: dict[str, str], expected: list[str]):
observed = expr_mapping_to_columns(expressions)
assert observed == expected

@@ -51,6 +50,7 @@ def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str
),
],
)
def test_expr_array_to_columns(expressions: Dict[str, str], expected: list[str]):
def test_expr_array_to_columns(expressions: dict[str, str], expected: list[str]):
observed = expr_array_to_columns(expressions)
assert observed == expected

@@ -57,7 +57,7 @@ def test_ddb_json_reader_all_str(temp_json_file):
expected_fields = [fld for fld in mdl.__fields__]
reader = DuckDBJSONReader()
rel: DuckDBPyRelation = reader.read_to_entity_type(
DuckDBPyRelation, uri, "test", stringify_model(mdl)
DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl)
)
assert rel.columns == expected_fields
assert dict(zip(rel.columns, rel.dtypes)) == {fld: "VARCHAR" for fld in expected_fields}
@@ -68,7 +68,7 @@ def test_ddb_json_reader_cast(temp_json_file):
uri, data, mdl = temp_json_file
expected_fields = [fld for fld in mdl.__fields__]
reader = DuckDBJSONReader()
rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri, "test", mdl)
rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri.as_posix(), "test", mdl)

assert rel.columns == expected_fields
assert dict(zip(rel.columns, rel.dtypes)) == {
@@ -82,7 +82,7 @@ def test_ddb_csv_write_parquet(temp_json_file):
uri, _, mdl = temp_json_file
reader = DuckDBJSONReader()
rel: DuckDBPyRelation = reader.read_to_entity_type(
DuckDBPyRelation, uri, "test", stringify_model(mdl)
DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl)
)
target_loc: Path = uri.parent.joinpath("test_parquet.parquet").as_posix()
reader.write_parquet(rel, target_loc)
@@ -0,0 +1,55 @@
import datetime as dt
from pathlib import Path
import tempfile
from uuid import uuid4

import pytest
from pydantic import BaseModel, create_model

from dve.core_engine.backends.readers.utilities import check_csv_header_expected

@pytest.mark.parametrize(
["header_row", "delim", "schema", "expected"],
[
(
"field1,field2,field3",
",",
{"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
set(),
),
(
"field2,field3,field1",
",",
{"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
set(),
),
(
"str_field|int_field|date_field|",
",",
{"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
{"str_field","int_field","date_field"},
),
(
'"str_field"|"int_field"|"date_field"',
"|",
{"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
set(),
),
(
'str_field,int_field,date_field\n',
",",
{"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
set(),
),

],
)
def test_check_csv_header_expected(
header_row: str, delim: str, schema: type[BaseModel], expected: set[str]
):
mdl = create_model("TestModel", **schema)
with tempfile.TemporaryDirectory() as tmpdir:
fle = Path(tmpdir).joinpath(f"test_file_{uuid4().hex}.csv")
fle.write_text(header_row)
res = check_csv_header_expected(fle.as_posix(), mdl, delim)
assert res == expected