Skip to content

Commit 10af0f1

Browse files
committed
Refactor schema, config, dataframe, and expression classes to use RwLock and Mutex for interior mutability
1 parent c8f7145 commit 10af0f1

File tree

7 files changed

+201
-75
lines changed

7 files changed

+201
-75
lines changed

src/common/schema.rs

Lines changed: 91 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use std::fmt::{self, Display, Formatter};
19-
use std::sync::Arc;
19+
use std::sync::{Arc, RwLock};
2020
use std::{any::Any, borrow::Cow};
2121

2222
use arrow::datatypes::Schema;
@@ -25,6 +25,7 @@ use datafusion::arrow::datatypes::SchemaRef;
2525
use datafusion::common::Constraints;
2626
use datafusion::datasource::TableType;
2727
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
28+
use pyo3::exceptions::PyRuntimeError;
2829
use pyo3::prelude::*;
2930

3031
use datafusion::logical_expr::utils::split_conjunction;
@@ -33,17 +34,13 @@ use crate::sql::logical::PyLogicalPlan;
3334

3435
use super::{data_type::DataTypeMap, function::SqlFunction};
3536

36-
#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
37+
#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)]
3738
#[derive(Debug, Clone)]
3839
pub struct SqlSchema {
39-
#[pyo3(get, set)]
40-
pub name: String,
41-
#[pyo3(get, set)]
42-
pub tables: Vec<SqlTable>,
43-
#[pyo3(get, set)]
44-
pub views: Vec<SqlView>,
45-
#[pyo3(get, set)]
46-
pub functions: Vec<SqlFunction>,
40+
name: Arc<RwLock<String>>,
41+
tables: Arc<RwLock<Vec<SqlTable>>>,
42+
views: Arc<RwLock<Vec<SqlView>>>,
43+
functions: Arc<RwLock<Vec<SqlFunction>>>,
4744
}
4845

4946
#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
@@ -104,28 +101,98 @@ impl SqlSchema {
104101
#[new]
105102
pub fn new(schema_name: &str) -> Self {
106103
Self {
107-
name: schema_name.to_owned(),
108-
tables: Vec::new(),
109-
views: Vec::new(),
110-
functions: Vec::new(),
104+
name: Arc::new(RwLock::new(schema_name.to_owned())),
105+
tables: Arc::new(RwLock::new(Vec::new())),
106+
views: Arc::new(RwLock::new(Vec::new())),
107+
functions: Arc::new(RwLock::new(Vec::new())),
111108
}
112109
}
113110

111+
#[getter]
112+
fn name(&self) -> PyResult<String> {
113+
Ok(self
114+
.name
115+
.read()
116+
.map_err(|_| PyRuntimeError::new_err("failed to read schema name"))?
117+
.clone())
118+
}
119+
120+
#[setter]
121+
fn set_name(&self, value: String) -> PyResult<()> {
122+
*self
123+
.name
124+
.write()
125+
.map_err(|_| PyRuntimeError::new_err("failed to write schema name"))? = value;
126+
Ok(())
127+
}
128+
129+
#[getter]
130+
fn tables(&self) -> PyResult<Vec<SqlTable>> {
131+
Ok(self
132+
.tables
133+
.read()
134+
.map_err(|_| PyRuntimeError::new_err("failed to read schema tables"))?
135+
.clone())
136+
}
137+
138+
#[setter]
139+
fn set_tables(&self, tables: Vec<SqlTable>) -> PyResult<()> {
140+
*self
141+
.tables
142+
.write()
143+
.map_err(|_| PyRuntimeError::new_err("failed to write schema tables"))? = tables;
144+
Ok(())
145+
}
146+
147+
#[getter]
148+
fn views(&self) -> PyResult<Vec<SqlView>> {
149+
Ok(self
150+
.views
151+
.read()
152+
.map_err(|_| PyRuntimeError::new_err("failed to read schema views"))?
153+
.clone())
154+
}
155+
156+
#[setter]
157+
fn set_views(&self, views: Vec<SqlView>) -> PyResult<()> {
158+
*self
159+
.views
160+
.write()
161+
.map_err(|_| PyRuntimeError::new_err("failed to write schema views"))? = views;
162+
Ok(())
163+
}
164+
165+
#[getter]
166+
fn functions(&self) -> PyResult<Vec<SqlFunction>> {
167+
Ok(self
168+
.functions
169+
.read()
170+
.map_err(|_| PyRuntimeError::new_err("failed to read schema functions"))?
171+
.clone())
172+
}
173+
174+
#[setter]
175+
fn set_functions(&self, functions: Vec<SqlFunction>) -> PyResult<()> {
176+
*self
177+
.functions
178+
.write()
179+
.map_err(|_| PyRuntimeError::new_err("failed to write schema functions"))? = functions;
180+
Ok(())
181+
}
182+
114183
pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
115-
for tbl in &self.tables {
116-
if tbl.name.eq(table_name) {
117-
return Some(tbl.clone());
118-
}
119-
}
120-
None
184+
let tables = self.tables.read().expect("failed to read schema tables");
185+
tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned()
121186
}
122187

123-
pub fn add_table(&mut self, table: SqlTable) {
124-
self.tables.push(table);
188+
pub fn add_table(&self, table: SqlTable) {
189+
let mut tables = self.tables.write().expect("failed to write schema tables");
190+
tables.push(table);
125191
}
126192

127-
pub fn drop_table(&mut self, table_name: String) {
128-
self.tables.retain(|x| !x.name.eq(&table_name));
193+
pub fn drop_table(&self, table_name: String) {
194+
let mut tables = self.tables.write().expect("failed to write schema tables");
195+
tables.retain(|x| !x.name.eq(&table_name));
129196
}
130197
}
131198

src/config.rs

Lines changed: 25 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -15,40 +15,45 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::{Arc, RwLock};
19+
1820
use pyo3::prelude::*;
1921
use pyo3::types::*;
2022

2123
use datafusion::config::ConfigOptions;
2224

23-
use crate::errors::PyDataFusionResult;
25+
use crate::errors::{PyDataFusionError, PyDataFusionResult};
2426
use crate::utils::py_obj_to_scalar_value;
2527

26-
#[pyclass(name = "Config", module = "datafusion", subclass)]
28+
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
2729
#[derive(Clone)]
2830
pub(crate) struct PyConfig {
29-
config: ConfigOptions,
31+
config: Arc<RwLock<ConfigOptions>>,
3032
}
3133

3234
#[pymethods]
3335
impl PyConfig {
3436
#[new]
3537
fn py_new() -> Self {
3638
Self {
37-
config: ConfigOptions::new(),
39+
config: Arc::new(RwLock::new(ConfigOptions::new())),
3840
}
3941
}
4042

4143
/// Get configurations from environment variables
4244
#[staticmethod]
4345
pub fn from_env() -> PyDataFusionResult<Self> {
4446
Ok(Self {
45-
config: ConfigOptions::from_env()?,
47+
config: Arc::new(RwLock::new(ConfigOptions::from_env()?)),
4648
})
4749
}
4850

4951
/// Get a configuration option
50-
pub fn get<'py>(&mut self, key: &str, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
51-
let options = self.config.to_owned();
52+
pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
53+
let options = self
54+
.config
55+
.read()
56+
.map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?;
5257
for entry in options.entries() {
5358
if entry.key == key {
5459
return Ok(entry.value.into_pyobject(py)?);
@@ -58,25 +63,31 @@ impl PyConfig {
5863
}
5964

6065
/// Set a configuration option
61-
pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> {
66+
pub fn set(&self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> {
6267
let scalar_value = py_obj_to_scalar_value(py, value)?;
63-
self.config.set(key, scalar_value.to_string().as_str())?;
68+
let mut options = self
69+
.config
70+
.write()
71+
.map_err(|_| PyDataFusionError::Common("failed to lock configuration".to_string()))?;
72+
options.set(key, scalar_value.to_string().as_str())?;
6473
Ok(())
6574
}
6675

6776
/// Get all configuration options
68-
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
77+
pub fn get_all(&self, py: Python) -> PyResult<PyObject> {
6978
let dict = PyDict::new(py);
70-
let options = self.config.to_owned();
79+
let options = self
80+
.config
81+
.read()
82+
.map_err(|_| PyDataFusionError::Common("failed to read configuration".to_string()))?;
7183
for entry in options.entries() {
7284
dict.set_item(entry.key, entry.value.clone().into_pyobject(py)?)?;
7385
}
7486
Ok(dict.into())
7587
}
7688

77-
fn __repr__(&mut self, py: Python) -> PyResult<String> {
78-
let dict = self.get_all(py);
79-
match dict {
89+
fn __repr__(&self, py: Python) -> PyResult<String> {
90+
match self.get_all(py) {
8091
Ok(result) => Ok(format!("Config({result})")),
8192
Err(err) => Ok(format!("Error: {:?}", err.to_string())),
8293
}

src/dataframe.rs

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
use std::collections::HashMap;
1919
use std::ffi::CString;
20-
use std::sync::Arc;
20+
use std::sync::{Arc, Mutex};
2121

2222
use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
@@ -284,30 +284,38 @@ impl PyParquetColumnOptions {
284284
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
285285
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
286286
/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
287-
#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
287+
#[pyclass(name = "DataFrame", module = "datafusion", subclass, frozen)]
288288
#[derive(Clone)]
289289
pub struct PyDataFrame {
290290
df: Arc<DataFrame>,
291291

292292
// In IPython environment cache batches between __repr__ and _repr_html_ calls.
293-
batches: Option<(Vec<RecordBatch>, bool)>,
293+
batches: Arc<Mutex<Option<(Vec<RecordBatch>, bool)>>>,
294294
}
295295

296296
impl PyDataFrame {
297297
/// creates a new PyDataFrame
298298
pub fn new(df: DataFrame) -> Self {
299299
Self {
300300
df: Arc::new(df),
301-
batches: None,
301+
batches: Arc::new(Mutex::new(None)),
302302
}
303303
}
304304

305-
fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
305+
fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
306306
// Get the Python formatter and config
307307
let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
308308

309-
let should_cache = *is_ipython_env(py) && self.batches.is_none();
310-
let (batches, has_more) = match self.batches.take() {
309+
let (cached_batches, should_cache) = {
310+
let mut cache = self.batches.lock().map_err(|_| {
311+
PyDataFusionError::Common("failed to lock DataFrame display cache".to_string())
312+
})?;
313+
let should_cache = *is_ipython_env(py) && cache.is_none();
314+
let batches = cache.take();
315+
(batches, should_cache)
316+
};
317+
318+
let (batches, has_more) = match cached_batches {
311319
Some(b) => b,
312320
None => wait_for_future(
313321
py,
@@ -346,7 +354,10 @@ impl PyDataFrame {
346354
let html_str: String = html_result.extract()?;
347355

348356
if should_cache {
349-
self.batches = Some((batches, has_more));
357+
let mut cache = self.batches.lock().map_err(|_| {
358+
PyDataFusionError::Common("failed to lock DataFrame display cache".to_string())
359+
})?;
360+
*cache = Some((batches.clone(), has_more));
350361
}
351362

352363
Ok(html_str)
@@ -376,7 +387,7 @@ impl PyDataFrame {
376387
}
377388
}
378389

379-
fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
390+
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
380391
self.prepare_repr_string(py, false)
381392
}
382393

@@ -411,7 +422,7 @@ impl PyDataFrame {
411422
Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
412423
}
413424

414-
fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
425+
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
415426
self.prepare_repr_string(py, true)
416427
}
417428

@@ -874,7 +885,7 @@ impl PyDataFrame {
874885

875886
#[pyo3(signature = (requested_schema=None))]
876887
fn __arrow_c_stream__<'py>(
877-
&'py mut self,
888+
&'py self,
878889
py: Python<'py>,
879890
requested_schema: Option<Bound<'py, PyCapsule>>,
880891
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {

0 commit comments

Comments
 (0)