From 72a6b5f40e643f208601eae3d55bbaf5e90d9509 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Mon, 10 Mar 2025 20:12:03 -0600 Subject: [PATCH 01/13] Rename _global_ctx to global_ctx --- src/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context.rs b/src/context.rs index 9ba87eb8a..0db0f4d7e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -308,7 +308,7 @@ impl PySessionContext { #[classmethod] #[pyo3(signature = ())] - fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult { + fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult { Ok(Self { ctx: get_global_ctx().clone(), }) From b7fea478c6cd8228faf00c1ab6131e4f7d7abfea Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Mon, 10 Mar 2025 22:40:32 -0600 Subject: [PATCH 02/13] Add global context to python wrapper code --- python/datafusion/context.py | 17 ++++++++++++++--- python/datafusion/io.py | 10 +++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 282b2a477..d23e39b9c 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -498,6 +498,15 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) + @classmethod + def global_ctx(cls) -> "SessionContext": + """Retrieve the global context + + Returns: + A `SessionContext` object that corresponds to the global context + """ + return SessionContextInternal.global_ctx() + def enable_url_table(self) -> "SessionContext": """Control if local files can be queried as tables. @@ -798,9 +807,11 @@ def register_parquet( file_extension, skip_metadata, schema, - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None, + ( + [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] + if file_sort_order is not None + else None + ), ) def register_csv( diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 3b6264948..c5c1f2e5b 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -26,7 +26,7 @@ from datafusion.dataframe import DataFrame from datafusion.expr import Expr -from ._internal import SessionContext as SessionContextInternal +from datafusion.context import SessionContext def read_parquet( @@ -65,7 +65,7 @@ def read_parquet( if table_partition_cols is None: table_partition_cols = [] return DataFrame( - SessionContextInternal._global_ctx().read_parquet( + SessionContext.global_ctx().read_parquet( str(path), table_partition_cols, parquet_pruning, @@ -107,7 +107,7 @@ def read_json( if table_partition_cols is None: table_partition_cols = [] return DataFrame( - SessionContextInternal._global_ctx().read_json( + SessionContext.global_ctx().read_json( str(path), schema, schema_infer_max_records, @@ -158,7 +158,7 @@ def read_csv( path = [str(p) for p in path] if isinstance(path, list) else str(path) return DataFrame( - SessionContextInternal._global_ctx().read_csv( + SessionContext.global_ctx().read_csv( path, schema, has_header, @@ -195,7 +195,7 @@ def read_avro( if file_partition_cols is None: file_partition_cols = [] return DataFrame( - SessionContextInternal._global_ctx().read_avro( + SessionContext.global_ctx().read_avro( str(path), schema, file_partition_cols, file_extension ) ) From 597ef33abefac0662a702b2934b898f1c6198b85 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:34:20 -0600 Subject: [PATCH 03/13] Update context.py --- python/datafusion/context.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index d23e39b9c..999f1c864 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -499,11 +499,11 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) @classmethod - def global_ctx(cls) -> "SessionContext": - """Retrieve the global context + def global_ctx(cls) -> "SessionContextInternal": + """Retrieve the global context. Returns: - A `SessionContext` object that corresponds to the global context + A `SessionContextInternal` object that corresponds to the global context """ return SessionContextInternal.global_ctx() From 3751cf51e2bd7ae49fa1fdc41b1671a37cbcf2ab Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:40:23 -0600 Subject: [PATCH 04/13] singleton for global context --- python/datafusion/context.py | 7 ++++++- python/tests/test_context.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 999f1c864..99dd2bd9e 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -468,6 +468,8 @@ class SessionContext: See :ref:`user_guide_concepts` in the online documentation for more information. """ + _global_instance = None + def __init__( self, config: SessionConfig | None = None, @@ -505,7 +507,10 @@ def global_ctx(cls) -> "SessionContextInternal": Returns: A `SessionContextInternal` object that corresponds to the global context """ - return SessionContextInternal.global_ctx() + if cls._global_instance is None: + internal_ctx = SessionContextInternal.global_ctx() + cls._global_instance = internal_ctx + return cls._global_instance def enable_url_table(self) -> "SessionContext": """Control if local files can be queried as tables. diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 91046e6b8..1e65d5d6c 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -30,8 +30,11 @@ SQLOptions, column, literal, + udf, ) +from datafusion._internal import SessionContext as SessionContextInternal + def test_create_context_no_args(): SessionContext() @@ -629,3 +632,32 @@ def test_sql_with_options_no_statements(ctx): options = SQLOptions().with_allow_statements(False) with pytest.raises(Exception, match="SetVariable"): ctx.sql_with_options(sql, options=options) + + +def test_global_context_type(): + ctx = SessionContext.global_ctx() + assert isinstance(ctx, SessionContextInternal) + + +def test_global_context_is_singleton(): + ctx1 = SessionContext.global_ctx() + ctx2 = SessionContext.global_ctx() + assert ctx1 is ctx2 + + +@pytest.fixture +def batch(): + return pa.RecordBatch.from_arrays( + [pa.array([4, 5, 6])], + names=["a"], + ) + + +def test_create_dataframe_with_global_ctx(batch): + ctx = SessionContext.global_ctx() + + df = ctx.create_dataframe([[batch]]) + + result = df.collect()[0].column(0) + + assert result == pa.array([4, 5, 6]) From 14ee8b3a28fd2f4c5cac5c27a4e24596338800a6 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:54:02 -0600 Subject: [PATCH 05/13] formatting --- python/datafusion/context.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 99dd2bd9e..de9737a17 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -812,11 +812,9 @@ def register_parquet( file_extension, skip_metadata, schema, - ( - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None - ), + [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] + if file_sort_order is not None + else None, ) def register_csv( From 9e90b8175c40df18817eb35da34cf2d58dfe26da Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Tue, 11 Mar 2025 23:00:56 -0600 Subject: [PATCH 06/13] remove udf from import --- python/tests/test_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 1e65d5d6c..8844208dd 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -30,7 +30,6 @@ SQLOptions, column, literal, - udf, ) from datafusion._internal import SessionContext as SessionContextInternal From f716fb8bbd25f1786c2277a40a66965bc31c3973 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:32:11 -0600 Subject: [PATCH 07/13] remove _global_instance --- python/datafusion/context.py | 24 ++++++++++++------------ python/tests/test_context.py | 8 +------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index de9737a17..ccdf00adf 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -468,8 +468,6 @@ class SessionContext: See :ref:`user_guide_concepts` in the online documentation for more information. """ - _global_instance = None - def __init__( self, config: SessionConfig | None = None, @@ -501,16 +499,16 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) @classmethod - def global_ctx(cls) -> "SessionContextInternal": - """Retrieve the global context. + def global_ctx(cls) -> "SessionContext": + """Retrieve the global context as a `SessionContext` wrapper. Returns: - A `SessionContextInternal` object that corresponds to the global context + A `SessionContext` object that wraps the global `SessionContextInternal`. """ - if cls._global_instance is None: - internal_ctx = SessionContextInternal.global_ctx() - cls._global_instance = internal_ctx - return cls._global_instance + internal_ctx = SessionContextInternal.global_ctx() + wrapper = cls() + wrapper.ctx = internal_ctx + return wrapper def enable_url_table(self) -> "SessionContext": """Control if local files can be queried as tables. @@ -812,9 +810,11 @@ def register_parquet( file_extension, skip_metadata, schema, - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None, + ( + [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] + if file_sort_order is not None + else None + ), ) def register_csv( diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 8844208dd..06c37ab2e 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -635,13 +635,7 @@ def test_sql_with_options_no_statements(ctx): def test_global_context_type(): ctx = SessionContext.global_ctx() - assert isinstance(ctx, SessionContextInternal) - - -def test_global_context_is_singleton(): - ctx1 = SessionContext.global_ctx() - ctx2 = SessionContext.global_ctx() - assert ctx1 is ctx2 + assert isinstance(ctx, SessionContext) @pytest.fixture From b2bbf3304c7d10beb3594643b72cff82486f7bb7 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:36:04 -0600 Subject: [PATCH 08/13] formatting --- python/datafusion/context.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index ccdf00adf..984a6ca8e 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -810,11 +810,9 @@ def register_parquet( file_extension, skip_metadata, schema, - ( - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None - ), + [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] + if file_sort_order is not None + else None ) def register_csv( From 8534f51ce32bd264581d7403ee2be09b3c547360 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:36:25 -0600 Subject: [PATCH 09/13] formatting --- python/datafusion/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 984a6ca8e..ce9391900 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -812,7 +812,7 @@ def register_parquet( schema, [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] if file_sort_order is not None - else None + else None, ) def register_csv( From ae9b6a2e16ea5ef245fb6af380b96d43c8647746 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:43:07 -0600 Subject: [PATCH 10/13] unnecessary test --- python/tests/test_context.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 06c37ab2e..98593a321 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -32,8 +32,6 @@ literal, ) -from datafusion._internal import SessionContext as SessionContextInternal - def test_create_context_no_args(): SessionContext() @@ -633,11 +631,6 @@ def test_sql_with_options_no_statements(ctx): ctx.sql_with_options(sql, options=options) -def test_global_context_type(): - ctx = SessionContext.global_ctx() - assert isinstance(ctx, SessionContext) - - @pytest.fixture def batch(): return pa.RecordBatch.from_arrays( From 069c4a339b1b6b649afbd16d45c45a8f941d0cb2 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 13:25:26 -0600 Subject: [PATCH 11/13] fix test_io.py --- python/datafusion/io.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/python/datafusion/io.py b/python/datafusion/io.py index c5c1f2e5b..3c0b8ae3c 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -64,8 +64,7 @@ def read_parquet( """ if table_partition_cols is None: table_partition_cols = [] - return DataFrame( - SessionContext.global_ctx().read_parquet( + return SessionContext.global_ctx().read_parquet( str(path), table_partition_cols, parquet_pruning, @@ -74,7 +73,6 @@ def read_parquet( schema, file_sort_order, ) - ) def read_json( @@ -106,8 +104,7 @@ def read_json( """ if table_partition_cols is None: table_partition_cols = [] - return DataFrame( - SessionContext.global_ctx().read_json( + return SessionContext.global_ctx().read_json( str(path), schema, schema_infer_max_records, @@ -115,7 +112,6 @@ def read_json( table_partition_cols, file_compression_type, ) - ) def read_csv( @@ -157,8 +153,7 @@ def read_csv( path = [str(p) for p in path] if isinstance(path, list) else str(path) - return DataFrame( - SessionContext.global_ctx().read_csv( + return SessionContext.global_ctx().read_csv( path, schema, has_header, @@ -168,7 +163,6 @@ def read_csv( table_partition_cols, file_compression_type, ) - ) def read_avro( @@ -194,8 +188,6 @@ def read_avro( """ if file_partition_cols is None: file_partition_cols = [] - return DataFrame( - SessionContext.global_ctx().read_avro( + return SessionContext.global_ctx().read_avro( str(path), schema, file_partition_cols, file_extension ) - ) From d124275078d617e237d7297db10a43f6187537d1 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:50:37 -0600 Subject: [PATCH 12/13] ran ruff --- python/datafusion/context.py | 4 ++-- python/datafusion/io.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index ce9391900..9cc341b71 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -499,7 +499,7 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) @classmethod - def global_ctx(cls) -> "SessionContext": + def global_ctx(cls) -> SessionContext: """Retrieve the global context as a `SessionContext` wrapper. Returns: @@ -510,7 +510,7 @@ def global_ctx(cls) -> "SessionContext": wrapper.ctx = internal_ctx return wrapper - def enable_url_table(self) -> "SessionContext": + def enable_url_table(self) -> SessionContext: """Control if local files can be queried as tables. Returns: diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 3c0b8ae3c..b962e628a 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -23,11 +23,10 @@ import pyarrow +from datafusion.context import SessionContext from datafusion.dataframe import DataFrame from datafusion.expr import Expr -from datafusion.context import SessionContext - def read_parquet( path: str | pathlib.Path, From 2baf728e77a29e66a898276e7338f3923b247f90 Mon Sep 17 00:00:00 2001 From: jsai28 <54253219+jsai28@users.noreply.github.com> Date: Wed, 12 Mar 2025 16:04:42 -0600 Subject: [PATCH 13/13] ran ruff format --- python/datafusion/io.py | 52 ++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 1ba4b8b35..ef5ebf96f 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -68,14 +68,14 @@ def read_parquet( if table_partition_cols is None: table_partition_cols = [] return SessionContext.global_ctx().read_parquet( - str(path), - table_partition_cols, - parquet_pruning, - file_extension, - skip_metadata, - schema, - file_sort_order, - ) + str(path), + table_partition_cols, + parquet_pruning, + file_extension, + skip_metadata, + schema, + file_sort_order, + ) def read_json( @@ -108,13 +108,13 @@ def read_json( if table_partition_cols is None: table_partition_cols = [] return SessionContext.global_ctx().read_json( - str(path), - schema, - schema_infer_max_records, - file_extension, - table_partition_cols, - file_compression_type, - ) + str(path), + schema, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) def read_csv( @@ -157,15 +157,15 @@ def read_csv( path = [str(p) for p in path] if isinstance(path, list) else str(path) return SessionContext.global_ctx().read_csv( - path, - schema, - has_header, - delimiter, - schema_infer_max_records, - file_extension, - table_partition_cols, - file_compression_type, - ) + path, + schema, + has_header, + delimiter, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) def read_avro( @@ -192,5 +192,5 @@ def read_avro( if file_partition_cols is None: file_partition_cols = [] return SessionContext.global_ctx().read_avro( - str(path), schema, file_partition_cols, file_extension - ) + str(path), schema, file_partition_cols, file_extension + )