@@ -136,6 +136,19 @@ def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation])
136136 raise NotImplementedError (f"{ self .__class__ } .execute" )
137137
138138
139+ class DuckdbInputPlugin (DuckdbPlugin ):
140+ num_inputs = 0
141+
142+ def execute_multi (
143+ self , ddbc : DuckDBPyConnection , sources : Mapping
144+ ) -> Optional [DuckDBPyRelation ]:
145+ assert len (sources ) == 0
146+ return self .execute (ddbc , None )
147+
148+ def execute (self , ddbc : DuckDBPyConnection , source : None ) -> Optional [DuckDBPyRelation ]:
149+ raise NotImplementedError (f"{ self .__class__ } .execute" )
150+
151+
139152class DuckdbStatementPlugin (DuckdbSimplePlugin ):
140153 def statement (self , ddbc : DuckDBPyConnection , source_table_name : str ) -> str :
141154 raise NotImplementedError (f"{ self .__class__ } .statement" )
@@ -156,10 +169,9 @@ class LoadFileMultiParam(MultiParam):
156169 filename = FileParam ("Filename" )
157170
158171
159- class DuckdbLoadFilePlugin (DuckdbSimplePlugin ):
172+ class DuckdbLoadFilePlugin (DuckdbInputPlugin ):
160173 files = FileArrayParam ("Files" , LoadFileMultiParam ("File" ))
161174 file_types : Sequence [tuple [str , Union [str , list [str ]]]] = [("Any" , "*" )]
162- num_inputs = 0
163175
164176 def __init__ (self , * a , ** k ):
165177 super ().__init__ (* a , ** k )
@@ -170,9 +182,7 @@ def filenames_and_params(self):
170182 for filename in glob .iglob (file_param .filename .value ):
171183 yield filename , file_param
172184
173- def execute (self , ddbc : DuckDBPyConnection , source : Optional [DuckDBPyRelation ]) -> Optional [DuckDBPyRelation ]:
174- assert source is None
175-
185+ def execute (self , ddbc : DuckDBPyConnection , source : None ) -> Optional [DuckDBPyRelation ]:
176186 filenames_and_params = list (self .filenames_and_params ())
177187
178188 cursor = ddbc
@@ -238,19 +248,25 @@ def dropped_columns(self) -> set[str]:
238248 return set ()
239249
240250 def input_columns (self ) -> dict [str , str ]:
241- raise NotImplementedError ( f" { self . __class__ } .input_columns" )
251+ return None
242252
243253 def output_columns (self ) -> dict [str , str ]:
244254 raise NotImplementedError (f"{ self .__class__ } .output_columns" )
245255
246256 def execute (self , ddbc , source ):
247257 """Perform a query which calls `self.transform` for every row."""
248258
249- escaped_input_columns = {
250- duckdb_escape_identifier (k ): str (v ).upper ()
251- for k , v in self .input_columns ().items ()
252- if k is not None and v is not None
253- }
259+ if self .input_columns () is None :
260+ escaped_input_columns = {
261+ duckdb_escape_identifier (k ): str (v ).upper ()
262+ for k , v in zip (source .columns , source .dtypes )
263+ }
264+ else :
265+ escaped_input_columns = {
266+ duckdb_escape_identifier (k ): str (v ).upper ()
267+ for k , v in self .input_columns ().items ()
268+ if k is not None and v is not None
269+ }
254270
255271 escaped_output_columns = {
256272 duckdb_escape_identifier (k ): str (v ).upper ()
0 commit comments