Skip to content

Commit 61827ae

Browse files
committed
test: added small tests for duckdb and spark utils
1 parent 4c3454b commit 61827ae

File tree

10 files changed

+115
-24
lines changed

10 files changed

+115
-24
lines changed

src/dve/core_engine/backends/implementations/duckdb/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Implementation of duckdb backend"""
2+
23
from dve.core_engine.backends.implementations.duckdb.readers.json import DuckDBJSONReader
34
from dve.core_engine.backends.readers import register_reader
45

src/dve/core_engine/backends/implementations/duckdb/utilities.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Utility objects for use with duckdb backend"""
22

33
import itertools
4+
45
from dve.core_engine.backends.base.utilities import _split_multiexpr_string
56

67

@@ -27,10 +28,9 @@ def expr_array_to_columns(expressions: list[str]) -> list[str]:
2728
"""Create list of duckdb expressions from list of expressions"""
2829
return list(
2930
itertools.chain.from_iterable(
30-
_split_multiexpr_string(expression)
31-
for expression in expressions
32-
)
31+
_split_multiexpr_string(expression) for expression in expressions
3332
)
33+
)
3434

3535

3636
def multiexpr_string_to_columns(expressions: str) -> list[str]:

src/dve/core_engine/backends/implementations/spark/spark_helpers.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,7 @@
1212
from dataclasses import dataclass, is_dataclass
1313
from decimal import Decimal
1414
from functools import wraps
15-
from typing import (
16-
Any,
17-
ClassVar,
18-
Optional,
19-
TypeVar,
20-
Union,
21-
overload,
22-
)
15+
from typing import Any, ClassVar, Optional, TypeVar, Union, overload
2316

2417
from delta.exceptions import ConcurrentAppendException, DeltaConcurrentModificationException
2518
from pydantic import BaseModel

src/dve/core_engine/backends/implementations/spark/utilities.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Some utilities which are useful for implementing Spark transformations."""
22

33
import datetime as dt
4+
import itertools
45
from collections.abc import Callable
56
from json import JSONEncoder
67
from operator import and_, or_
@@ -70,7 +71,13 @@ def expr_mapping_to_columns(expressions: ExpressionMapping) -> list[Column]:
7071

7172
def expr_array_to_columns(expressions: ExpressionArray) -> list[Column]:
7273
"""Convert an array of expressions to a list of columns."""
73-
return list(map(sf.expr, expressions))
74+
75+
_expr_list = list(
76+
itertools.chain.from_iterable(
77+
_split_multiexpr_string(expression) for expression in expressions
78+
)
79+
)
80+
return list(map(sf.expr, _expr_list))
7481

7582

7683
def multiexpr_string_to_columns(expressions: MultiExpression) -> list[Column]:

src/dve/core_engine/backends/readers/xml.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,7 @@
33

44
import re
55
from collections.abc import Collection, Iterator
6-
from typing import (
7-
IO,
8-
Any,
9-
GenericAlias, # type: ignore
10-
Optional,
11-
Union,
12-
overload
13-
)
6+
from typing import IO, Any, GenericAlias, Optional, Union, overload # type: ignore
147

158
import polars as pl
169
from lxml import etree # type: ignore

src/dve/core_engine/backends/utilities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from dataclasses import is_dataclass
55
from datetime import date, datetime
66
from decimal import Decimal
7-
from typing import Any, ClassVar, Union
87
from typing import GenericAlias # type: ignore
8+
from typing import Any, ClassVar, Union
99

1010
import polars as pl # type: ignore
1111
from polars.datatypes.classes import DataTypeClass as PolarsType
@@ -39,7 +39,9 @@
3939

4040
def stringify_type(type_: Union[type, GenericAlias]) -> type:
4141
"""Stringify an individual type."""
42-
if isinstance(type_, type) and not isinstance(type_, GenericAlias): # A model, return the contents. # pylint: disable=C0301
42+
if isinstance(type_, type) and not isinstance(
43+
type_, GenericAlias
44+
): # A model, return the contents. # pylint: disable=C0301
4345
if issubclass(type_, BaseModel):
4446
return stringify_model(type_)
4547

src/dve/core_engine/message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import copy
44
import datetime as dt
5-
import operator
65
import json
6+
import operator
77
from collections.abc import Callable
88
from decimal import Decimal
99
from functools import reduce

src/dve/core_engine/type_hints.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from pathlib import Path
77
from queue import Queue as ThreadQueue
88
from typing import TYPE_CHECKING, Any, List, Optional, TypeVar, Union # pylint: disable=W1901
9-
# TODO - cannot remove List from Typing. See L60 for details.
109

1110
from pyspark.sql import DataFrame
1211
from pyspark.sql.types import StructType
1312
from typing_extensions import Literal, ParamSpec, get_args
1413

14+
# TODO - cannot remove List from Typing. See L60 for details.
15+
1516

1617
if TYPE_CHECKING: # pragma: no cover
1718
from dve.core_engine.message import FeedbackMessage
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Dict, List
2+
import pytest
3+
4+
from dve.core_engine.backends.implementations.duckdb.utilities import (
5+
expr_mapping_to_columns,
6+
expr_array_to_columns,
7+
)
8+
9+
10+
@pytest.mark.parametrize(
    ["expressions", "expected"],
    [
        (
            {"size(array_field)": "field_length", "another_field": "rename_another_field"},
            ["size(array_field) as field_length", "another_field as rename_another_field"],
        ),
    ],
)
def test_expr_mapping_to_columns(expressions: dict[str, str], expected: list[str]):
    """Each mapping entry should be rendered as an ``<expr> as <alias>`` string.

    Uses the builtin ``dict[str, str]`` generic for consistency with the
    ``list[str]`` annotation already used in this signature (PEP 585),
    instead of the legacy ``typing.Dict``.
    """
    observed = expr_mapping_to_columns(expressions)
    assert observed == expected
22+
23+
24+
@pytest.mark.parametrize(
    ["expressions", "expected"],
    [
        (
            [
                "a_field",
                "another_field as renamed",
                "struct(a_field, another_field) as struct_field",
            ],
            [
                "a_field",
                "another_field as renamed",
                "struct(a_field, another_field) as struct_field",
            ],
        ),
        (
            [
                "size(array_field)",
                "another_field as rename_another_field",
                "a_dynamic_field, another_dynamic_field",
            ],
            [
                "size(array_field)",
                "another_field as rename_another_field",
                "a_dynamic_field",
                "another_dynamic_field",
            ],
        ),
    ],
)
def test_expr_array_to_columns(expressions: list[str], expected: list[str]):
    """An expression array should pass through unchanged, except that
    comma-joined multi-expression strings are split into separate entries.

    The parameter was previously annotated ``Dict[str, str]``, but the
    parametrized values are lists of expression strings — fixed to
    ``list[str]`` to match the actual input type.
    """
    observed = expr_array_to_columns(expressions)
    assert observed == expected
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
from pyspark.sql.functions import expr
3+
4+
from dve.core_engine.backends.implementations.spark.utilities import (
5+
expr_mapping_to_columns,
6+
expr_array_to_columns,
7+
)
8+
9+
10+
@pytest.mark.parametrize(
    ["expressions"],
    [
        (
            {"size(array_field)": "field_length", "another_field": "rename_another_field"},
        ),
    ]
)
def test_expr_mapping_to_columns(spark, expressions: dict[str, str]):
    """Every mapping entry should become an aliased Column equal to
    ``expr(expression).alias(rename)``.

    Columns have no value equality, so both sides are compared via their
    JVM string representation (``_jc.toString()``).
    """
    observed = expr_mapping_to_columns(expressions)
    reference = [
        expr(expression).alias(rename) for expression, rename in expressions.items()
    ]
    observed_strings = [column._jc.toString() for column in observed]
    reference_strings = [column._jc.toString() for column in reference]
    assert observed_strings == reference_strings
21+
22+
23+
@pytest.mark.parametrize(
    ["expressions", "expected"],
    [
        (
            ["a_field", "another_field as renamed", "struct(a_field, another_field) as struct_field"],
            ["a_field", "another_field as renamed", "struct(a_field, another_field) as struct_field"]
        ),
        (
            ["size(array_field)", "another_field as rename_another_field", "a_dynamic_field, another_dynamic_field"],
            ["size(array_field)", "another_field as rename_another_field", "a_dynamic_field", "another_dynamic_field"],
        ),
    ],
)
def test_expr_array_to_columns(spark, expressions: list[str], expected: list[str]):
    """An expression array should expand to one Column per expression,
    splitting comma-joined multi-expression strings into separate Columns.

    The parameter was previously annotated ``dict[str, str]``, but the
    parametrized values are lists of expression strings — fixed to
    ``list[str]``. Columns are compared via their JVM string
    representation (``_jc.toString()``) since Column has no value equality.
    """
    observed = expr_array_to_columns(expressions)
    assert [column._jc.toString() for column in observed] == [
        expr(expression)._jc.toString() for expression in expected
    ]

0 commit comments

Comments (0)