Skip to content

Commit dba5c6a

Browse files
committed
Refactor to use parking_lot for interior mutability in schema, config, dataframe, and conditional expression modules
1 parent 7030cec commit dba5c6a

File tree

6 files changed

+45
-87
lines changed

6 files changed

+45
-87
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ futures = "0.3"
5151
object_store = { version = "0.12.3", features = ["aws", "gcp", "azure", "http"] }
5252
url = "2"
5353
log = "0.4.27"
54+
parking_lot = "0.12"
5455

5556
[build-dependencies]
5657
prost-types = "0.13.1" # keep in line with `datafusion-substrait`

src/common/schema.rs

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use std::fmt::{self, Display, Formatter};
19-
use std::sync::{Arc, RwLock};
19+
use std::sync::Arc;
2020
use std::{any::Any, borrow::Cow};
2121

2222
use arrow::datatypes::Schema;
@@ -25,7 +25,6 @@ use datafusion::arrow::datatypes::SchemaRef;
2525
use datafusion::common::Constraints;
2626
use datafusion::datasource::TableType;
2727
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
28-
use pyo3::exceptions::PyRuntimeError;
2928
use pyo3::prelude::*;
3029

3130
use datafusion::logical_expr::utils::split_conjunction;
@@ -34,6 +33,8 @@ use crate::sql::logical::PyLogicalPlan;
3433

3534
use super::{data_type::DataTypeMap, function::SqlFunction};
3635

36+
use parking_lot::RwLock;
37+
3738
#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)]
3839
#[derive(Debug, Clone)]
3940
pub struct SqlSchema {
@@ -110,88 +111,60 @@ impl SqlSchema {
110111

111112
#[getter]
112113
fn name(&self) -> PyResult<String> {
113-
Ok(self
114-
.name
115-
.read()
116-
.map_err(|_| PyRuntimeError::new_err("failed to read schema name"))?
117-
.clone())
114+
Ok(self.name.read().clone())
118115
}
119116

120117
#[setter]
121118
fn set_name(&self, value: String) -> PyResult<()> {
122-
*self
123-
.name
124-
.write()
125-
.map_err(|_| PyRuntimeError::new_err("failed to write schema name"))? = value;
119+
*self.name.write() = value;
126120
Ok(())
127121
}
128122

129123
#[getter]
130124
fn tables(&self) -> PyResult<Vec<SqlTable>> {
131-
Ok(self
132-
.tables
133-
.read()
134-
.map_err(|_| PyRuntimeError::new_err("failed to read schema tables"))?
135-
.clone())
125+
Ok(self.tables.read().clone())
136126
}
137127

138128
#[setter]
139129
fn set_tables(&self, tables: Vec<SqlTable>) -> PyResult<()> {
140-
*self
141-
.tables
142-
.write()
143-
.map_err(|_| PyRuntimeError::new_err("failed to write schema tables"))? = tables;
130+
*self.tables.write() = tables;
144131
Ok(())
145132
}
146133

147134
#[getter]
148135
fn views(&self) -> PyResult<Vec<SqlView>> {
149-
Ok(self
150-
.views
151-
.read()
152-
.map_err(|_| PyRuntimeError::new_err("failed to read schema views"))?
153-
.clone())
136+
Ok(self.views.read().clone())
154137
}
155138

156139
#[setter]
157140
fn set_views(&self, views: Vec<SqlView>) -> PyResult<()> {
158-
*self
159-
.views
160-
.write()
161-
.map_err(|_| PyRuntimeError::new_err("failed to write schema views"))? = views;
141+
*self.views.write() = views;
162142
Ok(())
163143
}
164144

165145
#[getter]
166146
fn functions(&self) -> PyResult<Vec<SqlFunction>> {
167-
Ok(self
168-
.functions
169-
.read()
170-
.map_err(|_| PyRuntimeError::new_err("failed to read schema functions"))?
171-
.clone())
147+
Ok(self.functions.read().clone())
172148
}
173149

174150
#[setter]
175151
fn set_functions(&self, functions: Vec<SqlFunction>) -> PyResult<()> {
176-
*self
177-
.functions
178-
.write()
179-
.map_err(|_| PyRuntimeError::new_err("failed to write schema functions"))? = functions;
152+
*self.functions.write() = functions;
180153
Ok(())
181154
}
182155

183156
pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
184-
let tables = self.tables.read().expect("failed to read schema tables");
157+
let tables = self.tables.read();
185158
tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned()
186159
}
187160

188161
pub fn add_table(&self, table: SqlTable) {
189-
let mut tables = self.tables.write().expect("failed to write schema tables");
162+
let mut tables = self.tables.write();
190163
tables.push(table);
191164
}
192165

193166
pub fn drop_table(&self, table_name: String) {
194-
let mut tables = self.tables.write().expect("failed to write schema tables");
167+
let mut tables = self.tables.write();
195168
tables.retain(|x| !x.name.eq(&table_name));
196169
}
197170
}

src/config.rs

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::{Arc, RwLock};
18+
use std::sync::Arc;
1919

2020
use pyo3::prelude::*;
2121
use pyo3::types::*;
2222

2323
use datafusion::config::ConfigOptions;
2424

25-
use crate::errors::{PyDataFusionError, PyDataFusionResult};
25+
use crate::errors::PyDataFusionResult;
2626
use crate::utils::py_obj_to_scalar_value;
27+
use parking_lot::RwLock;
2728

2829
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
2930
#[derive(Clone)]
@@ -50,10 +51,7 @@ impl PyConfig {
5051

5152
/// Get a configuration option
5253
pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
53-
let options = self
54-
.config
55-
.read()
56-
.map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?;
54+
let options = self.config.read();
5755
for entry in options.entries() {
5856
if entry.key == key {
5957
return Ok(entry.value.into_pyobject(py)?);
@@ -65,21 +63,15 @@ impl PyConfig {
6563
/// Set a configuration option
6664
pub fn set(&self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> {
6765
let scalar_value = py_obj_to_scalar_value(py, value)?;
68-
let mut options = self
69-
.config
70-
.write()
71-
.map_err(|_| PyDataFusionError::Common("failed to lock configuration".to_string()))?;
66+
let mut options = self.config.write();
7267
options.set(key, scalar_value.to_string().as_str())?;
7368
Ok(())
7469
}
7570

7671
/// Get all configuration options
7772
pub fn get_all(&self, py: Python) -> PyResult<PyObject> {
7873
let dict = PyDict::new(py);
79-
let options = self
80-
.config
81-
.read()
82-
.map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?;
74+
let options = self.config.read();
8375
for entry in options.entries() {
8476
dict.set_item(entry.key, entry.value.clone().into_pyobject(py)?)?;
8577
}

src/dataframe.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use std::collections::HashMap;
1919
use std::ffi::CString;
20-
use std::sync::{Arc, Mutex};
20+
use std::sync::Arc;
2121

2222
use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
@@ -58,6 +58,8 @@ use crate::{
5858
expr::{sort_expr::PySortExpr, PyExpr},
5959
};
6060

61+
use parking_lot::Mutex;
62+
6163
// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
6264
// - we have not decided on the table_provider approach yet
6365
// this is an interim implementation
@@ -307,9 +309,7 @@ impl PyDataFrame {
307309
let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
308310

309311
let (cached_batches, should_cache) = {
310-
let mut cache = self.batches.lock().map_err(|_| {
311-
PyDataFusionError::Common("failed to lock DataFrame display cache".to_string())
312-
})?;
312+
let mut cache = self.batches.lock();
313313
let should_cache = *is_ipython_env(py) && cache.is_none();
314314
let batches = cache.take();
315315
(batches, should_cache)
@@ -354,9 +354,7 @@ impl PyDataFrame {
354354
let html_str: String = html_result.extract()?;
355355

356356
if should_cache {
357-
let mut cache = self.batches.lock().map_err(|_| {
358-
PyDataFusionError::Common("failed to lock DataFrame display cache".to_string())
359-
})?;
357+
let mut cache = self.batches.lock();
360358
*cache = Some((batches.clone(), has_more));
361359
}
362360

src/expr/conditional_expr.rs

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::{Arc, Mutex};
18+
use std::sync::Arc;
1919

2020
use crate::{
2121
errors::{PyDataFusionError, PyDataFusionResult},
@@ -24,23 +24,14 @@ use crate::{
2424
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
2525
use pyo3::prelude::*;
2626

27+
use parking_lot::{Mutex, MutexGuard};
28+
2729
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
2830
#[derive(Clone)]
2931
pub struct PyCaseBuilder {
3032
case_builder: Arc<Mutex<Option<CaseBuilder>>>,
3133
}
3234

33-
impl From<PyCaseBuilder> for CaseBuilder {
34-
fn from(case_builder: PyCaseBuilder) -> Self {
35-
case_builder
36-
.case_builder
37-
.lock()
38-
.expect("Case builder mutex poisoned")
39-
.take()
40-
.expect("CaseBuilder has already been consumed")
41-
}
42-
}
43-
4435
impl From<CaseBuilder> for PyCaseBuilder {
4536
fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
4637
PyCaseBuilder {
@@ -50,25 +41,27 @@ impl From<CaseBuilder> for PyCaseBuilder {
5041
}
5142

5243
impl PyCaseBuilder {
53-
fn lock_case_builder(
54-
&self,
55-
) -> PyDataFusionResult<std::sync::MutexGuard<'_, Option<CaseBuilder>>> {
56-
self.case_builder
57-
.lock()
58-
.map_err(|_| PyDataFusionError::Common("failed to lock CaseBuilder".to_string()))
44+
fn lock_case_builder(&self) -> MutexGuard<'_, Option<CaseBuilder>> {
45+
self.case_builder.lock()
5946
}
6047

6148
fn take_case_builder(&self) -> PyDataFusionResult<CaseBuilder> {
62-
let mut guard = self.lock_case_builder()?;
49+
let mut guard = self.lock_case_builder();
6350
guard.take().ok_or_else(|| {
6451
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
6552
})
6653
}
6754

68-
fn store_case_builder(&self, builder: CaseBuilder) -> PyDataFusionResult<()> {
69-
let mut guard = self.lock_case_builder()?;
55+
fn store_case_builder(&self, builder: CaseBuilder) {
56+
let mut guard = self.lock_case_builder();
7057
*guard = Some(builder);
71-
Ok(())
58+
}
59+
60+
pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> {
61+
let mut guard = self.case_builder.lock();
62+
guard.take().ok_or_else(|| {
63+
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
64+
})
7265
}
7366
}
7467

@@ -77,7 +70,7 @@ impl PyCaseBuilder {
7770
fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> {
7871
let mut builder = self.take_case_builder()?;
7972
let next_builder = builder.when(when.expr, then.expr);
80-
self.store_case_builder(next_builder)?;
73+
self.store_case_builder(next_builder);
8174
Ok(self.clone())
8275
}
8376

@@ -86,7 +79,7 @@ impl PyCaseBuilder {
8679
match builder.otherwise(else_expr.expr) {
8780
Ok(expr) => Ok(expr.clone().into()),
8881
Err(err) => {
89-
self.store_case_builder(builder)?;
82+
self.store_case_builder(builder);
9083
Err(err.into())
9184
}
9285
}
@@ -97,7 +90,7 @@ impl PyCaseBuilder {
9790
match builder.end() {
9891
Ok(expr) => Ok(expr.clone().into()),
9992
Err(err) => {
100-
self.store_case_builder(builder)?;
93+
self.store_case_builder(builder);
10194
Err(err.into())
10295
}
10396
}

0 commit comments

Comments (0)