From c37feea6467b45fffbfbc00fcb5acca7066dd0b3 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 4 Apr 2025 12:47:41 -0500 Subject: [PATCH 01/16] Refactoring function signatures --- accel.c | 118 +++++--- singlestoredb/config.py | 2 +- singlestoredb/functions/decorator.py | 45 +-- singlestoredb/functions/ext/asgi.py | 7 +- singlestoredb/functions/signature.py | 436 ++++++++++++++++++++++----- 5 files changed, 470 insertions(+), 138 deletions(-) diff --git a/accel.c b/accel.c index e2e2193fa..7b809f94d 100644 --- a/accel.c +++ b/accel.c @@ -35,6 +35,7 @@ #define NUMPY_TIMEDELTA 12 #define NUMPY_DATETIME 13 #define NUMPY_OBJECT 14 +#define NUMPY_BYTES 15 #define MYSQL_FLAG_NOT_NULL 1 #define MYSQL_FLAG_PRI_KEY 2 @@ -339,6 +340,11 @@ #define CHECKRC(x) if ((x) < 0) goto error; +typedef struct { + int type; + Py_ssize_t length; +} NumpyColType; + typedef struct { int results_type; int parse_json; @@ -2646,8 +2652,8 @@ static char *get_array_base_address(PyObject *py_array) { } -static int get_numpy_col_type(PyObject *py_array) { - int out = 0; +static NumpyColType get_numpy_col_type(PyObject *py_array) { + NumpyColType out = {0}; char *str = NULL; PyObject *py_array_interface = NULL; PyObject *py_typestr = NULL; @@ -2665,58 +2671,79 @@ static int get_numpy_col_type(PyObject *py_array) { switch (str[1]) { case 'b': - out = NUMPY_BOOL; + out.type = NUMPY_BOOL; + out.length = 1; break; case 'i': switch (str[2]) { case '1': - out = NUMPY_INT8; + out.type = NUMPY_INT8; + out.length = 1; break; case '2': - out = NUMPY_INT16; + out.type = NUMPY_INT16; + out.length = 2; break; case '4': - out = NUMPY_INT32; + out.type = NUMPY_INT32; + out.length = 4; break; case '8': - out = NUMPY_INT64; + out.type = NUMPY_INT64; + out.length = 8; break; } break; case 'u': switch (str[2]) { case '1': - out = NUMPY_UINT8; + out.type = NUMPY_UINT8; + out.length = 1; break; case '2': - out = NUMPY_UINT16; + out.type = NUMPY_UINT16; + out.length = 2; break; case '4': - out = NUMPY_UINT32; + out.type = NUMPY_UINT32; + out.length = 4; break; case '8': - out = NUMPY_UINT64; + out.type = NUMPY_UINT64; + out.length = 8; break; } break; case 'f': switch (str[2]) { case '4': - out = NUMPY_FLOAT32; + out.type = NUMPY_FLOAT32; + out.length = 4; break; case '8': - out = NUMPY_FLOAT64; + out.type = NUMPY_FLOAT64; + out.length = 8; break; } break; case 'O': - out = NUMPY_OBJECT; + out.type = NUMPY_OBJECT; + out.length = 8; break; case 'm': - out = NUMPY_TIMEDELTA; + out.type = NUMPY_TIMEDELTA; + out.length = 8; break; case 'M': - out = NUMPY_DATETIME; + out.type = NUMPY_DATETIME; + out.length = 8; + break; + case 'S': + out.type = NUMPY_BYTES; + out.length = (Py_ssize_t)strtol(str + 2, NULL, 10); + if (out.length < 0) { + goto error; + } break; default: goto error; @@ -2730,7 +2757,8 @@ static int get_numpy_col_type(PyObject *py_array) { return out; error: - out = 0; + out.type = 0; + out.length = 0; goto exit; } @@ -2774,7 +2802,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k unsigned long long j = 0; char **cols = NULL; char **masks = NULL; - int *col_types = NULL; + NumpyColType *col_types = NULL; int64_t *row_ids = NULL; // Parse function args. 
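For reference, the typestr values that get_numpy_col_type() parses come from numpy's array interface, and the new 'S' branch covers fixed-width byte strings. A minimal sketch of the inputs involved (plain numpy; a little-endian machine is assumed for the '<' byte-order prefix):

    import numpy as np

    for arr in (
        np.zeros(3, dtype=np.int64),    # typestr '<i8'  -> NUMPY_INT64,   length 8
        np.zeros(3, dtype=np.float32),  # typestr '<f4'  -> NUMPY_FLOAT32, length 4
        np.zeros(3, dtype='S16'),       # typestr '|S16' -> NUMPY_BYTES,   length 16
    ):
        print(arr.__array_interface__['typestr'])
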
@@ -2847,7 +2875,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Get column array memory cols = calloc(sizeof(char*), n_cols); if (!cols) goto error; - col_types = calloc(sizeof(int), n_cols); + col_types = calloc(sizeof(NumpyColType), n_cols); if (!col_types) goto error; masks = calloc(sizeof(char*), n_cols); if (!masks) goto error; @@ -2865,7 +2893,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k } col_types[i] = get_numpy_col_type(py_data); - if (!col_types[i]) { + if (!col_types[i].type) { PyErr_SetString(PyExc_ValueError, "unable to get column type of data column"); goto error; } @@ -2874,7 +2902,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k if (!py_mask) goto error; masks[i] = get_array_base_address(py_mask); - if (masks[i] && get_numpy_col_type(py_mask) != NUMPY_BOOL) { + if (masks[i] && get_numpy_col_type(py_mask).type != NUMPY_BOOL) { PyErr_SetString(PyExc_ValueError, "mask must only contain boolean values"); goto error; } @@ -2958,7 +2986,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_TINY: CHECKMEM(1); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_TINYINT(i8, 0); @@ -3025,7 +3053,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Use negative to indicate unsigned case -MYSQL_TYPE_TINY: CHECKMEM(1); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_UNSIGNED_TINYINT(i8, 0); @@ -3091,7 +3119,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_SHORT: CHECKMEM(2); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_SMALLINT(i8, 0); @@ -3158,7 +3186,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Use negative to indicate unsigned case -MYSQL_TYPE_SHORT: CHECKMEM(2); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_UNSIGNED_SMALLINT(i8, 0); @@ -3224,7 +3252,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_INT24: CHECKMEM(4); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_MEDIUMINT(i8, 0); @@ -3290,7 +3318,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_LONG: CHECKMEM(4); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_INT(i8, 0); @@ -3357,7 +3385,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Use negative to indicate unsigned case -MYSQL_TYPE_INT24: CHECKMEM(4); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_UNSIGNED_MEDIUMINT(i8, 0); @@ -3424,7 +3452,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Use negative to indicate unsigned case -MYSQL_TYPE_LONG: CHECKMEM(4); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_UNSIGNED_INT(i8, 0); @@ -3490,7 +3518,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_LONGLONG: CHECKMEM(8); - switch 
(col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_BIGINT(i8, 0); @@ -3557,7 +3585,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k // Use negative to indicate unsigned case -MYSQL_TYPE_LONGLONG: CHECKMEM(8); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_UNSIGNED_BIGINT(i8, 0); @@ -3623,7 +3651,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_FLOAT: CHECKMEM(4); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: flt = (float)((is_null) ? 0 : *(int8_t*)(cols[i] + j * 1)); break; @@ -3667,7 +3695,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_DOUBLE: CHECKMEM(8); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: dbl = (double)((is_null) ? 0 : *(int8_t*)(cols[i] + j * 1)); break; @@ -3742,7 +3770,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_YEAR: CHECKMEM(2); - switch (col_types[i]) { + switch (col_types[i].type) { case NUMPY_BOOL: i8 = *(int8_t*)(cols[i] + j * 1); CHECK_YEAR(i8); @@ -3817,7 +3845,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_MEDIUM_BLOB: case MYSQL_TYPE_LONG_BLOB: case MYSQL_TYPE_BLOB: - if (col_types[i] != NUMPY_OBJECT) { + if (col_types[i].type != NUMPY_OBJECT) { PyErr_SetString(PyExc_ValueError, "unsupported numpy data type for character output types"); goto error; } @@ -3873,7 +3901,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case -MYSQL_TYPE_MEDIUM_BLOB: case -MYSQL_TYPE_LONG_BLOB: case -MYSQL_TYPE_BLOB: - if (col_types[i] != NUMPY_OBJECT) { + if (col_types[i].type != NUMPY_OBJECT && col_types[i].type != NUMPY_BYTES) { PyErr_SetString(PyExc_ValueError, "unsupported numpy data type for binary output types"); goto error; } @@ -3884,6 +3912,24 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k memcpy(out+out_idx, &i64, 8); out_idx += 8; + } else if (col_types[i].type == NUMPY_BYTES) { + void *bytes = (void*)(cols[i] + j * 8); + + if (bytes == NULL) { + CHECKMEM(8); + i64 = 0; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + } else { + Py_ssize_t str_l = col_types[i].length; + CHECKMEM(8+str_l); + i64 = str_l; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + memcpy(out+out_idx, bytes, str_l); + out_idx += str_l; + } + } else { u64 = *(uint64_t*)(cols[i] + j * 8); diff --git a/singlestoredb/config.py b/singlestoredb/config.py index d83b9931b..664635977 100644 --- a/singlestoredb/config.py +++ b/singlestoredb/config.py @@ -415,7 +415,7 @@ ) register_option( - 'external_function.host', 'string', check_str, '127.0.0.1', + 'external_function.host', 'string', check_str, 'localhost', 'Specifies the host to bind the server to.', environ=['SINGLESTOREDB_EXT_FUNC_HOST'], ) diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 9e6ef7ff3..0d43fe148 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -122,28 +122,29 @@ def process_types(params: Any) -> Any: raise TypeError(f'unrecognized data type for args: {params}') +ParameterType = Union[ + str, + List[str], + Dict[str, str], + 'pydantic.BaseModel', + type, +] + +ReturnType = Union[ + str, + List[DataType], + List[type], + 'pydantic.BaseModel', + type, +] 
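A minimal usage sketch of what these aliases permit for the @udf/@tvf overrides; the SQL strings and function body here are illustrative assumptions, not taken from this patch:

    from singlestoredb.functions.decorator import udf

    # args given as a mapping of parameter names to SQL types, returns as a
    # SQL string; a pydantic model or a plain Python type would also satisfy
    # ParameterType / ReturnType
    @udf(args={'x': 'BIGINT NOT NULL', 'y': 'DOUBLE'}, returns='TEXT')
    def tag(x, y):
        return f'{x}:{y}'
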
+ + def _func( func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None, - args: Optional[ - Union[ - DataType, - List[DataType], - Dict[str, DataType], - 'pydantic.BaseModel', - type, - ] - ] = None, - returns: Optional[ - Union[ - str, - List[DataType], - List[type], - 'pydantic.BaseModel', - type, - ] - ] = None, + args: Optional[ParameterType] = None, + returns: Optional[ReturnType] = None, data_format: Optional[str] = None, include_masks: bool = False, function_type: str = 'udf', @@ -204,8 +205,8 @@ def udf( func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None, - args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None, - returns: Optional[Union[str, List[DataType], List[type]]] = None, + args: Optional[ParameterType] = None, + returns: Optional[ReturnType] = None, data_format: Optional[str] = None, include_masks: bool = False, ) -> Callable[..., Any]: @@ -267,8 +268,8 @@ def tvf( func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None, - args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None, - returns: Optional[Union[str, List[DataType], List[type]]] = None, + args: Optional[ParameterType] = None, + returns: Optional[ReturnType] = None, data_format: Optional[str] = None, include_masks: bool = False, output_fields: Optional[List[str]] = None, diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 702e3854b..4730d3269 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -286,7 +286,7 @@ async def do_func( # type: ignore do_func.__name__ = name do_func.__doc__ = func.__doc__ - sig = get_signature(func, name=name) + sig = get_signature(func, func_name=name) # Store signature for generating CREATE FUNCTION calls info['signature'] = sig @@ -1217,6 +1217,11 @@ def main(argv: Optional[List[str]] = None) -> None: or defaults.get('replace_existing') \ or get_option('external_function.replace_existing') + # Substitute in host / port if specified + if args.host != defaults.get('host') or args.port != defaults.get('port'): + u = urllib.parse.urlparse(args.url) + args.url = u._replace(netloc=f'{args.host}:{args.port}').geturl() + # Create application from functions / module app = Application( functions=args.functions, diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 794f08f7d..8f6e79458 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -192,6 +192,39 @@ class ArrayCollection(Collection): pass +def is_typeddict(obj: Any) -> bool: + """Check if an object is a TypedDict.""" + if hasattr(typing, 'is_typeddict'): + return typing.is_typeddict(obj) # noqa: TYP006 + return False + + +def is_namedtuple(obj: Any) -> bool: + """Check if an object is a named tuple.""" + if inspect.isclass(obj): + return ( + issubclass(obj, tuple) and + hasattr(obj, '_asdict') and + hasattr(obj, '_fields') + ) + return ( + isinstance(obj, tuple) and + hasattr(obj, '_asdict') and + hasattr(obj, '_fields') + ) + + +def is_pydantic(obj: Any) -> bool: + """Check if an object is a pydantic model.""" + if not has_pydantic: + return False + + if inspect.isclass(obj): + return issubclass(obj, pydantic.BaseModel) + + return isinstance(obj, pydantic.BaseModel) + + def escape_name(name: str) -> str: """Escape a function parameter name.""" if '`' in name: @@ -203,6 +236,12 @@ def simplify_dtype(dtype: Any) -> List[Any]: """ Expand a type annotation to a flattened list of atomic 
types. + This function will attempty to find the underlying type of a + type annotation. For example, a Union of types will be flattened + to a list of types. A Tuple or Array type will be expanded to + a list of types. A TypeVar will be expanded to a list of + constraints and bounds. + Parameters ---------- dtype : Any @@ -210,7 +249,8 @@ def simplify_dtype(dtype: Any) -> List[Any]: Returns ------- - List[Any] -- list of dtype strings, TupleCollections, and ArrayCollections + List[Any] + list of dtype strings, TupleCollections, and ArrayCollections """ origin = typing.get_origin(dtype) @@ -252,10 +292,24 @@ def simplify_dtype(dtype: Any) -> List[Any]: return args -def classify_dtype(dtype: Any) -> str: - """Classify the type annotation into a type name.""" +def normalize_dtype(dtype: Any) -> str: + """ + Normalize the type annotation into a type name. + + Parameters + ---------- + dtype : Any + Type annotation, list of type annotations, or a string + containing a SQL type name + + Returns + ------- + str + Normalized type name + + """ if isinstance(dtype, list): - return '|'.join(classify_dtype(x) for x in dtype) + return '|'.join(normalize_dtype(x) for x in dtype) if isinstance(dtype, str): return sql_to_dtype(dtype) @@ -271,44 +325,60 @@ def classify_dtype(dtype: Any) -> str: return 'bool' if dataclasses.is_dataclass(dtype): - fields = dataclasses.fields(dtype) + dc_fields = dataclasses.fields(dtype) item_dtypes = ','.join( - f'{classify_dtype(simplify_dtype(x.type))}' for x in fields + f'{normalize_dtype(simplify_dtype(x.type))}' for x in dc_fields ) return f'tuple[{item_dtypes}]' - if has_pydantic and inspect.isclass(dtype) and issubclass(dtype, pydantic.BaseModel): - fields = dtype.model_fields.values() + if is_typeddict(dtype): + td_fields = inspect.get_annotations(dtype).keys() item_dtypes = ','.join( - f'{classify_dtype(simplify_dtype(x.annotation))}' # type: ignore - for x in fields + f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in td_fields + ) + return f'tuple[{item_dtypes}]' + + if is_pydantic(dtype): + pyd_fields = dtype.model_fields.values() + item_dtypes = ','.join( + f'{normalize_dtype(simplify_dtype(x.annotation))}' # type: ignore + for x in pyd_fields + ) + return f'tuple[{item_dtypes}]' + + if is_namedtuple(dtype): + nt_fields = inspect.get_annotations(dtype).values() + item_dtypes = ','.join( + f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in nt_fields ) return f'tuple[{item_dtypes}]' if not inspect.isclass(dtype): + # Check for compound types origin = typing.get_origin(dtype) if origin is not None: + # Tuple type if origin is Tuple: args = typing.get_args(dtype) - item_dtypes = ','.join(classify_dtype(x) for x in args) + item_dtypes = ','.join(normalize_dtype(x) for x in args) return f'tuple[{item_dtypes}]' # Array types elif issubclass(origin, array_types): args = typing.get_args(dtype) - item_dtype = classify_dtype(args[0]) + item_dtype = normalize_dtype(args[0]) return f'array[{item_dtype}]' raise TypeError(f'unsupported type annotation: {dtype}') if isinstance(dtype, ArrayCollection): - item_dtypes = ','.join(classify_dtype(x) for x in dtype.item_dtypes) + item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes) return f'array[{item_dtypes}]' if isinstance(dtype, TupleCollection): - item_dtypes = ','.join(classify_dtype(x) for x in dtype.item_dtypes) + item_dtypes = ','.join(normalize_dtype(x) for x in dtype.item_dtypes) return f'tuple[{item_dtypes}]' # Check numpy types if it's available @@ -346,7 +416,7 @@ def classify_dtype(dtype: 
Any) -> str: raise TypeError( f'unsupported type annotation: {dtype}; ' - 'use `args`/`returns` on the @udf/@tvf decotator to specify the data type', + 'use `args`/`returns` on the @udf/@tvf decorator to specify the data type', ) @@ -354,6 +424,9 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: """ Collapse a dtype possibly containing multiple data types to one type. + This function can fail if there is no single type that naturally + encompasses all of the types in the list. + Parameters ---------- dtypes : str or list[str] @@ -364,6 +437,9 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: str """ + if isinstance(dtypes, str) and '|' in dtypes: + dtypes = dtypes.split('|') + if not isinstance(dtypes, list): return dtypes @@ -443,7 +519,232 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: return dtypes[0] + ('?' if is_nullable else '') -def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[str, Any]: +def create_type( + types: List[Any], + output_fields: List[str], + function_type: str = 'udf', +) -> Tuple[str, str]: + """ + Create the normalized type and SQL code for the given type information. + + Parameters + ---------- + types : List[Any] + List of types to be used + output_fields : List[str] + List of field names for the resulting type + function_type : str + Type of function, either 'udf' or 'tvf' + + Returns + ------- + Tuple[str, str] + Tuple containing the output type and SQL code + + """ + out_type = 'tuple[' + ','.join([ + collapse_dtypes(normalize_dtype(x)) + for x in [simplify_dtype(y) for y in types] + ]) + ']' + + sql = dtype_to_sql( + out_type, function_type=function_type, field_names=output_fields, + ) + + return out_type, sql + + +def get_dataclass_schema(obj: Any) -> Tuple[List[Any], List[str]]: + """ + Get the schema of a dataclass. + + Parameters + ---------- + obj : dataclass + The dataclass to get the schema of + + Returns + ------- + Tuple[List[Any], List[str]] + A tuple containing the field types and field names + + """ + if not dataclasses.is_dataclass(obj): + raise TypeError('object is not a dataclass') + return ( + [x.type for x in obj.fields], + [x.name for x in obj.fields], + ) + + +def get_typeddict_schema(obj: Any) -> Tuple[List[Any], List[str]]: + """ + Get the schema of a TypedDict. + + Parameters + ---------- + obj : TypedDict + The TypedDict to get the schema of + + Returns + ------- + Tuple[List[Any], List[str]] + A tuple containing the field types and field names + + """ + return ( + list(inspect.get_annotations(obj).values()), + list(inspect.get_annotations(obj).keys()), + ) + + +def get_pydantic_schema(obj: pydantic.BaseModel) -> Tuple[List[Any], List[str]]: + """ + Get the schema of a pydantic model. + + Parameters + ---------- + obj : pydantic.BaseModel + The pydantic model to get the schema of + + Returns + ------- + Tuple[List[Any], List[str]] + A tuple containing the field types and field names + + """ + return ( + list(obj.model_fields.values()), + list(obj.model_fields.keys()), + ) + + +def get_namedtuple_schema(obj: Any) -> Tuple[List[Any], List[str]]: + """ + Get the schema of a named tuple. 
+ + Parameters + ---------- + obj : NamedTuple + The named tuple to get the schema of + + Returns + ------- + Tuple[List[Any], List[str]] + A tuple containing the field types and field names + + """ + return ( + list(inspect.get_annotations(obj).values()), + list(inspect.get_annotations(obj).keys()), + ) + + +def get_return_type( + spec: Any, + name: str, + signature: inspect.Signature, + output_fields: Optional[List[str]] = None, + function_type: str = 'udf', +) -> Dict[str, Any]: + """ + Get the return type of a function. + + Parameters + ---------- + spec : Any + The return type specification + name : str + The name of the function + signature : inspect.Signature + The signature of the function + output_fields : List[str], optional + The output field names + function_type : str, optional + The type of function, either 'udf' or 'tvf' + + Returns + ------- + Dict[str, Any] + A dictionary containing the return type and SQL code + + """ + + # Make sure there is a return type annotation + if spec is None \ + and signature.return_annotation is inspect.Signature.empty: + raise TypeError(f'no return value annotation in function {name}') + + # + # Generate the return type and the corresponding SQL code for that value + # + + # Return type is specified by a SQL string + if isinstance(spec, str): + sql = spec + out_type = sql_to_dtype(sql) + + # Return type is a record (i.e., has multiple fields) + elif isinstance(spec, list): + + # Generate field names if needed + if not output_fields: + output_fields = [ + string.ascii_letters[i] for i in range(len(spec)) + ] + + out_type, sql = create_type( + spec, output_fields, function_type=function_type, + ) + + # Return type is specified by a dataclass definition + elif dataclasses.is_dataclass(spec): + out_type, sql = create_type( + *get_dataclass_schema(spec), + function_type=function_type, + ) + + # Return type is specified by a TypedDict definition + elif is_typeddict(spec): + out_type, sql = create_type( + *get_typeddict_schema(spec), + function_type=function_type, + ) + + # Return type is specified by a pydantic model + elif is_pydantic(spec): + out_type, sql = create_type( + *get_pydantic_schema(spec), + function_type=function_type, + ) + + # Return type is specified by a named tuple + elif is_namedtuple(spec): + out_type, sql = create_type( + *get_namedtuple_schema(spec), + function_type=function_type, + ) + + # Unrecognized return type + elif spec is not None: + if not output_fields and typing.get_origin(spec) is tuple: + output_fields = [ + string.ascii_letters[i] for i in range(len(typing.get_args(spec))) + ] + out_type = collapse_dtypes(normalize_dtype(simplify_dtype(spec))) + sql = dtype_to_sql(out_type, function_type=function_type) + + else: + out_type = 'null' + sql = 'NULL' + + return dict(dtype=out_type, sql=sql, default=None, output_fields=output_fields) + + +def get_signature( + func: Callable[..., Any], + func_name: Optional[str] = None, +) -> Dict[str, Any]: ''' Print the UDF signature of the Python callable. 
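A rough sketch of how the extracted signature might look for a simple annotated UDF; the exact dtype and SQL strings in the comments are assumptions, not verified output:

    from singlestoredb.functions.decorator import udf
    from singlestoredb.functions.signature import get_signature

    @udf
    def add_one(x: int) -> int:
        """Add one to an integer."""
        return x + 1

    sig = get_signature(add_one)
    # Roughly:
    #   sig['name']     -> 'add_one'
    #   sig['args']     -> [{'name': 'x', 'dtype': 'int64',
    #                        'sql': 'BIGINT NOT NULL', 'default': None}]
    #   sig['returns']  -> SQL type information derived from the return annotation
    #   sig['endpoint'] -> '/invoke'
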
@@ -451,7 +752,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ ---------- func : Callable The function to extract the signature of - name : str, optional + func_name : str, optional Name override for function Returns @@ -462,10 +763,11 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ signature = inspect.signature(func) args: List[Dict[str, Any]] = [] attrs = getattr(func, '_singlestoredb_attrs', {}) - name = attrs.get('name', name if name else func.__name__) + name = attrs.get('name', func_name if func_name else func.__name__) function_type = attrs.get('function_type', 'udf') out: Dict[str, Any] = dict(name=name, args=args) + # Get parameter names, defaults, and annotations arg_names = [x for x in signature.parameters] defaults = [ x.default if x.default is not inspect.Parameter.empty else None @@ -476,6 +778,7 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ if x.annotation is not inspect.Parameter.empty } + # Do not allow variable positional or keyword arguments for p in signature.parameters.values(): if p.kind == inspect.Parameter.VAR_POSITIONAL: raise TypeError('variable positional arguments are not supported') @@ -483,17 +786,23 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ raise TypeError('variable keyword arguments are not supported') args_overrides = attrs.get('args', None) - returns_overrides = attrs.get('returns', None) + returns = attrs.get('returns', signature.return_annotation) output_fields = attrs.get('output_fields', None) spec_diff = set(arg_names).difference(set(annotations.keys())) + # # Make sure all arguments are annotated + # + + # If there are missing annotations and no overrides, raise an error if spec_diff and args_overrides is None: raise TypeError( 'missing annotations for {} in {}' .format(', '.join(spec_diff), name), ) + + # If there are missing annotations and overrides are provided, make sure they match elif isinstance(args_overrides, dict): for s in spec_diff: if s not in args_overrides: @@ -501,6 +810,8 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ 'missing annotations for {} in {}' .format(', '.join(spec_diff), name), ) + + # If there are missing annotations and overrides are provided, make sure they match elif isinstance(args_overrides, list): if len(arg_names) != len(args_overrides): raise TypeError( @@ -508,91 +819,57 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ .format(name, ', '.join(spec_diff)), ) + # + # Generate the parameter type and the corresponding SQL code for that parameter + # + for i, arg in enumerate(arg_names): + + # If arg_overrides is a list, use corresponding item as SQL if isinstance(args_overrides, list): sql = args_overrides[i] arg_type = sql_to_dtype(sql) + + # If arg_overrides is a dict, use the corresponding key as SQL elif isinstance(args_overrides, dict) and arg in args_overrides: sql = args_overrides[arg] arg_type = sql_to_dtype(sql) + + # If args_overrides is a string, use it as SQL (only one function parameter) elif isinstance(args_overrides, str): sql = args_overrides arg_type = sql_to_dtype(sql) + + # Unrecognized type for args_overrides elif args_overrides is not None \ and not isinstance(args_overrides, (list, dict, str)): raise TypeError(f'unrecognized type for arguments: {args_overrides}') + + # No args_overrides, use the Python type annotation else: arg_type = collapse_dtypes([ - classify_dtype(x) 
for x in simplify_dtype(annotations[arg]) + normalize_dtype(x) for x in simplify_dtype(annotations[arg]) ]) sql = dtype_to_sql(arg_type, function_type=function_type) - args.append(dict(name=arg, dtype=arg_type, sql=sql, default=defaults[i])) - if returns_overrides is None \ - and signature.return_annotation is inspect.Signature.empty: - raise TypeError(f'no return value annotation in function {name}') + # Append parameter information to the args list + args.append(dict(name=arg, dtype=arg_type, sql=sql, default=defaults[i])) - if isinstance(returns_overrides, str): - sql = returns_overrides - out_type = sql_to_dtype(sql) - elif isinstance(returns_overrides, list): - if not output_fields: - output_fields = [ - string.ascii_letters[i] for i in range(len(returns_overrides)) - ] - out_type = 'tuple[' + collapse_dtypes([ - classify_dtype(x) - for x in simplify_dtype(returns_overrides) - ]).replace('|', ',') + ']' - sql = dtype_to_sql( - out_type, function_type=function_type, field_names=output_fields, - ) - elif dataclasses.is_dataclass(returns_overrides): - out_type = collapse_dtypes([ - classify_dtype(x) - for x in simplify_dtype([x.type for x in returns_overrides.fields]) - ]) - sql = dtype_to_sql( - out_type, - function_type=function_type, - field_names=[x.name for x in returns_overrides.fields], - ) - elif has_pydantic and inspect.isclass(returns_overrides) \ - and issubclass(returns_overrides, pydantic.BaseModel): - out_type = collapse_dtypes([ - classify_dtype(x) - for x in simplify_dtype([x for x in returns_overrides.model_fields.values()]) - ]) - sql = dtype_to_sql( - out_type, - function_type=function_type, - field_names=[x for x in returns_overrides.model_fields.keys()], - ) - elif returns_overrides is not None and not isinstance(returns_overrides, str): - raise TypeError(f'unrecognized type for return value: {returns_overrides}') - else: - if not output_fields: - if dataclasses.is_dataclass(signature.return_annotation): - output_fields = [ - x.name for x in dataclasses.fields(signature.return_annotation) - ] - elif has_pydantic and inspect.isclass(signature.return_annotation) \ - and issubclass(signature.return_annotation, pydantic.BaseModel): - output_fields = list(signature.return_annotation.model_fields.keys()) - out_type = collapse_dtypes([ - classify_dtype(x) for x in simplify_dtype(signature.return_annotation) - ]) - sql = dtype_to_sql( - out_type, function_type=function_type, field_names=output_fields, - ) - out['returns'] = dict(dtype=out_type, sql=sql, default=None) + out['returns'] = get_return_type( + returns, name, signature, + output_fields=output_fields, function_type=function_type, + ) + # Copy keys from decorator to signature copied_keys = ['database', 'environment', 'packages', 'resources', 'replace'] for key in copied_keys: if attrs.get(key): out[key] = attrs[key] + # Set the function endpoint out['endpoint'] = '/invoke' + + # Set the function doc string out['doc'] = func.__doc__ return out @@ -666,6 +943,9 @@ def dtype_to_sql( if dtype.endswith('?'): nullable = ' NULL' dtype = dtype[:-1] + elif '|null' in dtype: + nullable = ' NULL' + dtype = dtype.replace('|null', '') if dtype == 'null': nullable = '' From 9dc58321d85dd3302519355d05c33ce4aa4d8ae0 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 4 Apr 2025 16:32:03 -0500 Subject: [PATCH 02/16] Refactor return values --- singlestoredb/functions/ext/asgi.py | 20 ++- singlestoredb/functions/signature.py | 210 ++++++++++++++------------- 2 files changed, 118 insertions(+), 112 deletions(-) diff --git 
a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 4730d3269..4afac4908 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -306,20 +306,14 @@ async def do_func( # type: ignore colspec.append((x['name'], rowdat_1_type_map[dtype])) info['colspec'] = colspec - def parse_return_type(s: str) -> List[str]: - if s.startswith('tuple['): - return s[6:-1].split(',') - if s.startswith('array[tuple['): - return s[12:-2].split(',') - return [s] - # Setup return type returns = [] - for x in parse_return_type(sig['returns']['dtype']): - dtype = x.replace('?', '') + for x in sig['returns']: + dtype = x['dtype'].replace('?', '') if dtype not in rowdat_1_type_map: raise TypeError(f'no data type mapping for {dtype}') - returns.append(rowdat_1_type_map[dtype]) + print(x['name'], dtype) + returns.append((x['name'], rowdat_1_type_map[dtype])) info['returns'] = returns return do_func, info @@ -665,7 +659,9 @@ async def __call__( func_info['colspec'], b''.join(data), ), ) - body = output_handler['dump'](func_info['returns'], *out) # type: ignore + body = output_handler['dump']( + [x[1] for x in func_info['returns']], *out, # type: ignore + ) await send(output_handler['response']) @@ -682,6 +678,7 @@ async def __call__( endpoint_info['signature'], url=self.url or reflected_url, data_format=self.data_format, + function_type=endpoint_info['function_type'], ), ) body = '\n'.join(syntax).encode('utf-8') @@ -775,6 +772,7 @@ def show_create_functions( app_mode=self.app_mode, replace=replace, link=link or None, + function_type=endpoint_info['function_type'], ), ) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 8f6e79458..8b3bfcfa9 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -31,6 +31,7 @@ except ImportError: has_pydantic = False + from . import dtypes as dt from ..mysql.converters import escape_item # type: ignore @@ -192,6 +193,22 @@ class ArrayCollection(Collection): pass +def get_annotations(obj: Any) -> Dict[str, Any]: + """Get the annotations of an object.""" + if hasattr(inspect, 'get_annotations'): + return inspect.get_annotations(obj) + if isinstance(obj, type): + return obj.__dict__.get('__annotations__', {}) + return getattr(obj, '__annotations__', {}) + + +def is_dataframe(obj: Any) -> bool: + """Check if an object is a DataFrame.""" + # Cheating here a bit so we don't have to import pandas / polars / pyarrow + # unless we absolutely need to + return getattr(obj, '__name__', '') in ['DataFrame', 'Table'] + + def is_typeddict(obj: Any) -> bool: """Check if an object is a TypedDict.""" if hasattr(typing, 'is_typeddict'): @@ -332,7 +349,7 @@ def normalize_dtype(dtype: Any) -> str: return f'tuple[{item_dtypes}]' if is_typeddict(dtype): - td_fields = inspect.get_annotations(dtype).keys() + td_fields = get_annotations(dtype).keys() item_dtypes = ','.join( f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in td_fields ) @@ -347,7 +364,7 @@ def normalize_dtype(dtype: Any) -> str: return f'tuple[{item_dtypes}]' if is_namedtuple(dtype): - nt_fields = inspect.get_annotations(dtype).values() + nt_fields = get_annotations(dtype).values() item_dtypes = ','.join( f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in nt_fields ) @@ -554,7 +571,7 @@ def create_type( return out_type, sql -def get_dataclass_schema(obj: Any) -> Tuple[List[Any], List[str]]: +def get_dataclass_schema(obj: Any) -> List[Tuple[str, Any]]: """ Get the schema of a dataclass. 
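The hunks that follow switch the *_schema helpers to returning (field name, annotation) pairs pulled from class annotations. A small illustration under that assumption, using an example dataclass:

    import dataclasses

    @dataclasses.dataclass
    class Point:
        x: int
        y: float

    # get_dataclass_schema(Point) is expected to yield [('x', int), ('y', float)]
    print(list(Point.__annotations__.items()))
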
@@ -565,19 +582,14 @@ def get_dataclass_schema(obj: Any) -> Tuple[List[Any], List[str]]: Returns ------- - Tuple[List[Any], List[str]] - A tuple containing the field types and field names + List[Tuple[str, Any]] + A list of tuples containing the field names and field types """ - if not dataclasses.is_dataclass(obj): - raise TypeError('object is not a dataclass') - return ( - [x.type for x in obj.fields], - [x.name for x in obj.fields], - ) + return list(get_annotations(obj).items()) -def get_typeddict_schema(obj: Any) -> Tuple[List[Any], List[str]]: +def get_typeddict_schema(obj: Any) -> List[Tuple[str, Any]]: """ Get the schema of a TypedDict. @@ -588,17 +600,14 @@ def get_typeddict_schema(obj: Any) -> Tuple[List[Any], List[str]]: Returns ------- - Tuple[List[Any], List[str]] - A tuple containing the field types and field names + List[Tuple[str, Any]] + A list of tuples containing the field names and field types """ - return ( - list(inspect.get_annotations(obj).values()), - list(inspect.get_annotations(obj).keys()), - ) + return list(get_annotations(obj).items()) -def get_pydantic_schema(obj: pydantic.BaseModel) -> Tuple[List[Any], List[str]]: +def get_pydantic_schema(obj: pydantic.BaseModel) -> List[Tuple[str, Any]]: """ Get the schema of a pydantic model. @@ -609,17 +618,14 @@ def get_pydantic_schema(obj: pydantic.BaseModel) -> Tuple[List[Any], List[str]]: Returns ------- - Tuple[List[Any], List[str]] - A tuple containing the field types and field names + List[Tuple[str, Any]] + A list of tuples containing the field names and field types """ - return ( - list(obj.model_fields.values()), - list(obj.model_fields.keys()), - ) + return [(k, v.annotation) for k, v in obj.model_fields.items()] -def get_namedtuple_schema(obj: Any) -> Tuple[List[Any], List[str]]: +def get_namedtuple_schema(obj: Any) -> List[Tuple[Any, str]]: """ Get the schema of a named tuple. @@ -630,115 +636,103 @@ def get_namedtuple_schema(obj: Any) -> Tuple[List[Any], List[str]]: Returns ------- - Tuple[List[Any], List[str]] - A tuple containing the field types and field names + List[Tuple[Any, str]] + A list of tuples containing the field names and field types """ - return ( - list(inspect.get_annotations(obj).values()), - list(inspect.get_annotations(obj).keys()), - ) + return list(get_annotations(obj).items()) -def get_return_type( +def get_return_schema( spec: Any, - name: str, - signature: inspect.Signature, output_fields: Optional[List[str]] = None, function_type: str = 'udf', -) -> Dict[str, Any]: +) -> List[Tuple[str, Any]]: """ - Get the return type of a function. + Expand a return type annotation into a list of types and field names. 
Parameters ---------- spec : Any The return type specification - name : str - The name of the function - signature : inspect.Signature - The signature of the function output_fields : List[str], optional The output field names - function_type : str, optional + function_type : str The type of function, either 'udf' or 'tvf' Returns ------- - Dict[str, Any] - A dictionary containing the return type and SQL code + List[Tuple[str, Any]] + A list of tuples containing the field names and field types """ + # Make sure that the result of a TVF is a list or dataframe + if function_type == 'tvf': - # Make sure there is a return type annotation - if spec is None \ - and signature.return_annotation is inspect.Signature.empty: - raise TypeError(f'no return value annotation in function {name}') - - # - # Generate the return type and the corresponding SQL code for that value - # + if typing.get_origin(spec) is list: + spec = typing.get_args(spec)[0] - # Return type is specified by a SQL string - if isinstance(spec, str): - sql = spec - out_type = sql_to_dtype(sql) + # DataFrames require special handling. You can't get the schema + # from the annotation, you need a separate structure to specify + # the types. This should be specified in the output_fields. + elif is_dataframe(spec): + if output_fields is None: + raise TypeError( + 'output_fields must be specified for DataFrames / Tables', + ) + spec = output_fields + output_fields = None - # Return type is a record (i.e., has multiple fields) - elif isinstance(spec, list): + else: + raise TypeError( + 'return type for TVF must be a list or DataFrame', + ) - # Generate field names if needed - if not output_fields: - output_fields = [ - string.ascii_letters[i] for i in range(len(spec)) - ] + elif typing.get_origin(spec) in [list, tuple, dict] \ + or is_dataframe(spec) \ + or dataclasses.is_dataclass(spec) \ + or is_typeddict(spec) \ + or is_pydantic(spec) \ + or is_namedtuple(spec): + raise TypeError('return type for UDF must be a scalar type') - out_type, sql = create_type( - spec, output_fields, function_type=function_type, - ) + # Return type is specified by a SQL string + if isinstance(spec, str): + return [('', sql_to_dtype(spec))] # Return type is specified by a dataclass definition - elif dataclasses.is_dataclass(spec): - out_type, sql = create_type( - *get_dataclass_schema(spec), - function_type=function_type, - ) + if dataclasses.is_dataclass(spec): + schema = get_dataclass_schema(spec) # Return type is specified by a TypedDict definition elif is_typeddict(spec): - out_type, sql = create_type( - *get_typeddict_schema(spec), - function_type=function_type, - ) + schema = get_typeddict_schema(spec) # Return type is specified by a pydantic model elif is_pydantic(spec): - out_type, sql = create_type( - *get_pydantic_schema(spec), - function_type=function_type, - ) + schema = get_pydantic_schema(spec) # Return type is specified by a named tuple elif is_namedtuple(spec): - out_type, sql = create_type( - *get_namedtuple_schema(spec), - function_type=function_type, - ) + schema = get_namedtuple_schema(spec) # Unrecognized return type elif spec is not None: - if not output_fields and typing.get_origin(spec) is tuple: + if typing.get_origin(spec) is tuple: output_fields = [ string.ascii_letters[i] for i in range(len(typing.get_args(spec))) ] - out_type = collapse_dtypes(normalize_dtype(simplify_dtype(spec))) - sql = dtype_to_sql(out_type, function_type=function_type) - - else: - out_type = 'null' - sql = 'NULL' - - return dict(dtype=out_type, sql=sql, 
default=None, output_fields=output_fields) + schema = [(x, y) for x, y in zip(output_fields, typing.get_args(spec))] + else: + schema = [('', spec)] + + # Normalize schema data types + out = [] + for k, v in schema: + out.append(( + k, collapse_dtypes([normalize_dtype(x) for x in simplify_dtype(v)]), + )) + return out def get_signature( @@ -762,13 +756,15 @@ def get_signature( ''' signature = inspect.signature(func) args: List[Dict[str, Any]] = [] + returns: List[Dict[str, Any]] = [] attrs = getattr(func, '_singlestoredb_attrs', {}) name = attrs.get('name', func_name if func_name else func.__name__) function_type = attrs.get('function_type', 'udf') - out: Dict[str, Any] = dict(name=name, args=args) + out: Dict[str, Any] = dict(name=name, args=args, returns=returns) # Get parameter names, defaults, and annotations arg_names = [x for x in signature.parameters] + args_overrides = attrs.get('args', None) defaults = [ x.default if x.default is not inspect.Parameter.empty else None for x in signature.parameters.values() @@ -785,10 +781,6 @@ def get_signature( elif p.kind == inspect.Parameter.VAR_KEYWORD: raise TypeError('variable keyword arguments are not supported') - args_overrides = attrs.get('args', None) - returns = attrs.get('returns', signature.return_annotation) - output_fields = attrs.get('output_fields', None) - spec_diff = set(arg_names).difference(set(annotations.keys())) # @@ -855,11 +847,20 @@ def get_signature( # Append parameter information to the args list args.append(dict(name=arg, dtype=arg_type, sql=sql, default=defaults[i])) - out['returns'] = get_return_type( - returns, name, signature, - output_fields=output_fields, function_type=function_type, + # + # Generate the return types and the corresponding SQL code for those values + # + + ret_schema = get_return_schema( + attrs.get('returns', signature.return_annotation), + output_fields=attrs.get('output_fields', None), + function_type=function_type, ) + for i, (name, rtype) in enumerate(ret_schema): + sql = dtype_to_sql(rtype, function_type=function_type) + returns.append(dict(name=name, dtype=rtype, sql=sql)) + # Copy keys from decorator to signature copied_keys = ['database', 'environment', 'packages', 'resources', 'replace'] for key in copied_keys: @@ -994,6 +995,7 @@ def signature_to_sql( app_mode: str = 'remote', link: Optional[str] = None, replace: bool = False, + function_type: str = 'udf', ) -> str: ''' Convert a dictionary function signature into SQL. 
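A sketch of the kind of RETURNS clause the following hunk builds for a record-style return value; the exact SQL text is an assumption for illustration:

    ret = [
        {'name': 'a', 'dtype': 'int64', 'sql': 'BIGINT NOT NULL'},
        {'name': 'b', 'dtype': 'str', 'sql': 'TEXT NOT NULL'},
    ]
    # TVF:                RETURNS TABLE(`a` BIGINT NOT NULL, `b` TEXT NOT NULL)
    # UDF (named fields): RETURNS RECORD(`a` BIGINT NOT NULL, `b` TEXT NOT NULL)
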
@@ -1021,7 +1023,13 @@ def signature_to_sql( returns = '' if signature.get('returns'): - res = signature['returns']['sql'] + prefix = 'RECORD(' + if function_type == 'tvf': + prefix = 'TABLE(' + res = prefix + ', '.join( + f'{escape_name(x["name"])} {x["sql"]}' + for x in signature['returns'] + ) + ')' returns = f' RETURNS {res}' host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1') From edc689398802c3bc0dba872c2171295da10b8b2b Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 7 Apr 2025 10:09:55 -0500 Subject: [PATCH 03/16] Fix UDF return values --- singlestoredb/functions/signature.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 8b3bfcfa9..79ca65247 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -1023,13 +1023,17 @@ def signature_to_sql( returns = '' if signature.get('returns'): - prefix = 'RECORD(' + ret = signature['returns'] if function_type == 'tvf': - prefix = 'TABLE(' - res = prefix + ', '.join( - f'{escape_name(x["name"])} {x["sql"]}' - for x in signature['returns'] - ) + ')' + res = 'TABLE(' + ', '.join( + f'{escape_name(x["name"])} {x["sql"]}' for x in ret + ) + ')' + elif ret[0]['name']: + res = 'RECORD(' + ', '.join( + f'{escape_name(x["name"])} {x["sql"]}' for x in ret + ) + ')' + else: + res = ret[0]['sql'] returns = f' RETURNS {res}' host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1') From 114c73ad21884f25cec0d3132197b965a7b6cff6 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 10 Apr 2025 11:08:59 -0500 Subject: [PATCH 04/16] Big refactoring of the way parameters / return values work --- accel.c | 28 +- singlestoredb/functions/decorator.py | 215 ++------- singlestoredb/functions/dtypes.py | 609 ++++++++++++++++++------ singlestoredb/functions/ext/asgi.py | 147 +++--- singlestoredb/functions/ext/json.py | 4 + singlestoredb/functions/ext/rowdat_1.py | 4 + singlestoredb/functions/signature.py | 444 +++++++++++------ 7 files changed, 914 insertions(+), 537 deletions(-) diff --git a/accel.c b/accel.c index 7b809f94d..3c16f2182 100644 --- a/accel.c +++ b/accel.c @@ -36,6 +36,7 @@ #define NUMPY_DATETIME 13 #define NUMPY_OBJECT 14 #define NUMPY_BYTES 15 +#define NUMPY_FIXED_STRING 16 #define MYSQL_FLAG_NOT_NULL 1 #define MYSQL_FLAG_PRI_KEY 2 @@ -2745,6 +2746,13 @@ static NumpyColType get_numpy_col_type(PyObject *py_array) { goto error; } break; + case 'U': + out.type = NUMPY_FIXED_STRING; + out.length = (Py_ssize_t)strtol(str + 2, NULL, 10); + if (out.length < 0) { + goto error; + } + break; default: goto error; } @@ -3845,7 +3853,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_MEDIUM_BLOB: case MYSQL_TYPE_LONG_BLOB: case MYSQL_TYPE_BLOB: - if (col_types[i].type != NUMPY_OBJECT) { + if (col_types[i].type != NUMPY_OBJECT && col_types[i].type != NUMPY_FIXED_STRING) { PyErr_SetString(PyExc_ValueError, "unsupported numpy data type for character output types"); goto error; } @@ -3856,6 +3864,24 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k memcpy(out+out_idx, &i64, 8); out_idx += 8; + } else if (col_types[i].type == NUMPY_FIXED_STRING) { + void *bytes = (void*)(cols[i] + j * 8); + + if (bytes == NULL) { + CHECKMEM(8); + i64 = 0; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + } else { + Py_ssize_t str_l = strnlen(bytes, col_types[i].length); + CHECKMEM(8+str_l); + i64 = str_l; + memcpy(out+out_idx, 
&i64, 8); + out_idx += 8; + memcpy(out+out_idx, bytes, str_l); + out_idx += str_l; + } + } else { u64 = *(uint64_t*)(cols[i] + j * 8); diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 0d43fe148..82ade5a46 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -1,142 +1,56 @@ -import dataclasses -import datetime +from __future__ import annotations + import functools import inspect from typing import Any from typing import Callable -from typing import Dict from typing import List from typing import Optional -from typing import Tuple +from typing import Type from typing import Union -from . import dtypes -from .dtypes import DataType -from .signature import simplify_dtype - -try: - import pydantic - has_pydantic = True -except ImportError: - has_pydantic = False - -python_type_map: Dict[Any, Callable[..., str]] = { - str: dtypes.TEXT, - int: dtypes.BIGINT, - float: dtypes.DOUBLE, - bool: dtypes.BOOL, - bytes: dtypes.BINARY, - bytearray: dtypes.BINARY, - datetime.datetime: dtypes.DATETIME, - datetime.date: dtypes.DATE, - datetime.timedelta: dtypes.TIME, -} - - -def listify(x: Any) -> List[Any]: - """Make sure sure value is a list.""" - if x is None: - return [] - if isinstance(x, (list, tuple, set)): - return list(x) - return [x] - - -def process_annotation(annotation: Any) -> Tuple[Any, bool]: - types = simplify_dtype(annotation) - if isinstance(types, list): - nullable = False - if type(None) in types: - nullable = True - types = [x for x in types if x is not type(None)] - if len(types) > 1: - raise ValueError(f'multiple types not supported: {annotation}') - return types[0], nullable - return types, True +ParameterType = Union[ + str, + Callable[..., str], + List[Union[str, Callable[..., str]]], + Type[Any], +] -def process_types(params: Any) -> Any: - if params is None: - return params, [] - - elif isinstance(params, (list, tuple)): - params = list(params) - for i, item in enumerate(params): - if params[i] in python_type_map: - params[i] = python_type_map[params[i]]() - elif callable(item): - params[i] = item() - for item in params: - if not isinstance(item, str): - raise TypeError(f'unrecognized type for parameter: {item}') - return params, [] - - elif isinstance(params, dict): - names = [] - params = dict(params) - for k, v in list(params.items()): - names.append(k) - if params[k] in python_type_map: - params[k] = python_type_map[params[k]]() - elif callable(v): - params[k] = v() - for item in params.values(): - if not isinstance(item, str): - raise TypeError(f'unrecognized type for parameter: {item}') - return params, names - - elif dataclasses.is_dataclass(params): - names = [] - out = [] - for item in dataclasses.fields(params): - typ, nullable = process_annotation(item.type) - sql_type = process_types(typ)[0] - if not nullable: - sql_type = sql_type.replace('NULL', 'NOT NULL') - out.append(sql_type) - names.append(item.name) - return out, names - - elif has_pydantic and inspect.isclass(params) \ - and issubclass(params, pydantic.BaseModel): - names = [] - out = [] - for name, item in params.model_fields.items(): - typ, nullable = process_annotation(item.annotation) - sql_type = process_types(typ)[0] - if not nullable: - sql_type = sql_type.replace('NULL', 'NOT NULL') - out.append(sql_type) - names.append(name) - return out, names - - elif params in python_type_map: - return python_type_map[params](), [] +ReturnType = ParameterType - elif callable(params): - return params(), [] - elif 
isinstance(params, str): - return params, [] +def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]: + """Expand the types for the function arguments / return values.""" + if args is None: + return None - raise TypeError(f'unrecognized data type for args: {params}') + # SQL string + if isinstance(args, str): + return [args] + # General way of accepting pydantic.BaseModel, NamedTuple, TypedDict + elif inspect.isclass(args): + return args -ParameterType = Union[ - str, - List[str], - Dict[str, str], - 'pydantic.BaseModel', - type, -] + # Callable that returns a SQL string + elif callable(args): + out = args() + if not isinstance(out, str): + raise TypeError(f'unrecognized type for parameter: {args}') + return [out] -ReturnType = Union[ - str, - List[DataType], - List[type], - 'pydantic.BaseModel', - type, -] + # List of SQL strings or callables + else: + new_args = [] + for arg in args: + if isinstance(arg, str): + new_args.append(arg) + elif callable(arg): + new_args.append(arg()) + else: + raise TypeError(f'unrecognized type for parameter: {arg}') + return new_args def _func( @@ -145,40 +59,18 @@ def _func( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - data_format: Optional[str] = None, include_masks: bool = False, function_type: str = 'udf', - output_fields: Optional[List[str]] = None, ) -> Callable[..., Any]: """Generic wrapper for UDF and TVF decorators.""" - args, _ = process_types(args) - returns, fields = process_types(returns) - - if not output_fields and fields: - output_fields = fields - - if isinstance(returns, list) \ - and isinstance(output_fields, list) \ - and len(output_fields) != len(returns): - raise ValueError( - 'The number of output fields must match the number of return types', - ) - - if include_masks and data_format == 'python': - raise RuntimeError( - 'include_masks is only valid when using ' - 'vectors for input parameters', - ) _singlestoredb_attrs = { # type: ignore k: v for k, v in dict( name=name, - args=args, - returns=returns, - data_format=data_format, + args=expand_types(args), + returns=expand_types(returns), include_masks=include_masks, function_type=function_type, - output_fields=output_fields or None, ).items() if v is not None } @@ -207,7 +99,6 @@ def udf( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - data_format: Optional[str] = None, include_masks: bool = False, ) -> Callable[..., Any]: """ @@ -219,7 +110,7 @@ def udf( The UDF to apply parameters to name : str, optional The name to use for the UDF in the database - args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional + args : str | Callable | List[str | Callable], optional Specifies the data types of the function arguments. Typically, the function data types are derived from the function parameter annotations. These annotations can be overridden. If the function @@ -235,8 +126,6 @@ def udf( returns : str, optional Specifies the return data type of the function. If not specified, the type annotation from the function is used. - data_format : str, optional - The data format of each parameter: python, pandas, arrow, polars include_masks : bool, optional Should boolean masks be included with each input parameter to indicate which elements are NULL? 
This is only used when a input parameters are @@ -252,27 +141,18 @@ def udf( name=name, args=args, returns=returns, - data_format=data_format, include_masks=include_masks, function_type='udf', ) -udf.pandas = functools.partial(udf, data_format='pandas') # type: ignore -udf.polars = functools.partial(udf, data_format='polars') # type: ignore -udf.arrow = functools.partial(udf, data_format='arrow') # type: ignore -udf.numpy = functools.partial(udf, data_format='numpy') # type: ignore - - def tvf( func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - data_format: Optional[str] = None, include_masks: bool = False, - output_fields: Optional[List[str]] = None, ) -> Callable[..., Any]: """ Apply attributes to a TVF. @@ -283,7 +163,7 @@ def tvf( The TVF to apply parameters to name : str, optional The name to use for the TVF in the database - args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional + args : str | Callable | List[str | Callable], optional Specifies the data types of the function arguments. Typically, the function data types are derived from the function parameter annotations. These annotations can be overridden. If the function @@ -299,15 +179,10 @@ def tvf( returns : str, optional Specifies the return data type of the function. If not specified, the type annotation from the function is used. - data_format : str, optional - The data format of each parameter: python, pandas, arrow, polars include_masks : bool, optional Should boolean masks be included with each input parameter to indicate which elements are NULL? This is only used when a input parameters are configured to a vector type (numpy, pandas, polars, arrow). - output_fields : List[str], optional - The names of the output fields for the TVF. If not specified, the - names are generated. Returns ------- @@ -319,14 +194,6 @@ def tvf( name=name, args=args, returns=returns, - data_format=data_format, include_masks=include_masks, function_type='tvf', - output_fields=output_fields, ) - - -tvf.pandas = functools.partial(tvf, data_format='pandas') # type: ignore -tvf.polars = functools.partial(tvf, data_format='polars') # type: ignore -tvf.arrow = functools.partial(tvf, data_format='arrow') # type: ignore -tvf.numpy = functools.partial(tvf, data_format='numpy') # type: ignore diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py index da84e558a..905bec14d 100644 --- a/singlestoredb/functions/dtypes.py +++ b/singlestoredb/functions/dtypes.py @@ -20,6 +20,11 @@ DataType = Union[str, Callable[..., Any]] +class SQLString(str): + """SQL string type.""" + name: Optional[str] = None + + class NULL: """NULL (for use in default values).""" pass @@ -194,7 +199,12 @@ def _bool(x: Optional[bool] = None) -> Optional[bool]: return bool(x) -def BOOL(*, nullable: bool = True, default: Optional[bool] = None) -> str: +def BOOL( + *, + nullable: bool = True, + default: Optional[bool] = None, + name: Optional[str] = None, +) -> SQLString: """ BOOL type specification. @@ -204,16 +214,25 @@ def BOOL(*, nullable: bool = True, default: Optional[bool] = None) -> str: Can the value be NULL? 
default : bool, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return 'BOOL' + _modifiers(nullable=nullable, default=_bool(default)) + out = SQLString('BOOL' + _modifiers(nullable=nullable, default=_bool(default))) + out.name = name + return out -def BOOLEAN(*, nullable: bool = True, default: Optional[bool] = None) -> str: +def BOOLEAN( + *, + nullable: bool = True, + default: Optional[bool] = None, + name: Optional[str] = None, +) -> SQLString: """ BOOLEAN type specification. @@ -223,16 +242,25 @@ def BOOLEAN(*, nullable: bool = True, default: Optional[bool] = None) -> str: Can the value be NULL? default : bool, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return 'BOOLEAN' + _modifiers(nullable=nullable, default=_bool(default)) + out = SQLString('BOOLEAN' + _modifiers(nullable=nullable, default=_bool(default))) + out.name = name + return out -def BIT(*, nullable: bool = True, default: Optional[int] = None) -> str: +def BIT( + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: """ BIT type specification. @@ -242,13 +270,17 @@ def BIT(*, nullable: bool = True, default: Optional[int] = None) -> str: Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return 'BIT' + _modifiers(nullable=nullable, default=default) + out = SQLString('BIT' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def TINYINT( @@ -257,7 +289,8 @@ def TINYINT( nullable: bool = True, default: Optional[int] = None, unsigned: bool = False, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TINYINT type specification. @@ -271,14 +304,20 @@ def TINYINT( Default value unsigned : bool, optional Is the int unsigned? + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TINYINT({display_width})' if display_width else 'TINYINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=unsigned) + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out def TINYINT_UNSIGNED( @@ -286,7 +325,8 @@ def TINYINT_UNSIGNED( *, nullable: bool = True, default: Optional[int] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TINYINT UNSIGNED type specification. @@ -298,14 +338,18 @@ def TINYINT_UNSIGNED( Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TINYINT({display_width})' if display_width else 'TINYINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=True) + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out def SMALLINT( @@ -314,7 +358,8 @@ def SMALLINT( nullable: bool = True, default: Optional[int] = None, unsigned: bool = False, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ SMALLINT type specification. @@ -328,14 +373,20 @@ def SMALLINT( Default value unsigned : bool, optional Is the int unsigned? 
+ name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'SMALLINT({display_width})' if display_width else 'SMALLINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=unsigned) + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out def SMALLINT_UNSIGNED( @@ -343,7 +394,8 @@ def SMALLINT_UNSIGNED( *, nullable: bool = True, default: Optional[int] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ SMALLINT UNSIGNED type specification. @@ -355,14 +407,18 @@ def SMALLINT_UNSIGNED( Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'SMALLINT({display_width})' if display_width else 'SMALLINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=True) + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out def MEDIUMINT( @@ -371,7 +427,8 @@ def MEDIUMINT( nullable: bool = True, default: Optional[int] = None, unsigned: bool = False, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ MEDIUMINT type specification. @@ -385,14 +442,20 @@ def MEDIUMINT( Default value unsigned : bool, optional Is the int unsigned? + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'MEDIUMINT({display_width})' if display_width else 'MEDIUMINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=unsigned) + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out def MEDIUMINT_UNSIGNED( @@ -400,7 +463,8 @@ def MEDIUMINT_UNSIGNED( *, nullable: bool = True, default: Optional[int] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ MEDIUMINT UNSIGNED type specification. @@ -412,14 +476,18 @@ def MEDIUMINT_UNSIGNED( Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'MEDIUMINT({display_width})' if display_width else 'MEDIUMINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=True) + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out def INT( @@ -428,7 +496,8 @@ def INT( nullable: bool = True, default: Optional[int] = None, unsigned: bool = False, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ INT type specification. @@ -442,14 +511,20 @@ def INT( Default value unsigned : bool, optional Is the int unsigned? + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'INT({display_width})' if display_width else 'INT' - return out + _modifiers(nullable=nullable, default=default, unsigned=unsigned) + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out def INT_UNSIGNED( @@ -457,7 +532,8 @@ def INT_UNSIGNED( *, nullable: bool = True, default: Optional[int] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ INT UNSIGNED type specification. @@ -469,14 +545,18 @@ def INT_UNSIGNED( Can the value be NULL? 
default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'INT({display_width})' if display_width else 'INT' - return out + _modifiers(nullable=nullable, default=default, unsigned=True) + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out def INTEGER( @@ -485,7 +565,8 @@ def INTEGER( nullable: bool = True, default: Optional[int] = None, unsigned: bool = False, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ INTEGER type specification. @@ -499,14 +580,20 @@ def INTEGER( Default value unsigned : bool, optional Is the int unsigned? + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'INTEGER({display_width})' if display_width else 'INTEGER' - return out + _modifiers(nullable=nullable, default=default, unsigned=unsigned) + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out def INTEGER_UNSIGNED( @@ -514,7 +601,8 @@ def INTEGER_UNSIGNED( *, nullable: bool = True, default: Optional[int] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ INTEGER UNSIGNED type specification. @@ -526,14 +614,18 @@ def INTEGER_UNSIGNED( Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'INTEGER({display_width})' if display_width else 'INTEGER' - return out + _modifiers(nullable=nullable, default=default, unsigned=True) + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out def BIGINT( @@ -542,7 +634,8 @@ def BIGINT( nullable: bool = True, default: Optional[int] = None, unsigned: bool = False, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ BIGINT type specification. @@ -556,14 +649,20 @@ def BIGINT( Default value unsigned : bool, optional Is the int unsigned? + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'BIGINT({display_width})' if display_width else 'BIGINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=unsigned) + out = SQLString( + out + _modifiers(nullable=nullable, default=default, unsigned=unsigned), + ) + out.name = name + return out def BIGINT_UNSIGNED( @@ -571,7 +670,8 @@ def BIGINT_UNSIGNED( *, nullable: bool = True, default: Optional[int] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ BIGINT UNSIGNED type specification. @@ -583,14 +683,18 @@ def BIGINT_UNSIGNED( Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'BIGINT({int(display_width)})' if display_width else 'BIGINT' - return out + _modifiers(nullable=nullable, default=default, unsigned=True) + out = SQLString(out + _modifiers(nullable=nullable, default=default, unsigned=True)) + out.name = name + return out def FLOAT( @@ -598,7 +702,8 @@ def FLOAT( *, nullable: bool = True, default: Optional[float] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ FLOAT type specification. @@ -610,14 +715,18 @@ def FLOAT( Can the value be NULL? 
default : float, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'FLOAT({int(display_decimals)})' if display_decimals else 'FLOAT' - return out + _modifiers(nullable=nullable, default=default) + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def DOUBLE( @@ -625,7 +734,8 @@ def DOUBLE( *, nullable: bool = True, default: Optional[float] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ DOUBLE type specification. @@ -637,14 +747,18 @@ def DOUBLE( Can the value be NULL? default : float, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'DOUBLE({int(display_decimals)})' if display_decimals else 'DOUBLE' - return out + _modifiers(nullable=nullable, default=default) + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def REAL( @@ -652,7 +766,8 @@ def REAL( *, nullable: bool = True, default: Optional[float] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ REAL type specification. @@ -664,14 +779,18 @@ def REAL( Can the value be NULL? default : float, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'REAL({int(display_decimals)})' if display_decimals else 'REAL' - return out + _modifiers(nullable=nullable, default=default) + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def DECIMAL( @@ -680,7 +799,8 @@ def DECIMAL( *, nullable: bool = True, default: Optional[Union[str, decimal.Decimal]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ DECIMAL type specification. @@ -694,14 +814,20 @@ def DECIMAL( Can the value be NULL? default : str or decimal.Decimal, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return f'DECIMAL({int(precision)}, {int(scale)})' + \ - _modifiers(nullable=nullable, default=default) + out = SQLString( + f'DECIMAL({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out def DEC( @@ -710,7 +836,8 @@ def DEC( *, nullable: bool = True, default: Optional[Union[str, decimal.Decimal]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ DEC type specification. @@ -724,14 +851,20 @@ def DEC( Can the value be NULL? default : str or decimal.Decimal, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return f'DEC({int(precision)}, {int(scale)})' + \ - _modifiers(nullable=nullable, default=default) + out = SQLString( + f'DEC({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out def FIXED( @@ -740,7 +873,8 @@ def FIXED( *, nullable: bool = True, default: Optional[Union[str, decimal.Decimal]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ FIXED type specification. @@ -754,14 +888,20 @@ def FIXED( Can the value be NULL? 
default : str or decimal.Decimal, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return f'FIXED({int(precision)}, {int(scale)})' + \ - _modifiers(nullable=nullable, default=default) + out = SQLString( + f'FIXED({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out def NUMERIC( @@ -770,7 +910,8 @@ def NUMERIC( *, nullable: bool = True, default: Optional[Union[str, decimal.Decimal]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ NUMERIC type specification. @@ -784,21 +925,28 @@ def NUMERIC( Can the value be NULL? default : str or decimal.Decimal, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return f'NUMERIC({int(precision)}, {int(scale)})' + \ - _modifiers(nullable=nullable, default=default) + out = SQLString( + f'NUMERIC({int(precision)}, {int(scale)})' + + _modifiers(nullable=nullable, default=default), + ) + out.name = name + return out def DATE( *, nullable: bool = True, default: Optional[Union[str, datetime.date]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ DATE type specification. @@ -808,13 +956,17 @@ def DATE( Can the value be NULL? default : str or datetime.date, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return 'DATE' + _modifiers(nullable=nullable, default=default) + out = SQLString('DATE' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def TIME( @@ -822,7 +974,8 @@ def TIME( *, nullable: bool = True, default: Optional[Union[str, datetime.timedelta]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TIME type specification. @@ -834,14 +987,18 @@ def TIME( Can the value be NULL? default : str or datetime.timedelta, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TIME({int(precision)})' if precision else 'TIME' - return out + _modifiers(nullable=nullable, default=default) + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def DATETIME( @@ -849,7 +1006,8 @@ def DATETIME( *, nullable: bool = True, default: Optional[Union[str, datetime.datetime]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ DATETIME type specification. @@ -861,14 +1019,18 @@ def DATETIME( Can the value be NULL? default : str or datetime.datetime, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'DATETIME({int(precision)})' if precision else 'DATETIME' - return out + _modifiers(nullable=nullable, default=default) + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def TIMESTAMP( @@ -876,7 +1038,8 @@ def TIMESTAMP( *, nullable: bool = True, default: Optional[Union[str, datetime.datetime]] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TIMESTAMP type specification. @@ -888,17 +1051,26 @@ def TIMESTAMP( Can the value be NULL? 
default : str or datetime.datetime, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TIMESTAMP({int(precision)})' if precision else 'TIMESTAMP' - return out + _modifiers(nullable=nullable, default=default) + out = SQLString(out + _modifiers(nullable=nullable, default=default)) + out.name = name + return out -def YEAR(*, nullable: bool = True, default: Optional[int] = None) -> str: +def YEAR( + *, + nullable: bool = True, + default: Optional[int] = None, + name: Optional[str] = None, +) -> SQLString: """ YEAR type specification. @@ -908,13 +1080,17 @@ def YEAR(*, nullable: bool = True, default: Optional[int] = None) -> str: Can the value be NULL? default : int, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return 'YEAR' + _modifiers(nullable=nullable, default=default) + out = SQLString('YEAR' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out def CHAR( @@ -924,7 +1100,8 @@ def CHAR( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ CHAR type specification. @@ -940,17 +1117,23 @@ def CHAR( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'CHAR({int(length)})' if length else 'CHAR' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out def VARCHAR( @@ -960,7 +1143,8 @@ def VARCHAR( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ VARCHAR type specification. @@ -976,17 +1160,23 @@ def VARCHAR( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'VARCHAR({int(length)})' if length else 'VARCHAR' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out def LONGTEXT( @@ -996,7 +1186,8 @@ def LONGTEXT( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ LONGTEXT type specification. @@ -1012,17 +1203,23 @@ def LONGTEXT( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'LONGTEXT({int(length)})' if length else 'LONGTEXT' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out def MEDIUMTEXT( @@ -1032,7 +1229,8 @@ def MEDIUMTEXT( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ MEDIUMTEXT type specification. 
@@ -1048,17 +1246,23 @@ def MEDIUMTEXT( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'MEDIUMTEXT({int(length)})' if length else 'MEDIUMTEXT' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out def TEXT( @@ -1068,7 +1272,8 @@ def TEXT( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TEXT type specification. @@ -1084,17 +1289,23 @@ def TEXT( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TEXT({int(length)})' if length else 'TEXT' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out def TINYTEXT( @@ -1104,7 +1315,8 @@ def TINYTEXT( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TINYTEXT type specification. @@ -1120,17 +1332,23 @@ def TINYTEXT( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TINYTEXT({int(length)})' if length else 'TINYTEXT' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out def BINARY( @@ -1139,7 +1357,8 @@ def BINARY( nullable: bool = True, default: Optional[bytes] = None, collate: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ BINARY type specification. @@ -1153,16 +1372,22 @@ def BINARY( Default value collate : str, optional Collation + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'BINARY({int(length)})' if length else 'BINARY' - return out + _modifiers( - nullable=nullable, default=default, collate=collate, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), ) + out.name = name + return out def VARBINARY( @@ -1171,7 +1396,8 @@ def VARBINARY( nullable: bool = True, default: Optional[bytes] = None, collate: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ VARBINARY type specification. @@ -1185,16 +1411,22 @@ def VARBINARY( Default value collate : str, optional Collation + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'VARBINARY({int(length)})' if length else 'VARBINARY' - return out + _modifiers( - nullable=nullable, default=default, collate=collate, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), ) + out.name = name + return out def LONGBLOB( @@ -1203,7 +1435,8 @@ def LONGBLOB( nullable: bool = True, default: Optional[bytes] = None, collate: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ LONGBLOB type specification. 
@@ -1217,16 +1450,22 @@ def LONGBLOB( Default value collate : str, optional Collation + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'LONGBLOB({int(length)})' if length else 'LONGBLOB' - return out + _modifiers( - nullable=nullable, default=default, collate=collate, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), ) + out.name = name + return out def MEDIUMBLOB( @@ -1235,7 +1474,8 @@ def MEDIUMBLOB( nullable: bool = True, default: Optional[bytes] = None, collate: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ MEDIUMBLOB type specification. @@ -1249,16 +1489,22 @@ def MEDIUMBLOB( Default value collate : str, optional Collation + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'MEDIUMBLOB({int(length)})' if length else 'MEDIUMBLOB' - return out + _modifiers( - nullable=nullable, default=default, collate=collate, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), ) + out.name = name + return out def BLOB( @@ -1267,7 +1513,8 @@ def BLOB( nullable: bool = True, default: Optional[bytes] = None, collate: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ BLOB type specification. @@ -1281,16 +1528,22 @@ def BLOB( Default value collate : str, optional Collation + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'BLOB({int(length)})' if length else 'BLOB' - return out + _modifiers( - nullable=nullable, default=default, collate=collate, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), ) + out.name = name + return out def TINYBLOB( @@ -1299,7 +1552,8 @@ def TINYBLOB( nullable: bool = True, default: Optional[bytes] = None, collate: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ TINYBLOB type specification. @@ -1313,16 +1567,22 @@ def TINYBLOB( Default value collate : str, optional Collation + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'TINYBLOB({int(length)})' if length else 'TINYBLOB' - return out + _modifiers( - nullable=nullable, default=default, collate=collate, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, collate=collate, + ), ) + out.name = name + return out def JSON( @@ -1332,7 +1592,8 @@ def JSON( default: Optional[str] = None, collate: Optional[str] = None, charset: Optional[str] = None, -) -> str: + name: Optional[str] = None, +) -> SQLString: """ JSON type specification. @@ -1348,20 +1609,31 @@ def JSON( Collation charset : str, optional Character set + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ out = f'JSON({int(length)})' if length else 'JSON' - return out + _modifiers( - nullable=nullable, default=default, - collate=collate, charset=charset, + out = SQLString( + out + _modifiers( + nullable=nullable, default=default, + collate=collate, charset=charset, + ), ) + out.name = name + return out -def GEOGRAPHYPOINT(*, nullable: bool = True, default: Optional[str] = None) -> str: +def GEOGRAPHYPOINT( + *, + nullable: bool = True, + default: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: """ GEOGRAPHYPOINT type specification. 
@@ -1371,16 +1643,25 @@ def GEOGRAPHYPOINT(*, nullable: bool = True, default: Optional[str] = None) -> s Can the value be NULL? default : str, optional Default value + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ - return 'GEOGRAPHYPOINT' + _modifiers(nullable=nullable, default=default) + out = SQLString('GEOGRAPHYPOINT' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out -def GEOGRAPHY(*, nullable: bool = True, default: Optional[str] = None) -> str: +def GEOGRAPHY( + *, + nullable: bool = True, + default: Optional[str] = None, + name: Optional[str] = None, +) -> SQLString: """ GEOGRAPHYPOINT type specification. @@ -1396,10 +1677,16 @@ def GEOGRAPHY(*, nullable: bool = True, default: Optional[str] = None) -> str: str """ - return 'GEOGRAPHY' + _modifiers(nullable=nullable, default=default) + out = SQLString('GEOGRAPHY' + _modifiers(nullable=nullable, default=default)) + out.name = name + return out -def RECORD(*args: Tuple[str, DataType], nullable: bool = True) -> str: +def RECORD( + *args: Tuple[str, DataType], + nullable: bool = True, + name: Optional[str] = None, +) -> SQLString: """ RECORD type specification. @@ -1409,10 +1696,12 @@ def RECORD(*args: Tuple[str, DataType], nullable: bool = True) -> str: Field specifications nullable : bool, optional Can the value be NULL? + name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ assert len(args) > 0 @@ -1422,10 +1711,16 @@ def RECORD(*args: Tuple[str, DataType], nullable: bool = True) -> str: fields.append(f'{escape_name(name)} {value()}') else: fields.append(f'{escape_name(name)} {value}') - return f'RECORD({", ".join(fields)})' + _modifiers(nullable=nullable) + out = SQLString(f'RECORD({", ".join(fields)})' + _modifiers(nullable=nullable)) + out.name = name + return out -def ARRAY(dtype: DataType, nullable: bool = True) -> str: +def ARRAY( + dtype: DataType, + nullable: bool = True, + name: Optional[str] = None, +) -> SQLString: """ ARRAY type specification. @@ -1435,12 +1730,16 @@ def ARRAY(dtype: DataType, nullable: bool = True) -> str: The data type of the array elements nullable : bool, optional Can the value be NULL? 
+ name : str, optional + Name of the column / parameter Returns ------- - str + SQLString """ if callable(dtype): dtype = dtype() - return f'ARRAY({dtype})' + _modifiers(nullable=nullable) + out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable)) + out.name = name + return out diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 4afac4908..165391a5d 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -149,7 +149,9 @@ def as_tuple(x: Any) -> Any: return tuple(x.model_dump().values()) if dataclasses.is_dataclass(x): return dataclasses.astuple(x) - return x + if isinstance(x, dict): + return tuple(x.values()) + return tuple(x) def as_list_of_tuples(x: Any) -> Any: @@ -159,9 +161,44 @@ def as_list_of_tuples(x: Any) -> Any: return [tuple(y.model_dump().values()) for y in x] if dataclasses.is_dataclass(x[0]): return [dataclasses.astuple(y) for y in x] + if isinstance(x[0], dict): + return [tuple(y.values()) for y in x] return x +def get_dataframe_columns(df: Any) -> List[Any]: + """Return columns of data from a dataframe/table.""" + rtype = str(type(df)).lower() + if 'dataframe' in rtype: + return [df[x] for x in df.columns] + elif 'table' in rtype: + return df.columns + raise TypeError( + 'Unsupported data type for dataframe columns: ' + f'{rtype}', + ) + + +def get_array_class(data_format: str) -> Callable[..., Any]: + """ + Get the array class for the current data format. + + """ + if data_format == 'polars': + import polars as pl + array_cls = pl.Series + elif data_format == 'arrow': + import pyarrow as pa + array_cls = pa.array + elif data_format == 'pandas': + import pandas as pd + array_cls = pd.Series + else: + import numpy as np + array_cls = np.array + return array_cls + + def make_func( name: str, func: Callable[..., Any], @@ -182,98 +219,89 @@ def make_func( """ attrs = getattr(func, '_singlestoredb_attrs', {}) - data_format = attrs.get('data_format') or 'python' include_masks = attrs.get('include_masks', False) function_type = attrs.get('function_type', 'udf').lower() info: Dict[str, Any] = {} + sig = get_signature(func, func_name=name) + + args_data_format = sig.get('args_data_format', 'scalar') + returns_data_format = sig.get('returns_data_format', 'scalar') + if function_type == 'tvf': - if data_format == 'python': + # Scalar (Python) types + if returns_data_format == 'scalar': async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], - ) -> Tuple[ - Sequence[int], - List[Tuple[Any]], - ]: + ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' out_ids: List[int] = [] out = [] + # Call function on each row of data for i, res in zip(row_ids, func_map(func, rows)): out.extend(as_list_of_tuples(res)) out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out + # Vector formats else: - # Vector formats use the same function wrapper + array_cls = get_array_class(returns_data_format) + async def do_func( # type: ignore row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given cols of data.''' + # NOTE: There is no way to determine which row ID belongs to + # each result row, so we just have to use the same + # row ID for all rows in the result. + + # If `include_masks` is set, the function is expected to return + # a tuple of (data, mask) for each column. 
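# As a rough illustration (an assumption-laden sketch, not part of this patch), a
# masked vector TVF appears to receive and hand back (values, null_mask) pairs per
# column, along the lines of:
#
#     @tvf(include_masks=True)
#     def repeat(x):                          # x arrives as a (values, null_mask) pair
#         values, mask = x
#         # one (values, null_mask) pair per output column
#         return ((values.repeat(2), mask.repeat(2)),)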
if include_masks: out = func(*cols) assert isinstance(out, tuple) + row_ids = array_cls([row_ids[0]] * len(out[0][0])) return row_ids, [out] - out = [] - res = func(*[x[0] for x in cols]) - rtype = str(type(res)).lower() + # Call function on each column of data + res = get_dataframe_columns(func(*[x[0] for x in cols])) - # Map tables / dataframes to a list of columns - if 'dataframe' in rtype: - res = [res[x] for x in res.columns] - elif 'table' in rtype: - res = res.columns + # Generate row IDs + row_ids = array_cls([row_ids[0]] * len(res[0])) - for vec in res: - # C extension only supports Python objects as strings - if data_format == 'numpy' and str(vec.dtype)[:2] in [' Tuple[ - Sequence[int], - List[Tuple[Any]], - ]: + ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))] + # Vector formats else: - # Vector formats use the same function wrapper + array_cls = get_array_class(returns_data_format) + async def do_func( # type: ignore row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given cols of data.''' + row_ids = array_cls(row_ids) + + # If `include_masks` is set, the function is expected to return + # a tuple of (data, mask) for each column.` if include_masks: out = func(*cols) assert isinstance(out, tuple) return row_ids, [out] + # Call the function with `cols` as the function parameters out = func(*[x[0] for x in cols]) # Multiple return values @@ -286,13 +314,12 @@ async def do_func( # type: ignore do_func.__name__ = name do_func.__doc__ = func.__doc__ - sig = get_signature(func, func_name=name) - # Store signature for generating CREATE FUNCTION calls info['signature'] = sig # Set data format - info['data_format'] = data_format + info['args_data_format'] = args_data_format + info['returns_data_format'] = returns_data_format # Set function type info['function_type'] = function_type @@ -312,7 +339,6 @@ async def do_func( # type: ignore dtype = x['dtype'].replace('?', '') if dtype not in rowdat_1_type_map: raise TypeError(f'no data type mapping for {dtype}') - print(x['name'], dtype) returns.append((x['name'], rowdat_1_type_map[dtype])) info['returns'] = returns @@ -399,11 +425,16 @@ class Application(object): # Data format + version handlers handlers = { - (b'application/octet-stream', b'1.0', 'python'): dict( + (b'application/octet-stream', b'1.0', 'scalar'): dict( load=rowdat_1.load, dump=rowdat_1.dump, response=rowdat_1_response_dict, ), + (b'application/octet-stream', b'1.0', 'list'): dict( + load=rowdat_1.load_list, + dump=rowdat_1.dump_list, + response=rowdat_1_response_dict, + ), (b'application/octet-stream', b'1.0', 'pandas'): dict( load=rowdat_1.load_pandas, dump=rowdat_1.dump_pandas, @@ -424,11 +455,16 @@ class Application(object): dump=rowdat_1.dump_arrow, response=rowdat_1_response_dict, ), - (b'application/json', b'1.0', 'python'): dict( + (b'application/json', b'1.0', 'scalar'): dict( load=jdata.load, dump=jdata.dump, response=json_response_dict, ), + (b'application/json', b'1.0', 'list'): dict( + load=jdata.load_list, + dump=jdata.dump_list, + response=json_response_dict, + ), (b'application/json', b'1.0', 'pandas'): dict( load=jdata.load_pandas, dump=jdata.dump_pandas, @@ -449,7 +485,7 @@ class Application(object): dump=jdata.dump_arrow, response=json_response_dict, ), - (b'application/vnd.apache.arrow.file', b'1.0', 'python'): dict( + 
(b'application/vnd.apache.arrow.file', b'1.0', 'scalar'): dict( load=arrow.load, dump=arrow.dump, response=arrow_response_dict, @@ -642,7 +678,8 @@ async def __call__( # Call the endpoint if method == 'POST' and func is not None and path == self.invoke_path: - data_format = func_info['data_format'] + args_data_format = func_info['args_data_format'] + returns_data_format = func_info['returns_data_format'] data = [] more_body = True while more_body: @@ -651,8 +688,8 @@ async def __call__( more_body = request.get('more_body', False) data_version = headers.get(b's2-ef-version', b'') - input_handler = self.handlers[(content_type, data_version, data_format)] - output_handler = self.handlers[(accepts, data_version, data_format)] + input_handler = self.handlers[(content_type, data_version, args_data_format)] + output_handler = self.handlers[(accepts, data_version, returns_data_format)] out = await func( *input_handler['load']( # type: ignore diff --git a/singlestoredb/functions/ext/json.py b/singlestoredb/functions/ext/json.py index c385e6422..3221b1d4f 100644 --- a/singlestoredb/functions/ext/json.py +++ b/singlestoredb/functions/ext/json.py @@ -313,6 +313,10 @@ def _dump_vectors( return json.dumps(dict(data=data), cls=JSONEncoder).encode('utf-8') +load_list = _load_vectors +dump_list = _dump_vectors + + def dump_pandas( returns: List[int], row_ids: 'pd.Series[int]', diff --git a/singlestoredb/functions/ext/rowdat_1.py b/singlestoredb/functions/ext/rowdat_1.py index 3ef3c4905..22940ba63 100644 --- a/singlestoredb/functions/ext/rowdat_1.py +++ b/singlestoredb/functions/ext/rowdat_1.py @@ -720,6 +720,8 @@ def _dump_arrow_accel( if not has_accel: load = _load_accel = _load dump = _dump_accel = _dump + load_list = _load_vectors # noqa: F811 + dump_list = _dump_vectors # noqa: F811 load_pandas = _load_pandas_accel = _load_pandas # noqa: F811 dump_pandas = _dump_pandas_accel = _dump_pandas # noqa: F811 load_numpy = _load_numpy_accel = _load_numpy # noqa: F811 @@ -734,6 +736,8 @@ def _dump_arrow_accel( _dump_accel = _singlestoredb_accel.dump_rowdat_1 load = _load_accel dump = _dump_accel + load_list = _load_vectors + dump_list = _dump_vectors load_pandas = _load_pandas_accel dump_pandas = _dump_pandas_accel load_numpy = _load_numpy_accel diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 79ca65247..514d3b38e 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -41,6 +41,11 @@ _UNION_TYPES = {typing.Union} +def is_union(x: Any) -> bool: + """Check if the object is a Union.""" + return typing.get_origin(x) in _UNION_TYPES + + array_types: Tuple[Any, ...] 
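A quick illustration of the union check introduced above (a sketch, not part of the patch; it relies only on the standard `typing` module):

    import typing
    from typing import Optional, Union

    # Optional[int] is just Union[int, None], so both report typing.Union as their origin.
    assert typing.get_origin(Optional[int]) is typing.Union
    assert typing.get_origin(Union[str, None]) is typing.Union

    # Plain classes have no origin, so is_union() would return False for them.
    assert typing.get_origin(int) is None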
if has_numpy: @@ -202,11 +207,75 @@ def get_annotations(obj: Any) -> Dict[str, Any]: return getattr(obj, '__annotations__', {}) +def is_numpy(obj: Any) -> bool: + """Check if an object is a numpy array.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + if not has_numpy: + return False + if inspect.isclass(obj): + return obj is np.ndarray + if typing.get_origin(obj) is np.ndarray: + return True + return isinstance(obj, np.ndarray) + + def is_dataframe(obj: Any) -> bool: """Check if an object is a DataFrame.""" + # Cheating here a bit so we don't have to import pandas / polars / pyarrow: + # unless we absolutely need to + if getattr(obj, '__module__', '').startswith('pandas.'): + return getattr(obj, '__name__', '') == 'DataFrame' + if getattr(obj, '__module__', '').startswith('polars.'): + return getattr(obj, '__name__', '') == 'DataFrame' + if getattr(obj, '__module__', '').startswith('pyarrow.'): + return getattr(obj, '__name__', '') == 'Table' + return False + + +def get_data_format(obj: Any) -> str: + """Return the data format of the DataFrame / Table / vector.""" # Cheating here a bit so we don't have to import pandas / polars / pyarrow # unless we absolutely need to - return getattr(obj, '__name__', '') in ['DataFrame', 'Table'] + if getattr(obj, '__module__', '').startswith('pandas.'): + return 'pandas' + if getattr(obj, '__module__', '').startswith('polars.'): + return 'polars' + if getattr(obj, '__module__', '').startswith('pyarrow.'): + return 'pyarrow' + if getattr(obj, '__module__', '').startswith('numpy.'): + return 'numpy' + return 'scalar' + + +def is_pandas_series(obj: Any) -> bool: + """Check if an object is a pandas Series.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + return ( + getattr(obj, '__module__', '').startswith('pandas.') and + getattr(obj, '__name__', '') == 'Series' + ) + + +def is_polars_series(obj: Any) -> bool: + """Check if an object is a polars Series.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + return ( + getattr(obj, '__module__', '').startswith('polars.') and + getattr(obj, '__name__', '') == 'Series' + ) + + +def is_pyarrow_array(obj: Any) -> bool: + """Check if an object is a pyarrow Array.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + return ( + getattr(obj, '__module__', '').startswith('pyarrow.') and + getattr(obj, '__name__', '') == 'Array' + ) def is_typeddict(obj: Any) -> bool: @@ -275,7 +344,7 @@ def simplify_dtype(dtype: Any) -> List[Any]: args = [] # Flatten Unions - if origin in _UNION_TYPES: + if is_union(dtype): for x in typing.get_args(dtype): args.extend(simplify_dtype(x)) @@ -287,7 +356,7 @@ def simplify_dtype(dtype: Any) -> List[Any]: args.extend(simplify_dtype(dtype.__bound__)) # Sequence types - elif origin is not None and issubclass(origin, Sequence): + elif origin is not None and inspect.isclass(origin) and issubclass(origin, Sequence): item_args: List[Union[List[type], type]] = [] for x in typing.get_args(dtype): item_dtype = simplify_dtype(x) @@ -331,6 +400,9 @@ def normalize_dtype(dtype: Any) -> str: if isinstance(dtype, str): return sql_to_dtype(dtype) + if typing.get_origin(dtype) is np.dtype: + dtype = typing.get_args(dtype)[0] + # Specific types if dtype is None or dtype is type(None): # noqa: E721 return 'null' @@ -383,7 +455,7 @@ def normalize_dtype(dtype: Any) -> str: return f'tuple[{item_dtypes}]' # Array types - elif issubclass(origin, array_types): + elif inspect.isclass(origin) and issubclass(origin, array_types): args = typing.get_args(dtype) item_dtype = 
normalize_dtype(args[0]) return f'array[{item_dtype}]' @@ -536,41 +608,6 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: return dtypes[0] + ('?' if is_nullable else '') -def create_type( - types: List[Any], - output_fields: List[str], - function_type: str = 'udf', -) -> Tuple[str, str]: - """ - Create the normalized type and SQL code for the given type information. - - Parameters - ---------- - types : List[Any] - List of types to be used - output_fields : List[str] - List of field names for the resulting type - function_type : str - Type of function, either 'udf' or 'tvf' - - Returns - ------- - Tuple[str, str] - Tuple containing the output type and SQL code - - """ - out_type = 'tuple[' + ','.join([ - collapse_dtypes(normalize_dtype(x)) - for x in [simplify_dtype(y) for y in types] - ]) + ']' - - sql = dtype_to_sql( - out_type, function_type=function_type, field_names=output_fields, - ) - - return out_type, sql - - def get_dataclass_schema(obj: Any) -> List[Tuple[str, Any]]: """ Get the schema of a dataclass. @@ -643,11 +680,12 @@ def get_namedtuple_schema(obj: Any) -> List[Tuple[Any, str]]: return list(get_annotations(obj).items()) -def get_return_schema( +def get_schema( spec: Any, - output_fields: Optional[List[str]] = None, + overrides: Optional[List[str]] = None, function_type: str = 'udf', -) -> List[Tuple[str, Any]]: + mode: str = 'parameter', +) -> Tuple[List[Tuple[str, Any, Optional[str]]], str]: """ Expand a return type annotation into a list of types and field names. @@ -655,84 +693,225 @@ def get_return_schema( ---------- spec : Any The return type specification - output_fields : List[str], optional - The output field names + overrides : List[str], optional + List of SQL type specifications for the return type function_type : str The type of function, either 'udf' or 'tvf' + mode : str + The mode of the function, either 'parameter' or 'return' Returns ------- - List[Tuple[str, Any]] - A list of tuples containing the field names and field types + Tuple[List[Tuple[str, Any]], str] + A list of tuples containing the field names and field types, + the normalized data format, and optionally the SQL + definition of the type """ + data_format = 'scalar' + # Make sure that the result of a TVF is a list or dataframe - if function_type == 'tvf': + if function_type == 'tvf' and mode == 'return': + # Use the item type from the list if it's a list if typing.get_origin(spec) is list: spec = typing.get_args(spec)[0] # DataFrames require special handling. You can't get the schema # from the annotation, you need a separate structure to specify - # the types. This should be specified in the output_fields. + # the types. This should be specified in the overrides. 
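        # For example (a sketch assuming the decorator accepts dtype helpers for
        # `returns`), a DataFrame-returning TVF would carry its output schema in the
        # decorator, since the annotation alone cannot describe the column types:
        #
        #     @tvf(returns=[INT(name='id'), DOUBLE(name='score')])
        #     def expand(ids: pd.Series) -> pd.DataFrame:
        #         ...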
elif is_dataframe(spec): - if output_fields is None: + if not overrides: raise TypeError( - 'output_fields must be specified for DataFrames / Tables', + 'type overrides must be specified for DataFrames / Tables', ) - spec = output_fields - output_fields = None + # Unsuported types else: raise TypeError( 'return type for TVF must be a list or DataFrame', ) - elif typing.get_origin(spec) in [list, tuple, dict] \ + # Error out for incorrect types + elif typing.get_origin(spec) in [tuple, dict] \ or is_dataframe(spec) \ or dataclasses.is_dataclass(spec) \ or is_typeddict(spec) \ or is_pydantic(spec) \ or is_namedtuple(spec): + if mode == 'parameter': + raise TypeError('parameter types must be scalar or vector') raise TypeError('return type for UDF must be a scalar type') - # Return type is specified by a SQL string - if isinstance(spec, str): - return [('', sql_to_dtype(spec))] + # + # Process each parameter / return type into a colspec + # + + # Compute overrides colspec from various formats + overrides_colspec = [] + if overrides: + if dataclasses.is_dataclass(overrides): + overrides_colspec = get_dataclass_schema(overrides) + elif is_typeddict(overrides): + overrides_colspec = get_typeddict_schema(overrides) + elif is_namedtuple(overrides): + overrides_colspec = get_namedtuple_schema(overrides) + elif is_pydantic(overrides): + overrides_colspec = get_pydantic_schema(overrides) + else: + overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides] + + # Numpy array types + if is_numpy(spec): + data_format = 'numpy' + if overrides: + colspec = overrides_colspec + elif len(typing.get_args(spec)) < 2: + raise TypeError( + 'numpy array must have a data type specified ' + 'in the @udf / @tvf decorator or with an NDArray type annotation', + ) + else: + colspec = [('', typing.get_args(spec)[1])] + + # Pandas Series + elif is_pandas_series(spec): + data_format = 'pandas' + if not overrides: + raise TypeError( + 'pandas Series must have a data type specified ' + 'in the @udf / @tvf decorator', + ) + colspec = overrides_colspec + + # Polars Series + elif is_polars_series(spec): + data_format = 'polars' + if not overrides: + raise TypeError( + 'polars Series must have a data type specified ' + 'in the @udf / @tvf decorator', + ) + colspec = overrides_colspec + + # PyArrow Array + elif is_pyarrow_array(spec): + data_format = 'pyarrow' + if not overrides: + raise TypeError( + 'pyarrow Arrays must have a data type specified ' + 'in the @udf / @tvf decorator', + ) + colspec = overrides_colspec # Return type is specified by a dataclass definition - if dataclasses.is_dataclass(spec): - schema = get_dataclass_schema(spec) + elif dataclasses.is_dataclass(spec): + colspec = overrides_colspec or get_dataclass_schema(spec) # Return type is specified by a TypedDict definition elif is_typeddict(spec): - schema = get_typeddict_schema(spec) + colspec = overrides_colspec or get_typeddict_schema(spec) # Return type is specified by a pydantic model elif is_pydantic(spec): - schema = get_pydantic_schema(spec) + colspec = overrides_colspec or get_pydantic_schema(spec) # Return type is specified by a named tuple elif is_namedtuple(spec): - schema = get_namedtuple_schema(spec) + colspec = overrides_colspec or get_namedtuple_schema(spec) # Unrecognized return type elif spec is not None: - if typing.get_origin(spec) is tuple: - output_fields = [ - string.ascii_letters[i] for i in range(len(typing.get_args(spec))) - ] - schema = [(x, y) for x, y in zip(output_fields, typing.get_args(spec))] + + # Return type is 
specified by a SQL string + if isinstance(spec, str): + data_format = 'scalar' + colspec = [(getattr(spec, 'name', ''), spec)] + + # Plain list vector + elif typing.get_origin(spec) is list: + data_format = 'list' + colspec = [('', typing.get_args(spec)[0])] + + # Multiple return values + elif typing.get_origin(spec) is tuple: + + data_formats, colspec = [], [] + + for i, (y, vec) in enumerate(vector_check(typing.get_args(spec))): + + # Apply override types as needed + if overrides: + colspec.append(overrides_colspec[i]) + + # Some vector types do not have annotated types, so they must + # be specified in the decorator + elif y is None: + raise TypeError( + f'type overrides must be specified for vector type: {vec}', + ) + + else: + colspec.append(('', y)) + + data_formats.append(vec) + + # Make sure that all the data formats are the same + if len(set(data_formats)) > 1: + raise TypeError( + 'data formats must be all be the same vector / scalar type: ' + f'{", ".join(data_formats)}', + ) + + data_format = data_formats[0] + + # Use overrides if specified + elif overrides: + data_format = get_data_format(spec) + colspec = overrides_colspec + + # Single value, no override else: - schema = [('', spec)] + data_format = 'scalar' + colspec = [('', spec)] - # Normalize schema data types + # Normalize colspec data types out = [] - for k, v in schema: + + for k, v in colspec: out.append(( - k, collapse_dtypes([normalize_dtype(x) for x in simplify_dtype(v)]), + k, + collapse_dtypes([normalize_dtype(x) for x in simplify_dtype(v)]), + v if isinstance(v, str) else None, )) - return out + + return out, data_format + + +def vector_check(obj: Any) -> Tuple[Any, str]: + """ + Check if the object is a vector type. + + Parameters + ---------- + obj : Any + The object to check + + Returns + ------- + Tuple[Any, str] + The scalar type and the data format ('scalar', 'numpy', 'pandas', 'polars') + + """ + if is_numpy(obj): + return typing.get_args(obj)[1], 'numpy' + if is_pandas_series(obj): + return None, 'pandas' + if is_polars_series(obj): + return None, 'polars' + if is_pyarrow_array(obj): + return None, 'pyarrow' + return obj, 'scalar' def get_signature( @@ -757,22 +936,12 @@ def get_signature( signature = inspect.signature(func) args: List[Dict[str, Any]] = [] returns: List[Dict[str, Any]] = [] + attrs = getattr(func, '_singlestoredb_attrs', {}) - name = attrs.get('name', func_name if func_name else func.__name__) function_type = attrs.get('function_type', 'udf') - out: Dict[str, Any] = dict(name=name, args=args, returns=returns) + name = attrs.get('name', func_name if func_name else func.__name__) - # Get parameter names, defaults, and annotations - arg_names = [x for x in signature.parameters] - args_overrides = attrs.get('args', None) - defaults = [ - x.default if x.default is not inspect.Parameter.empty else None - for x in signature.parameters.values() - ] - annotations = { - k: x.annotation for k, x in signature.parameters.items() - if x.annotation is not inspect.Parameter.empty - } + out: Dict[str, Any] = dict(name=name, args=args, returns=returns) # Do not allow variable positional or keyword arguments for p in signature.parameters.values(): @@ -781,84 +950,55 @@ def get_signature( elif p.kind == inspect.Parameter.VAR_KEYWORD: raise TypeError('variable keyword arguments are not supported') - spec_diff = set(arg_names).difference(set(annotations.keys())) - - # - # Make sure all arguments are annotated - # - - # If there are missing annotations and no overrides, raise an error - if spec_diff and 
args_overrides is None: - raise TypeError( - 'missing annotations for {} in {}' - .format(', '.join(spec_diff), name), - ) - - # If there are missing annotations and overrides are provided, make sure they match - elif isinstance(args_overrides, dict): - for s in spec_diff: - if s not in args_overrides: - raise TypeError( - 'missing annotations for {} in {}' - .format(', '.join(spec_diff), name), - ) - - # If there are missing annotations and overrides are provided, make sure they match - elif isinstance(args_overrides, list): - if len(arg_names) != len(args_overrides): - raise TypeError( - 'number of annotations does not match in {}: {}' - .format(name, ', '.join(spec_diff)), - ) - - # # Generate the parameter type and the corresponding SQL code for that parameter - # - - for i, arg in enumerate(arg_names): - - # If arg_overrides is a list, use corresponding item as SQL - if isinstance(args_overrides, list): - sql = args_overrides[i] - arg_type = sql_to_dtype(sql) - - # If arg_overrides is a dict, use the corresponding key as SQL - elif isinstance(args_overrides, dict) and arg in args_overrides: - sql = args_overrides[arg] - arg_type = sql_to_dtype(sql) + args_schema = [] + args_data_formats = [] + for param in signature.parameters.values(): + arg_schema, args_data_format = get_schema( + param.annotation, + overrides=attrs.get('args', None), + function_type=function_type, + mode='parameter', + ) + args_data_formats.append(args_data_format) - # If args_overrides is a string, use it as SQL (only one function parameter) - elif isinstance(args_overrides, str): - sql = args_overrides - arg_type = sql_to_dtype(sql) + # Insert parameter names as needed + if not arg_schema[0][0]: + args_schema.append((param.name, *arg_schema[0][1:])) - # Unrecognized type for args_overrides - elif args_overrides is not None \ - and not isinstance(args_overrides, (list, dict, str)): - raise TypeError(f'unrecognized type for arguments: {args_overrides}') + for i, (name, atype, sql) in enumerate(args_schema): + sql = sql or dtype_to_sql( + atype, + function_type=function_type, + default=param.default if param.default is not param.empty else None, + ) + args.append(dict(name=name, dtype=atype, sql=sql)) - # No args_overrides, use the Python type annotation - else: - arg_type = collapse_dtypes([ - normalize_dtype(x) for x in simplify_dtype(annotations[arg]) - ]) - sql = dtype_to_sql(arg_type, function_type=function_type) + # Check that all the data formats are all the same + if len(set(args_data_formats)) > 1: + raise TypeError( + 'input data formats must be all be the same: ' + f'{", ".join(args_data_formats)}', + ) - # Append parameter information to the args list - args.append(dict(name=arg, dtype=arg_type, sql=sql, default=defaults[i])) + out['args_data_format'] = args_data_formats[0] - # # Generate the return types and the corresponding SQL code for those values - # - - ret_schema = get_return_schema( - attrs.get('returns', signature.return_annotation), - output_fields=attrs.get('output_fields', None), + ret_schema, out['returns_data_format'] = get_schema( + signature.return_annotation, + overrides=attrs.get('returns', None), function_type=function_type, + mode='return', ) - for i, (name, rtype) in enumerate(ret_schema): - sql = dtype_to_sql(rtype, function_type=function_type) + # Generate names for fields as needed + if function_type == 'tvf' or len(ret_schema) > 1: + for i, (name, rtype, sql) in enumerate(ret_schema): + if not name: + ret_schema[i] = (string.ascii_letters[i], rtype, sql) + + for i, (name, rtype, 
sql) in enumerate(ret_schema): + sql = sql or dtype_to_sql(rtype, function_type=function_type) returns.append(dict(name=name, dtype=rtype, sql=sql)) # Copy keys from decorator to signature From 2148a1f253666093e340a3419da130413b6a27b6 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 10 Apr 2025 14:03:39 -0500 Subject: [PATCH 05/16] Add default values from pydantic / namedtuple / etc --- singlestoredb/functions/signature.py | 170 ++++++++++++++++++++++----- 1 file changed, 141 insertions(+), 29 deletions(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 514d3b38e..7bb9a370c 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -16,6 +16,7 @@ from typing import Optional from typing import Sequence from typing import Tuple +from typing import Type from typing import TypeVar from typing import Union @@ -27,6 +28,7 @@ try: import pydantic + import pydantic_core has_pydantic = True except ImportError: has_pydantic = False @@ -46,6 +48,9 @@ def is_union(x: Any) -> bool: return typing.get_origin(x) in _UNION_TYPES +NO_DEFAULT = object() + + array_types: Tuple[Any, ...] if has_numpy: @@ -608,7 +613,10 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: return dtypes[0] + ('?' if is_nullable else '') -def get_dataclass_schema(obj: Any) -> List[Tuple[str, Any]]: +def get_dataclass_schema( + obj: Any, + include_default: bool = False, +) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]: """ Get the schema of a dataclass. @@ -619,14 +627,25 @@ def get_dataclass_schema(obj: Any) -> List[Tuple[str, Any]]: Returns ------- - List[Tuple[str, Any]] + List[Tuple[str, Any]] | List[Tuple[str, Any, Any]] A list of tuples containing the field names and field types """ - return list(get_annotations(obj).items()) + if include_default: + return [ + ( + f.name, f.type, + NO_DEFAULT if f.default is dataclasses.MISSING else f.default, + ) + for f in dataclasses.fields(obj) + ] + return [(f.name, f.type) for f in dataclasses.fields(obj)] -def get_typeddict_schema(obj: Any) -> List[Tuple[str, Any]]: +def get_typeddict_schema( + obj: Any, + include_default: bool = False, +) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]: """ Get the schema of a TypedDict. @@ -634,17 +653,27 @@ def get_typeddict_schema(obj: Any) -> List[Tuple[str, Any]]: ---------- obj : TypedDict The TypedDict to get the schema of + include_default : bool, optional + Whether to include the default value in the column specification Returns ------- - List[Tuple[str, Any]] + List[Tuple[str, Any]] | List[Tuple[str, Any, Any]] A list of tuples containing the field names and field types """ + if include_default: + return [ + (k, v, getattr(obj, 'k', NO_DEFAULT)) + for k, v in get_annotations(obj).items() + ] return list(get_annotations(obj).items()) -def get_pydantic_schema(obj: pydantic.BaseModel) -> List[Tuple[str, Any]]: +def get_pydantic_schema( + obj: pydantic.BaseModel, + include_default: bool = False, +) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]: """ Get the schema of a pydantic model. 
@@ -652,17 +681,30 @@ def get_pydantic_schema(obj: pydantic.BaseModel) -> List[Tuple[str, Any]]: ---------- obj : pydantic.BaseModel The pydantic model to get the schema of + include_default : bool, optional + Whether to include the default value in the column specification Returns ------- - List[Tuple[str, Any]] + List[Tuple[str, Any]] | List[Tuple[str, Any, Any]] A list of tuples containing the field names and field types """ + if include_default: + return [ + ( + k, v.annotation, + NO_DEFAULT if v.default is pydantic_core.PydanticUndefined else v.default, + ) + for k, v in obj.model_fields.items() + ] return [(k, v.annotation) for k, v in obj.model_fields.items()] -def get_namedtuple_schema(obj: Any) -> List[Tuple[Any, str]]: +def get_namedtuple_schema( + obj: Any, + include_default: bool = False, +) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]: """ Get the schema of a named tuple. @@ -670,19 +712,85 @@ def get_namedtuple_schema(obj: Any) -> List[Tuple[Any, str]]: ---------- obj : NamedTuple The named tuple to get the schema of + include_default : bool, optional + Whether to include the default value in the column specification Returns ------- - List[Tuple[Any, str]] + List[Tuple[Any, str]] | List[Tuple[Any, str, Any]] A list of tuples containing the field names and field types """ + if include_default: + return [ + ( + k, + v, + obj._field_defaults.get(k, NO_DEFAULT), + ) + for k, v in get_annotations(obj).items() + ] return list(get_annotations(obj).items()) +def get_colspec( + overrides: Any, + include_default: bool = False, +) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]: + """ + Get the column specification from the overrides. + + Parameters + ---------- + overrides : Any + The overrides to get the column specification from + include_default : bool, optional + Whether to include the default value in the column specification + + Returns + ------- + List[Tuple[str, Any]] | List[Tuple[str, Any, Any]] + A list of tuples containing the field names and field types + + """ + overrides_colspec = [] + if overrides: + if dataclasses.is_dataclass(overrides): + overrides_colspec = get_dataclass_schema( + overrides, include_default=include_default, + ) + elif is_typeddict(overrides): + overrides_colspec = get_typeddict_schema( + overrides, include_default=include_default, + ) + elif is_namedtuple(overrides): + overrides_colspec = get_namedtuple_schema( + overrides, include_default=include_default, + ) + elif is_pydantic(overrides): + overrides_colspec = get_pydantic_schema( + overrides, include_default=include_default, + ) + elif isinstance(overrides, list): + if include_default: + overrides_colspec = [ + (getattr(x, 'name', ''), x, NO_DEFAULT) for x in overrides + ] + else: + overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides] + else: + if include_default: + overrides_colspec = [ + (getattr(overrides, 'name', ''), overrides, NO_DEFAULT), + ] + else: + overrides_colspec = [(getattr(overrides, 'name', ''), overrides)] + return overrides_colspec + + def get_schema( spec: Any, - overrides: Optional[List[str]] = None, + overrides: Optional[Union[List[str], Type[Any]]] = None, function_type: str = 'udf', mode: str = 'parameter', ) -> Tuple[List[Tuple[str, Any, Optional[str]]], str]: @@ -748,18 +856,7 @@ def get_schema( # # Compute overrides colspec from various formats - overrides_colspec = [] - if overrides: - if dataclasses.is_dataclass(overrides): - overrides_colspec = get_dataclass_schema(overrides) - elif is_typeddict(overrides): - overrides_colspec = 
get_typeddict_schema(overrides) - elif is_namedtuple(overrides): - overrides_colspec = get_namedtuple_schema(overrides) - elif is_pydantic(overrides): - overrides_colspec = get_pydantic_schema(overrides) - else: - overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides] + overrides_colspec = get_colspec(overrides) # Numpy array types if is_numpy(spec): @@ -878,7 +975,7 @@ def get_schema( # Normalize colspec data types out = [] - for k, v in colspec: + for k, v, *_ in colspec: out.append(( k, collapse_dtypes([normalize_dtype(x) for x in simplify_dtype(v)]), @@ -953,10 +1050,13 @@ def get_signature( # Generate the parameter type and the corresponding SQL code for that parameter args_schema = [] args_data_formats = [] - for param in signature.parameters.values(): + args_colspec = [x for x in get_colspec(attrs.get('args', []), include_default=True)] + args_overrides = [x[1] for x in args_colspec] + args_defaults = [x[2] for x in args_colspec] # type: ignore + for i, param in enumerate(signature.parameters.values()): arg_schema, args_data_format = get_schema( param.annotation, - overrides=attrs.get('args', None), + overrides=args_overrides[i] if args_overrides else [], function_type=function_type, mode='parameter', ) @@ -967,12 +1067,24 @@ def get_signature( args_schema.append((param.name, *arg_schema[0][1:])) for i, (name, atype, sql) in enumerate(args_schema): + # Get default value + default_option = {} + if args_defaults: + if args_defaults[i] is not NO_DEFAULT: + default_option['default'] = args_defaults[i] + else: + if param.default is not param.empty: + default_option['default'] = param.default + + # Generate SQL code for the parameter sql = sql or dtype_to_sql( atype, function_type=function_type, - default=param.default if param.default is not param.empty else None, + **default_option, ) - args.append(dict(name=name, dtype=atype, sql=sql)) + + # Add parameter to args definitions + args.append(dict(name=name, dtype=atype, sql=sql, **default_option)) # Check that all the data formats are all the same if len(set(args_data_formats)) > 1: @@ -1059,7 +1171,7 @@ def sql_to_dtype(sql: str) -> str: def dtype_to_sql( dtype: str, - default: Any = None, + default: Any = NO_DEFAULT, field_names: Optional[List[str]] = None, function_type: str = 'udf', ) -> str: @@ -1092,7 +1204,7 @@ def dtype_to_sql( nullable = '' default_clause = '' - if default is not None: + if default is not NO_DEFAULT: if default is dt.NULL: default = None default_clause = f' DEFAULT {escape_item(default, "utf8")}' From 6a638856924c5998032119964fb15c8f71e8c263 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 10 Apr 2025 15:52:36 -0500 Subject: [PATCH 06/16] Fix TypedDict defaults --- singlestoredb/functions/signature.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 7bb9a370c..e86fc714e 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -48,7 +48,11 @@ def is_union(x: Any) -> bool: return typing.get_origin(x) in _UNION_TYPES -NO_DEFAULT = object() +class NoDefaultType: + pass + + +NO_DEFAULT = NoDefaultType() array_types: Tuple[Any, ...] 
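As background on the hunk above: the NoDefaultType sentinel lets downstream code distinguish "no default supplied" from an explicit default of None, which a bare object() also allows but without a descriptive type or useful repr. A minimal, self-contained sketch of the pattern follows; the build_default_clause helper name is hypothetical and only illustrates how such a sentinel is typically consumed, it is not code from this patch.

    class NoDefaultType:
        """Marker type for parameters that have no default value."""
        pass

    NO_DEFAULT = NoDefaultType()

    def build_default_clause(default):
        # Only the sentinel means "no default was given"; an explicit
        # default of None still produces a DEFAULT clause.
        if default is NO_DEFAULT:
            return ''
        return f' DEFAULT {default!r}'

    assert build_default_clause(NO_DEFAULT) == ''
    assert build_default_clause(None) == ' DEFAULT None'
    assert build_default_clause(1.5) == ' DEFAULT 1.5'
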
@@ -664,7 +668,7 @@ def get_typeddict_schema( """ if include_default: return [ - (k, v, getattr(obj, 'k', NO_DEFAULT)) + (k, v, getattr(obj, k, NO_DEFAULT)) for k, v in get_annotations(obj).items() ] return list(get_annotations(obj).items()) @@ -724,8 +728,7 @@ def get_namedtuple_schema( if include_default: return [ ( - k, - v, + k, v, obj._field_defaults.get(k, NO_DEFAULT), ) for k, v in get_annotations(obj).items() @@ -754,23 +757,34 @@ def get_colspec( """ overrides_colspec = [] + if overrides: + + # Dataclass if dataclasses.is_dataclass(overrides): overrides_colspec = get_dataclass_schema( overrides, include_default=include_default, ) + + # TypedDict elif is_typeddict(overrides): overrides_colspec = get_typeddict_schema( overrides, include_default=include_default, ) + + # Named tuple elif is_namedtuple(overrides): overrides_colspec = get_namedtuple_schema( overrides, include_default=include_default, ) + + # Pydantic model elif is_pydantic(overrides): overrides_colspec = get_pydantic_schema( overrides, include_default=include_default, ) + + # List of types elif isinstance(overrides, list): if include_default: overrides_colspec = [ @@ -778,6 +792,8 @@ def get_colspec( ] else: overrides_colspec = [(getattr(x, 'name', ''), x) for x in overrides] + + # Other else: if include_default: overrides_colspec = [ @@ -785,6 +801,7 @@ def get_colspec( ] else: overrides_colspec = [(getattr(overrides, 'name', ''), overrides)] + return overrides_colspec From b1290252e48d78c162a79a8912e365fff0658a66 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 10 Apr 2025 16:19:57 -0500 Subject: [PATCH 07/16] Disable lifespan in uvicorn --- singlestoredb/functions/ext/asgi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 165391a5d..e96c20e5d 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -1289,6 +1289,7 @@ def main(argv: Optional[List[str]] = None) -> None: host=args.host or None, port=args.port or None, log_level=args.log_level, + lifespan='off', ).items() if v is not None } From 7a21f465cf16e13d879f1aa2ff533b78529f9f77 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 11 Apr 2025 11:38:17 -0500 Subject: [PATCH 08/16] Fix multiple vector output in TVFs; fix ucs4 to utf8 transcoding in numpy fixed strings --- accel.c | 75 ++++++++++++++++++++++++++- singlestoredb/functions/ext/asgi.py | 37 +++++++++---- singlestoredb/functions/signature.py | 77 ++++++++++++++++++++-------- 3 files changed, 157 insertions(+), 32 deletions(-) diff --git a/accel.c b/accel.c index 3c16f2182..499a800b2 100644 --- a/accel.c +++ b/accel.c @@ -372,6 +372,70 @@ char *_PyUnicode_AsUTF8(PyObject *unicode) { return out; } +// Function to convert a UCS-4 string to a UTF-8 string +// Returns the length of the resulting UTF-8 string, or -1 on error +int ucs4_to_utf8(const uint32_t *ucs4_str, size_t ucs4_len, char **utf8_str) { + if (!ucs4_str || !utf8_str) { + return -1; // Invalid input + } + + // Allocate a buffer for the UTF-8 string (worst-case: 4 bytes per UCS-4 character) + size_t utf8_max_len = ucs4_len * 4 + 1; // +1 for null terminator + *utf8_str = malloc(utf8_max_len); + if (!*utf8_str) { + return -1; // Memory allocation failed + } + + char *utf8_ptr = *utf8_str; + size_t utf8_len = 0; + + for (size_t i = 0; i < ucs4_len; i++) { + uint32_t codepoint = ucs4_str[i]; + + if (codepoint <= 0x7F) { + // 1-byte UTF-8 + if (utf8_len + 1 > utf8_max_len) goto error; // Buffer overflow + *utf8_ptr++ = 
(char)codepoint; + utf8_len += 1; + } else if (codepoint <= 0x7FF) { + // 2-byte UTF-8 + if (utf8_len + 2 > utf8_max_len) goto error; // Buffer overflow + *utf8_ptr++ = (char)(0xC0 | (codepoint >> 6)); + *utf8_ptr++ = (char)(0x80 | (codepoint & 0x3F)); + utf8_len += 2; + } else if (codepoint <= 0xFFFF) { + // 3-byte UTF-8 + if (utf8_len + 3 > utf8_max_len) goto error; // Buffer overflow + *utf8_ptr++ = (char)(0xE0 | (codepoint >> 12)); + *utf8_ptr++ = (char)(0x80 | ((codepoint >> 6) & 0x3F)); + *utf8_ptr++ = (char)(0x80 | (codepoint & 0x3F)); + utf8_len += 3; + } else if (codepoint <= 0x10FFFF) { + // 4-byte UTF-8 + if (utf8_len + 4 > utf8_max_len) goto error; // Buffer overflow + *utf8_ptr++ = (char)(0xF0 | (codepoint >> 18)); + *utf8_ptr++ = (char)(0x80 | ((codepoint >> 12) & 0x3F)); + *utf8_ptr++ = (char)(0x80 | ((codepoint >> 6) & 0x3F)); + *utf8_ptr++ = (char)(0x80 | (codepoint & 0x3F)); + utf8_len += 4; + } else { + // Invalid codepoint + goto error; + } + } + + // Null-terminate the UTF-8 string + if (utf8_len + 1 > utf8_max_len) goto error; // Buffer overflow + *utf8_ptr = '\0'; + + return (int)utf8_len; + +error: + free(*utf8_str); + *utf8_str = NULL; + return -1; +} + // // Cached int values for date/time components // @@ -3873,13 +3937,20 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k memcpy(out+out_idx, &i64, 8); out_idx += 8; } else { - Py_ssize_t str_l = strnlen(bytes, col_types[i].length); + char *utf8_str = NULL; + Py_ssize_t str_l = ucs4_to_utf8(bytes, col_types[i].length, &utf8_str); + if (str_l < 0) { + PyErr_SetString(PyExc_ValueError, "invalid UCS4 string"); + if (utf8_str) free(utf8_str); + goto error; + } CHECKMEM(8+str_l); i64 = str_l; memcpy(out+out_idx, &i64, 8); out_idx += 8; - memcpy(out+out_idx, bytes, str_l); + memcpy(out+out_idx, utf8_str, str_l); out_idx += str_l; + free(utf8_str); } } else { diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index e96c20e5d..5016dc811 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -173,6 +173,12 @@ def get_dataframe_columns(df: Any) -> List[Any]: return [df[x] for x in df.columns] elif 'table' in rtype: return df.columns + elif 'series' in rtype: + return [df] + elif 'array' in rtype: + return [df] + elif 'tuple' in rtype: + return list(df) raise TypeError( 'Unsupported data type for dataframe columns: ' f'{rtype}', @@ -391,6 +397,13 @@ class Application(object): headers=[(b'content-type', b'text/plain')], ) + # Error response start + error_response_dict: Dict[str, Any] = dict( + type='http.response.start', + status=401, + headers=[(b'content-type', b'text/plain')], + ) + # JSON response start json_response_dict: Dict[str, Any] = dict( type='http.response.start', @@ -691,16 +704,22 @@ async def __call__( input_handler = self.handlers[(content_type, data_version, args_data_format)] output_handler = self.handlers[(accepts, data_version, returns_data_format)] - out = await func( - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), - ), - ) - body = output_handler['dump']( - [x[1] for x in func_info['returns']], *out, # type: ignore - ) + try: + out = await func( + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), + ) + print(func_info, *out) + body = output_handler['dump']( + [x[1] for x in func_info['returns']], *out, # type: ignore + ) + await send(output_handler['response']) - await send(output_handler['response']) + except Exception as e: + 
logging.exception('Error in function call') + body = f'[{type(e).__name__}] {str(e).strip()}'.encode('utf-8') + await send(self.error_response_dict) # Handle api reflection elif method == 'GET' and path == self.show_create_function_path: diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index e86fc714e..84a849a68 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -242,6 +242,14 @@ def is_dataframe(obj: Any) -> bool: return False +def is_vector(obj: Any) -> bool: + """Check if an object is a vector type.""" + return is_pandas_series(obj) \ + or is_polars_series(obj) \ + or is_pyarrow_array(obj) \ + or is_numpy(obj) + + def get_data_format(obj: Any) -> str: """Return the data format of the DataFrame / Table / vector.""" # Cheating here a bit so we don't have to import pandas / polars / pyarrow @@ -254,6 +262,8 @@ def get_data_format(obj: Any) -> str: return 'pyarrow' if getattr(obj, '__module__', '').startswith('numpy.'): return 'numpy' + if isinstance(obj, list): + return 'list' return 'scalar' @@ -842,10 +852,18 @@ def get_schema( if typing.get_origin(spec) is list: spec = typing.get_args(spec)[0] + # If it's a tuple, it must be a tuple of vectors + elif typing.get_origin(spec) is tuple: + if not all([is_vector(x) for x in typing.get_args(spec)]): + raise TypeError( + 'return type for TVF must be a list, DataFrame / Table, ' + 'or tuple of vectors', + ) + # DataFrames require special handling. You can't get the schema # from the annotation, you need a separate structure to specify # the types. This should be specified in the overrides. - elif is_dataframe(spec): + elif is_dataframe(spec) or is_vector(spec): if not overrides: raise TypeError( 'type overrides must be specified for DataFrames / Tables', @@ -854,7 +872,8 @@ def get_schema( # Unsuported types else: raise TypeError( - 'return type for TVF must be a list or DataFrame', + 'return type for TVF must be a list, DataFrame / Table, ' + 'or tuple of vectors', ) # Error out for incorrect types @@ -950,34 +969,48 @@ def get_schema( # Multiple return values elif typing.get_origin(spec) is tuple: - data_formats, colspec = [], [] - - for i, (y, vec) in enumerate(vector_check(typing.get_args(spec))): - - # Apply override types as needed - if overrides: - colspec.append(overrides_colspec[i]) - - # Some vector types do not have annotated types, so they must - # be specified in the decorator - elif y is None: - raise TypeError( - f'type overrides must be specified for vector type: {vec}', + out_names, out_overrides = [], [] + if overrides: + out_colspec = [ + x for x in get_colspec( + overrides, include_default=True, ) + ] + out_names = [x[0] for x in out_colspec] + out_overrides = [x[1] for x in out_colspec] + + colspec = [] + out_data_formats = [] + for i, x in enumerate(typing.get_args(spec)): + out_item, out_data_format = get_schema( + x, + overrides=out_overrides[i] if out_overrides else [], + # Always use UDF for individual items + function_type='udf', + mode=mode, + ) - else: - colspec.append(('', y)) + # Use the name from the overrides if specified + if out_names and out_names[i] and not out_item[0][0]: + out_item = [(out_names[i], *out_item[0][1:])] + elif not out_item[0][0]: + out_item = [(f'{string.ascii_letters[i]}', *out_item[0][1:])] - data_formats.append(vec) + colspec += out_item + out_data_formats.append(out_data_format) # Make sure that all the data formats are the same - if len(set(data_formats)) > 1: + if len(set(out_data_formats)) > 1: raise 
TypeError( 'data formats must be all be the same vector / scalar type: ' - f'{", ".join(data_formats)}', + f'{", ".join(out_data_formats)}', ) - data_format = data_formats[0] + data_format = out_data_formats[0] + + # Since the colspec was computed by get_schema already, don't go + # through the process of normalizing the dtypes again + return colspec, data_format # type: ignore # Use overrides if specified elif overrides: @@ -1018,6 +1051,8 @@ def vector_check(obj: Any) -> Tuple[Any, str]: """ if is_numpy(obj): + if len(typing.get_args(obj)) < 2: + return None, 'numpy' return typing.get_args(obj)[1], 'numpy' if is_pandas_series(obj): return None, 'pandas' From b36f7bf6370fa844e263c09bc62e323be1da0e66 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 11 Apr 2025 11:40:03 -0500 Subject: [PATCH 09/16] Remove debugging code --- singlestoredb/functions/ext/asgi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 5016dc811..1e110cb3b 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -710,7 +710,6 @@ async def __call__( func_info['colspec'], b''.join(data), ), ) - print(func_info, *out) body = output_handler['dump']( [x[1] for x in func_info['returns']], *out, # type: ignore ) From 1b02a3568761fcec8e8f61eef581068c7e111a2b Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 15 Apr 2025 17:10:21 -0500 Subject: [PATCH 10/16] Fix masked types; add asgi endpoint for function info --- singlestoredb/config.py | 12 + singlestoredb/functions/__init__.py | 4 + singlestoredb/functions/decorator.py | 229 +++++++-- singlestoredb/functions/ext/asgi.py | 105 +++- singlestoredb/functions/ext/json.py | 53 +- singlestoredb/functions/ext/mmap.py | 2 +- singlestoredb/functions/ext/rowdat_1.py | 111 ++--- singlestoredb/functions/signature.py | 294 +++++------ singlestoredb/functions/typing.py | 38 ++ singlestoredb/functions/utils.py | 152 ++++++ singlestoredb/tests/ext_funcs/__init__.py | 573 +++++++++++++--------- singlestoredb/tests/test_ext_func.py | 8 +- singlestoredb/tests/test_udf.py | 130 ++--- 13 files changed, 1073 insertions(+), 638 deletions(-) create mode 100644 singlestoredb/functions/typing.py create mode 100644 singlestoredb/functions/utils.py diff --git a/singlestoredb/config.py b/singlestoredb/config.py index 664635977..de386ba03 100644 --- a/singlestoredb/config.py +++ b/singlestoredb/config.py @@ -407,6 +407,18 @@ environ=['SINGLESTOREDB_EXT_FUNC_LOG_LEVEL'], ) +register_option( + 'external_function.name_prefix', 'string', check_str, '', + 'Prefix to add to external function names.', + environ=['SINGLESTOREDB_EXT_FUNC_NAME_PREFIX'], +) + +register_option( + 'external_function.name_suffix', 'string', check_str, '', + 'Suffix to add to external function names.', + environ=['SINGLESTOREDB_EXT_FUNC_NAME_SUFFIX'], +) + register_option( 'external_function.connection', 'string', check_str, os.environ.get('SINGLESTOREDB_URL') or None, diff --git a/singlestoredb/functions/__init__.py b/singlestoredb/functions/__init__.py index 38b95f5e4..01b059f76 100644 --- a/singlestoredb/functions/__init__.py +++ b/singlestoredb/functions/__init__.py @@ -1,2 +1,6 @@ from .decorator import tvf # noqa: F401 +from .decorator import tvf_with_null_masks # noqa: F401 from .decorator import udf # noqa: F401 +from .decorator import udf_with_null_masks # noqa: F401 +from .typing import Masked # noqa: F401 +from .typing import MaskedNDArray # noqa: F401 diff --git a/singlestoredb/functions/decorator.py 
b/singlestoredb/functions/decorator.py index 82ade5a46..e1d204b71 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -1,7 +1,6 @@ -from __future__ import annotations - import functools import inspect +import typing from typing import Any from typing import Callable from typing import List @@ -9,17 +8,92 @@ from typing import Type from typing import Union +from . import utils +from .dtypes import SQLString + ParameterType = Union[ str, - Callable[..., str], - List[Union[str, Callable[..., str]]], + Callable[..., SQLString], + List[Union[str, Callable[..., SQLString]]], Type[Any], ] ReturnType = ParameterType +def is_valid_type(obj: Any) -> bool: + """Check if the object is a valid type for a schema definition.""" + if not inspect.isclass(obj): + return False + + if utils.is_typeddict(obj): + return True + + if utils.is_namedtuple(obj): + return True + + if utils.is_dataclass(obj): + return True + + # We don't want to import pydantic here, so we check if + # the class is a subclass + if utils.is_pydantic(obj): + return True + + return False + + +def is_valid_callable(obj: Any) -> bool: + """Check if the object is a valid callable for a parameter type.""" + if not callable(obj): + return False + + returns = inspect.get_annotations(obj).get('return', None) + + if inspect.isclass(returns) and issubclass(returns, str): + return True + + raise TypeError( + f'callable {obj} must return a str, ' + f'but got {returns}', + ) + + +def verify_mask(obj: Any) -> bool: + """Verify that the object is a tuple of two vector types.""" + if typing.get_origin(obj) is not tuple or len(typing.get_args(obj)) != 2: + raise TypeError( + f'Expected a tuple of two vector types, but got {type(obj)}', + ) + + args = typing.get_args(obj) + + if not utils.is_vector(args[0]): + raise TypeError( + f'Expected a vector type for the first element, but got {args[0]}', + ) + + if not utils.is_vector(args[1]): + raise TypeError( + f'Expected a vector type for the second element, but got {args[1]}', + ) + + return True + + +def verify_masks(obj: Callable[..., Any]) -> bool: + """Verify that the function parameters and return value are all masks.""" + ann = utils.get_annotations(obj) + for name, value in ann.items(): + if not verify_mask(value): + raise TypeError( + f'Expected a vector type for the parameter {name} ' + f'in function {obj.__name__}, but got {value}', + ) + return True + + def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]: """Expand the types for the function arguments / return values.""" if args is None: @@ -30,18 +104,11 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]: return [args] # General way of accepting pydantic.BaseModel, NamedTuple, TypedDict - elif inspect.isclass(args): + elif is_valid_type(args): return args - # Callable that returns a SQL string - elif callable(args): - out = args() - if not isinstance(out, str): - raise TypeError(f'unrecognized type for parameter: {args}') - return [out] - # List of SQL strings or callables - else: + elif isinstance(args, list): new_args = [] for arg in args: if isinstance(arg, str): @@ -52,6 +119,15 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]: raise TypeError(f'unrecognized type for parameter: {arg}') return new_args + # Callable that returns a SQL string + elif is_valid_callable(args): + out = args() + if not isinstance(out, str): + raise TypeError(f'unrecognized type for parameter: {args}') + return [out] + + raise TypeError(f'unrecognized type for 
parameter: {args}') + def _func( func: Optional[Callable[..., Any]] = None, @@ -59,7 +135,7 @@ def _func( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - include_masks: bool = False, + with_null_masks: bool = False, function_type: str = 'udf', ) -> Callable[..., Any]: """Generic wrapper for UDF and TVF decorators.""" @@ -69,7 +145,7 @@ def _func( name=name, args=expand_types(args), returns=expand_types(returns), - include_masks=include_masks, + with_null_masks=with_null_masks, function_type=function_type, ).items() if v is not None } @@ -79,12 +155,21 @@ def _func( # in at that time. if func is None: def decorate(func: Callable[..., Any]) -> Callable[..., Any]: + if with_null_masks: + verify_masks(func) + def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: return func(*args, **kwargs) # type: ignore + wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore + return functools.wraps(func)(wrapper) + return decorate + if with_null_masks: + verify_masks(func) + def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: return func(*args, **kwargs) # type: ignore @@ -99,10 +184,9 @@ def udf( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - include_masks: bool = False, ) -> Callable[..., Any]: """ - Apply attributes to a UDF. + Define a user-defined function (UDF). Parameters ---------- @@ -126,10 +210,6 @@ def udf( returns : str, optional Specifies the return data type of the function. If not specified, the type annotation from the function is used. - include_masks : bool, optional - Should boolean masks be included with each input parameter to indicate - which elements are NULL? This is only used when a input parameters are - configured to a vector type (numpy, pandas, polars, arrow). Returns ------- @@ -141,7 +221,55 @@ def udf( name=name, args=args, returns=returns, - include_masks=include_masks, + with_null_masks=False, + function_type='udf', + ) + + +def udf_with_null_masks( + func: Optional[Callable[..., Any]] = None, + *, + name: Optional[str] = None, + args: Optional[ParameterType] = None, + returns: Optional[ReturnType] = None, +) -> Callable[..., Any]: + """ + Define a user-defined function (UDF) with null masks. + + Parameters + ---------- + func : callable, optional + The UDF to apply parameters to + name : str, optional + The name to use for the UDF in the database + args : str | Callable | List[str | Callable], optional + Specifies the data types of the function arguments. Typically, + the function data types are derived from the function parameter + annotations. These annotations can be overridden. If the function + takes a single type for all parameters, `args` can be set to a + SQL string describing all parameters. If the function takes more + than one parameter and all of the parameters are being manually + defined, a list of SQL strings may be used (one for each parameter). + A dictionary of SQL strings may be used to specify a parameter type + for a subset of parameters; the keys are the names of the + function parameters. Callables may also be used for datatypes. This + is primarily for using the functions in the ``dtypes`` module that + are associated with SQL types with all default options (e.g., ``dt.FLOAT``). + returns : str, optional + Specifies the return data type of the function. If not specified, + the type annotation from the function is used. 
+ + Returns + ------- + Callable + + """ + return _func( + func=func, + name=name, + args=args, + returns=returns, + with_null_masks=True, function_type='udf', ) @@ -152,10 +280,57 @@ def tvf( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - include_masks: bool = False, ) -> Callable[..., Any]: """ - Apply attributes to a TVF. + Define a table-valued function (TVF). + + Parameters + ---------- + func : callable, optional + The TVF to apply parameters to + name : str, optional + The name to use for the TVF in the database + args : str | Callable | List[str | Callable], optional + Specifies the data types of the function arguments. Typically, + the function data types are derived from the function parameter + annotations. These annotations can be overridden. If the function + takes a single type for all parameters, `args` can be set to a + SQL string describing all parameters. If the function takes more + than one parameter and all of the parameters are being manually + defined, a list of SQL strings may be used (one for each parameter). + A dictionary of SQL strings may be used to specify a parameter type + for a subset of parameters; the keys are the names of the + function parameters. Callables may also be used for datatypes. This + is primarily for using the functions in the ``dtypes`` module that + are associated with SQL types with all default options (e.g., ``dt.FLOAT``). + returns : str, optional + Specifies the return data type of the function. If not specified, + the type annotation from the function is used. + + Returns + ------- + Callable + + """ + return _func( + func=func, + name=name, + args=args, + returns=returns, + with_null_masks=False, + function_type='tvf', + ) + + +def tvf_with_null_masks( + func: Optional[Callable[..., Any]] = None, + *, + name: Optional[str] = None, + args: Optional[ParameterType] = None, + returns: Optional[ReturnType] = None, +) -> Callable[..., Any]: + """ + Define a table-valued function (TVF) using null masks. Parameters ---------- @@ -179,10 +354,6 @@ def tvf( returns : str, optional Specifies the return data type of the function. If not specified, the type annotation from the function is used. - include_masks : bool, optional - Should boolean masks be included with each input parameter to indicate - which elements are NULL? This is only used when a input parameters are - configured to a vector type (numpy, pandas, polars, arrow). Returns ------- @@ -194,6 +365,6 @@ def tvf( name=name, args=args, returns=returns, - include_masks=include_masks, + with_null_masks=True, function_type='tvf', ) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 1e110cb3b..fa6c8cb32 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -225,7 +225,7 @@ def make_func( """ attrs = getattr(func, '_singlestoredb_attrs', {}) - include_masks = attrs.get('include_masks', False) + with_null_masks = attrs.get('with_null_masks', False) function_type = attrs.get('function_type', 'udf').lower() info: Dict[str, Any] = {} @@ -263,9 +263,9 @@ async def do_func( # type: ignore # each result row, so we just have to use the same # row ID for all rows in the result. - # If `include_masks` is set, the function is expected to return + # If `with_null_masks` is set, the function is expected to return # a tuple of (data, mask) for each column. 
- if include_masks: + if with_null_masks: out = func(*cols) assert isinstance(out, tuple) row_ids = array_cls([row_ids[0]] * len(out[0][0])) @@ -300,9 +300,9 @@ async def do_func( # type: ignore '''Call function on given cols of data.''' row_ids = array_cls(row_ids) - # If `include_masks` is set, the function is expected to return + # If `with_null_masks` is set, the function is expected to return # a tuple of (data, mask) for each column.` - if include_masks: + if with_null_masks: out = func(*cols) assert isinstance(out, tuple) return row_ids, [out] @@ -528,6 +528,7 @@ class Application(object): # Valid URL paths invoke_path = ('invoke',) show_create_function_path = ('show', 'create_function') + show_function_info_path = ('show', 'function_info') def __init__( self, @@ -548,6 +549,8 @@ def __init__( link_name: Optional[str] = get_option('external_function.link_name'), link_config: Optional[Dict[str, Any]] = None, link_credentials: Optional[Dict[str, Any]] = None, + name_prefix: str = get_option('external_function.name_prefix'), + name_suffix: str = get_option('external_function.name_suffix'), ) -> None: if link_name and (link_config or link_credentials): raise ValueError( @@ -604,6 +607,7 @@ def __init__( if not hasattr(x, '_singlestoredb_attrs'): continue name = x._singlestoredb_attrs.get('name', x.__name__) + name = f'{name_prefix}{name}{name_suffix}' external_functions[x.__name__] = x func, info = make_func(name, x) endpoints[name.encode('utf-8')] = func, info @@ -619,6 +623,7 @@ def __init__( # Add endpoint for each exported function for name, alias in get_func_names(func_names): item = getattr(pkg, name) + alias = f'{name_prefix}{name}{name_suffix}' external_functions[name] = item func, info = make_func(alias, item) endpoints[alias.encode('utf-8')] = func, info @@ -631,6 +636,7 @@ def __init__( if not hasattr(x, '_singlestoredb_attrs'): continue name = x._singlestoredb_attrs.get('name', x.__name__) + name = f'{name_prefix}{name}{name_suffix}' external_functions[x.__name__] = x func, info = make_func(name, x) endpoints[name.encode('utf-8')] = func, info @@ -638,6 +644,7 @@ def __init__( else: alias = funcs.__name__ external_functions[funcs.__name__] = funcs + alias = f'{name_prefix}{alias}{name_suffix}' func, info = make_func(alias, funcs) endpoints[alias.encode('utf-8')] = func, info @@ -740,6 +747,12 @@ async def __call__( await send(self.text_response_dict) + # Return function info + elif method == 'GET' and (path == self.show_function_info_path or not path): + functions = self.get_function_info() + body = json.dumps(dict(functions=functions)).encode('utf-8') + await send(self.text_response_dict) + # Path not found else: body = b'' @@ -784,14 +797,70 @@ def _locate_app_functions(self, cur: Any) -> Tuple[Set[str], Set[str]]: # See if function URL matches url cur.execute(f'SHOW CREATE FUNCTION `{name}`') for fname, _, code, *_ in list(cur): - m = re.search(r" (?:\w+) SERVICE '([^']+)'", code) + m = re.search(r" (?:\w+) (?:SERVICE|MANAGED) '([^']+)'", code) if m and m.group(1) == self.url: funcs.add(fname) if link and re.match(r'^py_ext_func_link_\S{14}$', link): links.add(link) return funcs, links - def show_create_functions( + def get_function_info( + self, + func_name: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Return the functions and function signature information. 
+ Returns + ------- + Dict[str, Any] + """ + functions = {} + no_default = object() + + for key, (_, info) in self.endpoints.items(): + if not func_name or key == func_name: + sig = info['signature'] + args = [] + + # Function arguments + for a in sig.get('args', []): + dtype = a['dtype'] + nullable = '?' in dtype + args.append( + dict( + name=a['name'], + dtype=dtype.replace('?', ''), + nullable=nullable, + ), + ) + if a.get('default', no_default) is not no_default: + args[-1]['default'] = a['default'] + + # Return values + ret = sig.get('returns', []) + returns = [] + + for a in ret: + dtype = a['dtype'] + nullable = '?' in dtype + returns.append( + dict( + dtype=dtype.replace('?', ''), + nullable=nullable, + ), + ) + if a.get('name', None): + returns[-1]['name'] = a['name'] + if a.get('default', no_default) is not no_default: + returns[-1]['default'] = a['default'] + + functions[sig['name']] = dict( + args=args, returns=returns, function_type=info['function_type'], + ) + + return functions + + def get_create_functions( self, replace: bool = False, ) -> List[str]: @@ -860,7 +929,7 @@ def register_functions( cur.execute(f'DROP FUNCTION IF EXISTS `{fname}`') for link in links: cur.execute(f'DROP LINK {link}') - for func in self.show_create_functions(replace=replace): + for func in self.get_create_functions(replace=replace): cur.execute(func) def drop_functions( @@ -1188,6 +1257,22 @@ def main(argv: Optional[List[str]] = None) -> None: ), help='logging level', ) + parser.add_argument( + '--name-prefix', metavar='name_prefix', + default=defaults.get( + 'name_prefix', + get_option('external_function.name_prefix'), + ), + help='Prefix to add to function names', + ) + parser.add_argument( + '--name-suffix', metavar='name_suffix', + default=defaults.get( + 'name_suffix', + get_option('external_function.name_suffix'), + ), + help='Suffix to add to function names', + ) parser.add_argument( 'functions', metavar='module.or.func.path', nargs='*', help='functions or modules to export in UDF server', @@ -1285,9 +1370,11 @@ def main(argv: Optional[List[str]] = None) -> None: link_config=json.loads(args.link_config) or None, link_credentials=json.loads(args.link_credentials) or None, app_mode='remote', + name_prefix=args.name_prefix, + name_suffix=args.name_suffix, ) - funcs = app.show_create_functions(replace=args.replace_existing) + funcs = app.get_create_functions(replace=args.replace_existing) if not funcs: raise RuntimeError('no functions specified') diff --git a/singlestoredb/functions/ext/json.py b/singlestoredb/functions/ext/json.py index 3221b1d4f..6b2c4f3ae 100644 --- a/singlestoredb/functions/ext/json.py +++ b/singlestoredb/functions/ext/json.py @@ -3,6 +3,7 @@ from typing import Any from typing import List from typing import Tuple +from typing import TYPE_CHECKING from ..dtypes import DEFAULT_VALUES from ..dtypes import NUMPY_TYPE_MAP @@ -11,29 +12,23 @@ from ..dtypes import PYARROW_TYPE_MAP from ..dtypes import PYTHON_CONVERTERS -try: - import numpy as np - has_numpy = True -except ImportError: - has_numpy = False - -try: - import polars as pl - has_polars = True -except ImportError: - has_polars = False - -try: - import pandas as pd - has_pandas = True -except ImportError: - has_pandas = False - -try: - import pyarrow as pa - has_pyarrow = True -except ImportError: - has_pyarrow = False +if TYPE_CHECKING: + try: + import numpy as np + except ImportError: + pass + try: + import pandas as pd + except ImportError: + pass + try: + import polars as pl + except ImportError: + pass + try: + import 
pyarrow as pa + except ImportError: + pass class JSONEncoder(json.JSONEncoder): @@ -135,9 +130,6 @@ def load_pandas( Tuple[pd.Series[int], List[pd.Series[Any]] ''' - if not has_pandas or not has_numpy: - raise RuntimeError('This operation requires pandas and numpy to be installed') - row_ids, cols = _load_vectors(colspec, data) index = pd.Series(row_ids, dtype=np.longlong) return index, \ @@ -172,9 +164,6 @@ def load_polars( Tuple[polars.Series[int], List[polars.Series[Any]] ''' - if not has_polars or not has_numpy: - raise RuntimeError('This operation requires polars and numpy to be installed') - row_ids, cols = _load_vectors(colspec, data) return pl.Series(None, row_ids, dtype=pl.Int64), \ [ @@ -205,9 +194,6 @@ def load_numpy( Tuple[np.ndarray[int], List[np.ndarray[Any]] ''' - if not has_numpy: - raise RuntimeError('This operation requires numpy to be installed') - row_ids, cols = _load_vectors(colspec, data) return np.asarray(row_ids, dtype=np.longlong), \ [ @@ -238,9 +224,6 @@ def load_arrow( Tuple[pyarrow.Array[int], List[pyarrow.Array[Any]] ''' - if not has_pyarrow or not has_numpy: - raise RuntimeError('This operation requires pyarrow and numpy to be installed') - row_ids, cols = _load_vectors(colspec, data) return pa.array(row_ids, type=pa.int64()), \ [ diff --git a/singlestoredb/functions/ext/mmap.py b/singlestoredb/functions/ext/mmap.py index 3bdb6a6f5..df200fa14 100644 --- a/singlestoredb/functions/ext/mmap.py +++ b/singlestoredb/functions/ext/mmap.py @@ -338,7 +338,7 @@ def main(argv: Optional[List[str]] = None) -> None: app_mode='collocated', ) - funcs = app.show_create_functions(replace=args.replace_existing) + funcs = app.get_create_functions(replace=args.replace_existing) if not funcs: raise RuntimeError('no functions specified') diff --git a/singlestoredb/functions/ext/rowdat_1.py b/singlestoredb/functions/ext/rowdat_1.py index 22940ba63..bc597ce46 100644 --- a/singlestoredb/functions/ext/rowdat_1.py +++ b/singlestoredb/functions/ext/rowdat_1.py @@ -7,40 +7,37 @@ from typing import Optional from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING from ...config import get_option +from ...mysql.constants import FIELD_TYPE as ft from ..dtypes import DEFAULT_VALUES from ..dtypes import NUMPY_TYPE_MAP from ..dtypes import PANDAS_TYPE_MAP from ..dtypes import POLARS_TYPE_MAP from ..dtypes import PYARROW_TYPE_MAP -try: - import numpy as np - has_numpy = True -except ImportError: - has_numpy = False - -try: - import polars as pl - has_polars = True -except ImportError: - has_polars = False - -try: - import pandas as pd - has_pandas = True -except ImportError: - has_pandas = False - -try: - import pyarrow as pa - import pyarrow.compute as pc - has_pyarrow = True -except ImportError: - has_pyarrow = False - -from ...mysql.constants import FIELD_TYPE as ft +if TYPE_CHECKING: + try: + import numpy as np + except ImportError: + pass + try: + import polars as pl + except ImportError: + pass + try: + import pandas as pd + except ImportError: + pass + try: + import pyarrow as pa + except ImportError: + pass + try: + import pyarrow.compute as pc + except ImportError: + pass has_accel = False try: @@ -208,8 +205,7 @@ def _load_pandas( Tuple[pd.Series[int], List[Tuple[pd.Series[Any], pd.Series[bool]]]] ''' - if not has_pandas or not has_numpy: - raise RuntimeError('pandas must be installed for this operation') + import pandas as pd row_ids, cols = _load_vectors(colspec, data) index = pd.Series(row_ids) @@ -244,8 +240,7 @@ def _load_polars( 
Tuple[polars.Series[int], List[polars.Series[Any]]] ''' - if not has_polars: - raise RuntimeError('polars must be installed for this operation') + import polars as pl row_ids, cols = _load_vectors(colspec, data) return pl.Series(None, row_ids, dtype=pl.Int64), \ @@ -280,8 +275,7 @@ def _load_numpy( Tuple[np.ndarray[int], List[np.ndarray[Any]]] ''' - if not has_numpy: - raise RuntimeError('numpy must be installed for this operation') + import numpy as np row_ids, cols = _load_vectors(colspec, data) return np.asarray(row_ids, dtype=np.int64), \ @@ -298,8 +292,8 @@ def _load_arrow( colspec: List[Tuple[str, int]], data: bytes, ) -> Tuple[ - 'pa.Array[pa.int64()]', - List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_()]']], + 'pa.Array[pa.int64]', + List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_]']], ]: ''' Convert bytes in rowdat_1 format into rows of data. @@ -316,8 +310,7 @@ def _load_arrow( Tuple[pyarrow.Array[int], List[pyarrow.Array[Any]]] ''' - if not has_pyarrow: - raise RuntimeError('pyarrow must be installed for this operation') + import pyarrow as pa row_ids, cols = _load_vectors(colspec, data) return pa.array(row_ids, type=pa.int64()), \ @@ -488,9 +481,6 @@ def _dump_arrow( row_ids: 'pa.Array[int]', cols: List[Tuple['pa.Array[Any]', 'pa.Array[bool]']], ) -> bytes: - if not has_pyarrow: - raise RuntimeError('pyarrow must be installed for this operation') - return _dump_vectors( returns, row_ids.tolist(), @@ -503,9 +493,6 @@ def _dump_numpy( row_ids: 'np.typing.NDArray[np.int64]', cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ) -> bytes: - if not has_numpy: - raise RuntimeError('numpy must be installed for this operation') - return _dump_vectors( returns, row_ids.tolist(), @@ -518,9 +505,6 @@ def _dump_pandas( row_ids: 'pd.Series[np.int64]', cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']], ) -> bytes: - if not has_pandas or not has_numpy: - raise RuntimeError('pandas must be installed for this operation') - return _dump_vectors( returns, row_ids.to_list(), @@ -533,9 +517,6 @@ def _dump_polars( row_ids: 'pl.Series[pl.Int64]', cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']], ) -> bytes: - if not has_polars: - raise RuntimeError('polars must be installed for this operation') - return _dump_vectors( returns, row_ids.to_list(), @@ -550,8 +531,6 @@ def _load_numpy_accel( 'np.typing.NDArray[np.int64]', List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ]: - if not has_numpy: - raise RuntimeError('numpy must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') @@ -563,8 +542,6 @@ def _dump_numpy_accel( row_ids: 'np.typing.NDArray[np.int64]', cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']], ) -> bytes: - if not has_numpy: - raise RuntimeError('numpy must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') @@ -578,11 +555,11 @@ def _load_pandas_accel( 'pd.Series[np.int64]', List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']], ]: - if not has_pandas or not has_numpy: - raise RuntimeError('pandas must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') + import pandas as pd + numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( @@ -599,8 +576,6 @@ def _dump_pandas_accel( row_ids: 'pd.Series[np.int64]', cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']], ) -> bytes: - if not 
has_pandas or not has_numpy: - raise RuntimeError('pandas must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') @@ -622,11 +597,11 @@ def _load_polars_accel( 'pl.Series[pl.Int64]', List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']], ]: - if not has_polars: - raise RuntimeError('polars must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') + import polars as pl + numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( @@ -647,8 +622,6 @@ def _dump_polars_accel( row_ids: 'pl.Series[pl.Int64]', cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']], ) -> bytes: - if not has_polars: - raise RuntimeError('polars must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') @@ -667,14 +640,14 @@ def _load_arrow_accel( colspec: List[Tuple[str, int]], data: bytes, ) -> Tuple[ - 'pa.Array[pa.int64()]', - List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_()]']], + 'pa.Array[pa.int64]', + List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_]']], ]: - if not has_pyarrow: - raise RuntimeError('pyarrow must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') + import pyarrow as pa + numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( @@ -688,8 +661,8 @@ def _load_arrow_accel( def _create_arrow_mask( data: 'pa.Array[Any]', - mask: 'pa.Array[pa.bool_()]', -) -> 'pa.Array[pa.bool_()]': + mask: 'pa.Array[pa.bool_]', +) -> 'pa.Array[pa.bool_]': if mask is None: return data.is_null().to_numpy(zero_copy_only=False) return pc.or_(data.is_null(), mask.is_null()).to_numpy(zero_copy_only=False) @@ -697,11 +670,9 @@ def _create_arrow_mask( def _dump_arrow_accel( returns: List[int], - row_ids: 'pa.Array[pa.int64()]', - cols: List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_()]']], + row_ids: 'pa.Array[pa.int64]', + cols: List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_]']], ) -> bytes: - if not has_pyarrow: - raise RuntimeError('pyarrow must be installed for this operation') if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 84a849a68..c4b5adb2f 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -26,15 +26,9 @@ except ImportError: has_numpy = False -try: - import pydantic - import pydantic_core - has_pydantic = True -except ImportError: - has_pydantic = False - from . import dtypes as dt +from . 
import utils from ..mysql.converters import escape_item # type: ignore if sys.version_info >= (3, 10): @@ -207,49 +201,6 @@ class ArrayCollection(Collection): pass -def get_annotations(obj: Any) -> Dict[str, Any]: - """Get the annotations of an object.""" - if hasattr(inspect, 'get_annotations'): - return inspect.get_annotations(obj) - if isinstance(obj, type): - return obj.__dict__.get('__annotations__', {}) - return getattr(obj, '__annotations__', {}) - - -def is_numpy(obj: Any) -> bool: - """Check if an object is a numpy array.""" - if is_union(obj): - obj = typing.get_args(obj)[0] - if not has_numpy: - return False - if inspect.isclass(obj): - return obj is np.ndarray - if typing.get_origin(obj) is np.ndarray: - return True - return isinstance(obj, np.ndarray) - - -def is_dataframe(obj: Any) -> bool: - """Check if an object is a DataFrame.""" - # Cheating here a bit so we don't have to import pandas / polars / pyarrow: - # unless we absolutely need to - if getattr(obj, '__module__', '').startswith('pandas.'): - return getattr(obj, '__name__', '') == 'DataFrame' - if getattr(obj, '__module__', '').startswith('polars.'): - return getattr(obj, '__name__', '') == 'DataFrame' - if getattr(obj, '__module__', '').startswith('pyarrow.'): - return getattr(obj, '__name__', '') == 'Table' - return False - - -def is_vector(obj: Any) -> bool: - """Check if an object is a vector type.""" - return is_pandas_series(obj) \ - or is_polars_series(obj) \ - or is_pyarrow_array(obj) \ - or is_numpy(obj) - - def get_data_format(obj: Any) -> str: """Return the data format of the DataFrame / Table / vector.""" # Cheating here a bit so we don't have to import pandas / polars / pyarrow @@ -259,7 +210,7 @@ def get_data_format(obj: Any) -> str: if getattr(obj, '__module__', '').startswith('polars.'): return 'polars' if getattr(obj, '__module__', '').startswith('pyarrow.'): - return 'pyarrow' + return 'arrow' if getattr(obj, '__module__', '').startswith('numpy.'): return 'numpy' if isinstance(obj, list): @@ -267,69 +218,6 @@ def get_data_format(obj: Any) -> str: return 'scalar' -def is_pandas_series(obj: Any) -> bool: - """Check if an object is a pandas Series.""" - if is_union(obj): - obj = typing.get_args(obj)[0] - return ( - getattr(obj, '__module__', '').startswith('pandas.') and - getattr(obj, '__name__', '') == 'Series' - ) - - -def is_polars_series(obj: Any) -> bool: - """Check if an object is a polars Series.""" - if is_union(obj): - obj = typing.get_args(obj)[0] - return ( - getattr(obj, '__module__', '').startswith('polars.') and - getattr(obj, '__name__', '') == 'Series' - ) - - -def is_pyarrow_array(obj: Any) -> bool: - """Check if an object is a pyarrow Array.""" - if is_union(obj): - obj = typing.get_args(obj)[0] - return ( - getattr(obj, '__module__', '').startswith('pyarrow.') and - getattr(obj, '__name__', '') == 'Array' - ) - - -def is_typeddict(obj: Any) -> bool: - """Check if an object is a TypedDict.""" - if hasattr(typing, 'is_typeddict'): - return typing.is_typeddict(obj) # noqa: TYP006 - return False - - -def is_namedtuple(obj: Any) -> bool: - """Check if an object is a named tuple.""" - if inspect.isclass(obj): - return ( - issubclass(obj, tuple) and - hasattr(obj, '_asdict') and - hasattr(obj, '_fields') - ) - return ( - isinstance(obj, tuple) and - hasattr(obj, '_asdict') and - hasattr(obj, '_fields') - ) - - -def is_pydantic(obj: Any) -> bool: - """Check if an object is a pydantic model.""" - if not has_pydantic: - return False - - if inspect.isclass(obj): - return issubclass(obj, 
pydantic.BaseModel) - - return isinstance(obj, pydantic.BaseModel) - - def escape_name(name: str) -> str: """Escape a function parameter name.""" if '`' in name: @@ -432,21 +320,21 @@ def normalize_dtype(dtype: Any) -> str: if dtype is bool: return 'bool' - if dataclasses.is_dataclass(dtype): + if utils.is_dataclass(dtype): dc_fields = dataclasses.fields(dtype) item_dtypes = ','.join( f'{normalize_dtype(simplify_dtype(x.type))}' for x in dc_fields ) return f'tuple[{item_dtypes}]' - if is_typeddict(dtype): - td_fields = get_annotations(dtype).keys() + if utils.is_typeddict(dtype): + td_fields = utils.get_annotations(dtype).keys() item_dtypes = ','.join( f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in td_fields ) return f'tuple[{item_dtypes}]' - if is_pydantic(dtype): + if utils.is_pydantic(dtype): pyd_fields = dtype.model_fields.values() item_dtypes = ','.join( f'{normalize_dtype(simplify_dtype(x.annotation))}' # type: ignore @@ -454,8 +342,8 @@ def normalize_dtype(dtype: Any) -> str: ) return f'tuple[{item_dtypes}]' - if is_namedtuple(dtype): - nt_fields = get_annotations(dtype).values() + if utils.is_namedtuple(dtype): + nt_fields = utils.get_annotations(dtype).values() item_dtypes = ','.join( f'{normalize_dtype(simplify_dtype(dtype[x]))}' for x in nt_fields ) @@ -528,7 +416,7 @@ def normalize_dtype(dtype: Any) -> str: ) -def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: +def collapse_dtypes(dtypes: Union[str, List[str]], include_null: bool = False) -> str: """ Collapse a dtype possibly containing multiple data types to one type. @@ -539,6 +427,8 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: ---------- dtypes : str or list[str] The data types to collapse + include_null : bool, optional + Whether to force include null types in the result Returns ------- @@ -554,7 +444,7 @@ def collapse_dtypes(dtypes: Union[str, List[str]]) -> str: orig_dtypes = dtypes dtypes = list(set(dtypes)) - is_nullable = 'null' in dtypes + is_nullable = include_null or 'null' in dtypes dtypes = [x for x in dtypes if x != 'null'] @@ -679,13 +569,13 @@ def get_typeddict_schema( if include_default: return [ (k, v, getattr(obj, k, NO_DEFAULT)) - for k, v in get_annotations(obj).items() + for k, v in utils.get_annotations(obj).items() ] - return list(get_annotations(obj).items()) + return list(utils.get_annotations(obj).items()) def get_pydantic_schema( - obj: pydantic.BaseModel, + obj: Any, include_default: bool = False, ) -> List[Union[Tuple[str, Any], Tuple[str, Any, Any]]]: """ @@ -704,6 +594,7 @@ def get_pydantic_schema( A list of tuples containing the field names and field types """ + import pydantic_core if include_default: return [ ( @@ -741,9 +632,9 @@ def get_namedtuple_schema( k, v, obj._field_defaults.get(k, NO_DEFAULT), ) - for k, v in get_annotations(obj).items() + for k, v in utils.get_annotations(obj).items() ] - return list(get_annotations(obj).items()) + return list(utils.get_annotations(obj).items()) def get_colspec( @@ -771,25 +662,25 @@ def get_colspec( if overrides: # Dataclass - if dataclasses.is_dataclass(overrides): + if utils.is_dataclass(overrides): overrides_colspec = get_dataclass_schema( overrides, include_default=include_default, ) # TypedDict - elif is_typeddict(overrides): + elif utils.is_typeddict(overrides): overrides_colspec = get_typeddict_schema( overrides, include_default=include_default, ) # Named tuple - elif is_namedtuple(overrides): + elif utils.is_namedtuple(overrides): overrides_colspec = get_namedtuple_schema( overrides, 
include_default=include_default, ) # Pydantic model - elif is_pydantic(overrides): + elif utils.is_pydantic(overrides): overrides_colspec = get_pydantic_schema( overrides, include_default=include_default, ) @@ -815,11 +706,39 @@ def get_colspec( return overrides_colspec +def unpack_masked_type(obj: Any) -> Any: + """ + Unpack a masked type into a single type. + + Parameters + ---------- + obj : Any + The masked type to unpack + + Returns + ------- + Any + The unpacked type + + """ + if typing.get_origin(obj) is not tuple: + raise TypeError(f'masked type must be a tuple, got {obj}') + args = typing.get_args(obj) + if len(args) != 2: + raise TypeError(f'masked type must be a tuple of length 2, got {obj}') + if not utils.is_vector(args[0]): + raise TypeError(f'masked type must be a vector, got {args[0]}') + if not utils.is_vector(args[1]): + raise TypeError(f'masked type must be a vector, got {args[1]}') + return args[0] + + def get_schema( spec: Any, overrides: Optional[Union[List[str], Type[Any]]] = None, function_type: str = 'udf', mode: str = 'parameter', + with_null_masks: bool = False, ) -> Tuple[List[Tuple[str, Any, Optional[str]]], str]: """ Expand a return type annotation into a list of types and field names. @@ -834,6 +753,8 @@ def get_schema( The type of function, either 'udf' or 'tvf' mode : str The mode of the function, either 'parameter' or 'return' + with_null_masks : bool + Whether to use null masks for the parameters and return value Returns ------- @@ -843,6 +764,7 @@ def get_schema( definition of the type """ + colspec = [] data_format = 'scalar' # Make sure that the result of a TVF is a list or dataframe @@ -854,7 +776,7 @@ def get_schema( # If it's a tuple, it must be a tuple of vectors elif typing.get_origin(spec) is tuple: - if not all([is_vector(x) for x in typing.get_args(spec)]): + if not all([utils.is_vector(x) for x in typing.get_args(spec)]): raise TypeError( 'return type for TVF must be a list, DataFrame / Table, ' 'or tuple of vectors', @@ -863,7 +785,7 @@ def get_schema( # DataFrames require special handling. You can't get the schema # from the annotation, you need a separate structure to specify # the types. This should be specified in the overrides. 
- elif is_dataframe(spec) or is_vector(spec): + elif utils.is_dataframe(spec) or utils.is_vector(spec): if not overrides: raise TypeError( 'type overrides must be specified for DataFrames / Tables', @@ -877,15 +799,28 @@ def get_schema( ) # Error out for incorrect types - elif typing.get_origin(spec) in [tuple, dict] \ - or is_dataframe(spec) \ - or dataclasses.is_dataclass(spec) \ - or is_typeddict(spec) \ - or is_pydantic(spec) \ - or is_namedtuple(spec): - if mode == 'parameter': - raise TypeError('parameter types must be scalar or vector') - raise TypeError('return type for UDF must be a scalar type') + elif typing.get_origin(spec) in [tuple, dict] or \ + ( + # Check for optional tuples and dicts + is_union(spec) and + any([ + typing.get_origin(x) in [tuple, dict] + for x in typing.get_args(spec) + ]) + ) \ + or utils.is_dataframe(spec) \ + or utils.is_dataclass(spec) \ + or utils.is_typeddict(spec) \ + or utils.is_pydantic(spec) \ + or utils.is_namedtuple(spec): + if typing.get_origin(spec) is tuple: + raise TypeError( + f'{mode} types must be scalar or vector, got {spec}; ' + 'if you are trying to use null masks, you must use the ' + f'@{function_type}_with_null_masks decorator', + ) + else: + raise TypeError(f'{mode} types must be scalar or vector, got {spec}') # # Process each parameter / return type into a colspec @@ -895,7 +830,7 @@ def get_schema( overrides_colspec = get_colspec(overrides) # Numpy array types - if is_numpy(spec): + if utils.is_numpy(spec): data_format = 'numpy' if overrides: colspec = overrides_colspec @@ -908,7 +843,7 @@ def get_schema( colspec = [('', typing.get_args(spec)[1])] # Pandas Series - elif is_pandas_series(spec): + elif utils.is_pandas_series(spec): data_format = 'pandas' if not overrides: raise TypeError( @@ -918,7 +853,7 @@ def get_schema( colspec = overrides_colspec # Polars Series - elif is_polars_series(spec): + elif utils.is_polars_series(spec): data_format = 'polars' if not overrides: raise TypeError( @@ -928,8 +863,8 @@ def get_schema( colspec = overrides_colspec # PyArrow Array - elif is_pyarrow_array(spec): - data_format = 'pyarrow' + elif utils.is_pyarrow_array(spec): + data_format = 'arrow' if not overrides: raise TypeError( 'pyarrow Arrays must have a data type specified ' @@ -938,19 +873,19 @@ def get_schema( colspec = overrides_colspec # Return type is specified by a dataclass definition - elif dataclasses.is_dataclass(spec): + elif utils.is_dataclass(spec): colspec = overrides_colspec or get_dataclass_schema(spec) # Return type is specified by a TypedDict definition - elif is_typeddict(spec): + elif utils.is_typeddict(spec): colspec = overrides_colspec or get_typeddict_schema(spec) # Return type is specified by a pydantic model - elif is_pydantic(spec): + elif utils.is_pydantic(spec): colspec = overrides_colspec or get_pydantic_schema(spec) # Return type is specified by a named tuple - elif is_namedtuple(spec): + elif utils.is_namedtuple(spec): colspec = overrides_colspec or get_namedtuple_schema(spec) # Unrecognized return type @@ -979,15 +914,22 @@ def get_schema( out_names = [x[0] for x in out_colspec] out_overrides = [x[1] for x in out_colspec] + if out_overrides and len(typing.get_args(spec)) != len(out_overrides): + raise ValueError( + 'number of return types does not match the number of ' + 'overrides specified', + ) + colspec = [] out_data_formats = [] for i, x in enumerate(typing.get_args(spec)): out_item, out_data_format = get_schema( - x, + x if not with_null_masks else unpack_masked_type(x), overrides=out_overrides[i] if 
out_overrides else [], - # Always use UDF for individual items + # Always use UDF mode for individual items function_type='udf', mode=mode, + with_null_masks=with_null_masks, ) # Use the name from the overrides if specified @@ -1006,7 +948,8 @@ def get_schema( f'{", ".join(out_data_formats)}', ) - data_format = out_data_formats[0] + if out_data_formats: + data_format = out_data_formats[0] # Since the colspec was computed by get_schema already, don't go # through the process of normalizing the dtypes again @@ -1028,7 +971,10 @@ def get_schema( for k, v, *_ in colspec: out.append(( k, - collapse_dtypes([normalize_dtype(x) for x in simplify_dtype(v)]), + collapse_dtypes( + [normalize_dtype(x) for x in simplify_dtype(v)], + include_null=with_null_masks, + ), v if isinstance(v, str) else None, )) @@ -1050,16 +996,16 @@ def vector_check(obj: Any) -> Tuple[Any, str]: The scalar type and the data format ('scalar', 'numpy', 'pandas', 'polars') """ - if is_numpy(obj): + if utils.is_numpy(obj): if len(typing.get_args(obj)) < 2: return None, 'numpy' return typing.get_args(obj)[1], 'numpy' - if is_pandas_series(obj): + if utils.is_pandas_series(obj): return None, 'pandas' - if is_polars_series(obj): + if utils.is_polars_series(obj): return None, 'polars' - if is_pyarrow_array(obj): - return None, 'pyarrow' + if utils.is_pyarrow_array(obj): + return None, 'arrow' return obj, 'scalar' @@ -1088,6 +1034,7 @@ def get_signature( attrs = getattr(func, '_singlestoredb_attrs', {}) function_type = attrs.get('function_type', 'udf') + with_null_masks = attrs.get('with_null_masks', False) name = attrs.get('name', func_name if func_name else func.__name__) out: Dict[str, Any] = dict(name=name, args=args, returns=returns) @@ -1105,12 +1052,23 @@ def get_signature( args_colspec = [x for x in get_colspec(attrs.get('args', []), include_default=True)] args_overrides = [x[1] for x in args_colspec] args_defaults = [x[2] for x in args_colspec] # type: ignore - for i, param in enumerate(signature.parameters.values()): + + if args_overrides and len(args_overrides) != len(signature.parameters): + raise ValueError( + 'number of args in the decorator does not match ' + 'the number of parameters in the function signature', + ) + + params = list(signature.parameters.values()) + + for i, param in enumerate(params): arg_schema, args_data_format = get_schema( - param.annotation, + param.annotation + if not with_null_masks else unpack_masked_type(param.annotation), overrides=args_overrides[i] if args_overrides else [], function_type=function_type, mode='parameter', + with_null_masks=with_null_masks, ) args_data_formats.append(args_data_format) @@ -1125,8 +1083,8 @@ def get_signature( if args_defaults[i] is not NO_DEFAULT: default_option['default'] = args_defaults[i] else: - if param.default is not param.empty: - default_option['default'] = param.default + if params[i].default is not param.empty: + default_option['default'] = params[i].default # Generate SQL code for the parameter sql = sql or dtype_to_sql( @@ -1145,14 +1103,16 @@ def get_signature( f'{", ".join(args_data_formats)}', ) - out['args_data_format'] = args_data_formats[0] + out['args_data_format'] = args_data_formats[0] if args_data_formats else 'scalar' # Generate the return types and the corresponding SQL code for those values ret_schema, out['returns_data_format'] = get_schema( - signature.return_annotation, + signature.return_annotation + if not with_null_masks else unpack_masked_type(signature.return_annotation), overrides=attrs.get('returns', None), 
function_type=function_type, mode='return', + with_null_masks=with_null_masks, ) # Generate names for fields as needed @@ -1339,6 +1299,8 @@ def signature_to_sql( else: res = ret[0]['sql'] returns = f' RETURNS {res}' + else: + returns = ' RETURNS NULL' host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1') port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000') diff --git a/singlestoredb/functions/typing.py new file mode 100644 index 000000000..848a6a503 --- /dev/null +++ b/singlestoredb/functions/typing.py @@ -0,0 +1,38 @@ +from typing import Tuple +from typing import TypeVar + +try: + import numpy as np + import numpy.typing as npt + has_numpy = True +except ImportError: + has_numpy = False + + +# +# Masked types are used for pairs of vectors where the first element is the +# vector and the second element is a boolean mask indicating which elements +# are NULL. The boolean mask is a vector of the same length as the first +# element, where True indicates that the corresponding element in the first +# element is NULL. +# +# This is needed for vector types that do not support NULL values, such as +# numpy arrays and pandas Series. +# +T = TypeVar('T') +Masked = Tuple[T, T] + + +# +# The MaskedNDArray type is used for pairs of numpy arrays where the first +# element is the numpy array and the second element is a boolean mask +# indicating which elements are NULL. The boolean mask is a numpy array of +# the same shape as the first element, where True indicates that the +# corresponding element in the first element is NULL. +# +# This is needed because numpy arrays do not support NULL values, so we need to +# use a boolean mask to indicate which elements are NULL. +# +if has_numpy: + TT = TypeVar('TT', bound=np.generic) # Generic type for NumPy data types + MaskedNDArray = Tuple[npt.NDArray[TT], npt.NDArray[np.bool_]] diff --git a/singlestoredb/functions/utils.py new file mode 100644 index 000000000..cba08c324 --- /dev/null +++ b/singlestoredb/functions/utils.py @@ -0,0 +1,152 @@ +import dataclasses +import inspect +import sys +import types +import typing +from typing import Any +from typing import Dict + +try: + import numpy as np + has_numpy = True +except ImportError: + has_numpy = False + + +if sys.version_info >= (3, 10): + _UNION_TYPES = {typing.Union, types.UnionType} +else: + _UNION_TYPES = {typing.Union} + + +is_dataclass = dataclasses.is_dataclass + + +def is_union(x: Any) -> bool: + """Check if the object is a Union.""" + return typing.get_origin(x) in _UNION_TYPES + + +def get_annotations(obj: Any) -> Dict[str, Any]: + """Get the annotations of an object.""" + if hasattr(inspect, 'get_annotations'): + return inspect.get_annotations(obj) + if isinstance(obj, type): + return obj.__dict__.get('__annotations__', {}) + return getattr(obj, '__annotations__', {}) + + +def is_numpy(obj: Any) -> bool: + """Check if an object is a numpy array.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + if not has_numpy: + return False + if inspect.isclass(obj): + return obj is np.ndarray + if typing.get_origin(obj) is np.ndarray: + return True + return isinstance(obj, np.ndarray) + + +def is_dataframe(obj: Any) -> bool: + """Check if an object is a DataFrame.""" + # Cheating here a bit so we don't have to import pandas / polars / pyarrow: + # unless we absolutely need to + if getattr(obj, '__module__', '').startswith('pandas.'): + return getattr(obj, '__name__', '') == 'DataFrame' + if getattr(obj, '__module__', 
'').startswith('polars.'): + return getattr(obj, '__name__', '') == 'DataFrame' + if getattr(obj, '__module__', '').startswith('pyarrow.'): + return getattr(obj, '__name__', '') == 'Table' + return False + + +def is_vector(obj: Any) -> bool: + """Check if an object is a vector type.""" + return is_pandas_series(obj) \ + or is_polars_series(obj) \ + or is_pyarrow_array(obj) \ + or is_numpy(obj) + + +def get_data_format(obj: Any) -> str: + """Return the data format of the DataFrame / Table / vector.""" + # Cheating here a bit so we don't have to import pandas / polars / pyarrow + # unless we absolutely need to + if getattr(obj, '__module__', '').startswith('pandas.'): + return 'pandas' + if getattr(obj, '__module__', '').startswith('polars.'): + return 'polars' + if getattr(obj, '__module__', '').startswith('pyarrow.'): + return 'arrow' + if getattr(obj, '__module__', '').startswith('numpy.'): + return 'numpy' + if isinstance(obj, list): + return 'list' + return 'scalar' + + +def is_pandas_series(obj: Any) -> bool: + """Check if an object is a pandas Series.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + return ( + getattr(obj, '__module__', '').startswith('pandas.') and + getattr(obj, '__name__', '') == 'Series' + ) + + +def is_polars_series(obj: Any) -> bool: + """Check if an object is a polars Series.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + return ( + getattr(obj, '__module__', '').startswith('polars.') and + getattr(obj, '__name__', '') == 'Series' + ) + + +def is_pyarrow_array(obj: Any) -> bool: + """Check if an object is a pyarrow Array.""" + if is_union(obj): + obj = typing.get_args(obj)[0] + return ( + getattr(obj, '__module__', '').startswith('pyarrow.') and + getattr(obj, '__name__', '') == 'Array' + ) + + +def is_typeddict(obj: Any) -> bool: + """Check if an object is a TypedDict.""" + if hasattr(typing, 'is_typeddict'): + return typing.is_typeddict(obj) # noqa: TYP006 + return False + + +def is_namedtuple(obj: Any) -> bool: + """Check if an object is a named tuple.""" + if inspect.isclass(obj): + return ( + issubclass(obj, tuple) and + hasattr(obj, '_asdict') and + hasattr(obj, '_fields') + ) + return ( + isinstance(obj, tuple) and + hasattr(obj, '_asdict') and + hasattr(obj, '_fields') + ) + + +def is_pydantic(obj: Any) -> bool: + """Check if an object is a pydantic model.""" + if not inspect.isclass(obj): + return False + # We don't want to import pydantic here, so we check if + # the class is a subclass + return bool([ + x for x in inspect.getmro(obj) + if x.__module__.startswith('pydantic.') + and x.__name__ == 'BaseModel' + ]) diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index 0c1d78dff..1d4abf1ef 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -1,15 +1,28 @@ #!/usr/bin/env python3 -# type: ignore from typing import Optional -from typing import Tuple -from singlestoredb.functions.decorator import udf +import numpy as np +import numpy.typing as npt +import pandas as pd +import polars as pl +import pyarrow as pa + +from singlestoredb.functions import Masked +from singlestoredb.functions import MaskedNDArray +from singlestoredb.functions import udf +from singlestoredb.functions import udf_with_null_masks from singlestoredb.functions.dtypes import BIGINT +from singlestoredb.functions.dtypes import DOUBLE from singlestoredb.functions.dtypes import FLOAT from singlestoredb.functions.dtypes import MEDIUMINT from singlestoredb.functions.dtypes 
import SMALLINT +from singlestoredb.functions.dtypes import TEXT from singlestoredb.functions.dtypes import TINYINT -from singlestoredb.functions.dtypes import VARCHAR + + +@udf +def int_mult(x: int, y: int) -> int: + return x * y @udf @@ -17,24 +30,36 @@ def double_mult(x: float, y: float) -> float: return x * y -@udf.pandas -def pandas_double_mult(x: float, y: float) -> float: +@udf( + args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], + returns=DOUBLE(nullable=False), +) +def pandas_double_mult(x: pd.Series, y: pd.Series) -> pd.Series: return x * y -@udf.numpy -def numpy_double_mult(x: float, y: float) -> float: +@udf +def numpy_double_mult( + x: npt.NDArray[np.float64], + y: npt.NDArray[np.float64], +) -> npt.NDArray[np.float64]: return x * y -@udf.arrow -def arrow_double_mult(x: float, y: float) -> float: +@udf( + args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], + returns=DOUBLE(nullable=False), +) +def arrow_double_mult(x: pa.Array, y: pa.Array) -> pa.Array: import pyarrow.compute as pc return pc.multiply(x, y) -@udf.polars -def polars_double_mult(x: float, y: float) -> float: +@udf( + args=[DOUBLE(nullable=False), DOUBLE(nullable=False)], + returns=DOUBLE(nullable=False), +) +def polars_double_mult(x: pl.Series, y: pl.Series) -> pl.Series: return x * y @@ -57,279 +82,315 @@ def nullable_float_mult(x: Optional[float], y: Optional[float]) -> Optional[floa return x * y -def _int_mult(x: int, y: int) -> int: +# +# TINYINT +# + +tinyint_udf = udf( + args=[TINYINT(nullable=False), TINYINT(nullable=False)], + returns=TINYINT(nullable=False), +) + + +@tinyint_udf +def tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: if x is None or y is None: return None return x * y -def _arrow_int_mult(x: int, y: int) -> int: - import pyarrow.compute as pc - return pc.multiply(x, y) +@tinyint_udf +def pandas_tinyint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y -def _int_mult_with_masks(x: Tuple[int, bool], y: Tuple[int, bool]) -> Tuple[int, bool]: - x_data, x_nulls = x - y_data, y_nulls = y - return (x_data * y_data, x_nulls | y_nulls) +@tinyint_udf +def polars_tinyint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y -def _arrow_int_mult_with_masks( - x: Tuple[int, bool], - y: Tuple[int, bool], -) -> Tuple[int, bool]: +@tinyint_udf +def numpy_tinyint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y + + +@tinyint_udf +def arrow_tinyint_mult(x: pa.Array, y: pa.Array) -> pa.Array: import pyarrow.compute as pc - x_data, x_nulls = x - y_data, y_nulls = y - return (pc.multiply(x_data, y_data), pc.or_(x_nulls, y_nulls)) + return pc.multiply(x, y) +# +# SMALLINT +# -int_mult = udf(_int_mult, name='int_mult') -tinyint_mult = udf( - _int_mult, - name='tinyint_mult', - args=[TINYINT(nullable=False), TINYINT(nullable=False)], - returns=TINYINT(nullable=False), +smallint_udf = udf( + args=[SMALLINT(nullable=False), SMALLINT(nullable=False)], + returns=SMALLINT(nullable=False), ) -pandas_tinyint_mult = udf.pandas( - _int_mult, - name='pandas_tinyint_mult', - args=[TINYINT(nullable=False), TINYINT(nullable=False)], - returns=TINYINT(nullable=False), -) -polars_tinyint_mult = udf.polars( - _int_mult, - name='polars_tinyint_mult', - args=[TINYINT(nullable=False), TINYINT(nullable=False)], - returns=TINYINT(nullable=False), -) +@smallint_udf +def smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y -numpy_tinyint_mult = udf.numpy( - _int_mult, - name='numpy_tinyint_mult', - 
args=[TINYINT(nullable=False), TINYINT(nullable=False)], - returns=TINYINT(nullable=False), -) -arrow_tinyint_mult = udf.arrow( - _arrow_int_mult, - name='arrow_tinyint_mult', - args=[TINYINT(nullable=False), TINYINT(nullable=False)], - returns=TINYINT(nullable=False), -) +@smallint_udf +def pandas_smallint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y -smallint_mult = udf( - _int_mult, - name='smallint_mult', - args=[SMALLINT(nullable=False), SMALLINT(nullable=False)], - returns=SMALLINT(nullable=False), -) -pandas_smallint_mult = udf.pandas( - _int_mult, - name='pandas_smallint_mult', - args=[SMALLINT(nullable=False), SMALLINT(nullable=False)], - returns=SMALLINT(nullable=False), -) +@smallint_udf +def polars_smallint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y -polars_smallint_mult = udf.polars( - _int_mult, - name='polars_smallint_mult', - args=[SMALLINT(nullable=False), SMALLINT(nullable=False)], - returns=SMALLINT(nullable=False), -) -numpy_smallint_mult = udf.numpy( - _int_mult, - name='numpy_smallint_mult', - args=[SMALLINT(nullable=False), SMALLINT(nullable=False)], - returns=SMALLINT(nullable=False), -) +@smallint_udf +def numpy_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y -arrow_smallint_mult = udf.arrow( - _arrow_int_mult, - name='arrow_smallint_mult', - args=[SMALLINT(nullable=False), SMALLINT(nullable=False)], - returns=SMALLINT(nullable=False), -) -mediumint_mult = udf( - _int_mult, - name='mediumint_mult', - args=[MEDIUMINT(nullable=False), MEDIUMINT(nullable=False)], - returns=MEDIUMINT(nullable=False), -) +@smallint_udf +def arrow_smallint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) -pandas_mediumint_mult = udf.pandas( - _int_mult, - name='pandas_mediumint_mult', - args=[MEDIUMINT(nullable=False), MEDIUMINT(nullable=False)], - returns=MEDIUMINT(nullable=False), -) -polars_mediumint_mult = udf.polars( - _int_mult, - name='polars_mediumint_mult', - args=[MEDIUMINT(nullable=False), MEDIUMINT(nullable=False)], - returns=MEDIUMINT(nullable=False), -) +# +# MEDIUMINT +# -numpy_mediumint_mult = udf.numpy( - _int_mult, - name='numpy_mediumint_mult', - args=[MEDIUMINT(nullable=False), MEDIUMINT(nullable=False)], - returns=MEDIUMINT(nullable=False), -) -arrow_mediumint_mult = udf.arrow( - _arrow_int_mult, - name='arrow_mediumint_mult', +mediumint_udf = udf( args=[MEDIUMINT(nullable=False), MEDIUMINT(nullable=False)], returns=MEDIUMINT(nullable=False), ) -bigint_mult = udf( - _int_mult, - name='bigint_mult', - args=[BIGINT(nullable=False), BIGINT(nullable=False)], - returns=BIGINT(nullable=False), -) -pandas_bigint_mult = udf.pandas( - _int_mult, - name='pandas_bigint_mult', - args=[BIGINT(nullable=False), BIGINT(nullable=False)], - returns=BIGINT(nullable=False), -) +@mediumint_udf +def mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y -polars_bigint_mult = udf.polars( - _int_mult, - name='polars_bigint_mult', - args=[BIGINT(nullable=False), BIGINT(nullable=False)], - returns=BIGINT(nullable=False), -) -numpy_bigint_mult = udf.numpy( - _int_mult, - name='numpy_bigint_mult', - args=[BIGINT(nullable=False), BIGINT(nullable=False)], - returns=BIGINT(nullable=False), -) +@mediumint_udf +def pandas_mediumint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y + -arrow_bigint_mult = udf.arrow( - _arrow_int_mult, - name='arrow_bigint_mult', +@mediumint_udf +def 
polars_mediumint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y + + +@mediumint_udf +def numpy_mediumint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y + + +@mediumint_udf +def arrow_mediumint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) + + +# +# BIGINT +# + + +bigint_udf = udf( args=[BIGINT(nullable=False), BIGINT(nullable=False)], returns=BIGINT(nullable=False), ) -nullable_tinyint_mult = udf( - _int_mult, - name='nullable_tinyint_mult', - args=[TINYINT, TINYINT], - returns=TINYINT, -) -pandas_nullable_tinyint_mult = udf.pandas( - _int_mult, - name='pandas_nullable_tinyint_mult', - args=[TINYINT, TINYINT], - returns=TINYINT, -) +@bigint_udf +def bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y -pandas_nullable_tinyint_mult_with_masks = udf.pandas( - _int_mult_with_masks, - name='pandas_nullable_tinyint_mult_with_masks', - args=[TINYINT, TINYINT], - returns=TINYINT, - include_masks=True, -) -polars_nullable_tinyint_mult = udf.polars( - _int_mult, - name='polars_nullable_tinyint_mult', - args=[TINYINT, TINYINT], - returns=TINYINT, -) +@bigint_udf +def pandas_bigint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y -polars_nullable_tinyint_mult_with_masks = udf.polars( - _int_mult_with_masks, - name='polars_nullable_tinyint_mult_with_masks', - args=[TINYINT, TINYINT], - returns=TINYINT, - include_masks=True, -) -numpy_nullable_tinyint_mult = udf.numpy( - _int_mult, - name='numpy_nullable_tinyint_mult', - args=[TINYINT, TINYINT], - returns=TINYINT, -) +@bigint_udf +def polars_bigint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y -numpy_nullable_tinyint_mult_with_masks = udf.numpy( - _int_mult_with_masks, - name='numpy_nullable_tinyint_mult_with_masks', - args=[TINYINT, TINYINT], - returns=TINYINT, - include_masks=True, -) -arrow_nullable_tinyint_mult = udf.arrow( - _arrow_int_mult, - name='arrow_nullable_tinyint_mult', - args=[TINYINT, TINYINT], - returns=TINYINT, -) +@bigint_udf +def numpy_bigint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y -arrow_nullable_tinyint_mult_with_masks = udf.arrow( - _arrow_int_mult_with_masks, - name='arrow_nullable_tinyint_mult_with_masks', - args=[TINYINT, TINYINT], - returns=TINYINT, - include_masks=True, -) -nullable_smallint_mult = udf( - _int_mult, - name='nullable_smallint_mult', - args=[SMALLINT, SMALLINT], - returns=SMALLINT, -) +@bigint_udf +def arrow_bigint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) + + +# +# NULLABLE TINYINT +# + -nullable_mediumint_mult = udf( - _int_mult, - name='nullable_mediumint_mult', - args=[MEDIUMINT, MEDIUMINT], - returns=MEDIUMINT, +nullable_tinyint_udf = udf( + args=[TINYINT(nullable=True), TINYINT(nullable=True)], + returns=TINYINT(nullable=True), ) -nullable_bigint_mult = udf( - _int_mult, - name='nullable_bigint_mult', - args=[BIGINT, BIGINT], - returns=BIGINT, + +@nullable_tinyint_udf +def nullable_tinyint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y + + +@nullable_tinyint_udf +def pandas_nullable_tinyint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y + + +@nullable_tinyint_udf +def polars_nullable_tinyint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y + + +@nullable_tinyint_udf +def numpy_nullable_tinyint_mult(x: np.ndarray, y: np.ndarray) -> 
np.ndarray: + return x * y + + +@nullable_tinyint_udf +def arrow_nullable_tinyint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) + +# +# NULLABLE SMALLINT +# + + +nullable_smallint_udf = udf( + args=[SMALLINT(nullable=True), SMALLINT(nullable=True)], + returns=SMALLINT(nullable=True), ) -numpy_nullable_bigint_mult = udf.numpy( - _int_mult, - name='numpy_nullable_bigint_mult', - args=[BIGINT, BIGINT], - returns=BIGINT, + +@nullable_smallint_udf +def nullable_smallint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y + + +@nullable_smallint_udf +def pandas_nullable_smallint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y + + +@nullable_smallint_udf +def polars_nullable_smallint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y + + +@nullable_smallint_udf +def numpy_nullable_smallint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y + + +@nullable_smallint_udf +def arrow_nullable_smallint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) + + +# +# NULLABLE MEDIUMINT +# + + +nullable_mediumint_udf = udf( + args=[MEDIUMINT(nullable=True), MEDIUMINT(nullable=True)], + returns=MEDIUMINT(nullable=True), ) -numpy_nullable_bigint_mult_with_masks = udf.numpy( - _int_mult_with_masks, - name='numpy_nullable_bigint_mult', - args=[BIGINT, BIGINT], - returns=BIGINT, - include_masks=True, + +@nullable_mediumint_udf +def nullable_mediumint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y + + +@nullable_mediumint_udf +def pandas_nullable_mediumint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y + + +@nullable_mediumint_udf +def polars_nullable_mediumint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y + + +@nullable_mediumint_udf +def numpy_nullable_mediumint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y + + +@nullable_mediumint_udf +def arrow_nullable_mediumint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) + + +# +# NULLABLE BIGINT +# + + +nullable_bigint_udf = udf( + args=[BIGINT(nullable=True), BIGINT(nullable=True)], + returns=BIGINT(nullable=True), ) +@nullable_bigint_udf +def nullable_bigint_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: + if x is None or y is None: + return None + return x * y + + +@nullable_bigint_udf +def pandas_nullable_bigint_mult(x: pd.Series, y: pd.Series) -> pd.Series: + return x * y + + +@nullable_bigint_udf +def polars_nullable_bigint_mult(x: pl.Series, y: pl.Series) -> pl.Series: + return x * y + + +@nullable_bigint_udf +def numpy_nullable_bigint_mult(x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x * y + + +@nullable_bigint_udf +def arrow_nullable_bigint_mult(x: pa.Array, y: pa.Array) -> pa.Array: + import pyarrow.compute as pc + return pc.multiply(x, y) + + @udf def nullable_int_mult(x: Optional[int], y: Optional[int]) -> Optional[int]: if x is None or y is None: @@ -342,13 +403,15 @@ def string_mult(x: str, times: int) -> str: return x * times -@udf.pandas -def pandas_string_mult(x: str, times: int) -> str: +@udf(args=[TEXT(nullable=False), BIGINT(nullable=False)], returns=TEXT(nullable=False)) +def pandas_string_mult(x: pd.Series, times: pd.Series) -> pd.Series: return x * times -@udf.numpy -def numpy_string_mult(x: str, times: int) -> str: +@udf +def numpy_string_mult( + x: 
npt.NDArray[np.str_], times: npt.NDArray[np.int_], +) -> npt.NDArray[np.str_]: return x * times @@ -373,13 +436,47 @@ def nullable_string_mult(x: Optional[str], times: Optional[int]) -> Optional[str return x * times -@udf(args=dict(x=VARCHAR(20, nullable=False))) -def varchar_mult(x: str, times: int) -> str: - return x * times +@udf_with_null_masks( + args=[TINYINT(nullable=True), TINYINT(nullable=True)], + returns=TINYINT(nullable=True), +) +def pandas_nullable_tinyint_mult_with_masks( + x: Masked[pd.Series], y: Masked[pd.Series], +) -> Masked[pd.Series]: + x_data, x_nulls = x + y_data, y_nulls = y + return (x_data * y_data, x_nulls | y_nulls) -@udf(args=dict(x=VARCHAR(20, nullable=True))) -def nullable_varchar_mult(x: Optional[str], times: Optional[int]) -> Optional[str]: - if x is None or times is None: - return None - return x * times +@udf_with_null_masks +def numpy_nullable_tinyint_mult_with_masks( + x: MaskedNDArray[np.int8], y: MaskedNDArray[np.int8], +) -> MaskedNDArray[np.int8]: + x_data, x_nulls = x + y_data, y_nulls = y + return (x_data * y_data, x_nulls | y_nulls) + + +@udf_with_null_masks( + args=[TINYINT(nullable=True), TINYINT(nullable=True)], + returns=TINYINT(nullable=True), +) +def polars_nullable_tinyint_mult_with_masks( + x: Masked[pl.Series], y: Masked[pl.Series], +) -> Masked[pl.Series]: + x_data, x_nulls = x + y_data, y_nulls = y + return (x_data * y_data, x_nulls | y_nulls) + + +@udf_with_null_masks( + args=[TINYINT(nullable=True), TINYINT(nullable=True)], + returns=TINYINT(nullable=True), +) +def arrow_nullable_tinyint_mult_with_masks( + x: Masked[pa.Array], y: Masked[pa.Array], +) -> Masked[pa.Array]: + import pyarrow.compute as pc + x_data, x_nulls = x + y_data, y_nulls = y + return (pc.multiply(x_data, y_data), pc.or_(x_nulls, y_nulls)) diff --git a/singlestoredb/tests/test_ext_func.py b/singlestoredb/tests/test_ext_func.py index 9d383c45a..c3ef236c1 100755 --- a/singlestoredb/tests/test_ext_func.py +++ b/singlestoredb/tests/test_ext_func.py @@ -929,8 +929,10 @@ def test_numpy_nullable_bigint_mult(self): 'from data_with_nulls order by id', ) + # assert [tuple(x) for x in self.cur] == \ + # [(200,), (200,), (500,), (None,), (0,)] assert [tuple(x) for x in self.cur] == \ - [(200,), (200,), (500,), (None,), (0,)] + [(200,), (200,), (500,), (0,), (0,)] desc = self.cur.description assert len(desc) == 1 @@ -1145,7 +1147,7 @@ def test_nullable_string_mult(self): assert desc[0].type_code == ft.BLOB assert desc[0].null_ok is True - def test_varchar_mult(self): + def _test_varchar_mult(self): self.cur.execute( 'select varchar_mult(name, value) as res ' 'from data order by id', @@ -1172,7 +1174,7 @@ def test_varchar_mult(self): 'from data order by id', ) - def test_nullable_varchar_mult(self): + def _test_nullable_varchar_mult(self): self.cur.execute( 'select nullable_varchar_mult(name, value) as res ' 'from data_with_nulls order by id', diff --git a/singlestoredb/tests/test_udf.py b/singlestoredb/tests/test_udf.py index aa21c4785..78259cb9a 100755 --- a/singlestoredb/tests/test_udf.py +++ b/singlestoredb/tests/test_udf.py @@ -28,7 +28,10 @@ def to_sql(x): - out = sig.signature_to_sql(sig.get_signature(x)) + out = sig.signature_to_sql( + sig.get_signature(x), + function_type=getattr(x, '_singlestoredb_attrs', {}).get('function_type', 'udf'), + ) out = re.sub(r'^CREATE EXTERNAL FUNCTION ', r'', out) out = re.sub(r' AS REMOTE SERVICE.+$', r'', out) return out.strip() @@ -101,28 +104,24 @@ def foo() -> Union[int, str]: ... 
to_sql(foo) # Tuple - def foo() -> Tuple[int, float, str]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NOT NULL, ' \ - '`c` TEXT NOT NULL) NOT NULL' + with self.assertRaises(TypeError): + def foo() -> Tuple[int, float, str]: ... + to_sql(foo) # Optional tuple - def foo() -> Optional[Tuple[int, float, str]]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NOT NULL, ' \ - '`c` TEXT NOT NULL) NULL' + with self.assertRaises(TypeError): + def foo() -> Optional[Tuple[int, float, str]]: ... + to_sql(foo) # Optional tuple with optional element - def foo() -> Optional[Tuple[int, float, Optional[str]]]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NOT NULL, ' \ - '`c` TEXT NULL) NULL' + with self.assertRaises(TypeError): + def foo() -> Optional[Tuple[int, float, Optional[str]]]: ... + to_sql(foo) # Optional tuple with optional union element - def foo() -> Optional[Tuple[int, Optional[Union[float, int]], str]]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NULL, ' \ - '`c` TEXT NOT NULL) NULL' + with self.assertRaises(TypeError): + def foo() -> Optional[Tuple[int, Optional[Union[float, int]], str]]: ... + to_sql(foo) # Unknown type def foo() -> set: ... @@ -184,22 +183,21 @@ def foo(x: Union[int, str]) -> None: ... to_sql(foo) # Tuple - def foo(x: Tuple[int, float, str]) -> None: ... - assert to_sql(foo) == '`foo`(`x` RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NOT NULL, ' \ - '`c` TEXT NOT NULL) NOT NULL) RETURNS NULL' + with self.assertRaises(TypeError): + def foo(x: Tuple[int, float, str]) -> None: ... + to_sql(foo) # Optional tuple with optional element - def foo(x: Optional[Tuple[int, float, Optional[str]]]) -> None: ... - assert to_sql(foo) == '`foo`(`x` RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NOT NULL, ' \ - '`c` TEXT NULL) NULL) RETURNS NULL' + with self.assertRaises(TypeError): + def foo(x: Optional[Tuple[int, float, Optional[str]]]) -> None: ... + to_sql(foo) # Optional tuple with optional union element - def foo(x: Optional[Tuple[int, Optional[Union[float, int]], str]]) -> None: ... - assert to_sql(foo) == '`foo`(`x` RECORD(`a` BIGINT NOT NULL, ' \ - '`b` DOUBLE NULL, ' \ - '`c` TEXT NOT NULL) NULL) RETURNS NULL' + with self.assertRaises(TypeError): + def foo( + x: Optional[Tuple[int, Optional[Union[float, int]], str]], + ) -> None: ... + to_sql(foo) # Unknown type def foo(x: set) -> None: ... @@ -336,9 +334,8 @@ def foo(x: int) -> int: ... # Override multiple params with one type @udf(args=dt.SMALLINT(nullable=False)) def foo(x: int, y: float, z: np.int8) -> int: ... - assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL, ' \ - '`y` SMALLINT NOT NULL, ' \ - '`z` SMALLINT NOT NULL) RETURNS BIGINT NOT NULL' + with self.assertRaises(ValueError): + to_sql(foo) # Override with list @udf(args=[dt.SMALLINT, dt.FLOAT, dt.CHAR(30)]) @@ -350,13 +347,13 @@ def foo(x: int, y: float, z: str) -> int: ... # Override with too short of a list @udf(args=[dt.SMALLINT, dt.FLOAT]) def foo(x: int, y: float, z: str) -> int: ... - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): to_sql(foo) # Override with too long of a list @udf(args=[dt.SMALLINT, dt.FLOAT, dt.CHAR(30), dt.TEXT]) def foo(x: int, y: float, z: str) -> int: ... - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): to_sql(foo) # Override with list @@ -367,32 +364,10 @@ def foo(x: int, y: float, z: str) -> int: ... 
'`z` CHAR(30) NULL) RETURNS BIGINT NOT NULL' # Override with dict - @udf(args=dict(x=dt.SMALLINT, z=dt.CHAR(30))) - def foo(x: int, y: float, z: str) -> int: ... - assert to_sql(foo) == '`foo`(`x` SMALLINT NULL, ' \ - '`y` DOUBLE NOT NULL, ' \ - '`z` CHAR(30) NULL) RETURNS BIGINT NOT NULL' - - # Override with empty dict - @udf(args=dict()) - def foo(x: int, y: float, z: str) -> int: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL, ' \ - '`y` DOUBLE NOT NULL, ' \ - '`z` TEXT NOT NULL) RETURNS BIGINT NOT NULL' - - # Override with dict with extra keys - @udf(args=dict(bar=dt.INT)) - def foo(x: int, y: float, z: str) -> int: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL, ' \ - '`y` DOUBLE NOT NULL, ' \ - '`z` TEXT NOT NULL) RETURNS BIGINT NOT NULL' - - # Override parameters and return value - @udf(args=dict(x=dt.SMALLINT, z=dt.CHAR(30)), returns=dt.SMALLINT(nullable=False)) - def foo(x: int, y: float, z: str) -> int: ... - assert to_sql(foo) == '`foo`(`x` SMALLINT NULL, ' \ - '`y` DOUBLE NOT NULL, ' \ - '`z` CHAR(30) NULL) RETURNS SMALLINT NOT NULL' + with self.assertRaises(TypeError): + @udf(args=dict(x=dt.SMALLINT, z=dt.CHAR(30))) + def foo(x: int, y: float, z: str) -> int: ... + assert to_sql(foo) # Change function name @udf(name='hello_world') @@ -411,26 +386,19 @@ class MyData: two: str three: float - @udf - def foo(x: int) -> MyData: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ - 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ - '`three` DOUBLE NOT NULL) NOT NULL' - - @udf(returns=MyData) - def foo(x: int) -> Tuple[int, int, int]: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ - 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ - '`three` DOUBLE NOT NULL) NOT NULL' + with self.assertRaises(TypeError): + @udf + def foo(x: int) -> MyData: ... + to_sql(foo) @tvf - def foo(x: int) -> MyData: ... + def foo(x: int) -> List[MyData]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' @tvf(returns=MyData) - def foo(x: int) -> Tuple[int, int, int]: ... + def foo(x: int) -> List[Tuple[int, int, int]]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' @@ -440,26 +408,14 @@ class MyData(pydantic.BaseModel): two: str three: float - @udf - def foo(x: int) -> MyData: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ - 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ - '`three` DOUBLE NOT NULL) NOT NULL' - - @udf(returns=MyData) - def foo(x: int) -> Tuple[int, int, int]: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ - 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ - '`three` DOUBLE NOT NULL) NOT NULL' - @tvf - def foo(x: int) -> MyData: ... + def foo(x: int) -> List[MyData]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' @tvf(returns=MyData) - def foo(x: int) -> Tuple[int, int, int]: ... + def foo(x: int) -> List[Tuple[int, int, int]]: ... 
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' From 858e6524a2c52749fa00f6732dd27610c8e2ecee Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 15 Apr 2025 17:45:41 -0500 Subject: [PATCH 11/16] Fix get_annotations call --- singlestoredb/functions/decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index e1d204b71..a67497012 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -49,7 +49,7 @@ def is_valid_callable(obj: Any) -> bool: if not callable(obj): return False - returns = inspect.get_annotations(obj).get('return', None) + returns = utils.get_annotations(obj).get('return', None) if inspect.isclass(returns) and issubclass(returns, str): return True From 7ec2c24f07fba04b61253335deb320b868c0c549 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 16 Apr 2025 09:13:44 -0500 Subject: [PATCH 12/16] Fix imports; change binary from hex to base64 --- singlestoredb/functions/dtypes.py | 3 +- singlestoredb/functions/ext/json.py | 8 ++- singlestoredb/functions/ext/rowdat_1.py | 7 ++- singlestoredb/functions/utils.py | 80 +++++++++++++++---------- 4 files changed, 63 insertions(+), 35 deletions(-) diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py index 905bec14d..b6aa02f2f 100644 --- a/singlestoredb/functions/dtypes.py +++ b/singlestoredb/functions/dtypes.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import base64 import datetime import decimal import re @@ -106,7 +107,7 @@ def bytestr(x: Any) -> Optional[bytes]: return x if isinstance(x, bytes): return x - return bytes.fromhex(x) + return base64.b64decode(x) PYTHON_CONVERTERS = { diff --git a/singlestoredb/functions/ext/json.py b/singlestoredb/functions/ext/json.py index 6b2c4f3ae..05710247d 100644 --- a/singlestoredb/functions/ext/json.py +++ b/singlestoredb/functions/ext/json.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import base64 import json from typing import Any from typing import List @@ -35,7 +36,7 @@ class JSONEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: if isinstance(obj, bytes): - return obj.hex() + return base64.b64encode(obj).decode('utf-8') return json.JSONEncoder.default(self, obj) @@ -130,6 +131,8 @@ def load_pandas( Tuple[pd.Series[int], List[pd.Series[Any]] ''' + import numpy as np + import pandas as pd row_ids, cols = _load_vectors(colspec, data) index = pd.Series(row_ids, dtype=np.longlong) return index, \ @@ -164,6 +167,7 @@ def load_polars( Tuple[polars.Series[int], List[polars.Series[Any]] ''' + import polars as pl row_ids, cols = _load_vectors(colspec, data) return pl.Series(None, row_ids, dtype=pl.Int64), \ [ @@ -194,6 +198,7 @@ def load_numpy( Tuple[np.ndarray[int], List[np.ndarray[Any]] ''' + import numpy as np row_ids, cols = _load_vectors(colspec, data) return np.asarray(row_ids, dtype=np.longlong), \ [ @@ -224,6 +229,7 @@ def load_arrow( Tuple[pyarrow.Array[int], List[pyarrow.Array[Any]] ''' + import pyarrow as pa row_ids, cols = _load_vectors(colspec, data) return pa.array(row_ids, type=pa.int64()), \ [ diff --git a/singlestoredb/functions/ext/rowdat_1.py b/singlestoredb/functions/ext/rowdat_1.py index bc597ce46..83052b671 100644 --- a/singlestoredb/functions/ext/rowdat_1.py +++ b/singlestoredb/functions/ext/rowdat_1.py @@ -35,7 +35,7 @@ except ImportError: pass try: - import pyarrow.compute as pc + import pyarrow.compute as pc # 
noqa: F401 except ImportError: pass @@ -205,6 +205,7 @@ def _load_pandas( Tuple[pd.Series[int], List[Tuple[pd.Series[Any], pd.Series[bool]]]] ''' + import numpy as np import pandas as pd row_ids, cols = _load_vectors(colspec, data) @@ -558,6 +559,7 @@ def _load_pandas_accel( if not has_accel: raise RuntimeError('could not load SingleStoreDB extension') + import numpy as np import pandas as pd numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) @@ -663,8 +665,11 @@ def _create_arrow_mask( data: 'pa.Array[Any]', mask: 'pa.Array[pa.bool_]', ) -> 'pa.Array[pa.bool_]': + import pyarrow.compute as pc # noqa: F811 + if mask is None: return data.is_null().to_numpy(zero_copy_only=False) + return pc.or_(data.is_null(), mask.is_null()).to_numpy(zero_copy_only=False) diff --git a/singlestoredb/functions/utils.py b/singlestoredb/functions/utils.py index cba08c324..3b56707c7 100644 --- a/singlestoredb/functions/utils.py +++ b/singlestoredb/functions/utils.py @@ -6,12 +6,6 @@ from typing import Any from typing import Dict -try: - import numpy as np - has_numpy = True -except ImportError: - has_numpy = False - if sys.version_info >= (3, 10): _UNION_TYPES = {typing.Union, types.UnionType} @@ -36,29 +30,51 @@ def get_annotations(obj: Any) -> Dict[str, Any]: return getattr(obj, '__annotations__', {}) +def get_module(obj: Any) -> str: + """Get the module of an object.""" + module = getattr(obj, '__module__', '').split('.') + if module: + return module[0] + return '' + + +def get_type_name(obj: Any) -> str: + """Get the type name of an object.""" + if hasattr(obj, '__name__'): + return obj.__name__ + if hasattr(obj, '__class__'): + return obj.__class__.__name__ + return '' + + def is_numpy(obj: Any) -> bool: """Check if an object is a numpy array.""" - if is_union(obj): - obj = typing.get_args(obj)[0] - if not has_numpy: - return False if inspect.isclass(obj): - return obj is np.ndarray - if typing.get_origin(obj) is np.ndarray: - return True - return isinstance(obj, np.ndarray) + if get_module(obj) == 'numpy': + return get_type_name(obj) == 'ndarray' + + origin = typing.get_origin(obj) + if get_module(origin) == 'numpy': + if get_type_name(origin) == 'ndarray': + return True + + dtype = type(obj) + if get_module(dtype) == 'numpy': + return get_type_name(dtype) == 'ndarray' + + return False def is_dataframe(obj: Any) -> bool: """Check if an object is a DataFrame.""" # Cheating here a bit so we don't have to import pandas / polars / pyarrow: # unless we absolutely need to - if getattr(obj, '__module__', '').startswith('pandas.'): - return getattr(obj, '__name__', '') == 'DataFrame' - if getattr(obj, '__module__', '').startswith('polars.'): - return getattr(obj, '__name__', '') == 'DataFrame' - if getattr(obj, '__module__', '').startswith('pyarrow.'): - return getattr(obj, '__name__', '') == 'Table' + if get_module(obj) == 'pandas': + return get_type_name(obj) == 'DataFrame' + if get_module(obj) == 'polars': + return get_type_name(obj) == 'DataFrame' + if get_module(obj) == 'pyarrow': + return get_type_name(obj) == 'Table' return False @@ -74,13 +90,13 @@ def get_data_format(obj: Any) -> str: """Return the data format of the DataFrame / Table / vector.""" # Cheating here a bit so we don't have to import pandas / polars / pyarrow # unless we absolutely need to - if getattr(obj, '__module__', '').startswith('pandas.'): + if get_module(obj) == 'pandas': return 'pandas' - if getattr(obj, '__module__', '').startswith('polars.'): + if get_module(obj) == 'polars': return 'polars' - if 
getattr(obj, '__module__', '').startswith('pyarrow.'): + if get_module(obj) == 'pyarrow': return 'arrow' - if getattr(obj, '__module__', '').startswith('numpy.'): + if get_module(obj) == 'numpy': return 'numpy' if isinstance(obj, list): return 'list' @@ -92,8 +108,8 @@ def is_pandas_series(obj: Any) -> bool: if is_union(obj): obj = typing.get_args(obj)[0] return ( - getattr(obj, '__module__', '').startswith('pandas.') and - getattr(obj, '__name__', '') == 'Series' + get_module(obj) == 'pandas' and + get_type_name(obj) == 'Series' ) @@ -102,8 +118,8 @@ def is_polars_series(obj: Any) -> bool: if is_union(obj): obj = typing.get_args(obj)[0] return ( - getattr(obj, '__module__', '').startswith('polars.') and - getattr(obj, '__name__', '') == 'Series' + get_module(obj) == 'polars' and + get_type_name(obj) == 'Series' ) @@ -112,8 +128,8 @@ def is_pyarrow_array(obj: Any) -> bool: if is_union(obj): obj = typing.get_args(obj)[0] return ( - getattr(obj, '__module__', '').startswith('pyarrow.') and - getattr(obj, '__name__', '') == 'Array' + get_module(obj) == 'pyarrow' and + get_type_name(obj) == 'Array' ) @@ -147,6 +163,6 @@ def is_pydantic(obj: Any) -> bool: # the class is a subclass return bool([ x for x in inspect.getmro(obj) - if x.__module__.startswith('pydantic.') - and x.__name__ == 'BaseModel' + if get_module(x) == 'pydantic' + and get_type_name(x) == 'BaseModel' ]) From 2da33ff2f7fa5ec19e3d97d3313ded4cc079bccb Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 16 Apr 2025 10:33:31 -0500 Subject: [PATCH 13/16] Fix show functions call --- singlestoredb/functions/ext/asgi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index fa6c8cb32..e6206d10c 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -790,7 +790,8 @@ def _locate_app_functions(self, cur: Any) -> Tuple[Set[str], Set[str]]: """Locate all current functions and links belonging to this app.""" funcs, links = set(), set() cur.execute('SHOW FUNCTIONS') - for name, ftype, _, _, _, link in list(cur): + for row in list(cur): + name, ftype, link = row[0], row[1], row[-1] # Only look at external functions if 'external' not in ftype.lower(): continue From 36f2b13f0ff394c27f2bdb864ac73e2c92a07621 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 16 Apr 2025 13:58:40 -0500 Subject: [PATCH 14/16] Fix fixed length strings / binary; add tests for fixed strings / binary; test no args / no return value --- accel.c | 11 ++-- singlestoredb/functions/ext/asgi.py | 11 +++- singlestoredb/functions/signature.py | 9 +++- singlestoredb/tests/ext_funcs/__init__.py | 33 ++++++++++++ singlestoredb/tests/test_ext_func.py | 41 ++++++++++++++ singlestoredb/tests/test_udf.py | 65 ++++++++++++----------- 6 files changed, 132 insertions(+), 38 deletions(-) diff --git a/accel.c b/accel.c index 499a800b2..02de6fb76 100644 --- a/accel.c +++ b/accel.c @@ -3929,7 +3929,8 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k out_idx += 8; } else if (col_types[i].type == NUMPY_FIXED_STRING) { - void *bytes = (void*)(cols[i] + j * 8); + // Jump to col_types[i].length * 4 for UCS4 fixed length string + void *bytes = (void*)(cols[i] + j * col_types[i].length * 4); if (bytes == NULL) { CHECKMEM(8); @@ -3944,6 +3945,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k if (utf8_str) free(utf8_str); goto error; } + str_l = strnlen(utf8_str, str_l); CHECKMEM(8+str_l); i64 = 
str_l; memcpy(out+out_idx, &i64, 8); @@ -4010,7 +4012,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k out_idx += 8; } else if (col_types[i].type == NUMPY_BYTES) { - void *bytes = (void*)(cols[i] + j * 8); + void *bytes = (void*)(cols[i] + j * col_types[i].length); if (bytes == NULL) { CHECKMEM(8); @@ -4434,7 +4436,10 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) // Get return types n_cols = (unsigned long long)PyObject_Length(py_returns); - if (n_cols == 0) goto error; + if (n_cols == 0) { + PyErr_SetString(PyExc_ValueError, "no return values specified"); + goto error; + } returns = malloc(sizeof(int) * n_cols); if (!returns) goto error; diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index e6206d10c..7b5fe7e31 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -272,7 +272,10 @@ async def do_func( # type: ignore return row_ids, [out] # Call function on each column of data - res = get_dataframe_columns(func(*[x[0] for x in cols])) + if cols and cols[0]: + res = get_dataframe_columns(func(*[x[0] for x in cols])) + else: + res = get_dataframe_columns(func()) # Generate row IDs row_ids = array_cls([row_ids[0]] * len(res[0])) @@ -308,7 +311,10 @@ async def do_func( # type: ignore return row_ids, [out] # Call the function with `cols` as the function parameters - out = func(*[x[0] for x in cols]) + if cols and cols[0]: + out = func(*[x[0] for x in cols]) + else: + out = func() # Multiple return values if isinstance(out, tuple): @@ -717,6 +723,7 @@ async def __call__( func_info['colspec'], b''.join(data), ), ) + print(func_info['returns'], out) body = output_handler['dump']( [x[1] for x in func_info['returns']], *out, # type: ignore ) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index c4b5adb2f..83edd98a8 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -1115,6 +1115,11 @@ def get_signature( with_null_masks=with_null_masks, ) + # All functions have to return a value, so if none was specified try to + # insert a reasonable default that includes NULLs. 
+ if not ret_schema: + ret_schema = [('', 'int8?', 'TINYINT NULL')] + # Generate names for fields as needed if function_type == 'tvf' or len(ret_schema) > 1: for i, (name, rtype, sql) in enumerate(ret_schema): @@ -1300,7 +1305,9 @@ def signature_to_sql( res = ret[0]['sql'] returns = f' RETURNS {res}' else: - returns = ' RETURNS NULL' + raise ValueError( + 'function signature must have a return type specified', + ) host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1') port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000') diff --git a/singlestoredb/tests/ext_funcs/__init__.py index 1d4abf1ef..74f6b25a8 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -9,9 +9,11 @@ from singlestoredb.functions import Masked from singlestoredb.functions import MaskedNDArray +from singlestoredb.functions import tvf from singlestoredb.functions import udf from singlestoredb.functions import udf_with_null_masks from singlestoredb.functions.dtypes import BIGINT +from singlestoredb.functions.dtypes import BLOB from singlestoredb.functions.dtypes import DOUBLE from singlestoredb.functions.dtypes import FLOAT from singlestoredb.functions.dtypes import MEDIUMINT @@ -480,3 +482,34 @@ def arrow_nullable_tinyint_mult_with_masks( x_data, x_nulls = x y_data, y_nulls = y return (pc.multiply(x_data, y_data), pc.or_(x_nulls, y_nulls)) + + +@tvf(returns=[TEXT(nullable=False, name='res')]) +def numpy_fixed_strings() -> npt.NDArray[np.str_]: + out = np.array( + [ + 'hello', + 'hi there 😜', + '😜 bye', + ], dtype=np.str_, + ) + assert str(out.dtype) == '<U10' + return out + + +@tvf(returns=[BLOB(nullable=False, name='res')]) +def numpy_fixed_binary() -> npt.NDArray[np.bytes_]: + out = np.array( + [ + 'hello'.encode('utf8'), + 'hi there 😜'.encode('utf8'), + '😜 bye'.encode('utf8'), + ], dtype=np.bytes_, + ) + assert str(out.dtype) == '|S13' + return out + + +@udf +def no_args_no_return_value() -> None: + pass diff --git a/singlestoredb/tests/test_ext_func.py index c3ef236c1..fa27cf66a 100755 --- a/singlestoredb/tests/test_ext_func.py +++ b/singlestoredb/tests/test_ext_func.py @@ -1193,3 +1193,44 @@ def _test_nullable_varchar_mult(self): assert desc[0].name == 'res' assert desc[0].type_code == ft.BLOB assert desc[0].null_ok is True + + def test_numpy_fixed_strings(self): + self.cur.execute('select * from numpy_fixed_strings()') + + assert [tuple(x) for x in self.cur] == [ + ('hello',), + ('hi there 😜',), + ('😜 bye',), + ] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.BLOB + assert desc[0].null_ok is False + + def test_numpy_fixed_binary(self): + self.cur.execute('select * from numpy_fixed_binary()') + + assert [tuple(x) for x in self.cur] == [ + ('hello'.encode('utf8') + b'\x00' * 8,), + ('hi there 😜'.encode('utf8'),), + ('😜 bye'.encode('utf8') + b'\x00' * 5,), + ] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.BLOB + assert desc[0].null_ok is False + + def test_no_args_no_return_value(self): + self.cur.execute('select no_args_no_return_value() as res') + + assert [tuple(x) for x in self.cur] == [(None,)] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.TINY + assert desc[0].null_ok is True diff --git a/singlestoredb/tests/test_udf.py index 78259cb9a..79fae0ed5 100755 --- a/singlestoredb/tests/test_udf.py +++ 
b/singlestoredb/tests/test_udf.py @@ -48,7 +48,7 @@ def foo(): ... # NULL return value def foo() -> None: ... - assert to_sql(foo) == '`foo`() RETURNS NULL' + assert to_sql(foo) == '`foo`() RETURNS TINYINT NULL' # Simple return value def foo() -> int: ... @@ -138,44 +138,44 @@ def foo(x) -> None: ... # Simple parameter def foo(x: int) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS TINYINT NULL' # Optional parameter def foo(x: Optional[int]) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS TINYINT NULL' # Optional parameter def foo(x: Union[int, None]) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS TINYINT NULL' # Optional multiple parameter types def foo(x: Union[int, float, None]) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS TINYINT NULL' # Optional parameter with custom type def foo(x: Optional[B]) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS TINYINT NULL' # Optional parameter with nested custom type def foo(x: Optional[C]) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS TINYINT NULL' # Optional parameter with collection type def foo(x: Optional[List[str]]) -> None: ... - assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NOT NULL) NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NOT NULL) NULL) RETURNS TINYINT NULL' # Optional parameter with nested collection type def foo(x: Optional[List[List[str]]]) -> None: ... assert to_sql(foo) == '`foo`(`x` ARRAY(ARRAY(TEXT NOT NULL) NOT NULL) NULL) ' \ - 'RETURNS NULL' + 'RETURNS TINYINT NULL' # Optional parameter with collection type with nulls def foo(x: Optional[List[Optional[str]]]) -> None: ... - assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NULL) NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NULL) NULL) RETURNS TINYINT NULL' # Custom type with bound def foo(x: D) -> None: ... - assert to_sql(foo) == '`foo`(`x` TEXT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` TEXT NOT NULL) RETURNS TINYINT NULL' # Incompatible types def foo(x: Union[int, str]) -> None: ... @@ -209,15 +209,15 @@ def test_datetimes(self): # Datetime def foo(x: datetime.datetime) -> None: ... - assert to_sql(foo) == '`foo`(`x` DATETIME NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DATETIME NOT NULL) RETURNS TINYINT NULL' # Date def foo(x: datetime.date) -> None: ... - assert to_sql(foo) == '`foo`(`x` DATE NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DATE NOT NULL) RETURNS TINYINT NULL' # Time def foo(x: datetime.timedelta) -> None: ... - assert to_sql(foo) == '`foo`(`x` TIME NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` TIME NOT NULL) RETURNS TINYINT NULL' # Datetime + Date def foo(x: Union[datetime.datetime, datetime.date]) -> None: ... @@ -229,75 +229,76 @@ def test_numerics(self): # Ints # def foo(x: int) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS TINYINT NULL' def foo(x: np.int8) -> None: ... 
- assert to_sql(foo) == '`foo`(`x` TINYINT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` TINYINT NOT NULL) RETURNS TINYINT NULL' def foo(x: np.int16) -> None: ... - assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS TINYINT NULL' def foo(x: np.int32) -> None: ... - assert to_sql(foo) == '`foo`(`x` INT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` INT NOT NULL) RETURNS TINYINT NULL' def foo(x: np.int64) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS TINYINT NULL' # # Unsigned ints # def foo(x: np.uint8) -> None: ... - assert to_sql(foo) == '`foo`(`x` TINYINT UNSIGNED NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` TINYINT UNSIGNED NOT NULL) RETURNS TINYINT NULL' def foo(x: np.uint16) -> None: ... - assert to_sql(foo) == '`foo`(`x` SMALLINT UNSIGNED NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` SMALLINT UNSIGNED NOT NULL) RETURNS TINYINT NULL' def foo(x: np.uint32) -> None: ... - assert to_sql(foo) == '`foo`(`x` INT UNSIGNED NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` INT UNSIGNED NOT NULL) RETURNS TINYINT NULL' def foo(x: np.uint64) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT UNSIGNED NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` BIGINT UNSIGNED NOT NULL) RETURNS TINYINT NULL' # # Floats # def foo(x: float) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL' def foo(x: np.float32) -> None: ... - assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS TINYINT NULL' def foo(x: np.float64) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL' # # Type collapsing # def foo(x: Union[np.int8, np.int16]) -> None: ... - assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS TINYINT NULL' def foo(x: Union[np.int64, np.double]) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL' def foo(x: Union[int, float]) -> None: ... - assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL' def test_positional_and_keyword_parameters(self): # Keyword only def foo(x: int = 100) -> None: ... - assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL DEFAULT 100) RETURNS NULL' + assert to_sql(foo) == \ + '`foo`(`x` BIGINT NOT NULL DEFAULT 100) RETURNS TINYINT NULL' # Multiple keywords def foo(x: int = 100, y: float = 3.14) -> None: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL DEFAULT 100, ' \ - '`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS NULL' + '`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS TINYINT NULL' # Keywords and positional def foo(a: str, b: str, x: int = 100, y: float = 3.14) -> None: ... assert to_sql(foo) == '`foo`(`a` TEXT NOT NULL, ' \ '`b` TEXT NOT NULL, ' \ '`x` BIGINT NOT NULL DEFAULT 100, ' \ - '`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS NULL' + '`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS TINYINT NULL' # Variable positional def foo(*args: int) -> None: ... 
From b32ac13a5d9f38ffa4c4080a4ff55b3c2dd85352 Mon Sep 17 00:00:00 2001
From: Kevin Smith
Date: Wed, 16 Apr 2025 15:27:08 -0500
Subject: [PATCH 15/16] Add vector type

---
 singlestoredb/functions/dtypes.py    | 48 ++++++++++++++++++++++++++++
 singlestoredb/functions/signature.py |  2 +-
 singlestoredb/tests/test_udf.py      | 16 ++++++++++
 3 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py
index b6aa02f2f..cadfbc1a7 100644
--- a/singlestoredb/functions/dtypes.py
+++ b/singlestoredb/functions/dtypes.py
@@ -1744,3 +1744,51 @@ def ARRAY(
     out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable))
     out.name = name
     return out
+
+
+F32 = 'F32'
+F64 = 'F64'
+I8 = 'I8'
+I16 = 'I16'
+I32 = 'I32'
+I64 = 'I64'
+
+
+def VECTOR(
+    length: int,
+    element_type: str = F32,
+    *,
+    nullable: bool = True,
+    default: Optional[bytes] = None,
+    name: Optional[str] = None,
+) -> SQLString:
+    """
+    VECTOR type specification.
+
+    Parameters
+    ----------
+    length : int
+        Number of elements in the vector
+    element_type : str, optional
+        Type of the elements in the vector:
+        F32, F64, I8, I16, I32, I64
+    nullable : bool, optional
+        Can the value be NULL?
+    default : bytes, optional
+        Default value
+    name : str, optional
+        Name of the column / parameter
+
+    Returns
+    -------
+    SQLString
+
+    """
+    out = f'VECTOR({int(length)}, {element_type})'
+    out = SQLString(
+        out + _modifiers(
+            nullable=nullable, default=default,
+        ),
+    )
+    out.name = name
+    return out
diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py
index 83edd98a8..74c41fba0 100644
--- a/singlestoredb/functions/signature.py
+++ b/singlestoredb/functions/signature.py
@@ -788,7 +788,7 @@ def get_schema(
     elif utils.is_dataframe(spec) or utils.is_vector(spec):
         if not overrides:
             raise TypeError(
-                'type overrides must be specified for DataFrames / Tables',
+                'type overrides must be specified for vectors or DataFrames / Tables',
             )
 
     # Unsuported types
diff --git a/singlestoredb/tests/test_udf.py b/singlestoredb/tests/test_udf.py
index 79fae0ed5..53303d77e 100755
--- a/singlestoredb/tests/test_udf.py
+++ b/singlestoredb/tests/test_udf.py
@@ -694,3 +694,19 @@ def test_dtypes(self):
 
         assert dt.ARRAY(dt.INT) == 'ARRAY(INT NULL) NULL'
         assert dt.ARRAY(dt.INT, nullable=False) == 'ARRAY(INT NULL) NOT NULL'
+
+        assert dt.VECTOR(8) == 'VECTOR(8, F32) NULL'
+        assert dt.VECTOR(8, dt.F32) == 'VECTOR(8, F32) NULL'
+        assert dt.VECTOR(8, dt.F64) == 'VECTOR(8, F64) NULL'
+        assert dt.VECTOR(8, dt.I8) == 'VECTOR(8, I8) NULL'
+        assert dt.VECTOR(8, dt.I16) == 'VECTOR(8, I16) NULL'
+        assert dt.VECTOR(8, dt.I32) == 'VECTOR(8, I32) NULL'
+        assert dt.VECTOR(8, dt.I64) == 'VECTOR(8, I64) NULL'
+
+        assert dt.VECTOR(8, nullable=False) == 'VECTOR(8, F32) NOT NULL'
+        assert dt.VECTOR(8, dt.F32, nullable=False) == 'VECTOR(8, F32) NOT NULL'
+        assert dt.VECTOR(8, dt.F64, nullable=False) == 'VECTOR(8, F64) NOT NULL'
+        assert dt.VECTOR(8, dt.I8, nullable=False) == 'VECTOR(8, I8) NOT NULL'
+        assert dt.VECTOR(8, dt.I16, nullable=False) == 'VECTOR(8, I16) NOT NULL'
+        assert dt.VECTOR(8, dt.I32, nullable=False) == 'VECTOR(8, I32) NOT NULL'
+        assert dt.VECTOR(8, dt.I64, nullable=False) == 'VECTOR(8, I64) NOT NULL'

From 266cc4a4773d77bfc965553cec8eb83e7d769b71 Mon Sep 17 00:00:00 2001
From: Kevin Smith
Date: Wed, 16 Apr 2025 18:45:29 -0500
Subject: [PATCH 16/16] Fix fixed length binary values

---
 accel.c                              | 16 ++++++++++++++++
 singlestoredb/functions/ext/asgi.py  |  1 -
 singlestoredb/tests/test_ext_func.py |  4 ++--
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/accel.c b/accel.c
index 02de6fb76..9436a04d1 100644
--- a/accel.c
+++ b/accel.c
@@ -436,6 +436,19 @@ int ucs4_to_utf8(const uint32_t *ucs4_str, size_t ucs4_len, char **utf8_str) {
     return -1;
 }
 
+size_t length_without_trailing_nulls(const char *str, size_t len) {
+    if (!str || len == 0) {
+        return 0; // Handle null or empty input
+    }
+
+    // Start from the end of the string and move backward
+    while (len > 0 && str[len - 1] == '\0') {
+        len--;
+    }
+
+    return len;
+}
+
 //
 // Cached int values for date/time components
 //
@@ -4022,6 +4035,9 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
                 }
                 else {
                     Py_ssize_t str_l = col_types[i].length;
                     CHECKMEM(8+str_l);
+
+                    str_l = length_without_trailing_nulls(bytes, str_l);
+
                     i64 = str_l;
                     memcpy(out+out_idx, &i64, 8);
                     out_idx += 8;
diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py
index 7b5fe7e31..0a5780faa 100755
--- a/singlestoredb/functions/ext/asgi.py
+++ b/singlestoredb/functions/ext/asgi.py
@@ -723,7 +723,6 @@ async def __call__(
                     func_info['colspec'], b''.join(data),
                 ),
             )
-            print(func_info['returns'], out)
             body = output_handler['dump'](
                 [x[1] for x in func_info['returns']], *out,  # type: ignore
             )
diff --git a/singlestoredb/tests/test_ext_func.py b/singlestoredb/tests/test_ext_func.py
index fa27cf66a..651500bb6 100755
--- a/singlestoredb/tests/test_ext_func.py
+++ b/singlestoredb/tests/test_ext_func.py
@@ -1213,9 +1213,9 @@ def test_numpy_fixed_binary(self):
         self.cur.execute('select * from numpy_fixed_binary()')
 
         assert [tuple(x) for x in self.cur] == [
-            ('hello'.encode('utf8') + b'\x00' * 8,),
+            ('hello'.encode('utf8'),),
             ('hi there 😜'.encode('utf8'),),
-            ('😜 bye'.encode('utf8') + b'\x00' * 5,),
+            ('😜 bye'.encode('utf8'),),
         ]
 
         desc = self.cur.description