@@ -285,6 +285,8 @@ def build_udf_endpoint(
285285 """
286286 if returns_data_format in ['scalar' , 'list' ]:
287287
288+ is_async = asyncio .iscoroutinefunction (func )
289+
288290 async def do_func (
289291 cancel_event : threading .Event ,
290292 row_ids : Sequence [int ],
@@ -297,7 +299,10 @@ async def do_func(
297299 raise asyncio .CancelledError (
298300 'Function call was cancelled' ,
299301 )
300- out .append (func (* row ))
302+ if is_async :
303+ out .append (await func (* row ))
304+ else :
305+ out .append (func (* row ))
301306 return row_ids , list (zip (out ))
302307
303308 return do_func
@@ -327,6 +332,7 @@ def build_vector_udf_endpoint(
327332 """
328333 masks = get_masked_params (func )
329334 array_cls = get_array_class (returns_data_format )
335+ is_async = asyncio .iscoroutinefunction (func )
330336
331337 async def do_func (
332338 cancel_event : threading .Event ,
@@ -341,9 +347,15 @@ async def do_func(
341347
342348 # Call the function with `cols` as the function parameters
343349 if cols and cols [0 ]:
344- out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
350+ if is_async :
351+ out = await func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
352+ else :
353+ out = func (* [x if m else x [0 ] for x , m in zip (cols , masks )])
345354 else :
346- out = func ()
355+ if is_async :
356+ out = await func ()
357+ else :
358+ out = func ()
347359
348360 # Single masked value
349361 if isinstance (out , Masked ):
@@ -381,6 +393,8 @@ def build_tvf_endpoint(
381393 """
382394 if returns_data_format in ['scalar' , 'list' ]:
383395
396+ is_async = asyncio .iscoroutinefunction (func )
397+
384398 async def do_func (
385399 cancel_event : threading .Event ,
386400 row_ids : Sequence [int ],
@@ -390,11 +404,15 @@ async def do_func(
390404 out_ids : List [int ] = []
391405 out = []
392406 # Call function on each row of data
393- for i , res in zip (row_ids , func_map ( func , rows ) ):
407+ for i , row in zip (row_ids , rows ):
394408 if cancel_event .is_set ():
395409 raise asyncio .CancelledError (
396410 'Function call was cancelled' ,
397411 )
412+ if is_async :
413+ res = await func (* row )
414+ else :
415+ res = func (* row )
398416 out .extend (as_list_of_tuples (res ))
399417 out_ids .extend ([row_ids [i ]] * (len (out )- len (out_ids )))
400418 return out_ids , out
@@ -440,13 +458,23 @@ async def do_func(
440458 # each result row, so we just have to use the same
441459 # row ID for all rows in the result.
442460
461+ is_async = asyncio .iscoroutinefunction (func )
462+
443463 # Call function on each column of data
444464 if cols and cols [0 ]:
445- res = get_dataframe_columns (
446- func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
447- )
465+ if is_async :
466+ res = get_dataframe_columns (
467+ await func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
468+ )
469+ else :
470+ res = get_dataframe_columns (
471+ func (* [x if m else x [0 ] for x , m in zip (cols , masks )]),
472+ )
448473 else :
449- res = get_dataframe_columns (func ())
474+ if is_async :
475+ res = get_dataframe_columns (await func ())
476+ else :
477+ res = get_dataframe_columns (func ())
450478
451479 # Generate row IDs
452480 if isinstance (res [0 ], Masked ):
@@ -508,6 +536,9 @@ def make_func(
508536 # Set timeout
509537 info ['timeout' ] = max (timeout , 1 )
510538
539+ # Set async flag
540+ info ['is_async' ] = asyncio .iscoroutinefunction (func )
541+
511542 # Setup argument types for rowdat_1 parser
512543 colspec = []
513544 for x in sig ['args' ]:
@@ -927,18 +958,28 @@ async def __call__(
927958
928959 cancel_event = threading .Event ()
929960
930- func_task = asyncio .create_task (
931- to_thread (
932- lambda : asyncio .run (
933- func (
934- cancel_event ,
935- * input_handler ['load' ]( # type: ignore
936- func_info ['colspec' ], b'' .join (data ),
961+ if func_info ['is_async' ]:
962+ func_task = asyncio .create_task (
963+ func (
964+ cancel_event ,
965+ * input_handler ['load' ]( # type: ignore
966+ func_info ['colspec' ], b'' .join (data ),
967+ ),
968+ ),
969+ )
970+ else :
971+ func_task = asyncio .create_task (
972+ to_thread (
973+ lambda : asyncio .run (
974+ func (
975+ cancel_event ,
976+ * input_handler ['load' ]( # type: ignore
977+ func_info ['colspec' ], b'' .join (data ),
978+ ),
937979 ),
938980 ),
939981 ),
940- ),
941- )
982+ )
942983 disconnect_task = asyncio .create_task (
943984 cancel_on_disconnect (receive ),
944985 )
@@ -970,6 +1011,7 @@ async def __call__(
9701011 elif task is func_task :
9711012 result .extend (task .result ())
9721013
1014+ print (result )
9731015 body = output_handler ['dump' ](
9741016 [x [1 ] for x in func_info ['returns' ]], * result , # type: ignore
9751017 )
0 commit comments