Skip to content

Commit 4a3f89d

Browse files
committed
feat: add option to validate CSV headers in DuckDB CSV readers
1 parent 0f0e72c commit 4a3f89d

File tree

3 files changed

+114
-5
lines changed

3 files changed

+114
-5
lines changed

src/dve/core_engine/backends/implementations/duckdb/readers/csv.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
get_duckdb_type_from_annotation,
1717
)
1818
from dve.core_engine.backends.implementations.duckdb.types import SQLType
19+
from dve.core_engine.backends.implementations.duckdb.utilities import check_csv_header_expected
1920
from dve.core_engine.backends.utilities import get_polars_type_from_annotation
2021
from dve.core_engine.message import FeedbackMessage
2122
from dve.core_engine.type_hints import URI, EntityName
22-
from dve.parser.file_handling import get_content_length
23+
from dve.parser.file_handling import get_content_length, open_stream
2324

2425

2526
@duckdb_write_parquet
@@ -35,15 +36,46 @@ def __init__(
3536
delim: str = ",",
3637
quotechar: str = '"',
3738
connection: Optional[DuckDBPyConnection] = None,
39+
field_check: bool = False,
40+
field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch",
41+
field_check_error_message: Optional[str] = "The submitted header does not match what is expected",
3842
**_,
3943
):
4044
self.header = header
4145
self.delim = delim
4246
self.quotechar = quotechar
4347
self._connection = connection if connection else default_connection
48+
self.field_check = field_check
49+
self.field_check_error_code = field_check_error_code
50+
self.field_check_error_message = field_check_error_message
4451

4552
super().__init__()
4653

54+
def perform_field_check(
    self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
) -> None:
    """Validate that the CSV header at ``resource`` covers ``expected_schema``.

    Args:
        resource: URI of the CSV file whose header line is inspected.
        entity_name: name of the entity being read; used in feedback messages.
        expected_schema: pydantic model whose field names define the expected
            header fields.

    Raises:
        ValueError: if the reader was configured with ``header=False`` — there
            is no header line to check.
        MessageBearingError: if any expected field is absent from the header;
            carries a single whole-file ``FeedbackMessage`` using the reader's
            configured error code and message.
    """
    if not self.header:
        raise ValueError("Cannot perform field check without a CSV header")

    # Forward the reader's quote character as well as its delimiter so that
    # headers quoted with a non-default quotechar are stripped correctly
    # (previously the utility's hard-coded '"' default was always used).
    if missing := check_csv_header_expected(
        resource,
        expected_schema,
        self.delim,
        self.quotechar,
    ):
        raise MessageBearingError(
            "The CSV header doesn't match what is expected",
            messages=[
                FeedbackMessage(
                    entity=entity_name,
                    failure_type="submission",
                    error_location="Whole File",
                    error_code=self.field_check_error_code,
                    error_message=self.field_check_error_message,
                    value=f"Missing fields: {missing}",
                )
            ],
        )
78+
4779
def read_to_py_iterator(
4880
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
4981
) -> Iterator[dict[str, Any]]:
@@ -58,6 +90,9 @@ def read_to_relation( # pylint: disable=unused-argument
5890
if get_content_length(resource) == 0:
5991
raise EmptyFileError(f"File at {resource} is empty.")
6092

93+
if self.field_check:
94+
self.perform_field_check(resource, entity_name, schema)
95+
6196
reader_options: dict[str, Any] = {
6297
"header": self.header,
6398
"delimiter": self.delim,
@@ -89,6 +124,9 @@ def read_to_relation( # pylint: disable=unused-argument
89124
if get_content_length(resource) == 0:
90125
raise EmptyFileError(f"File at {resource} is empty.")
91126

127+
if self.field_check:
128+
self.perform_field_check(resource, entity_name, schema)
129+
92130
reader_options: dict[str, Any] = {
93131
"has_header": self.header,
94132
"separator": self.delim,
@@ -132,6 +170,12 @@ class DuckDBCSVRepeatingHeaderReader(PolarsToDuckDBCSVReader):
132170
| shop1 | clothes | 2025-01-01 |
133171
"""
134172

173+
def __init__(
174+
self, non_unique_header_error_code: Optional[str] = "NonUniqueHeader", *args, **kwargs
175+
):
176+
self._non_unique_header_code = non_unique_header_error_code
177+
super().__init__(*args, **kwargs)
178+
135179
@read_function(DuckDBPyRelation)
136180
def read_to_relation( # pylint: disable=unused-argument
137181
self, resource: URI, entity_name: EntityName, schema: type[BaseModel]
@@ -159,7 +203,7 @@ def read_to_relation( # pylint: disable=unused-argument
159203
),
160204
error_location=entity_name,
161205
category="Bad file",
162-
error_code="NonUniqueHeader",
206+
error_code=self._non_unique_header_code,
163207
)
164208
],
165209
)

src/dve/core_engine/backends/implementations/duckdb/utilities.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
"""Utility objects for use with duckdb backend"""
22

33
import itertools
4+
from typing import Optional
5+
6+
from pydantic import BaseModel
47

58
from dve.core_engine.backends.base.utilities import _split_multiexpr_string
9+
from dve.core_engine.backends.exceptions import MessageBearingError
10+
from dve.core_engine.message import FeedbackMessage
11+
from dve.core_engine.type_hints import URI
12+
from dve.parser.file_handling import open_stream
613

714

815
def parse_multiple_expressions(expressions) -> list[str]:
@@ -39,3 +46,15 @@ def multiexpr_string_to_columns(expressions: str) -> list[str]:
3946
"""
4047
expression_list = _split_multiexpr_string(expressions)
4148
return expr_array_to_columns(expression_list)
49+
50+
def check_csv_header_expected(
    resource: URI,
    expected_schema: type[BaseModel],
    delimiter: str = ",",
    quote_char: str = '"',
) -> set[str]:
    """Return the expected fields missing from the CSV header at ``resource``.

    Only the first line of the file is read. Quote characters are stripped
    before splitting and field order is ignored; extra fields in the header
    are not reported.

    Args:
        resource: URI of the CSV file to inspect.
        expected_schema: pydantic model whose field names define the expected
            header fields.
        delimiter: field separator used in the file.
        quote_char: quote character stripped from the header before comparison.

    Returns:
        The set of expected field names absent from the header (empty when the
        header matches).
    """
    with open_stream(resource) as fle:
        # Strip the line terminator before splitting: without this, the last
        # header field of any file whose header line ends in "\n"/"\r\n"
        # carries the terminator and is falsely reported as missing.
        first_line = fle.readline().rstrip("\r\n")
        header_fields = first_line.replace(quote_char, "").split(delimiter)
    expected_fields = expected_schema.__fields__.keys()
    return set(expected_fields).difference(header_fields)
60+

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_utils.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
from typing import Dict, List
1+
import tempfile
2+
import datetime as dt
3+
from pathlib import Path
4+
from uuid import uuid4
5+
from pydantic import BaseModel, create_model
26
import pytest
37

48
from dve.core_engine.backends.implementations.duckdb.utilities import (
59
expr_mapping_to_columns,
610
expr_array_to_columns,
11+
check_csv_header_expected,
712
)
813

914

@@ -16,7 +21,7 @@
1621
),
1722
],
1823
)
19-
def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str]):
24+
def test_expr_mapping_to_columns(expressions: dict[str, str], expected: list[str]):
2025
observed = expr_mapping_to_columns(expressions)
2126
assert observed == expected
2227

@@ -51,6 +56,47 @@ def test_expr_mapping_to_columns(expressions: Dict[str, str], expected: list[str
5156
),
5257
],
5358
)
54-
def test_expr_array_to_columns(expressions: Dict[str, str], expected: list[str]):
59+
def test_expr_array_to_columns(expressions: dict[str, str], expected: list[str]):
5560
observed = expr_array_to_columns(expressions)
5661
assert observed == expected
62+
63+
64+
@pytest.mark.parametrize(
    ["header_row", "delim", "schema", "expected"],
    [
        # Exact match with the default delimiter.
        (
            "field1,field2,field3",
            ",",
            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
            set(),
        ),
        # Field order does not matter.
        (
            "field2,field3,field1",
            ",",
            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
            set(),
        ),
        # Wrong delimiter: the header parses as one field, so every expected
        # field is reported missing.
        (
            "str_field|int_field|date_field|",
            ",",
            {
                "str_field": (str, ...),
                "int_field": (int, ...),
                "date_field": (dt.date, dt.date.today()),
            },
            {"str_field", "int_field", "date_field"},
        ),
        # Quote characters are stripped before comparison.
        (
            '"str_field"|"int_field"|"date_field"',
            "|",
            {
                "str_field": (str, ...),
                "int_field": (int, ...),
                "date_field": (dt.date, dt.date.today()),
            },
            set(),
        ),
    ],
)
def test_check_csv_header_expected(
    header_row: str, delim: str, schema: type[BaseModel], expected: set[str]
):
    """check_csv_header_expected reports exactly the expected-but-absent fields."""
    mdl = create_model("TestModel", **schema)
    with tempfile.TemporaryDirectory() as tmpdir:
        fle = Path(tmpdir).joinpath(f"test_file_{uuid4().hex}.csv")
        # write_text opens, writes, and closes in one call, guaranteeing the
        # header is flushed to disk before being read back (the previous bare
        # fle.open("w+").write(...) left an unclosed, possibly unflushed
        # handle, which is flaky on non-refcounting interpreters).
        fle.write_text(header_row)
        res = check_csv_header_expected(fle.as_posix(), mdl, delim)
        assert res == expected

0 commit comments

Comments
 (0)