Skip to content

Commit d6ad744

Browse files
committed
data table plugin :check:
1 parent 9591a05 commit d6ad744

File tree

2 files changed

+50
-27
lines changed

2 files changed

+50
-27
lines changed

countess/core/plugins.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@ def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation])
136136
raise NotImplementedError(f"{self.__class__}.execute")
137137

138138

139+
class DuckdbInputPlugin(DuckdbPlugin):
    """Base class for DuckDB plugins that are run without any source relation.

    `execute_multi` checks that the mapping of sources is empty and then
    delegates to `execute` with `source=None`.  Subclasses provide the
    actual `execute` implementation.
    """

    # Zero inputs; presumably read by the surrounding framework when
    # wiring plugins together -- confirm against the caller.
    num_inputs = 0

    def execute_multi(self, ddbc: DuckDBPyConnection, sources: Mapping) -> Optional[DuckDBPyRelation]:
        """Run the plugin, passing `None` as the single source.

        Args:
            ddbc: connection handed through unchanged to `execute`.
            sources: mapping of sources; must be empty.

        Returns:
            Whatever `execute` returns.
        """
        # The plugin declares no inputs, so nothing may be supplied here.
        assert len(sources) == 0
        return self.execute(ddbc, None)

    def execute(self, ddbc: DuckDBPyConnection, source: None) -> Optional[DuckDBPyRelation]:
        """Produce the plugin's output; this base version is a placeholder.

        Raises:
            NotImplementedError: always, naming the concrete class.
        """
        raise NotImplementedError(f"{self.__class__}.execute")
150+
151+
139152
class DuckdbStatementPlugin(DuckdbSimplePlugin):
140153
def statement(self, ddbc: DuckDBPyConnection, source_table_name: str) -> str:
141154
raise NotImplementedError(f"{self.__class__}.statement")
@@ -156,10 +169,9 @@ class LoadFileMultiParam(MultiParam):
156169
filename = FileParam("Filename")
157170

158171

159-
class DuckdbLoadFilePlugin(DuckdbSimplePlugin):
172+
class DuckdbLoadFilePlugin(DuckdbInputPlugin):
160173
files = FileArrayParam("Files", LoadFileMultiParam("File"))
161174
file_types: Sequence[tuple[str, Union[str, list[str]]]] = [("Any", "*")]
162-
num_inputs = 0
163175

164176
def __init__(self, *a, **k):
165177
super().__init__(*a, **k)
@@ -170,9 +182,7 @@ def filenames_and_params(self):
170182
for filename in glob.iglob(file_param.filename.value):
171183
yield filename, file_param
172184

173-
def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
174-
assert source is None
175-
185+
def execute(self, ddbc: DuckDBPyConnection, source: None) -> Optional[DuckDBPyRelation]:
176186
filenames_and_params = list(self.filenames_and_params())
177187

178188
cursor = ddbc
@@ -238,19 +248,25 @@ def dropped_columns(self) -> set[str]:
238248
return set()
239249

240250
def input_columns(self) -> dict[str, str]:
241-
raise NotImplementedError(f"{self.__class__}.input_columns")
251+
return None
242252

243253
def output_columns(self) -> dict[str, str]:
244254
raise NotImplementedError(f"{self.__class__}.output_columns")
245255

246256
def execute(self, ddbc, source):
247257
"""Perform a query which calls `self.transform` for every row."""
248258

249-
escaped_input_columns = {
250-
duckdb_escape_identifier(k): str(v).upper()
251-
for k, v in self.input_columns().items()
252-
if k is not None and v is not None
253-
}
259+
if self.input_columns() is None:
260+
escaped_input_columns = {
261+
duckdb_escape_identifier(k): str(v).upper()
262+
for k, v in zip(source.columns, source.dtypes)
263+
}
264+
else:
265+
escaped_input_columns = {
266+
duckdb_escape_identifier(k): str(v).upper()
267+
for k, v in self.input_columns().items()
268+
if k is not None and v is not None
269+
}
254270

255271
escaped_output_columns = {
256272
duckdb_escape_identifier(k): str(v).upper()

countess/plugins/data_table.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Iterable
1+
import logging
22

3-
import pandas as pd
3+
from duckdb import DuckDBPyConnection, DuckDBPyRelation
44

55
from countess import VERSION
66
from countess.core.parameters import (
@@ -11,16 +11,18 @@
1111
StringParam,
1212
TabularMultiParam,
1313
)
14-
from countess.core.plugins import PandasInputPlugin
14+
from countess.core.plugins import DuckdbInputPlugin
15+
from countess.utils.duckdb import duckdb_escape_literal, duckdb_escape_identifier
1516

17+
logger = logging.getLogger(__name__)
1618

1719
class _ColumnsMultiParam(MultiParam):
1820
name = StringParam("Name")
19-
type = DataTypeChoiceParam("Type", "string")
21+
type = DataTypeChoiceParam("Type")
2022
index = BooleanParam("Index?")
2123

2224

23-
class DataTablePlugin(PandasInputPlugin):
25+
class DataTablePlugin(DuckdbInputPlugin):
2426
"""DataTable"""
2527

2628
name = "DataTable"
@@ -32,7 +34,7 @@ class DataTablePlugin(PandasInputPlugin):
3234
columns = ArrayParam("Columns", _ColumnsMultiParam("Column"))
3335
rows = ArrayParam("Rows", TabularMultiParam("Row"))
3436

35-
show_preview = False
37+
#show_preview = False
3638

3739
def fix_columns(self):
3840
old_rows = self.rows.params
@@ -61,17 +63,22 @@ def set_parameter(self, key: str, *a, **k):
6163
self.fix_columns()
6264
super().set_parameter(key, *a, **k)
6365

64-
def finalize(self) -> Iterable[pd.DataFrame]:
66+
def execute(self, ddbc: DuckDBPyConnection, source: None) -> DuckDBPyRelation:
6567
self.fix_columns()
66-
values = []
67-
for row in self.rows:
68-
values.append({str(col.name): row[str(col.name)].value for col in self.columns})
6968

70-
df = pd.DataFrame(values)
71-
72-
index_cols = [str(col.name) for col in self.columns if col.index]
69+
if len(self.rows) == 0:
70+
return None
71+
72+
sql = ("SELECT * FROM (VALUES " +
73+
(", ".join(
74+
"(" + (", ".join(duckdb_escape_literal(val.value) for val in row.values())) + ")"
75+
for row in self.rows
76+
)) +
77+
") _(" +
78+
(", ".join(duckdb_escape_identifier(col.name.value) for col in self.columns)) +
79+
")"
80+
)
7381

74-
if index_cols:
75-
df = df.set_index(index_cols)
82+
logger.debug("DataTablePlugin.execute sql %s", sql)
7683

77-
yield df
84+
return ddbc.sql(sql)

0 commit comments

Comments
 (0)