From 9c4a375215bbcaeabe876e99f7e1153e5fe25e6c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 09:30:58 -0400 Subject: [PATCH 01/24] Run python tests on all currently supported python versions --- .github/workflows/test.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c1d9ac838..da3582766 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -33,9 +33,11 @@ jobs: fail-fast: false matrix: python-version: + - "3.9" - "3.10" - "3.11" - "3.12" + - "3.13" toolchain: - "stable" From 99dcb4f3be640b0402a56e7b3fba358c3a4cbaa3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 09:35:48 -0400 Subject: [PATCH 02/24] Update ruff checks to select all --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d16a18aa6..32ed5d155 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,8 @@ features = ["substrait"] # Enable docstring linting using the google style guide [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "FA", "D", "W", "I"] +select = ["ALL" ] +ignore = ["COM812", "ISC001"] # Recommended to ignore these rules when using with ruff-format [tool.ruff.lint.pydocstyle] convention = "google" From 5306d1e6704192750fccb9710db56fdbd4631f9f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 09:36:15 -0400 Subject: [PATCH 03/24] Ruff auto fix --- benchmarks/tpch/tpch.py | 14 ++--- dev/release/check-rat-report.py | 2 +- dev/release/generate-changelog.py | 8 +-- examples/tpch/_tests.py | 15 ++--- python/datafusion/__init__.py | 38 ++++++------ python/datafusion/common.py | 12 ++-- python/datafusion/context.py | 3 +- python/datafusion/dataframe.py | 6 +- python/datafusion/expr.py | 94 ++++++++++++++--------------- python/datafusion/input/base.py | 2 - python/datafusion/input/location.py | 2 +- python/datafusion/object_store.py | 2 +- python/datafusion/plan.py | 2 +- python/datafusion/substrait.py | 6 +- python/datafusion/udf.py | 18 ++---- python/tests/test_dataframe.py | 12 ++-- python/tests/test_input.py | 4 +- 17 files changed, 109 insertions(+), 131 deletions(-) diff --git a/benchmarks/tpch/tpch.py b/benchmarks/tpch/tpch.py index fb86b12b6..bfb9ac398 100644 --- a/benchmarks/tpch/tpch.py +++ b/benchmarks/tpch/tpch.py @@ -59,13 +59,13 @@ def bench(data_path, query_path): end = time.time() time_millis = (end - start) * 1000 total_time_millis += time_millis - print("setup,{}".format(round(time_millis, 1))) - results.write("setup,{}\n".format(round(time_millis, 1))) + print(f"setup,{round(time_millis, 1)}") + results.write(f"setup,{round(time_millis, 1)}\n") results.flush() # run queries for query in range(1, 23): - with open("{}/q{}.sql".format(query_path, query)) as f: + with open(f"{query_path}/q{query}.sql") as f: text = f.read() tmp = text.split(";") queries = [] @@ -83,14 +83,14 @@ def bench(data_path, query_path): end = time.time() time_millis = (end - start) * 1000 total_time_millis += time_millis - print("q{},{}".format(query, round(time_millis, 1))) - results.write("q{},{}\n".format(query, round(time_millis, 1))) + print(f"q{query},{round(time_millis, 1)}") + results.write(f"q{query},{round(time_millis, 1)}\n") results.flush() except Exception as e: print("query", query, "failed", e) - print("total,{}".format(round(total_time_millis, 1))) - results.write("total,{}\n".format(round(total_time_millis, 1))) + print(f"total,{round(total_time_millis, 1)}") + 
results.write(f"total,{round(total_time_millis, 1)}\n") if __name__ == "__main__": diff --git a/dev/release/check-rat-report.py b/dev/release/check-rat-report.py index d3dd7c5dd..0c9f4c326 100644 --- a/dev/release/check-rat-report.py +++ b/dev/release/check-rat-report.py @@ -29,7 +29,7 @@ exclude_globs_filename = sys.argv[1] xml_filename = sys.argv[2] -globs = [line.strip() for line in open(exclude_globs_filename, "r")] +globs = [line.strip() for line in open(exclude_globs_filename)] tree = ET.parse(xml_filename) root = tree.getroot() diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index 2564eea86..f3bda2ef5 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -26,14 +26,12 @@ def print_pulls(repo_name, title, pulls): if len(pulls) > 0: - print("**{}:**".format(title)) + print(f"**{title}:**") print() for pull, commit in pulls: - url = "https://github.com/{}/pull/{}".format(repo_name, pull.number) + url = f"https://github.com/{repo_name}/pull/{pull.number}" print( - "- {} [#{}]({}) ({})".format( - pull.title, pull.number, url, commit.author.login - ) + f"- {pull.title} [#{pull.number}]({url}) ({commit.author.login})" ) print() diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py index c4d872085..2be4dfabd 100644 --- a/examples/tpch/_tests.py +++ b/examples/tpch/_tests.py @@ -27,28 +27,25 @@ def df_selection(col_name, col_type): if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type): return F.round(col(col_name), lit(2)).alias(col_name) - elif col_type == pa.string() or col_type == pa.string_view(): + if col_type == pa.string() or col_type == pa.string_view(): return F.trim(col(col_name)).alias(col_name) - else: - return col(col_name) + return col(col_name) def load_schema(col_name, col_type): if col_type == pa.int64() or col_type == pa.int32(): return col_name, pa.string() - elif isinstance(col_type, pa.Decimal128Type): + if isinstance(col_type, pa.Decimal128Type): return col_name, pa.float64() - else: - return col_name, col_type + return col_name, col_type def expected_selection(col_name, col_type): if col_type == pa.int64() or col_type == pa.int32(): return F.trim(col(col_name)).cast(col_type).alias(col_name) - elif col_type == pa.string() or col_type == pa.string_view(): + if col_type == pa.string() or col_type == pa.string_view(): return F.trim(col(col_name)).alias(col_name) - else: - return col(col_name) + return col(col_name) def selections_and_schema(original_schema): diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index f11ce54a6..5a506e74e 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -54,38 +54,38 @@ __all__ = [ "Accumulator", + "AggregateUDF", + "Catalog", "Config", - "DataFrame", - "SessionContext", - "SessionConfig", - "SQLOptions", - "RuntimeEnvBuilder", - "Expr", - "ScalarUDF", - "WindowFrame", - "column", - "col", - "literal", - "lit", "DFSchema", - "Catalog", + "DataFrame", "Database", - "Table", - "AggregateUDF", - "WindowUDF", - "LogicalPlan", "ExecutionPlan", + "Expr", + "LogicalPlan", "RecordBatch", "RecordBatchStream", + "RuntimeEnvBuilder", + "SQLOptions", + "ScalarUDF", + "SessionConfig", + "SessionContext", + "Table", + "WindowFrame", + "WindowUDF", + "col", + "column", "common", "expr", "functions", + "lit", + "literal", "object_store", - "substrait", - "read_parquet", "read_avro", "read_csv", "read_json", + "read_parquet", + "substrait", ] diff --git a/python/datafusion/common.py 
b/python/datafusion/common.py index a2298c634..6dc3a3acb 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -38,15 +38,15 @@ "DFSchema", "DataType", "DataTypeMap", - "RexType", - "PythonType", - "SqlType", "NullTreatment", - "SqlTable", + "PythonType", + "RexType", + "SqlFunction", "SqlSchema", - "SqlView", "SqlStatistics", - "SqlFunction", + "SqlTable", + "SqlType", + "SqlView", ] diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 282b2a477..04d21a223 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -393,7 +393,6 @@ def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeEnvBuilder: class RuntimeConfig(RuntimeEnvBuilder): """See `RuntimeEnvBuilder`.""" - pass class SQLOptions: @@ -498,7 +497,7 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) - def enable_url_table(self) -> "SessionContext": + def enable_url_table(self) -> SessionContext: """Control if local files can be queried as tables. Returns: diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index de5d8376e..c1d289c70 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -73,7 +73,7 @@ class Compression(Enum): LZ4_RAW = "lz4_raw" @classmethod - def from_str(cls, value: str) -> "Compression": + def from_str(cls, value: str) -> Compression: """Convert a string to a Compression enum value. Args: @@ -104,9 +104,9 @@ def get_default_level(self) -> Optional[int]: # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223 if self == Compression.GZIP: return 6 - elif self == Compression.BROTLI: + if self == Compression.BROTLI: return 1 - elif self == Compression.ZSTD: + if self == Compression.ZSTD: return 4 return None diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 3639abec6..702f75aed 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -101,63 +101,63 @@ WindowExpr = expr_internal.WindowExpr __all__ = [ - "Expr", - "Column", - "Literal", - "BinaryExpr", - "Literal", + "Aggregate", "AggregateFunction", - "Not", - "IsNotNull", - "IsNull", - "IsTrue", - "IsFalse", - "IsUnknown", - "IsNotTrue", - "IsNotFalse", - "IsNotUnknown", - "Negative", - "Like", - "ILike", - "SimilarTo", - "ScalarVariable", "Alias", - "InList", - "Exists", - "Subquery", - "InSubquery", - "ScalarSubquery", - "Placeholder", - "GroupingSet", + "Analyze", + "Between", + "BinaryExpr", "Case", "CaseBuilder", "Cast", - "TryCast", - "Between", + "Column", + "CreateMemoryTable", + "CreateView", + "Distinct", + "DropTable", + "EmptyRelation", + "Exists", "Explain", + "Expr", + "Extension", + "Filter", + "GroupingSet", + "ILike", + "InList", + "InSubquery", + "IsFalse", + "IsNotFalse", + "IsNotNull", + "IsNotTrue", + "IsNotUnknown", + "IsNull", + "IsTrue", + "IsUnknown", + "Join", + "JoinConstraint", + "JoinType", + "Like", "Limit", - "Aggregate", + "Literal", + "Literal", + "Negative", + "Not", + "Partitioning", + "Placeholder", + "Projection", + "Repartition", + "ScalarSubquery", + "ScalarVariable", + "SimilarTo", "Sort", "SortExpr", - "Analyze", - "EmptyRelation", - "Join", - "JoinType", - "JoinConstraint", + "Subquery", + "SubqueryAlias", + "TableScan", + "TryCast", "Union", "Unnest", "UnnestExpr", - "Extension", - "Filter", - "Projection", - "TableScan", - "CreateMemoryTable", - "CreateView", - "Distinct", - "SubqueryAlias", - "DropTable", - "Partitioning", - "Repartition", "Window", "WindowExpr", "WindowFrame", @@ -311,7 +311,7 @@ def 
__getitem__(self, key: str | int) -> Expr: ) return Expr(self.expr.__getitem__(key)) - def __eq__(self, rhs: Any) -> Expr: + def __eq__(self, rhs: object) -> Expr: """Equal to. Accepts either an expression or any valid PyArrow scalar literal value. @@ -320,7 +320,7 @@ def __eq__(self, rhs: Any) -> Expr: rhs = Expr.literal(rhs) return Expr(self.expr.__eq__(rhs.expr)) - def __ne__(self, rhs: Any) -> Expr: + def __ne__(self, rhs: object) -> Expr: """Not equal to. Accepts either an expression or any valid PyArrow scalar literal value. diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py index 4eba19784..d53f7b23c 100644 --- a/python/datafusion/input/base.py +++ b/python/datafusion/input/base.py @@ -40,9 +40,7 @@ class BaseInputSource(ABC): @abstractmethod def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool: """Returns `True` if the input is valid.""" - pass @abstractmethod def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable: """Create a table from the input source.""" - pass diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py index 517cd1578..6e9736738 100644 --- a/python/datafusion/input/location.py +++ b/python/datafusion/input/location.py @@ -69,7 +69,7 @@ def build_table( # to get that information. However, this should only be occurring # at table creation time and therefore shouldn't # slow down query performance. - with open(input_item, "r") as file: + with open(input_item) as file: reader = csv.reader(file) header_row = next(reader) print(header_row) diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py index 7cc17506f..6298526f5 100644 --- a/python/datafusion/object_store.py +++ b/python/datafusion/object_store.py @@ -24,4 +24,4 @@ MicrosoftAzure = object_store.MicrosoftAzure Http = object_store.Http -__all__ = ["AmazonS3", "GoogleCloud", "LocalFileSystem", "MicrosoftAzure", "Http"] +__all__ = ["AmazonS3", "GoogleCloud", "Http", "LocalFileSystem", "MicrosoftAzure"] diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index 133fc446d..d9f5b81e7 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -27,8 +27,8 @@ from datafusion.context import SessionContext __all__ = [ - "LogicalPlan", "ExecutionPlan", + "LogicalPlan", ] diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index 06302fe38..e8c0e678e 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -39,8 +39,8 @@ from datafusion.context import SessionContext __all__ = [ - "Plan", "Consumer", + "Plan", "Producer", "Serde", ] @@ -71,7 +71,6 @@ def encode(self) -> bytes: class plan(Plan): """See `Plan`.""" - pass class Serde: @@ -143,7 +142,6 @@ def deserialize_bytes(proto_bytes: bytes) -> Plan: class serde(Serde): """See `Serde` instead.""" - pass class Producer: @@ -171,7 +169,6 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: class producer(Producer): """Use `Producer` instead.""" - pass class Consumer: @@ -197,4 +194,3 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: class consumer(Consumer): """Use `Consumer` instead.""" - pass diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index af7bcf2ed..907477604 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -161,9 +161,8 @@ def __new__(cls, *args, **kwargs): if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable 
return cls._function(*args, **kwargs) - else: - # Case 2: Used as a decorator with parameters - return cls._decorator(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return cls._decorator(*args, **kwargs) @staticmethod def _function( @@ -215,22 +214,18 @@ class Accumulator(metaclass=ABCMeta): @abstractmethod def state(self) -> List[pyarrow.Scalar]: """Return the current state.""" - pass @abstractmethod def update(self, *values: pyarrow.Array) -> None: """Evaluate an array of values and update state.""" - pass @abstractmethod def merge(self, states: List[pyarrow.Array]) -> None: """Merge a set of states.""" - pass @abstractmethod def evaluate(self) -> pyarrow.Scalar: """Return the resultant value.""" - pass class AggregateUDF: @@ -353,9 +348,8 @@ def __new__(cls, *args, **kwargs): if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return cls._function(*args, **kwargs) - else: - # Case 2: Used as a decorator with parameters - return cls._decorator(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return cls._decorator(*args, **kwargs) @staticmethod def _function( @@ -436,7 +430,6 @@ def memoize(self) -> None: `memoize` is called after each input batch is processed, and such functions can save whatever they need """ - pass def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: """Return the range for the window fuction. @@ -500,7 +493,6 @@ def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Ar avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) """ # noqa: W505 - pass def evaluate( self, values: list[pyarrow.Array], eval_range: tuple[int, int] @@ -519,7 +511,6 @@ def evaluate( and evaluation results of ORDER BY expressions. If function has a single argument, `values[1..]` will contain ORDER BY expression results. """ - pass def evaluate_all_with_rank( self, num_rows: int, ranks_in_partition: list[tuple[int, int]] @@ -552,7 +543,6 @@ def evaluate_all_with_rank( The user must implement this method if ``include_rank`` returns True. 
""" - pass def supports_bounded_execution(self) -> bool: """Can the window function be incrementally computed using bounded memory?""" diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index c636e896a..4a76c5697 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -731,7 +731,7 @@ def test_execution_plan(aggregate_df): plan = aggregate_df.execution_plan() expected = ( - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" # noqa: E501 + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" ) assert expected == plan.display() @@ -756,7 +756,7 @@ def test_execution_plan(aggregate_df): ctx = SessionContext() rows_returned = 0 - for idx in range(0, plan.partition_count): + for idx in range(plan.partition_count): stream = ctx.execute(plan, idx) try: batch = stream.next() @@ -969,7 +969,7 @@ def test_execute_stream_to_arrow_table(df, schema): (batch.to_pyarrow() for batch in stream), schema=df.schema() ) else: - pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream)) + pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream) assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) @@ -1152,7 +1152,7 @@ def test_dataframe_export(df) -> None: table = pa.table(df, schema=desired_schema) assert table.num_columns == 1 assert table.num_rows == 3 - for i in range(0, 3): + for i in range(3): assert table[0][i].as_py() is None # Expect an error when we cannot convert schema @@ -1186,8 +1186,8 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame: result = df.to_pydict() assert result["a"] == [1, 2, 3] - assert result["string_col"] == ["string data" for _i in range(0, 3)] - assert result["new_col"] == [3 for _i in range(0, 3)] + assert result["string_col"] == ["string data" for _i in range(3)] + assert result["new_col"] == [3 for _i in range(3)] def test_dataframe_repr_html(df) -> None: diff --git a/python/tests/test_input.py b/python/tests/test_input.py index 806471357..0b32769ca 100644 --- a/python/tests/test_input.py +++ b/python/tests/test_input.py @@ -27,6 +27,6 @@ def test_location_input(): input_file = cwd + "/testing/data/parquet/generated_simple_numerics/blogs.parquet" table_name = "blog" tbl = location_input.build_table(input_file, table_name) - assert "blog" == tbl.name - assert 3 == len(tbl.columns) + assert tbl.name == "blog" + assert len(tbl.columns) == 3 assert "blogs.parquet" in tbl.filepaths[0] From d66a946aeca754cb62f6e5f4e0e5f3e9ddb8bce9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 09:42:04 -0400 Subject: [PATCH 04/24] Applying ruff suggestions --- dev/release/generate-changelog.py | 4 +--- pyproject.toml | 2 +- python/datafusion/context.py | 1 - python/datafusion/substrait.py | 4 ---- python/tests/test_wrapper_coverage.py | 7 +++---- 5 files changed, 5 insertions(+), 13 deletions(-) diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index f3bda2ef5..e30e2def2 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -30,9 +30,7 @@ def print_pulls(repo_name, title, pulls): print() for pull, commit in pulls: url = f"https://github.com/{repo_name}/pull/{pull.number}" - print( - f"- {pull.title} [#{pull.number}]({url}) ({commit.author.login})" - ) + print(f"- {pull.title} [#{pull.number}]({url}) ({commit.author.login})") print() diff --git a/pyproject.toml b/pyproject.toml index 32ed5d155..905e36b7d 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ max-doc-length = 88 # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] -"python/tests/*" = ["D"] +"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD"] "examples/*" = ["D", "W505"] "dev/*" = ["D"] "benchmarks/*" = ["D", "F"] diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 04d21a223..0ab1a908a 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -394,7 +394,6 @@ class RuntimeConfig(RuntimeEnvBuilder): """See `RuntimeEnvBuilder`.""" - class SQLOptions: """Options to be used when performing SQL queries.""" diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index e8c0e678e..36cfa074a 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -72,7 +72,6 @@ class plan(Plan): """See `Plan`.""" - class Serde: """Provides the ``Substrait`` serialization and deserialization.""" @@ -143,7 +142,6 @@ class serde(Serde): """See `Serde` instead.""" - class Producer: """Generates substrait plans from a logical plan.""" @@ -170,7 +168,6 @@ class producer(Producer): """Use `Producer` instead.""" - class Consumer: """Generates a logical plan from a substrait plan.""" @@ -193,4 +190,3 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: @deprecated("Use `Consumer` instead.") class consumer(Consumer): """Use `Consumer` instead.""" - diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index ac064ba95..d3ce7c57f 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -19,6 +19,7 @@ import datafusion.functions import datafusion.object_store import datafusion.substrait +import pytest # EnumType introduced in 3.11. 3.10 and prior it was called EnumMeta. 
try: @@ -41,10 +42,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None: internal_attr = getattr(internal_obj, attr) wrapped_attr = getattr(wrapped_obj, attr) - if internal_attr is not None: - if wrapped_attr is None: - print("Missing attribute: ", attr) - assert False + if internal_attr is not None and wrapped_attr is None: + pytest.fail(f"Missing attribute: {attr}") if attr in ["__self__", "__class__"]: continue From d87e794ca635035ce5060c2ba1927b5fe33603c3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 09:49:44 -0400 Subject: [PATCH 05/24] noqa rules updates per ruff checks --- pyproject.toml | 2 +- python/tests/test_aggregation.py | 2 +- python/tests/test_imports.py | 4 ++-- python/tests/test_udaf.py | 2 +- python/tests/test_udwf.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 905e36b7d..751fe0e31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ max-doc-length = 88 # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] -"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD"] +"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD", "PLR2004", "RUF015"] "examples/*" = ["D", "W505"] "dev/*" = ["D"] "benchmarks/*" = ["D", "F"] diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py index 5ef46131b..acf7dff8e 100644 --- a/python/tests/test_aggregation.py +++ b/python/tests/test_aggregation.py @@ -299,7 +299,7 @@ def test_aggregate_100(df_aggregate_100, name, expr, expected): ] -@pytest.mark.parametrize("name,expr,result", data_test_bitwise_and_boolean_functions) +@pytest.mark.parametrize(("name","expr","result"), data_test_bitwise_and_boolean_functions) def test_bit_and_bool_fns(df, name, expr, result): df = df.aggregate([], [expr.alias(name)]) diff --git a/python/tests/test_imports.py b/python/tests/test_imports.py index 0c155cbde..5313c937f 100644 --- a/python/tests/test_imports.py +++ b/python/tests/test_imports.py @@ -169,14 +169,14 @@ def test_class_module_is_datafusion(): def test_import_from_functions_submodule(): - from datafusion.functions import abs, sin # noqa + from datafusion.functions import abs, sin # noqa: A004 assert functions.abs is abs assert functions.sin is sin msg = "cannot import name 'foobar' from 'datafusion.functions'" with pytest.raises(ImportError, match=msg): - from datafusion.functions import foobar # noqa + from datafusion.functions import foobar # noqa: F401 def test_classes_are_inheritable(): diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 97cf81f3c..82787fed7 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -86,7 +86,7 @@ def test_errors(df): "evaluate, merge, update)" ) with pytest.raises(Exception, match=msg): - accum = udaf( # noqa F841 + accum = udaf( # noqa: F841 MissingMethods, pa.int64(), pa.int64(), diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 2fea34aa3..c8be0250a 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -298,7 +298,7 @@ def test_udwf_errors(df): ] -@pytest.mark.parametrize("name,expr,expected", data_test_udwf_functions) +@pytest.mark.parametrize(("name","expr","expected"), data_test_udwf_functions) def test_udwf_functions(df, name, expr, expected): df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) From 2216d21a7a4df1863e660ba2a0448edfdd56d7ab Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 10:07:34 -0400 Subject: [PATCH 06/24] Working through more 
ruff suggestions --- pyproject.toml | 2 +- python/tests/test_sql.py | 36 +++++++++++++++++++--------------- python/tests/test_store.py | 13 ++++++------ python/tests/test_substrait.py | 2 +- python/tests/test_udaf.py | 8 +++----- 5 files changed, 31 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 751fe0e31..2d94a9e58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ max-doc-length = 88 # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] -"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD", "PLR2004", "RUF015"] +"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD", "PLR2004", "RUF015", "S608", "PLR0913"] "examples/*" = ["D", "W505"] "dev/*" = ["D"] "benchmarks/*" = ["D", "F"] diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 862f745bf..3ecbaeadd 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import gzip -import os +from pathlib import Path import numpy as np import pyarrow as pa @@ -47,9 +47,8 @@ def test_register_csv(ctx, tmp_path): ) write_csv(table, path) - with open(path, "rb") as csv_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(csv_file) + with Path.open(path, "rb") as csv_file, gzip.open(gzip_path, "wb") as gzipped_file: + gzipped_file.writelines(csv_file) ctx.register_csv("csv", path) ctx.register_csv("csv1", str(path)) @@ -158,7 +157,7 @@ def test_register_parquet(ctx, tmp_path): assert result.to_pydict() == {"cnt": [100]} -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) @@ -194,7 +193,7 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1} -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_register_dataset(ctx, tmp_path, path_to_str): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) path = str(path) if path_to_str else path @@ -209,13 +208,15 @@ def test_register_dataset(ctx, tmp_path, path_to_str): def test_register_json(ctx, tmp_path): - path = os.path.dirname(os.path.abspath(__file__)) - test_data_path = os.path.join(path, "data_test_context", "data.json") + path = Path(__file__).parent.resolve() + test_data_path = Path(path) / "data_test_context" / "data.json" gzip_path = tmp_path / "data.json.gz" - with open(test_data_path, "rb") as json_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(json_file) + with ( + Path.open(test_data_path, "rb") as json_file, + gzip.open(gzip_path, "wb") as gzipped_file + ): + gzipped_file.writelines(json_file) ctx.register_json("json", test_data_path) ctx.register_json("json1", str(test_data_path)) @@ -470,16 +471,19 @@ def test_simple_select(ctx, tmp_path, arr): # In DF 43.0.0 we now default to having BinaryView and StringView # so the array that is saved to the parquet is slightly different # than the array read. Convert to values for comparison. 
- if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray): + if isinstance(result, (pa.BinaryViewArray, pa.StringViewArray)): arr = arr.tolist() result = result.tolist() np.testing.assert_equal(result, arr) -@pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]])) -@pytest.mark.parametrize("pass_schema", (True, False)) -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("file_sort_order", [ + None, + [[col("int").sort(ascending=True, nulls_first=True)]] + ]) +@pytest.mark.parametrize("pass_schema", [True, False]) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_register_listing_table( ctx, tmp_path, pass_schema, file_sort_order, path_to_str ): @@ -528,7 +532,7 @@ def test_register_listing_table( assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2} result = ctx.sql( - "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" + "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501 ).collect() result = pa.Table.from_batches(result) diff --git a/python/tests/test_store.py b/python/tests/test_store.py index 53ffc3acf..ac9af98f3 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -import os +from pathlib import Path import pytest from datafusion import SessionContext @@ -23,17 +23,16 @@ @pytest.fixture def ctx(): - ctx = SessionContext() - return ctx + return SessionContext() def test_read_parquet(ctx): ctx.register_parquet( "test", - f"file://{os.getcwd()}/parquet/data/alltypes_plain.parquet", - [], - True, - ".parquet", + f"file://{Path.cwd()}/parquet/data/alltypes_plain.parquet", + table_partition_cols=[], + parquet_pruning=True, + file_extension=".parquet", ) df = ctx.sql("SELECT * FROM test") assert isinstance(df.collect(), list) diff --git a/python/tests/test_substrait.py b/python/tests/test_substrait.py index feada7cde..f367a447d 100644 --- a/python/tests/test_substrait.py +++ b/python/tests/test_substrait.py @@ -50,7 +50,7 @@ def test_substrait_serialization(ctx): substrait_plan = ss.Producer.to_substrait_plan(df.logical_plan(), ctx) -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_substrait_file_serialization(ctx, tmp_path, path_to_str): batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 82787fed7..453ff6f4f 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -17,8 +17,6 @@ from __future__ import annotations -from typing import List - import pyarrow as pa import pyarrow.compute as pc import pytest @@ -31,7 +29,7 @@ class Summarize(Accumulator): def __init__(self, initial_value: float = 0.0): self._sum = pa.scalar(initial_value) - def state(self) -> List[pa.Scalar]: + def state(self) -> list[pa.Scalar]: return [self._sum] def update(self, values: pa.Array) -> None: @@ -39,7 +37,7 @@ def update(self, values: pa.Array) -> None: # This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: List[pa.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. 
# This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) @@ -56,7 +54,7 @@ class MissingMethods(Accumulator): def __init__(self): self._sum = pa.scalar(0) - def state(self) -> List[pa.Scalar]: + def state(self) -> list[pa.Scalar]: return [self._sum] From bd041f843b85475a05ad2158fb3b2cb164c4105a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 10:49:59 -0400 Subject: [PATCH 07/24] Working through more ruff suggestions --- pyproject.toml | 2 +- python/tests/test_aggregation.py | 4 +- python/tests/test_functions.py | 353 ++++++++++++++------------ python/tests/test_input.py | 8 +- python/tests/test_io.py | 13 +- python/tests/test_sql.py | 9 +- python/tests/test_udwf.py | 2 +- python/tests/test_wrapper_coverage.py | 2 +- 8 files changed, 206 insertions(+), 187 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2d94a9e58..23828d834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ max-doc-length = 88 # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] -"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD", "PLR2004", "RUF015", "S608", "PLR0913"] +"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD", "PLR2004", "PT011", "RUF015", "S608", "PLR0913"] "examples/*" = ["D", "W505"] "dev/*" = ["D"] "benchmarks/*" = ["D", "F"] diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py index acf7dff8e..8093d66a8 100644 --- a/python/tests/test_aggregation.py +++ b/python/tests/test_aggregation.py @@ -299,7 +299,9 @@ def test_aggregate_100(df_aggregate_100, name, expr, expected): ] -@pytest.mark.parametrize(("name","expr","result"), data_test_bitwise_and_boolean_functions) +@pytest.mark.parametrize( + ("name", "expr", "result"), data_test_bitwise_and_boolean_functions +) def test_bit_and_bool_fns(df, name, expr, result): df = df.aggregate([], [expr.alias(name)]) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index fca05bb8f..9d5301533 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import math -from datetime import datetime +from datetime import datetime, timezone import numpy as np import pyarrow as pa @@ -25,6 +25,8 @@ np.seterr(invalid="ignore") +DEFAULT_TZ = timezone.utc + @pytest.fixture def df(): @@ -37,9 +39,9 @@ def df(): pa.array(["hello ", " world ", " !"], type=pa.string_view()), pa.array( [ - datetime(2022, 12, 31), - datetime(2027, 6, 26), - datetime(2020, 7, 2), + datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 26, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), ] ), pa.array([False, True, True]), @@ -79,7 +81,7 @@ def test_literal(df): literal("1"), literal("OK"), literal(3.14), - literal(True), + literal(True), # noqa: FBT003 literal(b"hello world"), ) result = df.collect() @@ -221,12 +223,12 @@ def py_indexof(arr, v): def py_arr_remove(arr, v, n=None): new_arr = arr[:] found = 0 - while found != n: - try: + try: + while found != n: new_arr.remove(v) found += 1 - except ValueError: - break + except ValueError: + pass return new_arr @@ -234,13 +236,13 @@ def py_arr_remove(arr, v, n=None): def py_arr_replace(arr, from_, to, n=None): new_arr = arr[:] found = 0 - while found != n: - try: + try: + while found != n: idx = new_arr.index(from_) new_arr[idx] = to found += 1 - except ValueError: - break + except ValueError: + pass return new_arr @@ -268,266 +270,266 @@ def py_flatten(arr): @pytest.mark.parametrize( ("stmt", "py_expr"), [ - [ + ( lambda col: f.array_append(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_push_back(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_append(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_push_back(col, literal(99.0)), lambda data: [np.append(arr, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_concat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.array_cat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.list_cat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.list_concat(col, col), lambda data: [np.concatenate([arr, arr]) for arr in data], - ], - [ + ), + ( lambda col: f.array_dims(col), lambda data: [[len(r)] for r in data], - ], - [ + ), + ( lambda col: f.array_distinct(col), lambda data: [list(set(r)) for r in data], - ], - [ + ), + ( lambda col: f.list_distinct(col), lambda data: [list(set(r)) for r in data], - ], - [ + ), + ( lambda col: f.list_dims(col), lambda data: [[len(r)] for r in data], - ], - [ + ), + ( lambda col: f.array_element(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.array_empty(col), lambda data: [len(r) == 0 for r in data], - ], - [ + ), + ( lambda col: f.empty(col), lambda data: [len(r) == 0 for r in data], - ], - [ + ), + ( lambda col: f.array_extract(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.list_element(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.list_extract(col, literal(1)), lambda data: [r[0] for r in data], - ], - [ + ), + ( lambda col: f.array_length(col), lambda data: [len(r) for r in data], - ], - [ + ), + ( lambda col: f.list_length(col), lambda data: [len(r) for r in data], - ], - [ + ), + ( lambda col: f.array_has(col, literal(1.0)), lambda data: [1.0 in r for r in data], 
- ], - [ + ), + ( lambda col: f.array_has_all( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], - ], - [ + ), + ( lambda col: f.array_has_any( col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) ), lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], - ], - [ + ), + ( lambda col: f.array_position(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.array_indexof(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.list_position(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.list_indexof(col, literal(1.0)), lambda data: [py_indexof(r, 1.0) for r in data], - ], - [ + ), + ( lambda col: f.array_positions(col, literal(1.0)), lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], - ], - [ + ), + ( lambda col: f.list_positions(col, literal(1.0)), lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], - ], - [ + ), + ( lambda col: f.array_ndims(col), lambda data: [np.array(r).ndim for r in data], - ], - [ + ), + ( lambda col: f.list_ndims(col), lambda data: [np.array(r).ndim for r in data], - ], - [ + ), + ( lambda col: f.array_prepend(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_push_front(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_prepend(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_push_front(literal(99.0), col), lambda data: [np.insert(arr, 0, 99.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_pop_back(col), lambda data: [arr[:-1] for arr in data], - ], - [ + ), + ( lambda col: f.array_pop_front(col), lambda data: [arr[1:] for arr in data], - ], - [ + ), + ( lambda col: f.array_remove(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.list_remove(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.array_remove_n(col, literal(3.0), literal(2)), lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data], - ], - [ + ), + ( lambda col: f.list_remove_n(col, literal(3.0), literal(2)), lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data], - ], - [ + ), + ( lambda col: f.array_remove_all(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_remove_all(col, literal(3.0)), lambda data: [py_arr_remove(arr, 3.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_repeat(col, literal(2)), lambda data: [[arr] * 2 for arr in data], - ], - [ + ), + ( lambda col: f.list_repeat(col, literal(2)), lambda data: [[arr] * 2 for arr in data], - ], - [ + ), + ( lambda col: f.array_replace(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.list_replace(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], - ], - [ + ), + ( lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)), lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) 
for arr in data], - ], - [ + ), + ( lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], - ], - [ + ), + ( lambda col: f.array_sort(col, descending=True, null_first=True), lambda data: [np.sort(arr)[::-1] for arr in data], - ], - [ + ), + ( lambda col: f.list_sort(col, descending=False, null_first=False), lambda data: [np.sort(arr) for arr in data], - ], - [ + ), + ( lambda col: f.array_slice(col, literal(2), literal(4)), lambda data: [arr[1:4] for arr in data], - ], + ), pytest.param( lambda col: f.list_slice(col, literal(-1), literal(2)), lambda data: [arr[-1:2] for arr in data], ), - [ + ( lambda col: f.array_intersect(col, literal([3.0, 4.0])), lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], - ], - [ + ), + ( lambda col: f.list_intersect(col, literal([3.0, 4.0])), lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], - ], - [ + ), + ( lambda col: f.array_union(col, literal([12.0, 999.0])), lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data], - ], - [ + ), + ( lambda col: f.list_union(col, literal([12.0, 999.0])), lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data], - ], - [ + ), + ( lambda col: f.array_except(col, literal([3.0])), lambda data: [np.setdiff1d(arr, [3.0]) for arr in data], - ], - [ + ), + ( lambda col: f.list_except(col, literal([3.0])), lambda data: [np.setdiff1d(arr, [3.0]) for arr in data], - ], - [ + ), + ( lambda col: f.array_resize(col, literal(10), literal(0.0)), lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data], - ], - [ + ), + ( lambda col: f.list_resize(col, literal(10), literal(0.0)), lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data], - ], - [ + ), + ( lambda col: f.range(literal(1), literal(5), literal(2)), lambda data: [np.arange(1, 5, 2)], - ], + ), ], ) def test_array_functions(stmt, py_expr): @@ -611,22 +613,22 @@ def test_make_array_functions(make_func): @pytest.mark.parametrize( ("stmt", "py_expr"), [ - [ + ( f.array_to_string(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], - [ + ), + ( f.array_join(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], - [ + ), + ( f.list_to_string(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], - [ + ), + ( f.list_join(column("arr"), literal(",")), lambda data: [",".join([str(int(v)) for v in r]) for r in data], - ], + ), ], ) def test_array_function_obj_tests(stmt, py_expr): @@ -640,7 +642,7 @@ def test_array_function_obj_tests(stmt, py_expr): @pytest.mark.parametrize( - "function, expected_result", + ("function", "expected_result"), [ ( f.ascii(column("a")), @@ -894,54 +896,71 @@ def test_temporal_functions(df): assert result.column(0) == pa.array([12, 6, 7], type=pa.int32()) assert result.column(1) == pa.array([2022, 2027, 2020], type=pa.int32()) assert result.column(2) == pa.array( - [datetime(2022, 12, 1), datetime(2027, 6, 1), datetime(2020, 7, 1)], + [ + datetime(2022, 12, 1, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 1, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 1, tzinfo=DEFAULT_TZ), + ], type=pa.timestamp("us"), ) assert result.column(3) == pa.array( - [datetime(2022, 12, 31), datetime(2027, 6, 26), datetime(2020, 7, 2)], + [ + datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 
26, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), + ], type=pa.timestamp("us"), ) assert result.column(4) == pa.array( [ - datetime(2022, 12, 30, 23, 47, 30), - datetime(2027, 6, 25, 23, 47, 30), - datetime(2020, 7, 1, 23, 47, 30), + datetime(2022, 12, 30, 23, 47, 30, tzinfo=DEFAULT_TZ), + datetime(2027, 6, 25, 23, 47, 30, tzinfo=DEFAULT_TZ), + datetime(2020, 7, 1, 23, 47, 30, tzinfo=DEFAULT_TZ), ], type=pa.timestamp("ns"), ) assert result.column(5) == pa.array( - [datetime(2023, 1, 10, 20, 52, 54)] * 3, type=pa.timestamp("s") + [datetime(2023, 1, 10, 20, 52, 54, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("s"), ) assert result.column(6) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) assert result.column(7) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s") + [datetime(2023, 9, 7, 5, 6, 1, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s") ) assert result.column(8) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms") + [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ms"), ) assert result.column(9) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("us"), ) assert result.column(10) == pa.array([31, 26, 2], type=pa.int32()) assert result.column(11) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) assert result.column(12) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14)] * 3, type=pa.timestamp("s") + [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s") ) assert result.column(13) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523000)] * 3, type=pa.timestamp("ms") + [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ms"), ) assert result.column(14) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("us") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("us"), ) assert result.column(15) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) assert result.column(16) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, 523952)] * 3, type=pa.timestamp("ns") + [datetime(2023, 9, 7, 5, 6, 14, 523952, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("ns"), ) @@ -1057,7 +1076,7 @@ def test_regr_funcs_sql_2(): @pytest.mark.parametrize( - "func, expected", + ("func", "expected"), [ pytest.param(f.regr_slope(column("c2"), column("c1")), [4.6], id="regr_slope"), pytest.param( @@ -1160,7 +1179,7 @@ def test_binary_string_functions(df): @pytest.mark.parametrize( - "python_datatype, name, expected", + ("python_datatype", "name", "expected"), [ pytest.param(bool, "e", pa.bool_(), id="bool"), pytest.param(int, "b", pa.int64(), id="int"), @@ -1179,7 +1198,7 @@ def test_cast(df, python_datatype, name: str, expected): @pytest.mark.parametrize( - "negated, low, high, expected", + ("negated", "low", "high", "expected"), [ pytest.param(False, 3, 5, {"filtered": [4, 5]}), pytest.param(False, 4, 5, {"filtered": [4, 5]}), diff --git a/python/tests/test_input.py b/python/tests/test_input.py index 
0b32769ca..4663f6148 100644 --- a/python/tests/test_input.py +++ b/python/tests/test_input.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -import os +import pathlib from datafusion.input.location import LocationInputPlugin @@ -23,10 +23,10 @@ def test_location_input(): location_input = LocationInputPlugin() - cwd = os.getcwd() - input_file = cwd + "/testing/data/parquet/generated_simple_numerics/blogs.parquet" + cwd = pathlib.Path.cwd() + input_file = cwd / "testing/data/parquet/generated_simple_numerics/blogs.parquet" table_name = "blog" - tbl = location_input.build_table(input_file, table_name) + tbl = location_input.build_table(str(input_file), table_name) assert tbl.name == "blog" assert len(tbl.columns) == 3 assert "blogs.parquet" in tbl.filepaths[0] diff --git a/python/tests/test_io.py b/python/tests/test_io.py index 21ad188ee..7ca509689 100644 --- a/python/tests/test_io.py +++ b/python/tests/test_io.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import pathlib +from pathlib import Path import pyarrow as pa from datafusion import column @@ -23,10 +22,10 @@ def test_read_json_global_ctx(ctx): - path = os.path.dirname(os.path.abspath(__file__)) + path = Path(__file__).parent.resolve() # Default - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = Path(path) / "data_test_context" / "data.json" df = read_json(test_data_path) result = df.collect() @@ -46,7 +45,7 @@ def test_read_json_global_ctx(ctx): assert result[0].schema == schema # File extension - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = Path(path) / "data_test_context" / "data.json" df = read_json(test_data_path, file_extension=".json") result = df.collect() @@ -59,7 +58,7 @@ def test_read_parquet_global(): parquet_df.show() assert parquet_df is not None - path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet" + path = Path.cwd() / "parquet/data/alltypes_plain.parquet" parquet_df = read_parquet(path=path) assert parquet_df is not None @@ -90,6 +89,6 @@ def test_read_avro(): avro_df.show() assert avro_df is not None - path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro" + path = Path.cwd() / "testing/data/avro/alltypes_plain.avro" avro_df = read_avro(path=path) assert avro_df is not None diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 3ecbaeadd..b6348e3a0 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -214,7 +214,7 @@ def test_register_json(ctx, tmp_path): with ( Path.open(test_data_path, "rb") as json_file, - gzip.open(gzip_path, "wb") as gzipped_file + gzip.open(gzip_path, "wb") as gzipped_file, ): gzipped_file.writelines(json_file) @@ -478,10 +478,9 @@ def test_simple_select(ctx, tmp_path, arr): np.testing.assert_equal(result, arr) -@pytest.mark.parametrize("file_sort_order", [ - None, - [[col("int").sort(ascending=True, nulls_first=True)]] - ]) +@pytest.mark.parametrize( + "file_sort_order", [None, [[col("int").sort(ascending=True, nulls_first=True)]]] +) @pytest.mark.parametrize("pass_schema", [True, False]) @pytest.mark.parametrize("path_to_str", [True, False]) def test_register_listing_table( diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index c8be0250a..3d6dcf9d8 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -298,7 +298,7 @@ def 
test_udwf_errors(df): ] -@pytest.mark.parametrize(("name","expr","expected"), data_test_udwf_functions) +@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) def test_udwf_functions(df, name, expr, expected): df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index d3ce7c57f..d7f6f6e35 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -43,7 +43,7 @@ def missing_exports(internal_obj, wrapped_obj) -> None: wrapped_attr = getattr(wrapped_obj, attr) if internal_attr is not None and wrapped_attr is None: - pytest.fail(f"Missing attribute: {attr}") + pytest.fail(f"Missing attribute: {attr}") if attr in ["__self__", "__class__"]: continue From 7b5642e783a0b5b9a04ef6c778a3ad50f82afe0c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 11:13:04 -0400 Subject: [PATCH 08/24] update timestamps on tests --- python/tests/test_functions.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 9d5301533..9e96742bb 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -901,7 +901,7 @@ def test_temporal_functions(df): datetime(2027, 6, 1, tzinfo=DEFAULT_TZ), datetime(2020, 7, 1, tzinfo=DEFAULT_TZ), ], - type=pa.timestamp("us"), + type=pa.timestamp("ns", tz=DEFAULT_TZ), ) assert result.column(3) == pa.array( [ @@ -909,7 +909,7 @@ def test_temporal_functions(df): datetime(2027, 6, 26, tzinfo=DEFAULT_TZ), datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), ], - type=pa.timestamp("us"), + type=pa.timestamp("ns", tz=DEFAULT_TZ), ) assert result.column(4) == pa.array( [ @@ -917,7 +917,7 @@ def test_temporal_functions(df): datetime(2027, 6, 25, 23, 47, 30, tzinfo=DEFAULT_TZ), datetime(2020, 7, 1, 23, 47, 30, tzinfo=DEFAULT_TZ), ], - type=pa.timestamp("ns"), + type=pa.timestamp("ns", tz=DEFAULT_TZ), ) assert result.column(5) == pa.array( [datetime(2023, 1, 10, 20, 52, 54, tzinfo=DEFAULT_TZ)] * 3, @@ -928,7 +928,7 @@ def test_temporal_functions(df): type=pa.timestamp("ns"), ) assert result.column(7) == pa.array( - [datetime(2023, 9, 7, 5, 6, 1, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s") + [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s") ) assert result.column(8) == pa.array( [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3, @@ -944,7 +944,8 @@ def test_temporal_functions(df): type=pa.timestamp("ns"), ) assert result.column(12) == pa.array( - [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, type=pa.timestamp("s") + [datetime(2023, 9, 7, 5, 6, 14, tzinfo=DEFAULT_TZ)] * 3, + type=pa.timestamp("s"), ) assert result.column(13) == pa.array( [datetime(2023, 9, 7, 5, 6, 14, 523000, tzinfo=DEFAULT_TZ)] * 3, From ee2e488f04b5fa7e256107672c5b4c7db1eb2fc0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 11:32:13 -0400 Subject: [PATCH 09/24] More ruff updates --- pyproject.toml | 17 ++++++++++-- python/tests/test_context.py | 51 +++++++++++++++++----------------- python/tests/test_dataframe.py | 26 ++++++++--------- python/tests/test_expr.py | 11 +++----- 4 files changed, 57 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 23828d834..0bc24bf7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ features = ["substrait"] # Enable docstring linting using the google style guide [tool.ruff.lint] select = ["ALL" ] 
-ignore = ["COM812", "ISC001"] # Recommended to ignore these rules when using with ruff-format +ignore = ["COM812", "ISC001", "TD002"] # Recommended to ignore these rules when using with ruff-format [tool.ruff.lint.pydocstyle] convention = "google" @@ -77,7 +77,20 @@ max-doc-length = 88 # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] -"python/tests/*" = ["ANN", "ARG", "D", "S101", "SLF", "PD", "PLR2004", "PT011", "RUF015", "S608", "PLR0913"] +"python/tests/*" = [ + "ANN", + "ARG", + "BLE001", + "D", + "S101", + "SLF", + "PD", + "PLR2004", + "PT011", + "RUF015", + "S608", + "PLR0913" +] "examples/*" = ["D", "W505"] "dev/*" = ["D"] "benchmarks/*" = ["D", "F"] diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 91046e6b8..04556ef4d 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -16,7 +16,6 @@ # under the License. import datetime as dt import gzip -import os import pathlib import pyarrow as pa @@ -45,7 +44,7 @@ def test_create_context_runtime_config_only(): SessionContext(runtime=RuntimeEnvBuilder()) -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_runtime_configs(tmp_path, path_to_str): path1 = tmp_path / "dir1" path2 = tmp_path / "dir2" @@ -62,7 +61,7 @@ def test_runtime_configs(tmp_path, path_to_str): assert db is not None -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_temporary_files(tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -79,14 +78,14 @@ def test_create_context_with_all_valid_args(): runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000) config = ( SessionConfig() - .with_create_default_catalog_and_schema(True) + .with_create_default_catalog_and_schema(enabled=True) .with_default_catalog_and_schema("foo", "bar") .with_target_partitions(1) - .with_information_schema(True) - .with_repartition_joins(False) - .with_repartition_aggregations(False) - .with_repartition_windows(False) - .with_parquet_pruning(False) + .with_information_schema(enabled=True) + .with_repartition_joins(enabled=False) + .with_repartition_aggregations(enabled=False) + .with_repartition_windows(enabled=False) + .with_parquet_pruning(enabled=False) ) ctx = SessionContext(config, runtime) @@ -167,7 +166,7 @@ def test_from_arrow_table(ctx): def record_batch_generator(num_batches: int): schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) - for i in range(num_batches): + for _i in range(num_batches): yield pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema ) @@ -492,10 +491,10 @@ def test_table_not_found(ctx): def test_read_json(ctx): - path = os.path.dirname(os.path.abspath(__file__)) + path = pathlib.Path(__file__).parent.resolve() # Default - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = path / "data_test_context" / "data.json" df = ctx.read_json(test_data_path) result = df.collect() @@ -515,7 +514,7 @@ def test_read_json(ctx): assert result[0].schema == schema # File extension - test_data_path = os.path.join(path, "data_test_context", "data.json") + test_data_path = path / "data_test_context" / "data.json" df = ctx.read_json(test_data_path, file_extension=".json") result = df.collect() @@ -524,15 +523,16 @@ def test_read_json(ctx): def test_read_json_compressed(ctx, tmp_path): - path = os.path.dirname(os.path.abspath(__file__)) - 
test_data_path = os.path.join(path, "data_test_context", "data.json") + path = pathlib.Path(__file__).parent.resolve() + test_data_path = path / "data_test_context" / "data.json" # File compression type gzip_path = tmp_path / "data.json.gz" - with open(test_data_path, "rb") as csv_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(csv_file) + with pathlib.Path.open(test_data_path, "rb") as csv_file, gzip.open( + gzip_path, "wb" + ) as gzipped_file: + gzipped_file.writelines(csv_file) df = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz") result = df.collect() @@ -563,14 +563,15 @@ def test_read_csv_list(ctx): def test_read_csv_compressed(ctx, tmp_path): - test_data_path = "testing/data/csv/aggregate_test_100.csv" + test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv") # File compression type gzip_path = tmp_path / "aggregate_test_100.csv.gz" - with open(test_data_path, "rb") as csv_file: - with gzip.open(gzip_path, "wb") as gzipped_file: - gzipped_file.writelines(csv_file) + with pathlib.Path.open(test_data_path, "rb") as csv_file, gzip.open( + gzip_path, "wb" + ) as gzipped_file: + gzipped_file.writelines(csv_file) csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz") csv_df.select(column("c1")).show() @@ -603,7 +604,7 @@ def test_create_sql_options(): def test_sql_with_options_no_ddl(ctx): sql = "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')" ctx.sql(sql) - options = SQLOptions().with_allow_ddl(False) + options = SQLOptions().with_allow_ddl(allow=False) with pytest.raises(Exception, match="DDL"): ctx.sql_with_options(sql, options=options) @@ -618,7 +619,7 @@ def test_sql_with_options_no_dml(ctx): ctx.register_dataset(table_name, dataset) sql = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);' ctx.sql(sql) - options = SQLOptions().with_allow_dml(False) + options = SQLOptions().with_allow_dml(allow=False) with pytest.raises(Exception, match="DML"): ctx.sql_with_options(sql, options=options) @@ -626,6 +627,6 @@ def test_sql_with_options_no_dml(ctx): def test_sql_with_options_no_statements(ctx): sql = "SET time zone = 1;" ctx.sql(sql) - options = SQLOptions().with_allow_statements(False) + options = SQLOptions().with_allow_statements(allow=False) with pytest.raises(Exception, match="SetVariable"): ctx.sql_with_options(sql, options=options) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 4a76c5697..d084f12dd 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -339,7 +339,7 @@ def test_join(): # Verify we don't make a breaking change to pre-43.0.0 # where users would pass join_keys as a positional argument - df2 = df.join(df1, (["a"], ["a"]), how="inner") # type: ignore + df2 = df.join(df1, (["a"], ["a"]), how="inner") df2.show() df2 = df2.sort(column("l.a")) table = pa.Table.from_batches(df2.collect()) @@ -375,17 +375,17 @@ def test_join_invalid_params(): with pytest.raises( ValueError, match=r"`left_on` or `right_on` should not provided with `on`" ): - df2 = df.join(df1, on="a", how="inner", right_on="test") # type: ignore + df2 = df.join(df1, on="a", how="inner", right_on="test") with pytest.raises( ValueError, match=r"`left_on` and `right_on` should both be provided." ): - df2 = df.join(df1, left_on="a", how="inner") # type: ignore + df2 = df.join(df1, left_on="a", how="inner") with pytest.raises( ValueError, match=r"either `on` or `left_on` and `right_on` should be provided." 
): - df2 = df.join(df1, how="inner") # type: ignore + df2 = df.join(df1, how="inner") def test_join_on(): @@ -567,7 +567,7 @@ def test_distinct(): ] -@pytest.mark.parametrize("name,expr,result", data_test_window_functions) +@pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions) def test_window_functions(partitioned_df, name, expr, result): df = partitioned_df.select( column("a"), column("b"), column("c"), f.alias(expr, name) @@ -885,7 +885,7 @@ def test_union_distinct(ctx): ) df_c = ctx.create_dataframe([[batch]]).sort(column("a")) - df_a_u_b = df_a.union(df_b, True).sort(column("a")) + df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a")) assert df_c.collect() == df_a_u_b.collect() assert df_c.collect() == df_a_u_b.collect() @@ -954,8 +954,6 @@ def test_to_arrow_table(df): def test_execute_stream(df): stream = df.execute_stream() - for s in stream: - print(type(s)) assert all(batch is not None for batch in stream) assert not list(stream) # after one iteration the generator must be exhausted @@ -1033,7 +1031,7 @@ def test_describe(df): } -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_csv(ctx, df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -1046,7 +1044,7 @@ def test_write_csv(ctx, df, tmp_path, path_to_str): assert result == expected -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_json(ctx, df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -1059,7 +1057,7 @@ def test_write_json(ctx, df, tmp_path, path_to_str): assert result == expected -@pytest.mark.parametrize("path_to_str", (True, False)) +@pytest.mark.parametrize("path_to_str", [True, False]) def test_write_parquet(df, tmp_path, path_to_str): path = str(tmp_path) if path_to_str else tmp_path @@ -1071,7 +1069,7 @@ def test_write_parquet(df, tmp_path, path_to_str): @pytest.mark.parametrize( - "compression, compression_level", + ("compression", "compression_level"), [("gzip", 6), ("brotli", 7), ("zstd", 15)], ) def test_write_compressed_parquet(df, tmp_path, compression, compression_level): @@ -1082,7 +1080,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level): ) # test that the actual compression scheme is the one written - for root, dirs, files in os.walk(path): + for _root, _dirs, files in os.walk(path): for file in files: if file.endswith(".parquet"): metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() @@ -1097,7 +1095,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level): @pytest.mark.parametrize( - "compression, compression_level", + ("compression", "compression_level"), [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], ) def test_write_compressed_parquet_wrong_compression_level( diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 354c7e180..926e69845 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -85,18 +85,14 @@ def test_limit(test_ctx): plan = plan.to_variant() assert isinstance(plan, Limit) - # TODO: Upstream now has expressions for skip and fetch - # REF: https://github.com/apache/datafusion/pull/12836 - # assert plan.skip() == 0 + assert "Skip: None" in str(plan) df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5") plan = df.logical_plan() plan = plan.to_variant() assert isinstance(plan, Limit) - # TODO: Upstream now has expressions for skip 
and fetch - # REF: https://github.com/apache/datafusion/pull/12836 - # assert plan.skip() == 5 + assert "Skip: Some(Literal(Int64(5)))" in str(plan) def test_aggregate_query(test_ctx): @@ -165,6 +161,7 @@ def traverse_logical_plan(plan): res = traverse_logical_plan(input_plan) if res is not None: return res + return None ctx = SessionContext() data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]} @@ -176,7 +173,7 @@ def traverse_logical_plan(plan): assert variant.expr().to_variant().qualified_name() == "table1.name" assert ( str(variant.list()) - == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' + == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]' # noqa: E501 ) assert not variant.negated() From f4e1754724d02fbb525b79393335e0eb80429bfb Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 11:52:16 -0400 Subject: [PATCH 10/24] More ruff updates --- python/tests/generic.py | 19 ++++++++++--------- python/tests/test_aggregation.py | 14 ++++++-------- python/tests/test_catalog.py | 9 ++++++--- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/python/tests/generic.py b/python/tests/generic.py index 0177e2df0..1b98fdf9e 100644 --- a/python/tests/generic.py +++ b/python/tests/generic.py @@ -16,6 +16,7 @@ # under the License. import datetime +from datetime import timezone import numpy as np import pyarrow as pa @@ -26,29 +27,29 @@ def data(): - np.random.seed(1) + rng = np.random.default_rng(1) data = np.concatenate( [ - np.random.normal(0, 0.01, size=50), - np.random.normal(50, 0.01, size=50), + rng.normal(0, 0.01, size=50), + rng.normal(50, 0.01, size=50), ] ) return pa.array(data) def data_with_nans(): - np.random.seed(0) - data = np.random.normal(0, 0.01, size=50) - mask = np.random.randint(0, 2, size=50) + rng = np.random.default_rng(0) + data = rng.normal(0, 0.01, size=50) + mask = rng.normal(0, 2, size=50) data[mask == 0] = np.nan return data def data_datetime(f): data = [ - datetime.datetime.now(), - datetime.datetime.now() - datetime.timedelta(days=1), - datetime.datetime.now() + datetime.timedelta(days=1), + datetime.datetime.now(tz=timezone.utc), + datetime.datetime.now(tz=timezone.utc) - datetime.timedelta(days=1), + datetime.datetime.now(tz=timezone.utc) + datetime.timedelta(days=1), ] return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True, False])) diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py index 8093d66a8..d5e76dd0e 100644 --- a/python/tests/test_aggregation.py +++ b/python/tests/test_aggregation.py @@ -66,7 +66,7 @@ def df_aggregate_100(): @pytest.mark.parametrize( - "agg_expr, calc_expected", + ("agg_expr", "calc_expected"), [ (f.avg(column("a")), lambda a, b, c, d: np.array(np.average(a))), ( @@ -88,7 +88,7 @@ def df_aggregate_100(): f.covar_samp(column("b"), column("c")), lambda a, b, c, d: np.array(np.cov(b, c, ddof=1)[0][1]), ), - # f.grouping(col_a), # No physical plan implemented yet + # f.grouping(col_a), # No physical plan implemented yet # noqa: ERA001 (f.max(column("a")), lambda a, b, c, d: np.array(np.max(a))), (f.mean(column("b")), lambda a, b, c, d: np.array(np.mean(b))), (f.median(column("b")), lambda a, b, c, d: np.array(np.median(b))), @@ -114,7 +114,7 @@ def test_aggregation_stats(df, agg_expr, calc_expected): @pytest.mark.parametrize( - "agg_expr, expected, array_sort", + ("agg_expr", "expected", "array_sort"), [ (f.approx_distinct(column("b")), pa.array([2], type=pa.uint64()), False), ( @@ -182,12 +182,11 @@ def 
test_aggregation(df, agg_expr, expected, array_sort): agg_df.show() result = agg_df.collect()[0] - print(result) assert result.column(0) == expected @pytest.mark.parametrize( - "name,expr,expected", + ("name", "expr", "expected"), [ ( "approx_percentile_cont", @@ -313,7 +312,7 @@ def test_bit_and_bool_fns(df, name, expr, result): @pytest.mark.parametrize( - "name,expr,result", + ("name", "expr", "result"), [ ("first_value", f.first_value(column("a")), [0, 4]), ( @@ -363,7 +362,6 @@ def test_bit_and_bool_fns(df, name, expr, result): ), [8, 9], ), - ("first_value", f.first_value(column("a")), [0, 4]), ( "nth_value_ordered", f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]), @@ -403,7 +401,7 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None: @pytest.mark.parametrize( - "name,expr,result", + ("name", "expr", "result"), [ ("string_agg", f.string_agg(column("a"), ","), "one,two,three,two"), ("string_agg", f.string_agg(column("b"), ""), "03124"), diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 214f6b165..23b328458 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -19,6 +19,9 @@ import pytest +# Note we take in `database` as a variable even though we don't use +# it because that will cause the fixture to set up the context with +# the tables we need. def test_basic(ctx, database): with pytest.raises(KeyError): ctx.catalog("non-existent") @@ -26,10 +29,10 @@ def test_basic(ctx, database): default = ctx.catalog() assert default.names() == ["public"] - for database in [default.database("public"), default.database()]: - assert database.names() == {"csv1", "csv", "csv2"} + for db in [default.database("public"), default.database()]: + assert db.names() == {"csv1", "csv", "csv2"} - table = database.table("csv") + table = db.table("csv") assert table.kind == "physical" assert table.schema == pa.schema( [ From 8c9b6d83b132427c0c05e78c1463f74313bc7a93 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 11:54:25 -0400 Subject: [PATCH 11/24] Instead of importing udf static functions as variables, import --- python/datafusion/__init__.py | 9 +-------- python/datafusion/udf.py | 6 ++++++ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 5a506e74e..58ba4fdd2 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -48,7 +48,7 @@ from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream -from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF +from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF, udaf, udf, udwf __version__ = importlib_metadata.version(__name__) @@ -120,10 +120,3 @@ def str_lit(value): def lit(value): """Create a literal expression.""" return Expr.literal(value) - - -udf = ScalarUDF.udf - -udaf = AggregateUDF.udaf - -udwf = WindowUDF.udwf diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 907477604..9224ae2b4 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -656,3 +656,9 @@ def bias_10() -> BiasedNumbers: return_type=return_type, volatility=volatility, ) + +# Convenience exports so we can import instead of treating as +# variables at the package root +udf = ScalarUDF.udf +udaf = AggregateUDF.udaf +udwf = WindowUDF.udwf From f67a72767bb6f10c51462b42bd7268e650ce2c21 Mon Sep 17 00:00:00 2001 From: Tim Saucer 
Date: Sun, 9 Mar 2025 19:45:23 -0400 Subject: [PATCH 12/24] More ruff formatting suggestions --- pyproject.toml | 8 +- python/datafusion/__init__.py | 3 + python/datafusion/udf.py | 228 +++++++++++++++++++--------------- 3 files changed, 138 insertions(+), 101 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0bc24bf7c..5bdaa0518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,13 @@ features = ["substrait"] # Enable docstring linting using the google style guide [tool.ruff.lint] select = ["ALL" ] -ignore = ["COM812", "ISC001", "TD002"] # Recommended to ignore these rules when using with ruff-format +ignore = [ + "ANN401", # Allow Any for wrapper classes + "COM812", # Recommended to ignore these rules when using with ruff-format + "ISC001", # Recommended to ignore these rules when using with ruff-format + "TD002", + "UP007" # Disallowing Union is pedantic +] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 58ba4fdd2..286e5dc31 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -86,6 +86,9 @@ "read_json", "read_parquet", "substrait", + "udaf", + "udf", + "udwf", ] diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 9224ae2b4..fbb208459 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,15 +22,15 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload -import pyarrow +import pyarrow as pa import datafusion._internal as df_internal from datafusion.expr import Expr if TYPE_CHECKING: - _R = TypeVar("_R", bound=pyarrow.DataType) + _R = TypeVar("_R", bound=pa.DataType) class Volatility(Enum): @@ -72,7 +72,7 @@ class Volatility(Enum): for each output row, resulting in a unique random value for each row. """ - def __str__(self): + def __str__(self) -> str: """Returns the string equivalent.""" return self.name.lower() @@ -88,7 +88,7 @@ def __init__( self, name: str, func: Callable[..., _R], - input_types: pyarrow.DataType | list[pyarrow.DataType], + input_types: pa.DataType | list[pa.DataType], return_type: _R, volatility: Volatility | str, ) -> None: @@ -96,7 +96,7 @@ def __init__( See helper method :py:func:`udf` for argument details. """ - if isinstance(input_types, pyarrow.DataType): + if isinstance(input_types, pa.DataType): input_types = [input_types] self._udf = df_internal.ScalarUDF( name, func, input_types, return_type, str(volatility) @@ -111,7 +111,27 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udf.__call__(*args_raw)) - class udf: + @overload + @staticmethod + def udf( + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., ScalarUDF]: ... + + @overload + @staticmethod + def udf( + func: Callable[..., _R], + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> ScalarUDF: ... + + @staticmethod + def udf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Function (UDF). This class can be used both as a **function** and as a **decorator**. @@ -125,7 +145,7 @@ class udf: Args: func (Callable, optional): **Only needed when calling as a function.** Skip this argument when using `udf` as a decorator. 
- input_types (list[pyarrow.DataType]): The data types of the arguments + input_types (list[pa.DataType]): The data types of the arguments to `func`. This list must be of the same length as the number of arguments. return_type (_R): The data type of the return value from the function. @@ -141,39 +161,28 @@ class udf: ``` def double_func(x): return x * 2 - double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), + double_udf = udf(double_func, [pa.int32()], pa.int32(), "volatile", "double_it") ``` **Using `udf` as a decorator:** ``` - @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") + @udf([pa.int32()], pa.int32(), "volatile", "double_it") def double_udf(x): return x * 2 ``` """ - def __new__(cls, *args, **kwargs): - """Create a new UDF. - - Trigger UDF function or decorator depending on if the first args is callable - """ - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return cls._function(*args, **kwargs) - # Case 2: Used as a decorator with parameters - return cls._decorator(*args, **kwargs) - - @staticmethod def _function( func: Callable[..., _R], - input_types: list[pyarrow.DataType], + input_types: list[pa.DataType], return_type: _R, volatility: Volatility | str, name: Optional[str] = None, ) -> ScalarUDF: if not callable(func): - raise TypeError("`func` argument must be callable") + msg = "`func` argument must be callable" + raise TypeError(msg) if name is None: if hasattr(func, "__qualname__"): name = func.__qualname__.lower() @@ -187,44 +196,49 @@ def _function( volatility=volatility, ) - @staticmethod def _decorator( - input_types: list[pyarrow.DataType], + input_types: list[pa.DataType], return_type: _R, volatility: Volatility | str, name: Optional[str] = None, - ): - def decorator(func): + ) -> Callable: + def decorator(func: Callable): # noqa: ANN202 udf_caller = ScalarUDF.udf( func, input_types, return_type, volatility, name ) @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any): # noqa: ANN202 return udf_caller(*args, **kwargs) return wrapper return decorator + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @abstractmethod - def state(self) -> List[pyarrow.Scalar]: + def state(self) -> list[pa.Scalar]: """Return the current state.""" @abstractmethod - def update(self, *values: pyarrow.Array) -> None: + def update(self, *values: pa.Array) -> None: """Evaluate an array of values and update state.""" @abstractmethod - def merge(self, states: List[pyarrow.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: """Merge a set of states.""" @abstractmethod - def evaluate(self) -> pyarrow.Scalar: + def evaluate(self) -> pa.Scalar: """Return the resultant value.""" @@ -235,13 +249,13 @@ class AggregateUDF: also :py:class:`ScalarUDF` for operating on a row by row basis. 
""" - def __init__( + def __init__( # noqa: PLR0913 self, name: str, accumulator: Callable[[], Accumulator], - input_types: list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], + input_types: list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], volatility: Volatility | str, ) -> None: """Instantiate a user-defined aggregate function (UDAF). @@ -267,7 +281,29 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udaf.__call__(*args_raw)) - class udaf: + @overload + @staticmethod + def udaf( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., AggregateUDF]: ... + + @overload + @staticmethod + def udaf( + accum: Callable[[], Accumulator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> AggregateUDF: ... + + @staticmethod + def udaf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Aggregate Function (UDAF). This class allows you to define an **aggregate function** that can be used in @@ -295,13 +331,13 @@ class Summarize(Accumulator): def __init__(self, bias: float = 0.0): self._sum = pa.scalar(bias) - def state(self) -> List[pa.Scalar]: + def state(self) -> list[pa.Scalar]: return [self._sum] def update(self, values: pa.Array) -> None: self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: List[pa.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) def evaluate(self) -> pa.Scalar: @@ -339,36 +375,23 @@ def udf4() -> Summarize: aggregation or window function calls. """ - def __new__(cls, *args, **kwargs): - """Create a new UDAF. - - Trigger UDAF function or decorator depending on if the first args is - callable - """ - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return cls._function(*args, **kwargs) - # Case 2: Used as a decorator with parameters - return cls._decorator(*args, **kwargs) - - @staticmethod - def _function( + def _function( # noqa: PLR0913 accum: Callable[[], Accumulator], - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], volatility: Volatility | str, name: Optional[str] = None, ) -> AggregateUDF: if not callable(accum): - raise TypeError("`func` must be callable.") - if not isinstance(accum.__call__(), Accumulator): - raise TypeError( - "Accumulator must implement the abstract base class Accumulator" - ) + msg = "`func` must be callable." 
+ raise TypeError(msg) + if not isinstance(accum(), Accumulator): + msg = "Accumulator must implement the abstract base class Accumulator" + raise TypeError(msg) if name is None: - name = accum.__call__().__class__.__qualname__.lower() - if isinstance(input_types, pyarrow.DataType): + name = accum().__class__.__qualname__.lower() + if isinstance(input_types, pa.DataType): input_types = [input_types] return AggregateUDF( name=name, @@ -379,27 +402,32 @@ def _function( volatility=volatility, ) - @staticmethod def _decorator( - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], volatility: Volatility | str, name: Optional[str] = None, - ): - def decorator(accum: Callable[[], Accumulator]): + ) -> Callable[..., Callable[..., Expr]]: + def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]: udaf_caller = AggregateUDF.udaf( accum, input_types, return_type, state_type, volatility, name ) @functools.wraps(accum) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Expr: return udaf_caller(*args, **kwargs) return wrapper return decorator + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + class WindowEvaluator(metaclass=ABCMeta): """Evaluator class for user-defined window functions (UDWF). @@ -417,9 +445,9 @@ class WindowEvaluator(metaclass=ABCMeta): +------------------------+--------------------------------+------------------+---------------------------+ | True | True/False | True/False | ``evaluate`` | +------------------------+--------------------------------+------------------+---------------------------+ - """ # noqa: W505 + """ # noqa: W505, E501 - def memoize(self) -> None: + def memoize(self) -> None: # noqa: B027 """Perform a memoize operation to improve performance. When the window frame has a fixed beginning (e.g UNBOUNDED @@ -431,7 +459,7 @@ def memoize(self) -> None: such functions can save whatever they need """ - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: + def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002 """Return the range for the window fuction. If `uses_window_frame` flag is `false`. This method is used to @@ -453,14 +481,17 @@ def is_causal(self) -> bool: """Get whether evaluator needs future data for its result.""" return False - def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Array: + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: # noqa: B027 """Evaluate a window function on an entire input partition. 
This function is called once per input *partition* for window functions that *do not use* values from the window frame, such as - :py:func:`~datafusion.functions.row_number`, :py:func:`~datafusion.functions.rank`, - :py:func:`~datafusion.functions.dense_rank`, :py:func:`~datafusion.functions.percent_rank`, - :py:func:`~datafusion.functions.cume_dist`, :py:func:`~datafusion.functions.lead`, + :py:func:`~datafusion.functions.row_number`, + :py:func:`~datafusion.functions.rank`, + :py:func:`~datafusion.functions.dense_rank`, + :py:func:`~datafusion.functions.percent_rank`, + :py:func:`~datafusion.functions.cume_dist`, + :py:func:`~datafusion.functions.lead`, and :py:func:`~datafusion.functions.lag`. It produces the result of all rows in a single pass. It @@ -492,11 +523,11 @@ def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Ar .. code-block:: text avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) - """ # noqa: W505 + """ # noqa: W505, E501 - def evaluate( - self, values: list[pyarrow.Array], eval_range: tuple[int, int] - ) -> pyarrow.Scalar: + def evaluate( # noqa: B027 + self, values: list[pa.Array], eval_range: tuple[int, int] + ) -> pa.Scalar: """Evaluate window function on a range of rows in an input partition. This is the simplest and most general function to implement @@ -512,9 +543,9 @@ def evaluate( single argument, `values[1..]` will contain ORDER BY expression results. """ - def evaluate_all_with_rank( + def evaluate_all_with_rank( # noqa: B027 self, num_rows: int, ranks_in_partition: list[tuple[int, int]] - ) -> pyarrow.Array: + ) -> pa.Array: """Called for window functions that only need the rank of a row. Evaluate the partition evaluator against the partition using @@ -557,10 +588,6 @@ def include_rank(self) -> bool: return False -if TYPE_CHECKING: - _W = TypeVar("_W", bound=WindowEvaluator) - - class WindowUDF: """Class for performing window user-defined functions (UDF). @@ -572,8 +599,8 @@ def __init__( self, name: str, func: Callable[[], WindowEvaluator], - input_types: list[pyarrow.DataType], - return_type: pyarrow.DataType, + input_types: list[pa.DataType], + return_type: pa.DataType, volatility: Volatility | str, ) -> None: """Instantiate a user-defined window function (UDWF). @@ -597,8 +624,8 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udwf( func: Callable[[], WindowEvaluator], - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, ) -> WindowUDF: @@ -638,16 +665,16 @@ def bias_10() -> BiasedNumbers: Returns: A user-defined window function. - """ # noqa W505 + """ # noqa: W505, E501 if not callable(func): - raise TypeError("`func` must be callable.") - if not isinstance(func.__call__(), WindowEvaluator): - raise TypeError( - "`func` must implement the abstract base class WindowEvaluator" - ) + msg = "`func` must be callable." 
+ raise TypeError(msg) + if not isinstance(func(), WindowEvaluator): + msg = "`func` must implement the abstract base class WindowEvaluator" + raise TypeError(msg) if name is None: - name = func.__call__().__class__.__qualname__.lower() - if isinstance(input_types, pyarrow.DataType): + name = func().__class__.__qualname__.lower() + if isinstance(input_types, pa.DataType): input_types = [input_types] return WindowUDF( name=name, @@ -657,6 +684,7 @@ def bias_10() -> BiasedNumbers: volatility=volatility, ) + # Convenience exports so we can import instead of treating as # variables at the package root udf = ScalarUDF.udf From 5d2e384e3d61300cf7b73e242c088020f5ec9fa2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 20:07:01 -0400 Subject: [PATCH 13/24] more ruff formatting suggestions --- pyproject.toml | 1 + python/datafusion/io.py | 29 ++++++++++++++++++----------- python/datafusion/plan.py | 6 +++--- python/datafusion/record_batch.py | 8 ++++---- python/datafusion/substrait.py | 11 ++++++----- python/datafusion/udf.py | 2 ++ 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5bdaa0518..0a7f48e2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ ignore = [ "ANN401", # Allow Any for wrapper classes "COM812", # Recommended to ignore these rules when using with ruff-format "ISC001", # Recommended to ignore these rules when using with ruff-format + "SLF001", # Allow accessing private members "TD002", "UP007" # Disallowing Union is pedantic ] diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 3b6264948..a97228369 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -19,23 +19,28 @@ from __future__ import annotations -import pathlib - -import pyarrow +from typing import TYPE_CHECKING from datafusion.dataframe import DataFrame -from datafusion.expr import Expr from ._internal import SessionContext as SessionContextInternal +if TYPE_CHECKING: + import pathlib + + import pyarrow as pa + + from datafusion.expr import Expr + -def read_parquet( +def read_parquet( # noqa: PLR0913 path: str | pathlib.Path, + *, table_partition_cols: list[tuple[str, str]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, - schema: pyarrow.Schema | None = None, + schema: pa.Schema | None = None, file_sort_order: list[list[Expr]] | None = None, ) -> DataFrame: """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. 
@@ -77,9 +82,10 @@ def read_parquet( ) -def read_json( +def read_json( # noqa: PLR0913 path: str | pathlib.Path, - schema: pyarrow.Schema | None = None, + *, + schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", table_partition_cols: list[tuple[str, str]] | None = None, @@ -118,9 +124,10 @@ def read_json( ) -def read_csv( +def read_csv( # noqa: PLR0913 path: str | pathlib.Path | list[str] | list[pathlib.Path], - schema: pyarrow.Schema | None = None, + *, + schema: pa.Schema | None = None, has_header: bool = True, delimiter: str = ",", schema_infer_max_records: int = 1000, @@ -173,7 +180,7 @@ def read_csv( def read_avro( path: str | pathlib.Path, - schema: pyarrow.Schema | None = None, + schema: pa.Schema | None = None, file_partition_cols: list[tuple[str, str]] | None = None, file_extension: str = ".avro", ) -> DataFrame: diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index d9f5b81e7..0b7bebcb3 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any import datafusion._internal as df_internal @@ -54,7 +54,7 @@ def to_variant(self) -> Any: """Convert the logical plan into its specific variant.""" return self._raw_plan.to_variant() - def inputs(self) -> List[LogicalPlan]: + def inputs(self) -> list[LogicalPlan]: """Returns the list of inputs to the logical plan.""" return [LogicalPlan(p) for p in self._raw_plan.inputs()] @@ -106,7 +106,7 @@ def __init__(self, plan: df_internal.ExecutionPlan) -> None: """This constructor should not be called by the end user.""" self._raw_plan = plan - def children(self) -> List[ExecutionPlan]: + def children(self) -> list[ExecutionPlan]: """Get a list of children `ExecutionPlan` that act as inputs to this plan. The returned list will be empty for leaf nodes such as scans, will contain a diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 772cd9089..556eaa786 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -26,14 +26,14 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - import pyarrow + import pyarrow as pa import typing_extensions import datafusion._internal as df_internal class RecordBatch: - """This class is essentially a wrapper for :py:class:`pyarrow.RecordBatch`.""" + """This class is essentially a wrapper for :py:class:`pa.RecordBatch`.""" def __init__(self, record_batch: df_internal.RecordBatch) -> None: """This constructor is generally not called by the end user. 
@@ -42,8 +42,8 @@ def __init__(self, record_batch: df_internal.RecordBatch) -> None: """ self.record_batch = record_batch - def to_pyarrow(self) -> pyarrow.RecordBatch: - """Convert to :py:class:`pyarrow.RecordBatch`.""" + def to_pyarrow(self) -> pa.RecordBatch: + """Convert to :py:class:`pa.RecordBatch`.""" return self.record_batch.to_pyarrow() diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index 36cfa074a..f10adfb0c 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -23,7 +23,6 @@ from __future__ import annotations -import pathlib from typing import TYPE_CHECKING try: @@ -36,6 +35,8 @@ from ._internal import substrait as substrait_internal if TYPE_CHECKING: + import pathlib + from datafusion.context import SessionContext __all__ = [ @@ -68,7 +69,7 @@ def encode(self) -> bytes: @deprecated("Use `Plan` instead.") -class plan(Plan): +class plan(Plan): # noqa: N801 """See `Plan`.""" @@ -138,7 +139,7 @@ def deserialize_bytes(proto_bytes: bytes) -> Plan: @deprecated("Use `Serde` instead.") -class serde(Serde): +class serde(Serde): # noqa: N801 """See `Serde` instead.""" @@ -164,7 +165,7 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: @deprecated("Use `Producer` instead.") -class producer(Producer): +class producer(Producer): # noqa: N801 """Use `Producer` instead.""" @@ -188,5 +189,5 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: @deprecated("Use `Consumer` instead.") -class consumer(Consumer): +class consumer(Consumer): # noqa: N801 """Use `Consumer` instead.""" diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index fbb208459..a96f95e67 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -575,10 +575,12 @@ def evaluate_all_with_rank( # noqa: B027 The user must implement this method if ``include_rank`` returns True. 
""" + @abstractmethod def supports_bounded_execution(self) -> bool: """Can the window function be incrementally computed using bounded memory?""" return False + @abstractmethod def uses_window_frame(self) -> bool: """Does the window function use the values from the window frame?""" return False From 412eef0aa8a956578de62a489d56ef892c4a8fe3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 20:20:36 -0400 Subject: [PATCH 14/24] More ruff formatting --- pyproject.toml | 2 ++ python/datafusion/common.py | 2 +- python/datafusion/functions.py | 6 +++-- python/datafusion/input/__init__.py | 2 +- python/datafusion/input/base.py | 4 +-- python/datafusion/input/location.py | 42 ++++++++++++++--------------- 6 files changed, 31 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0a7f48e2a..78acdd75c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,9 +70,11 @@ select = ["ALL" ] ignore = [ "ANN401", # Allow Any for wrapper classes "COM812", # Recommended to ignore these rules when using with ruff-format + "FIX002", # Allow TODO lines - consider removing at some point "ISC001", # Recommended to ignore these rules when using with ruff-format "SLF001", # Allow accessing private members "TD002", + "TD003", # Allow TODO lines "UP007" # Disallowing Union is pedantic ] diff --git a/python/datafusion/common.py b/python/datafusion/common.py index 6dc3a3acb..e762a993b 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -20,7 +20,7 @@ from ._internal import common as common_internal -# TODO these should all have proper wrapper classes +# TODO: these should all have proper wrapper classes DFSchema = common_internal.DFSchema DataType = common_internal.DataType diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index b449c4868..f9a91b64e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -18,13 +18,12 @@ from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import pyarrow as pa from datafusion._internal import functions as f from datafusion.common import NullTreatment -from datafusion.context import SessionContext from datafusion.expr import ( CaseBuilder, Expr, @@ -34,6 +33,9 @@ sort_list_to_raw_sort_list, ) +if TYPE_CHECKING: + from datafusion.context import SessionContext + __all__ = [ "abs", "acos", diff --git a/python/datafusion/input/__init__.py b/python/datafusion/input/__init__.py index f85ce21f0..f0c1f42b4 100644 --- a/python/datafusion/input/__init__.py +++ b/python/datafusion/input/__init__.py @@ -23,5 +23,5 @@ from .location import LocationInputPlugin __all__ = [ - LocationInputPlugin, + "LocationInputPlugin", ] diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py index d53f7b23c..bb3ee183b 100644 --- a/python/datafusion/input/base.py +++ b/python/datafusion/input/base.py @@ -38,9 +38,9 @@ class BaseInputSource(ABC): """ @abstractmethod - def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool: + def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> bool: """Returns `True` if the input is valid.""" @abstractmethod - def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable: + def build_table(self, input_item: Any, table_name: str, **kwarg: Any) -> SqlTable: # type: ignore[invalid-type-form] """Create a table from the input source.""" diff --git a/python/datafusion/input/location.py 
b/python/datafusion/input/location.py index 6e9736738..343dca2da 100644 --- a/python/datafusion/input/location.py +++ b/python/datafusion/input/location.py @@ -18,7 +18,7 @@ """The default input source for DataFusion.""" import glob -import os +from pathlib import Path from typing import Any from datafusion.common import DataTypeMap, SqlTable @@ -31,7 +31,7 @@ class LocationInputPlugin(BaseInputSource): This can be read in from a file (on disk, remote etc.). """ - def is_correct_input(self, input_item: Any, table_name: str, **kwargs): + def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> bool: # noqa: ARG002 """Returns `True` if the input is valid.""" return isinstance(input_item, str) @@ -39,27 +39,28 @@ def build_table( self, input_item: str, table_name: str, - **kwargs, - ) -> SqlTable: + **kwargs: Any, # noqa: ARG002 + ) -> SqlTable: # type: ignore[invalid-type-form] """Create a table from the input source.""" - _, extension = os.path.splitext(input_item) - format = extension.lstrip(".").lower() + extension = Path(input_item).suffix + file_format = extension.lstrip(".").lower() num_rows = 0 # Total number of rows in the file. Used for statistics columns = [] - if format == "parquet": + if file_format == "parquet": import pyarrow.parquet as pq # Read the Parquet metadata metadata = pq.read_metadata(input_item) num_rows = metadata.num_rows # Iterate through the schema and build the SqlTable - for col in metadata.schema: - columns.append( - ( - col.name, - DataTypeMap.from_parquet_type_str(col.physical_type), - ) + columns = [ + ( + col.name, + DataTypeMap.from_parquet_type_str(col.physical_type), ) + for col in metadata.schema + ] + elif format == "csv": import csv @@ -69,21 +70,20 @@ def build_table( # to get that information. However, this should only be occurring # at table creation time and therefore shouldn't # slow down query performance. - with open(input_item) as file: + with Path(input_item).open() as file: reader = csv.reader(file) - header_row = next(reader) - print(header_row) + _header_row = next(reader) for _ in reader: num_rows += 1 # TODO: Need to actually consume this row into reasonable columns - raise RuntimeError("TODO: Currently unable to support CSV input files.") + msg = "TODO: Currently unable to support CSV input files." + raise RuntimeError(msg) else: - raise RuntimeError( - f"Input of format: `{format}` is currently not supported.\ + msg = f"Input of format: `{format}` is currently not supported.\ Only Parquet and CSV." - ) + raise RuntimeError(msg) # Input could possibly be multiple files. 
Create a list if so - input_files = glob.glob(input_item) + input_files = glob.glob(input_item) # noqa: PTH207 return SqlTable(table_name, columns, num_rows, input_files) From 76fda6f86a775f8b34d2e8fc9b4668fd68994083 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 20:22:59 -0400 Subject: [PATCH 15/24] More ruff formatting --- pyproject.toml | 2 ++ python/datafusion/functions.py | 40 ++++++++++++++++----------------- python/datafusion/input/base.py | 2 +- python/datafusion/io.py | 3 --- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 78acdd75c..434aef8db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,8 @@ ignore = [ "ANN401", # Allow Any for wrapper classes "COM812", # Recommended to ignore these rules when using with ruff-format "FIX002", # Allow TODO lines - consider removing at some point + "FBT001", # Allow boolean positional args + "FBT002", # Allow boolean positional args "ISC001", # Recommended to ignore these rules when using with ruff-format "SLF001", # Allow accessing private members "TD002", diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f9a91b64e..0cc7434cf 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -83,8 +83,8 @@ "array_sort", "array_to_string", "array_union", - "arrow_typeof", "arrow_cast", + "arrow_typeof", "ascii", "asin", "asinh", @@ -99,6 +99,7 @@ "bool_and", "bool_or", "btrim", + "cardinality", "case", "cbrt", "ceil", @@ -118,6 +119,7 @@ "covar", "covar_pop", "covar_samp", + "cume_dist", "current_date", "current_time", "date_bin", @@ -127,17 +129,17 @@ "datetrunc", "decode", "degrees", + "dense_rank", "digest", "empty", "encode", "ends_with", - "extract", "exp", + "extract", "factorial", "find_in_set", "first_value", "flatten", - "cardinality", "floor", "from_unixtime", "gcd", @@ -145,8 +147,10 @@ "initcap", "isnan", "iszero", + "lag", "last_value", "lcm", + "lead", "left", "length", "levenshtein", @@ -168,10 +172,10 @@ "list_prepend", "list_push_back", "list_push_front", - "list_repeat", "list_remove", "list_remove_all", "list_remove_n", + "list_repeat", "list_replace", "list_replace_all", "list_replace_n", @@ -182,14 +186,14 @@ "list_union", "ln", "log", - "log10", "log2", + "log10", "lower", "lpad", "ltrim", "make_array", - "make_list", "make_date", + "make_list", "max", "md5", "mean", @@ -197,19 +201,22 @@ "min", "named_struct", "nanvl", - "nvl", "now", "nth_value", + "ntile", "nullif", + "nvl", "octet_length", "order_by", "overlay", + "percent_rank", "pi", "pow", "power", "radians", "random", "range", + "rank", "regexp_like", "regexp_match", "regexp_replace", @@ -227,6 +234,7 @@ "reverse", "right", "round", + "row_number", "rpad", "rtrim", "sha224", @@ -254,8 +262,8 @@ "to_hex", "to_timestamp", "to_timestamp_micros", - "to_timestamp_nanos", "to_timestamp_millis", + "to_timestamp_nanos", "to_timestamp_seconds", "to_unixtime", "translate", @@ -270,14 +278,6 @@ "when", # Window Functions "window", - "lead", - "lag", - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", ] @@ -294,14 +294,14 @@ def nullif(expr1: Expr, expr2: Expr) -> Expr: return Expr(f.nullif(expr1.expr, expr2.expr)) -def encode(input: Expr, encoding: Expr) -> Expr: +def encode(expr: Expr, encoding: Expr) -> Expr: """Encode the ``input``, using the ``encoding``. 
encoding can be base64 or hex.""" - return Expr(f.encode(input.expr, encoding.expr)) + return Expr(f.encode(expr.expr, encoding.expr)) -def decode(input: Expr, encoding: Expr) -> Expr: +def decode(expr: Expr, encoding: Expr) -> Expr: """Decode the ``input``, using the ``encoding``. encoding can be base64 or hex.""" - return Expr(f.decode(input.expr, encoding.expr)) + return Expr(f.decode(expr.expr, encoding.expr)) def array_to_string(expr: Expr, delimiter: Expr) -> Expr: diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py index bb3ee183b..f67dde2a1 100644 --- a/python/datafusion/input/base.py +++ b/python/datafusion/input/base.py @@ -42,5 +42,5 @@ def is_correct_input(self, input_item: Any, table_name: str, **kwargs: Any) -> b """Returns `True` if the input is valid.""" @abstractmethod - def build_table(self, input_item: Any, table_name: str, **kwarg: Any) -> SqlTable: # type: ignore[invalid-type-form] + def build_table(self, input_item: Any, table_name: str, **kwarg: Any) -> SqlTable: # type: ignore[invalid-type-form] """Create a table from the input source.""" diff --git a/python/datafusion/io.py b/python/datafusion/io.py index a97228369..a8ad7bf99 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -35,7 +35,6 @@ def read_parquet( # noqa: PLR0913 path: str | pathlib.Path, - *, table_partition_cols: list[tuple[str, str]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", @@ -84,7 +83,6 @@ def read_parquet( # noqa: PLR0913 def read_json( # noqa: PLR0913 path: str | pathlib.Path, - *, schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", @@ -126,7 +124,6 @@ def read_json( # noqa: PLR0913 def read_csv( # noqa: PLR0913 path: str | pathlib.Path | list[str] | list[pathlib.Path], - *, schema: pa.Schema | None = None, has_header: bool = True, delimiter: str = ",", From 591425fcca1ec852ef14746abd9f74d2567b1eb6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 20:25:19 -0400 Subject: [PATCH 16/24] Cut off lint errors for this PR --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 434aef8db..c02a60350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,8 @@ features = ["substrait"] [tool.ruff.lint] select = ["ALL" ] ignore = [ + "A001", # Allow using words like min as variable names + "A002", # Allow using words like filter as variable names "ANN401", # Allow Any for wrapper classes "COM812", # Recommended to ignore these rules when using with ruff-format "FIX002", # Allow TODO lines - consider removing at some point @@ -78,6 +80,7 @@ ignore = [ "TD002", "TD003", # Allow TODO lines "UP007" # Disallowing Union is pedantic + # TODO: Enable all of the following, but this PR is getting too large already ] [tool.ruff.lint.pydocstyle] From bd18e40cf7af01013296ef6948873a731adcdd58 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 20:48:54 -0400 Subject: [PATCH 17/24] Working through more ruff checks and disabling a bunch for now --- docs/source/conf.py | 2 +- examples/python-udwf.py | 2 +- pyproject.toml | 45 +++++++++++++++++++++++++---- python/datafusion/dataframe.py | 11 ++++--- python/datafusion/input/location.py | 2 +- python/datafusion/io.py | 6 ++-- python/datafusion/udf.py | 8 ++--- python/tests/test_aggregation.py | 2 +- python/tests/test_functions.py | 2 +- python/tests/test_imports.py | 5 ++-- 10 files changed, 62 insertions(+), 23 deletions(-) diff --git a/docs/source/conf.py 
b/docs/source/conf.py index 2e5a41339..c82a189e0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,7 @@ autoapi_python_class_content = "both" -def autoapi_skip_member_fn(app, what, name, obj, skip, options): +def autoapi_skip_member_fn(app, what, name, obj, skip, options): # noqa: ARG001 skip_contents = [ # Re-exports ("class", "datafusion.DataFrame"), diff --git a/examples/python-udwf.py b/examples/python-udwf.py index 7d39dc1b8..98d118bf2 100644 --- a/examples/python-udwf.py +++ b/examples/python-udwf.py @@ -59,7 +59,7 @@ def __init__(self, alpha: float) -> None: def supports_bounded_execution(self) -> bool: return True - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: + def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002 # Override the default range of current row since uses_window_frame is False # So for the purpose of this test we just smooth from the previous row to # current. diff --git a/pyproject.toml b/pyproject.toml index c02a60350..48ee7f86d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,8 +79,42 @@ ignore = [ "SLF001", # Allow accessing private members "TD002", "TD003", # Allow TODO lines - "UP007" # Disallowing Union is pedantic + "UP007", # Disallowing Union is pedantic # TODO: Enable all of the following, but this PR is getting too large already + "PT001", + "ANN101", + "ANN204", + "B008", + "EM101", + "PLR0913", + "PLR1714", + "ANN201", + "C400", + "TRY003", + "B904", + "UP006", + "RUF012", + "FBT003", + "C416", + "SIM102", + "PGH003", + "PLR2004", + "PERF401", + "PD901", + "EM102", + "ERA001", + "SIM108", + "ICN001", + "ANN001", + "ANN202", + "PTH", + "N812", + "INP001", + "DTZ007", + "PLW2901", + "RET503", + "RUF015", + "TCH001", ] [tool.ruff.lint.pydocstyle] @@ -103,11 +137,12 @@ max-doc-length = 88 "PT011", "RUF015", "S608", - "PLR0913" + "PLR0913", + "PT004", ] -"examples/*" = ["D", "W505"] -"dev/*" = ["D"] -"benchmarks/*" = ["D", "F"] +"examples/*" = ["D", "W505", "E501", "T201", "S101"] +"dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"] +"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP"] "docs/*" = ["D"] [dependency-groups] diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c1d289c70..d1c71c2bb 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -29,6 +29,7 @@ List, Literal, Optional, + Type, Union, overload, ) @@ -49,10 +50,11 @@ import polars as pl import pyarrow as pa + from datafusion._internal import DataFrame as DataFrameInternal + from datafusion._internal import expr as expr_internal + from enum import Enum -from datafusion._internal import DataFrame as DataFrameInternal -from datafusion._internal import expr as expr_internal from datafusion.expr import Expr, SortExpr, sort_or_default @@ -73,7 +75,7 @@ class Compression(Enum): LZ4_RAW = "lz4_raw" @classmethod - def from_str(cls, value: str) -> Compression: + def from_str(cls: Type[Compression], value: str) -> Compression: """Convert a string to a Compression enum value. Args: @@ -88,8 +90,9 @@ def from_str(cls, value: str) -> Compression: try: return cls(value.lower()) except ValueError: + valid_values = str([item.value for item in Compression]) raise ValueError( - f"{value} is not a valid Compression. Valid values are: {[item.value for item in Compression]}" + f"{value} is not a valid Compression. 
Valid values are: {valid_values}" ) def get_default_level(self) -> Optional[int]: diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py index 343dca2da..08d98d115 100644 --- a/python/datafusion/input/location.py +++ b/python/datafusion/input/location.py @@ -84,6 +84,6 @@ def build_table( raise RuntimeError(msg) # Input could possibly be multiple files. Create a list if so - input_files = glob.glob(input_item) # noqa: PTH207 + input_files = glob.glob(input_item) return SqlTable(table_name, columns, num_rows, input_files) diff --git a/python/datafusion/io.py b/python/datafusion/io.py index a8ad7bf99..3e39703e3 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -33,7 +33,7 @@ from datafusion.expr import Expr -def read_parquet( # noqa: PLR0913 +def read_parquet( path: str | pathlib.Path, table_partition_cols: list[tuple[str, str]] | None = None, parquet_pruning: bool = True, @@ -81,7 +81,7 @@ def read_parquet( # noqa: PLR0913 ) -def read_json( # noqa: PLR0913 +def read_json( path: str | pathlib.Path, schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, @@ -122,7 +122,7 @@ def read_json( # noqa: PLR0913 ) -def read_csv( # noqa: PLR0913 +def read_csv( path: str | pathlib.Path | list[str] | list[pathlib.Path], schema: pa.Schema | None = None, has_header: bool = True, diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index a96f95e67..aebf9af1b 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -202,13 +202,13 @@ def _decorator( volatility: Volatility | str, name: Optional[str] = None, ) -> Callable: - def decorator(func: Callable): # noqa: ANN202 + def decorator(func: Callable): udf_caller = ScalarUDF.udf( func, input_types, return_type, volatility, name ) @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any): # noqa: ANN202 + def wrapper(*args: Any, **kwargs: Any): return udf_caller(*args, **kwargs) return wrapper @@ -249,7 +249,7 @@ class AggregateUDF: also :py:class:`ScalarUDF` for operating on a row by row basis. """ - def __init__( # noqa: PLR0913 + def __init__( self, name: str, accumulator: Callable[[], Accumulator], @@ -375,7 +375,7 @@ def udf4() -> Summarize: aggregation or window function calls. 
""" - def _function( # noqa: PLR0913 + def _function( accum: Callable[[], Accumulator], input_types: pa.DataType | list[pa.DataType], return_type: pa.DataType, diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py index d5e76dd0e..61b1c7d80 100644 --- a/python/tests/test_aggregation.py +++ b/python/tests/test_aggregation.py @@ -88,7 +88,7 @@ def df_aggregate_100(): f.covar_samp(column("b"), column("c")), lambda a, b, c, d: np.array(np.cov(b, c, ddof=1)[0][1]), ), - # f.grouping(col_a), # No physical plan implemented yet # noqa: ERA001 + # f.grouping(col_a), # No physical plan implemented yet (f.max(column("a")), lambda a, b, c, d: np.array(np.max(a))), (f.mean(column("b")), lambda a, b, c, d: np.array(np.mean(b))), (f.median(column("b")), lambda a, b, c, d: np.array(np.median(b))), diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 9e96742bb..ed88a16e3 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -81,7 +81,7 @@ def test_literal(df): literal("1"), literal("OK"), literal(3.14), - literal(True), # noqa: FBT003 + literal(True), literal(b"hello world"), ) result = df.collect() diff --git a/python/tests/test_imports.py b/python/tests/test_imports.py index 5313c937f..9ef7ed89a 100644 --- a/python/tests/test_imports.py +++ b/python/tests/test_imports.py @@ -169,9 +169,10 @@ def test_class_module_is_datafusion(): def test_import_from_functions_submodule(): - from datafusion.functions import abs, sin # noqa: A004 + from datafusion.functions import abs as df_abs + from datafusion.functions import sin - assert functions.abs is abs + assert functions.abs is df_abs assert functions.sin is sin msg = "cannot import name 'foobar' from 'datafusion.functions'" From 6102da56b47d2f9b778978f8ccb2e9c8c21d93d9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 9 Mar 2025 20:54:01 -0400 Subject: [PATCH 18/24] Address CI difference from local ruff --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 48ee7f86d..4c73557d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,6 @@ ignore = [ "UP007", # Disallowing Union is pedantic # TODO: Enable all of the following, but this PR is getting too large already "PT001", - "ANN101", "ANN204", "B008", "EM101", @@ -114,7 +113,7 @@ ignore = [ "PLW2901", "RET503", "RUF015", - "TCH001", + "TC001", ] [tool.ruff.lint.pydocstyle] From 5f3f7c7519f1bc93da1bfa7da8c28c6a16192a9f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 10 Mar 2025 07:03:52 -0400 Subject: [PATCH 19/24] UDWF isn't a proper abstract base class right now since users can opt in to all methods --- pyproject.toml | 2 +- python/datafusion/udf.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c73557d2..9b5f35622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,7 @@ ignore = [ "PLW2901", "RET503", "RUF015", - "TC001", + "ANN101", ] [tool.ruff.lint.pydocstyle] diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index aebf9af1b..603b7063d 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -429,7 +429,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return _decorator(*args, **kwargs) -class WindowEvaluator(metaclass=ABCMeta): +class WindowEvaluator: """Evaluator class for user-defined window functions (UDWF). It is up to the user to decide which evaluate function is appropriate. 
@@ -447,7 +447,7 @@ class WindowEvaluator(metaclass=ABCMeta):
     +------------------------+--------------------------------+------------------+---------------------------+
     """  # noqa: W505, E501
 
-    def memoize(self) -> None:  # noqa: B027
+    def memoize(self) -> None:
         """Perform a memoize operation to improve performance.
 
         When the window frame has a fixed beginning (e.g UNBOUNDED
@@ -481,7 +481,7 @@ def is_causal(self) -> bool:
         """Get whether evaluator needs future data for its result."""
         return False
 
-    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:  # noqa: B027
+    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
         """Evaluate a window function on an entire input partition.
 
         This function is called once per input *partition* for window functions that
@@ -525,7 +525,7 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:  # no
         avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
         """  # noqa: W505, E501
 
-    def evaluate(  # noqa: B027
+    def evaluate(
         self, values: list[pa.Array], eval_range: tuple[int, int]
     ) -> pa.Scalar:
         """Evaluate window function on a range of rows in an input partition.
@@ -543,7 +543,7 @@ def evaluate(  # noqa: B027
         single argument, `values[1..]` will contain ORDER BY expression results.
         """
 
-    def evaluate_all_with_rank(  # noqa: B027
+    def evaluate_all_with_rank(
         self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
     ) -> pa.Array:
         """Called for window functions that only need the rank of a row.
@@ -575,12 +575,10 @@ def evaluate_all_with_rank(  # noqa: B027
         The user must implement this method if ``include_rank`` returns True.
         """
 
-    @abstractmethod
     def supports_bounded_execution(self) -> bool:
         """Can the window function be incrementally computed using bounded memory?"""
         return False
 
-    @abstractmethod
     def uses_window_frame(self) -> bool:
         """Does the window function use the values from the window frame?"""
         return False

From 96ad9ffc4db01c709d9b1afd69df6f02d3e62d8f Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 10 Mar 2025 07:52:36 -0400
Subject: [PATCH 20/24] Update pre-commit to match the version of ruff used in CI

---
 .pre-commit-config.yaml | 2 +-
 pyproject.toml          | 3 ++-
 uv.lock                 | 1 -
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b548ff18f..abcfcf321 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
       - id: actionlint-docker
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.3.0
+    rev: v0.9.10
    hooks:
      # Run the linter.
      - id: ruff
diff --git a/pyproject.toml b/pyproject.toml
index 9b5f35622..0ee1e2ec9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -113,7 +113,8 @@ ignore = [
     "PLW2901",
     "RET503",
     "RUF015",
-    "ANN101",
+    "A005",
+    "TC001",
 ]
 
 [tool.ruff.lint.pydocstyle]
diff --git a/uv.lock b/uv.lock
index 587ddc8b7..8958c3086 100644
--- a/uv.lock
+++ b/uv.lock
@@ -351,7 +351,6 @@ wheels = [
 
 [[package]]
 name = "datafusion"
-version = "44.0.0"
 source = { editable = "." }
 dependencies = [
     { name = "pyarrow", version = "17.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" },

From 9b6d627ef4a9b89e7eaad0fec76297f27d0f81ae Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 10 Mar 2025 08:07:47 -0400
Subject: [PATCH 21/24] To enable testing in python 3.9 we need numpy.
 Also going to the current minimal supported version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0ee1e2ec9..9f05c8b42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -148,7 +148,7 @@ max-doc-length = 88
 [dependency-groups]
 dev = [
     "maturin>=1.8.1",
-    "numpy>1.24.4 ; python_full_version >= '3.10'",
+    "numpy>1.25.0",
     "pytest>=7.4.4",
     "ruff>=0.9.1",
     "toml>=0.10.2",

From 099baaa50b292a25f6a153f822be9a93b95b3c83 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 10 Mar 2025 08:11:15 -0400
Subject: [PATCH 22/24] Update min required version of python to 3.9 in pyproject.toml.

The other changes will come in #1043, which is soon to be merged.
---
 pyproject.toml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 9f05c8b42..66691674c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ name = "datafusion"
 description = "Build and run queries against data"
 readme = "README.md"
 license = { file = "LICENSE.txt" }
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 keywords = ["datafusion", "dataframe", "rust", "query-engine"]
 classifiers = [
     "Development Status :: 2 - Pre-Alpha",
@@ -35,7 +35,6 @@ classifiers = [
     "Operating System :: Microsoft :: Windows",
     "Operating System :: POSIX :: Linux",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",

From a33bc2193f6bd8132af5da386b0216f274caac3c Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 10 Mar 2025 08:13:51 -0400
Subject: [PATCH 23/24] Suppress UP035

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 66691674c..060e3b80a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -114,6 +114,7 @@ ignore = [
     "RUF015",
     "A005",
     "TC001",
+    "UP035",
 ]
 
 [tool.ruff.lint.pydocstyle]

From a52b330b967d683a07f52524be8c7acffa6fb392 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 10 Mar 2025 08:15:00 -0400
Subject: [PATCH 24/24] ruff format

---
 python/tests/test_context.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/tests/test_context.py b/python/tests/test_context.py
index 04556ef4d..7a0a7aa08 100644
--- a/python/tests/test_context.py
+++ b/python/tests/test_context.py
@@ -529,9 +529,10 @@ def test_read_json_compressed(ctx, tmp_path):
 
     # File compression type
     gzip_path = tmp_path / "data.json.gz"
-    with pathlib.Path.open(test_data_path, "rb") as csv_file, gzip.open(
-        gzip_path, "wb"
-    ) as gzipped_file:
+    with (
+        pathlib.Path.open(test_data_path, "rb") as csv_file,
+        gzip.open(gzip_path, "wb") as gzipped_file,
+    ):
         gzipped_file.writelines(csv_file)
 
     df = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz")
@@ -568,9 +569,10 @@ def test_read_csv_compressed(ctx, tmp_path):
 
     # File compression type
     gzip_path = tmp_path / "aggregate_test_100.csv.gz"
-    with pathlib.Path.open(test_data_path, "rb") as csv_file, gzip.open(
-        gzip_path, "wb"
-    ) as gzipped_file:
+    with (
+        pathlib.Path.open(test_data_path, "rb") as csv_file,
+        gzip.open(gzip_path, "wb") as gzipped_file,
+    ):
         gzipped_file.writelines(csv_file)
 
     csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz")
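
Note on the UDWF change in [PATCH 19/24] above: with WindowEvaluator no longer an abstract base class, a user-defined window evaluator only needs to override the hooks it actually uses, and everything else falls back to the defaults shown in the diffs. Below is a minimal sketch of such a subclass, assuming the WindowEvaluator base class from python/datafusion/udf.py; the CumulativeSum class and its running-sum logic are illustrative only and are not part of this patch series.

import pyarrow as pa

from datafusion.udf import WindowEvaluator


class CumulativeSum(WindowEvaluator):
    """Running total over the whole partition (illustrative example only).

    Only evaluate_all is overridden; supports_bounded_execution,
    uses_window_frame, and the other hooks fall back to the defaults
    defined on WindowEvaluator.
    """

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        # values[0] holds the single input column for this window function.
        column = values[0]
        running_total = 0.0
        totals = []
        for row in range(num_rows):
            running_total += column[row].as_py()
            totals.append(running_total)
        return pa.array(totals)

Registering such an evaluator as a window UDF then goes through the library's UDWF constructor in the same module (not shown here).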