From 834d17778945eebad0618e681a49fd3a03da513d Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:02:20 -0500 Subject: [PATCH 01/16] Refactor to use Table / Masked types --- .pre-commit-config.yaml | 2 +- singlestoredb/functions/__init__.py | 4 +- singlestoredb/functions/decorator.py | 118 +----- singlestoredb/functions/dtypes.py | 119 +++--- singlestoredb/functions/ext/asgi.py | 8 +- singlestoredb/functions/signature.py | 251 +++++++----- singlestoredb/functions/typing.py | 39 +- singlestoredb/functions/utils.py | 6 +- singlestoredb/http/connection.py | 4 +- singlestoredb/management/utils.py | 2 +- singlestoredb/tests/ext_funcs/__init__.py | 40 +- singlestoredb/tests/test_udf.py | 39 +- singlestoredb/tests/test_udf_returns.py | 459 ++++++++++++++++++++++ 13 files changed, 749 insertions(+), 342 deletions(-) create mode 100644 singlestoredb/tests/test_udf_returns.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8cb8b15e1..e5a4b4cf1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: hooks: - id: setup-cfg-fmt - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.15.0 hooks: - id: mypy additional_dependencies: [types-requests] diff --git a/singlestoredb/functions/__init__.py b/singlestoredb/functions/__init__.py index 01b059f76..bca08470f 100644 --- a/singlestoredb/functions/__init__.py +++ b/singlestoredb/functions/__init__.py @@ -1,6 +1,4 @@ -from .decorator import tvf # noqa: F401 -from .decorator import tvf_with_null_masks # noqa: F401 from .decorator import udf # noqa: F401 from .decorator import udf_with_null_masks # noqa: F401 from .typing import Masked # noqa: F401 -from .typing import MaskedNDArray # noqa: F401 +from .typing import Table # noqa: F401 diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index a67497012..4917944ad 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -10,6 +10,7 @@ from . 
import utils from .dtypes import SQLString +from .typing import Masked ParameterType = Union[ @@ -62,23 +63,10 @@ def is_valid_callable(obj: Any) -> bool: def verify_mask(obj: Any) -> bool: """Verify that the object is a tuple of two vector types.""" - if typing.get_origin(obj) is not tuple or len(typing.get_args(obj)) != 2: + if not typing.get_origin(obj) is Masked: raise TypeError( - f'Expected a tuple of two vector types, but got {type(obj)}', + f'expected a Masked type, but got {type(obj)}', ) - - args = typing.get_args(obj) - - if not utils.is_vector(args[0]): - raise TypeError( - f'Expected a vector type for the first element, but got {args[0]}', - ) - - if not utils.is_vector(args[1]): - raise TypeError( - f'Expected a vector type for the second element, but got {args[1]}', - ) - return True @@ -136,7 +124,6 @@ def _func( args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, with_null_masks: bool = False, - function_type: str = 'udf', ) -> Callable[..., Any]: """Generic wrapper for UDF and TVF decorators.""" @@ -146,7 +133,6 @@ def _func( args=expand_types(args), returns=expand_types(returns), with_null_masks=with_null_masks, - function_type=function_type, ).items() if v is not None } @@ -222,7 +208,6 @@ def udf( args=args, returns=returns, with_null_masks=False, - function_type='udf', ) @@ -270,101 +255,4 @@ def udf_with_null_masks( args=args, returns=returns, with_null_masks=True, - function_type='udf', - ) - - -def tvf( - func: Optional[Callable[..., Any]] = None, - *, - name: Optional[str] = None, - args: Optional[ParameterType] = None, - returns: Optional[ReturnType] = None, -) -> Callable[..., Any]: - """ - Define a table-valued function (TVF). - - Parameters - ---------- - func : callable, optional - The TVF to apply parameters to - name : str, optional - The name to use for the TVF in the database - args : str | Callable | List[str | Callable], optional - Specifies the data types of the function arguments. Typically, - the function data types are derived from the function parameter - annotations. These annotations can be overridden. If the function - takes a single type for all parameters, `args` can be set to a - SQL string describing all parameters. If the function takes more - than one parameter and all of the parameters are being manually - defined, a list of SQL strings may be used (one for each parameter). - A dictionary of SQL strings may be used to specify a parameter type - for a subset of parameters; the keys are the names of the - function parameters. Callables may also be used for datatypes. This - is primarily for using the functions in the ``dtypes`` module that - are associated with SQL types with all default options (e.g., ``dt.FLOAT``). - returns : str, optional - Specifies the return data type of the function. If not specified, - the type annotation from the function is used. - - Returns - ------- - Callable - - """ - return _func( - func=func, - name=name, - args=args, - returns=returns, - with_null_masks=False, - function_type='tvf', - ) - - -def tvf_with_null_masks( - func: Optional[Callable[..., Any]] = None, - *, - name: Optional[str] = None, - args: Optional[ParameterType] = None, - returns: Optional[ReturnType] = None, -) -> Callable[..., Any]: - """ - Define a table-valued function (TVF) using null masks. 
- - Parameters - ---------- - func : callable, optional - The TVF to apply parameters to - name : str, optional - The name to use for the TVF in the database - args : str | Callable | List[str | Callable], optional - Specifies the data types of the function arguments. Typically, - the function data types are derived from the function parameter - annotations. These annotations can be overridden. If the function - takes a single type for all parameters, `args` can be set to a - SQL string describing all parameters. If the function takes more - than one parameter and all of the parameters are being manually - defined, a list of SQL strings may be used (one for each parameter). - A dictionary of SQL strings may be used to specify a parameter type - for a subset of parameters; the keys are the names of the - function parameters. Callables may also be used for datatypes. This - is primarily for using the functions in the ``dtypes`` module that - are associated with SQL types with all default options (e.g., ``dt.FLOAT``). - returns : str, optional - Specifies the return data type of the function. If not specified, - the type annotation from the function is used. - - Returns - ------- - Callable - - """ - return _func( - func=func, - name=name, - args=args, - returns=returns, - with_null_masks=True, - function_type='tvf', ) diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py index 4522abd22..0fe26a452 100644 --- a/singlestoredb/functions/dtypes.py +++ b/singlestoredb/functions/dtypes.py @@ -6,7 +6,6 @@ from typing import Any from typing import Callable from typing import Optional -from typing import Tuple from typing import Union from ..converters import converters @@ -1683,67 +1682,67 @@ def GEOGRAPHY( return out -def RECORD( - *args: Tuple[str, DataType], - nullable: bool = True, - name: Optional[str] = None, -) -> SQLString: - """ - RECORD type specification. - - Parameters - ---------- - *args : Tuple[str, DataType] - Field specifications - nullable : bool, optional - Can the value be NULL? - name : str, optional - Name of the column / parameter - - Returns - ------- - SQLString - - """ - assert len(args) > 0 - fields = [] - for name, value in args: - if callable(value): - fields.append(f'{escape_name(name)} {value()}') - else: - fields.append(f'{escape_name(name)} {value}') - out = SQLString(f'RECORD({", ".join(fields)})' + _modifiers(nullable=nullable)) - out.name = name - return out - - -def ARRAY( - dtype: DataType, - nullable: bool = True, - name: Optional[str] = None, -) -> SQLString: - """ - ARRAY type specification. - - Parameters - ---------- - dtype : DataType - The data type of the array elements - nullable : bool, optional - Can the value be NULL? - name : str, optional - Name of the column / parameter +# def RECORD( +# *args: Tuple[str, DataType], +# nullable: bool = True, +# name: Optional[str] = None, +# ) -> SQLString: +# """ +# RECORD type specification. +# +# Parameters +# ---------- +# *args : Tuple[str, DataType] +# Field specifications +# nullable : bool, optional +# Can the value be NULL? 
+# name : str, optional +# Name of the column / parameter +# +# Returns +# ------- +# SQLString +# +# """ +# assert len(args) > 0 +# fields = [] +# for name, value in args: +# if callable(value): +# fields.append(f'{escape_name(name)} {value()}') +# else: +# fields.append(f'{escape_name(name)} {value}') +# out = SQLString(f'RECORD({", ".join(fields)})' + _modifiers(nullable=nullable)) +# out.name = name +# return out - Returns - ------- - SQLString - """ - if callable(dtype): - dtype = dtype() - out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable)) - out.name = name - return out +# def ARRAY( +# dtype: DataType, +# nullable: bool = True, +# name: Optional[str] = None, +# ) -> SQLString: +# """ +# ARRAY type specification. +# +# Parameters +# ---------- +# dtype : DataType +# The data type of the array elements +# nullable : bool, optional +# Can the value be NULL? +# name : str, optional +# Name of the column / parameter +# +# Returns +# ------- +# SQLString +# +# """ +# if callable(dtype): +# dtype = dtype() +# out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable)) +# out.name = name +# return out # F32 = 'F32' diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 0a5780faa..413b7cd3f 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -148,7 +148,7 @@ def as_tuple(x: Any) -> Any: if has_pydantic and isinstance(x, BaseModel): return tuple(x.model_dump().values()) if dataclasses.is_dataclass(x): - return dataclasses.astuple(x) + return dataclasses.astuple(x) # type: ignore if isinstance(x, dict): return tuple(x.values()) return tuple(x) @@ -168,6 +168,8 @@ def as_list_of_tuples(x: Any) -> Any: def get_dataframe_columns(df: Any) -> List[Any]: """Return columns of data from a dataframe/table.""" + if isinstance(df, tuple): + return list(df) rtype = str(type(df)).lower() if 'dataframe' in rtype: return [df[x] for x in df.columns] @@ -226,11 +228,11 @@ def make_func( """ attrs = getattr(func, '_singlestoredb_attrs', {}) with_null_masks = attrs.get('with_null_masks', False) - function_type = attrs.get('function_type', 'udf').lower() info: Dict[str, Any] = {} sig = get_signature(func, func_name=name) + function_type = sig.get('function_type', 'udf') args_data_format = sig.get('args_data_format', 'scalar') returns_data_format = sig.get('returns_data_format', 'scalar') @@ -746,7 +748,6 @@ async def __call__( endpoint_info['signature'], url=self.url or reflected_url, data_format=self.data_format, - function_type=endpoint_info['function_type'], ), ) body = '\n'.join(syntax).encode('utf-8') @@ -903,7 +904,6 @@ def get_create_functions( app_mode=self.app_mode, replace=replace, link=link or None, - function_type=endpoint_info['function_type'], ), ) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index b36e404eb..6a0f97c1a 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -29,6 +29,7 @@ from . import dtypes as dt from . import utils +from .typing import Table from ..mysql.converters import escape_item # type: ignore if sys.version_info >= (3, 10): @@ -637,6 +638,34 @@ def get_namedtuple_schema( return list(utils.get_annotations(obj).items()) +def get_table_schema( + obj: Table[Any], + include_default: bool = False, +) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]: + """ + Get the schema of a Table. 
+
+    Parameters
+    ----------
+    obj : Table
+        The Table to get the schema of
+    include_default : bool, optional
+        Whether to include the default value in the column specification
+
+    Returns
+    -------
+    List[Tuple[Any, str]] | List[Tuple[Any, str, Any]]
+        A list of tuples containing the field names and field types
+
+    """
+    if include_default:
+        return [
+            (k, v, getattr(obj, k, NO_DEFAULT))
+            for k, v in utils.get_annotations(obj).items()
+        ]
+    return list(utils.get_annotations(obj).items())
+
+
 def get_colspec(
     overrides: Any,
     include_default: bool = False,
@@ -721,25 +750,23 @@ def unpack_masked_type(obj: Any) -> Any:
         The unpacked type
 
     """
-    if typing.get_origin(obj) is not tuple:
-        raise TypeError(f'masked type must be a tuple, got {obj}')
+    # TODO: Fix checks
+    # if typing.get_origin(obj) not in MASK_TYPES:
+    #     raise TypeError(f'masked type must be a tuple, got {obj}')
 
     args = typing.get_args(obj)
-    if len(args) != 2:
-        raise TypeError(f'masked type must be a tuple of length 2, got {obj}')
-    if not utils.is_vector(args[0]):
-        raise TypeError(f'masked type must be a vector, got {args[0]}')
-    if not utils.is_vector(args[1]):
-        raise TypeError(f'masked type must be a vector, got {args[1]}')
+    if len(args) != 1:
+        raise TypeError(f'masked type must have one type argument, got {obj}')
+    # if not utils.is_vector(args[0]):
+    #     raise TypeError(f'masked type must be a vector, got {args[0]}')
    return args[0]
 
 
 def get_schema(
     spec: Any,
     overrides: Optional[Union[List[str], Type[Any]]] = None,
-    function_type: str = 'udf',
     mode: str = 'parameter',
     with_null_masks: bool = False,
-) -> Tuple[List[Tuple[str, Any, Optional[str]]], str]:
+) -> Tuple[List[Tuple[str, Any, Optional[str]]], str, str]:
     """
     Expand a return type annotation into a list of types and field names. 
@@ -749,8 +776,6 @@
     spec : Any
         The return type specification
     overrides : List[str], optional
         List of SQL type specifications for the return type
-    function_type : str
-        The type of function, either 'udf' or 'tvf'
     mode : str
         The mode of the function, either 'parameter' or 'return'
     with_null_masks : bool
@@ -758,69 +783,88 @@
 
     Returns
     -------
-    Tuple[List[Tuple[str, Any]], str]
+    Tuple[List[Tuple[str, Any, Optional[str]]], str, str]
         A list of tuples containing the field names and field types,
-        the normalized data format, and optionally the SQL
-        definition of the type
+        the normalized data format, optionally the SQL
+        definition of the type, and the function type ('udf' or 'tvf')
 
     """
     colspec = []
-    data_format = 'scalar'
+    data_format = ''
+    function_type = 'udf'
+
+    origin = typing.get_origin(spec)
+    args = typing.get_args(spec)
+    args_origins = [typing.get_origin(x) if x is not None else None for x in args]
 
     # Make sure that the result of a TVF is a list or dataframe
-    if function_type == 'tvf' and mode == 'return':
+    if mode == 'return':
+
+        # See if the return type is a Table with column annotations
+        if inspect.isclass(origin) and origin is Table:
+            function_type = 'tvf'
+            if utils.is_dataframe(args[0]):
+                if not overrides:
+                    raise TypeError(
+                        'column types must be specified by the '
+                        '`returns=` parameter of the @udf decorator',
+                    )
 
-        # Use the item type from the list if it's a list
-        if typing.get_origin(spec) is list:
-            spec = typing.get_args(spec)[0]
+                if utils.get_module(args[0]) in ['pandas', 'polars', 'pyarrow']:
+                    data_format = utils.get_module(args[0])
+                    spec = args[0]
+                else:
+                    raise TypeError(
+                        'only pandas.DataFrames, polars.DataFrames, '
+                        'and pyarrow.Tables are supported as tables.',
+                    )
 
-        # If it's a tuple, it must be a tuple of vectors
-        elif typing.get_origin(spec) is tuple:
-            if not all([utils.is_vector(x) for x in typing.get_args(spec)]):
+            elif typing.get_origin(args[0]) is list:
+                if len(args) != 1:
+                    raise TypeError(
+                        'only one list is supported within a table; to '
+                        'return multiple columns, use a NamedTuple, dataclass, '
+                        'TypedDict, or pydantic model',
+                    )
+                spec = typing.get_args(args[0])[0]
+                data_format = 'list'
+
+            elif not all([utils.is_vector(x) for x in args]):
+                # TODO: Don't fail if types are specified in np.ndarrays
                 raise TypeError(
                     'return type for TVF must be a list, DataFrame / Table, '
                     'or tuple of vectors',
                 )
 
-        # DataFrames require special handling. You can't get the schema
-        # from the annotation, you need a separate structure to specify
-        # the types. This should be specified in the overrides.
-        elif utils.is_dataframe(spec) or utils.is_vector(spec):
-            if not overrides:
-                raise TypeError(
-                    'type overrides must be specified for vectors or DataFrames / Tables',
-                )
-
-        # Unsuported types
-        else:
+        # Try to catch some common mistakes
+        elif origin in [tuple, dict] or tuple in args_origins or \
+                (
+                    inspect.isclass(spec) and
+                    (
+                        utils.is_dataframe(spec)
+                        or utils.is_dataclass(spec)
+                        or utils.is_typeddict(spec)
+                        or utils.is_pydantic(spec)
+                        or utils.is_namedtuple(spec)
+                    )
+                ):
             raise TypeError(
-                'return type for TVF must be a list, DataFrame / Table, '
-                'or tuple of vectors',
+                'return type for table-valued functions must be annotated with a Table',
             )
 
-    # Error out for incorrect types
-    elif typing.get_origin(spec) in [tuple, dict] or \
+    # Error out for incorrect parameter types
+    elif origin in [tuple, dict] or tuple in args_origins or \
             (
-                # Check for optional tuples and dicts
-                is_union(spec) and
-                any([
-                    typing.get_origin(x) in [tuple, dict]
-                    for x in typing.get_args(spec)
-                ])
-            ) \
-            or utils.is_dataframe(spec) \
-            or utils.is_dataclass(spec) \
-            or utils.is_typeddict(spec) \
-            or utils.is_pydantic(spec) \
-            or utils.is_namedtuple(spec):
-        if typing.get_origin(spec) is tuple:
-            raise TypeError(
-                f'{mode} types must be scalar or vector, got {spec}; '
-                'if you are trying to use null masks, you must use the '
-                f'@{function_type}_with_null_masks decorator',
-            )
-        else:
-            raise TypeError(f'{mode} types must be scalar or vector, got {spec}')
+                inspect.isclass(spec) and
+                (
+                    utils.is_dataframe(spec)
+                    or utils.is_dataclass(spec)
+                    or utils.is_typeddict(spec)
+                    or utils.is_pydantic(spec)
+                    or utils.is_namedtuple(spec)
+                )
+            ):
+        raise TypeError(f'parameter types must be scalar or vector, got {spec}')
 
     #
     # Process each parameter / return type into a colspec
@@ -829,15 +873,19 @@
 
     # Compute overrides colspec from various formats
     overrides_colspec = get_colspec(overrides)
 
+    # Dataframe type
+    if utils.is_dataframe(spec):
+        colspec = overrides_colspec
+
     # Numpy array types
-    if utils.is_numpy(spec):
+    elif utils.is_numpy(spec):
         data_format = 'numpy'
         if overrides:
             colspec = overrides_colspec
         elif len(typing.get_args(spec)) < 2:
             raise TypeError(
                 'numpy array must have a data type specified '
-                'in the @udf / @tvf decorator or with an NDArray type annotation',
+                'in the @udf decorator or with an NDArray type annotation',
             )
         else:
             colspec = [('', typing.get_args(spec)[1])]
@@ -848,7 +896,7 @@
         if not overrides:
             raise TypeError(
                 'pandas Series must have a data type specified '
-                'in the @udf / @tvf decorator',
+                'in the @udf decorator',
             )
         colspec = overrides_colspec
@@ -858,7 +906,7 @@
         if not overrides:
             raise TypeError(
                 'polars Series must have a data type specified '
-                'in the @udf / @tvf decorator',
+                'in the @udf decorator',
             )
         colspec = overrides_colspec
@@ -868,7 +916,7 @@
         if not overrides:
             raise TypeError(
                 'pyarrow Arrays must have a data type specified '
-                'in the @udf / @tvf decorator',
+                'in the @udf decorator',
             )
         colspec = overrides_colspec
@@ -902,32 +950,36 @@
         colspec = [('', typing.get_args(spec)[0])]
 
     # Multiple return values
-    elif typing.get_origin(spec) is tuple:
+    elif inspect.isclass(typing.get_origin(spec)) \
+            and issubclass(typing.get_origin(spec), tuple):
 
         out_names, out_overrides = [], []
+
+        # Get the colspec for the overrides
         if overrides:
             out_colspec = [
-                x for x in get_colspec(
-                    overrides, include_default=True,
-                )
+                x for x in get_colspec(overrides, include_default=True)
             ] 
            out_names = [x[0] for x in out_colspec]
             out_overrides = [x[1] for x in out_colspec]
 
+        # Make sure that the number of overrides matches the number of
+        # return types or parameter types
         if out_overrides and len(typing.get_args(spec)) != len(out_overrides):
             raise ValueError(
-                'number of return types does not match the number of '
+                f'number of {mode} types does not match the number of '
                 'overrides specified',
             )
 
         colspec = []
         out_data_formats = []
+
+        # Get the colspec for each item in the tuple
         for i, x in enumerate(typing.get_args(spec)):
-            out_item, out_data_format = get_schema(
+            out_item, out_data_format, _ = get_schema(
                 x if not with_null_masks else unpack_masked_type(x),
                 overrides=out_overrides[i] if out_overrides else [],
-                # Always use UDF mode for individual items
-                function_type='udf',
+                # Individual items are always processed in UDF mode
                 mode=mode,
                 with_null_masks=with_null_masks,
             )
@@ -953,21 +1005,23 @@
 
         # Since the colspec was computed by get_schema already, don't go
         # through the process of normalizing the dtypes again
-        return colspec, data_format  # type: ignore
+        return colspec, data_format, function_type  # type: ignore
 
     # Use overrides if specified
     elif overrides:
-        data_format = get_data_format(spec)
+        if not data_format:
+            data_format = get_data_format(spec)
         colspec = overrides_colspec
 
     # Single value, no override
     else:
-        data_format = 'scalar'
+        if not data_format:
+            data_format = 'scalar'
         colspec = [('', spec)]
 
-    # Normalize colspec data types
     out = []
+
+    # Normalize colspec data types
     for k, v, *_ in colspec:
         out.append((
             k,
@@ -978,7 +1032,7 @@
             v if isinstance(v, str) else None,
         ))
 
-    return out, data_format
+    return out, data_format, function_type
 
 
 def vector_check(obj: Any) -> Tuple[Any, str]:
@@ -993,7 +1047,8 @@
     Returns
     -------
     Tuple[Any, str]
-        The scalar type and the data format ('scalar', 'numpy', 'pandas', 'polars')
+        The scalar type and the data format:
+        'scalar', 'list', 'numpy', 'pandas', 'polars', or 'arrow'
 
     """
     if utils.is_numpy(obj):
@@ -1001,11 +1056,17 @@
             return None, 'numpy'
         return typing.get_args(obj)[1], 'numpy'
     if utils.is_pandas_series(obj):
-        return None, 'pandas'
+        if len(typing.get_args(obj)) < 2:
+            return None, 'pandas'
+        return typing.get_args(obj)[1], 'pandas'
     if utils.is_polars_series(obj):
         return None, 'polars'
     if utils.is_pyarrow_array(obj):
         return None, 'arrow'
+    if obj is list or typing.get_origin(obj) is list:
+        if len(typing.get_args(obj)) < 1:
+            return None, 'list'
+        return typing.get_args(obj)[0], 'list'
     return obj, 'scalar'
 
 
@@ -1028,12 +1089,11 @@ def get_signature(
     Dict[str, Any]
 
     '''
-    signature = inspect.signature(func)
+    signature = inspect.signature(func, eval_str=True)
 
     args: List[Dict[str, Any]] = []
     returns: List[Dict[str, Any]] = []
 
     attrs = getattr(func, '_singlestoredb_attrs', {})
-    function_type = attrs.get('function_type', 'udf')
     with_null_masks = attrs.get('with_null_masks', False)
     name = attrs.get('name', func_name if func_name else func.__name__)
@@ -1061,12 +1121,12 @@
 
     params = list(signature.parameters.values())
 
+    # Get the colspec for each parameter
     for i, param in enumerate(params):
-        arg_schema, args_data_format = get_schema(
+        arg_schema, args_data_format, _ = get_schema(
             param.annotation if not with_null_masks
             else unpack_masked_type(param.annotation),
             overrides=args_overrides[i] if args_overrides else [],
-            function_type=function_type,
             mode='parameter',
             with_null_masks=with_null_masks,
         )
@@ 
-1076,9 +1136,10 @@ def get_signature( if not arg_schema[0][0]: args_schema.append((param.name, *arg_schema[0][1:])) + # Insert default values as needed for i, (name, atype, sql) in enumerate(args_schema): - # Get default value default_option = {} + if args_defaults: if args_defaults[i] is not NO_DEFAULT: default_option['default'] = args_defaults[i] @@ -1087,11 +1148,7 @@ def get_signature( default_option['default'] = params[i].default # Generate SQL code for the parameter - sql = sql or dtype_to_sql( - atype, - function_type=function_type, - **default_option, - ) + sql = sql or dtype_to_sql(atype, **default_option) # Add parameter to args definitions args.append(dict(name=name, dtype=atype, sql=sql, **default_option)) @@ -1106,36 +1163,33 @@ def get_signature( out['args_data_format'] = args_data_formats[0] if args_data_formats else 'scalar' # Generate the return types and the corresponding SQL code for those values - ret_schema, out['returns_data_format'] = get_schema( + ret_schema, out['returns_data_format'], function_type = get_schema( signature.return_annotation if not with_null_masks else unpack_masked_type(signature.return_annotation), overrides=attrs.get('returns', None), - function_type=function_type, mode='return', with_null_masks=with_null_masks, ) + out['returns_data_format'] = out['returns_data_format'] or 'scalar' + out['function_type'] = function_type + # All functions have to return a value, so if none was specified try to # insert a reasonable default that includes NULLs. if not ret_schema: ret_schema = [('', 'int8?', 'TINYINT NULL')] - # Generate names for fields as needed + # Generate field names for the return values if function_type == 'tvf' or len(ret_schema) > 1: for i, (name, rtype, sql) in enumerate(ret_schema): if not name: ret_schema[i] = (string.ascii_letters[i], rtype, sql) + # Generate SQL code for the return values for i, (name, rtype, sql) in enumerate(ret_schema): sql = sql or dtype_to_sql(rtype, function_type=function_type) returns.append(dict(name=name, dtype=rtype, sql=sql)) - # Copy keys from decorator to signature - copied_keys = ['database', 'environment', 'packages', 'resources', 'replace'] - for key in copied_keys: - if attrs.get(key): - out[key] = attrs[key] - # Set the function endpoint out['endpoint'] = '/invoke' @@ -1264,7 +1318,6 @@ def signature_to_sql( app_mode: str = 'remote', link: Optional[str] = None, replace: bool = False, - function_type: str = 'udf', ) -> str: ''' Convert a dictionary function signature into SQL. 
@@ -1280,6 +1333,8 @@ def signature_to_sql( str : SQL formatted function signature ''' + function_type = signature.get('function_type') or 'udf' + args = [] for arg in signature['args']: # Use default value from Python function if SQL doesn't set one diff --git a/singlestoredb/functions/typing.py b/singlestoredb/functions/typing.py index 848a6a503..e002e0944 100644 --- a/singlestoredb/functions/typing.py +++ b/singlestoredb/functions/typing.py @@ -1,14 +1,17 @@ +from typing import Any +from typing import Iterable from typing import Tuple from typing import TypeVar try: - import numpy as np - import numpy.typing as npt - has_numpy = True + from typing import TypeVarTuple except ImportError: - has_numpy = False + # Python 3.8 and earlier do not have TypeVarTuple + from typing_extensions import TypeVarTuple # type: ignore +T = TypeVar('T', bound=Iterable[Any]) # Generic type for iterable types + # # Masked types are used for pairs of vectors where the first element is the # vector and the second element is a boolean mask indicating which elements @@ -19,20 +22,18 @@ # This is needed for vector types that do not support NULL values, such as # numpy arrays and pandas Series. # -T = TypeVar('T') -Masked = Tuple[T, T] -# -# The MaskedNDArray type is used for pairs of numpy arrays where the first -# element is the numpy array and the second element is a boolean mask -# indicating which elements are NULL. The boolean mask is a numpy array of -# the same shape as the first element, where True indicates that the -# corresponding element in the first element is NULL. -# -# This is needed bebause numpy arrays do not support NULL values, so we need to -# use a boolean mask to indicate which elements are NULL. -# -if has_numpy: - TT = TypeVar('TT', bound=np.generic) # Generic type for NumPy data types - MaskedNDArray = Tuple[npt.NDArray[TT], npt.NDArray[np.bool_]] +class Masked(Tuple[T, T]): + def __new__(cls, *args: T) -> 'Masked[T]': + return tuple.__new__(cls, (args[0], args[1])) # type: ignore + + +Ts = TypeVarTuple('Ts') + + +class Table(Tuple[*Ts]): + """Return type for a table valued function.""" + + def __new__(cls, *args: *Ts) -> 'Table[*Ts]': + return tuple.__new__(cls, args) # type: ignore diff --git a/singlestoredb/functions/utils.py b/singlestoredb/functions/utils.py index 3b56707c7..9895085ef 100644 --- a/singlestoredb/functions/utils.py +++ b/singlestoredb/functions/utils.py @@ -23,11 +23,7 @@ def is_union(x: Any) -> bool: def get_annotations(obj: Any) -> Dict[str, Any]: """Get the annotations of an object.""" - if hasattr(inspect, 'get_annotations'): - return inspect.get_annotations(obj) - if isinstance(obj, type): - return obj.__dict__.get('__annotations__', {}) - return getattr(obj, '__annotations__', {}) + return typing.get_type_hints(obj) def get_module(obj: Any) -> str: diff --git a/singlestoredb/http/connection.py b/singlestoredb/http/connection.py index a4bb60187..874f8173d 100644 --- a/singlestoredb/http/connection.py +++ b/singlestoredb/http/connection.py @@ -648,7 +648,9 @@ def json_to_str(x: Any) -> Optional[str]: if 'UNSIGNED' in data_type: flags = 32 if data_type.endswith('BLOB') or data_type.endswith('BINARY'): - converter = functools.partial(b64decode_converter, converter) + converter = functools.partial( + b64decode_converter, converter, # type: ignore + ) charset = 63 # BINARY if type_code == 0: # DECIMAL type_code = types.ColumnType.get_code('NEWDECIMAL') diff --git a/singlestoredb/management/utils.py b/singlestoredb/management/utils.py index c7640ffc2..a1aee2236 100644 
--- a/singlestoredb/management/utils.py +++ b/singlestoredb/management/utils.py @@ -73,7 +73,7 @@ def ttl_property(ttl: datetime.timedelta) -> Callable[[Any], Any]: """Property with a time-to-live.""" def wrapper(func: Callable[[Any], Any]) -> Any: out = TTLProperty(func, ttl=ttl) - return functools.wraps(func)(out) + return functools.wraps(func)(out) # type: ignore return wrapper diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index 74f6b25a8..e8bb013e4 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -8,8 +8,6 @@ import pyarrow as pa from singlestoredb.functions import Masked -from singlestoredb.functions import MaskedNDArray -from singlestoredb.functions import tvf from singlestoredb.functions import udf from singlestoredb.functions import udf_with_null_masks from singlestoredb.functions.dtypes import BIGINT @@ -20,6 +18,7 @@ from singlestoredb.functions.dtypes import SMALLINT from singlestoredb.functions.dtypes import TEXT from singlestoredb.functions.dtypes import TINYINT +from singlestoredb.functions.typing import Table @udf @@ -447,16 +446,16 @@ def pandas_nullable_tinyint_mult_with_masks( ) -> Masked[pd.Series]: x_data, x_nulls = x y_data, y_nulls = y - return (x_data * y_data, x_nulls | y_nulls) + return Masked(x_data * y_data, x_nulls | y_nulls) @udf_with_null_masks def numpy_nullable_tinyint_mult_with_masks( - x: MaskedNDArray[np.int8], y: MaskedNDArray[np.int8], -) -> MaskedNDArray[np.int8]: + x: Masked[npt.NDArray[np.int8]], y: Masked[npt.NDArray[np.int8]], +) -> Masked[npt.NDArray[np.int8]]: x_data, x_nulls = x y_data, y_nulls = y - return (x_data * y_data, x_nulls | y_nulls) + return Masked(x_data * y_data, x_nulls | y_nulls) @udf_with_null_masks( @@ -468,7 +467,7 @@ def polars_nullable_tinyint_mult_with_masks( ) -> Masked[pl.Series]: x_data, x_nulls = x y_data, y_nulls = y - return (x_data * y_data, x_nulls | y_nulls) + return Masked(x_data * y_data, x_nulls | y_nulls) @udf_with_null_masks( @@ -481,11 +480,11 @@ def arrow_nullable_tinyint_mult_with_masks( import pyarrow.compute as pc x_data, x_nulls = x y_data, y_nulls = y - return (pc.multiply(x_data, y_data), pc.or_(x_nulls, y_nulls)) + return Masked(pc.multiply(x_data, y_data), pc.or_(x_nulls, y_nulls)) -@tvf(returns=[TEXT(nullable=False, name='res')]) -def numpy_fixed_strings() -> npt.NDArray[np.str_]: +@udf(returns=[TEXT(nullable=False, name='res')]) +def numpy_fixed_strings() -> Table[npt.NDArray[np.str_]]: out = np.array( [ 'hello', @@ -494,11 +493,24 @@ def numpy_fixed_strings() -> npt.NDArray[np.str_]: ], dtype=np.str_, ) assert str(out.dtype) == ' npt.NDArray[np.bytes_]: +@udf(returns=[TEXT(nullable=False, name='res'), TINYINT(nullable=False, name='res2')]) +def numpy_fixed_strings_2() -> Table[npt.NDArray[np.str_], npt.NDArray[np.int8]]: + out = np.array( + [ + 'hello', + 'hi there 😜', + '😜 bye', + ], dtype=np.str_, + ) + assert str(out.dtype) == ' Table[npt.NDArray[np.bytes_]]: out = np.array( [ 'hello'.encode('utf8'), @@ -507,7 +519,7 @@ def numpy_fixed_binary() -> npt.NDArray[np.bytes_]: ], dtype=np.bytes_, ) assert str(out.dtype) == '|S13' - return out + return Table(out) @udf diff --git a/singlestoredb/tests/test_udf.py b/singlestoredb/tests/test_udf.py index 227105be5..16eb325d6 100755 --- a/singlestoredb/tests/test_udf.py +++ b/singlestoredb/tests/test_udf.py @@ -16,7 +16,7 @@ from ..functions import dtypes as dt from ..functions import signature as sig -from ..functions import tvf +from ..functions 
import Table from ..functions import udf @@ -28,10 +28,7 @@ def to_sql(x): - out = sig.signature_to_sql( - sig.get_signature(x), - function_type=getattr(x, '_singlestoredb_attrs', {}).get('function_type', 'udf'), - ) + out = sig.signature_to_sql(sig.get_signature(x)) out = re.sub(r'^CREATE EXTERNAL FUNCTION ', r'', out) out = re.sub(r' AS REMOTE SERVICE.+$', r'', out) return out.strip() @@ -392,14 +389,14 @@ class MyData: def foo(x: int) -> MyData: ... to_sql(foo) - @tvf - def foo(x: int) -> List[MyData]: ... + @udf + def foo(x: int) -> Table[List[MyData]]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' - @tvf(returns=MyData) - def foo(x: int) -> List[Tuple[int, int, int]]: ... + @udf(returns=MyData) + def foo(x: int) -> Table[List[Tuple[int, int, int]]]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' @@ -409,14 +406,14 @@ class MyData(pydantic.BaseModel): two: str three: float - @tvf - def foo(x: int) -> List[MyData]: ... + @udf + def foo(x: int) -> Table[List[MyData]]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' - @tvf(returns=MyData) - def foo(x: int) -> List[Tuple[int, int, int]]: ... + @udf(returns=MyData) + def foo(x: int) -> Table[List[Tuple[int, int, int]]]: ... assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ '`three` DOUBLE NOT NULL)' @@ -685,15 +682,15 @@ def test_dtypes(self): assert dt.GEOGRAPHY(nullable=False) == 'GEOGRAPHY NOT NULL' assert dt.GEOGRAPHY(default='hi') == "GEOGRAPHY NULL DEFAULT 'hi'" - with self.assertRaises(AssertionError): - dt.RECORD() - assert dt.RECORD(('a', dt.INT), ('b', dt.FLOAT)) == \ - 'RECORD(`a` INT NULL, `b` FLOAT NULL) NULL' - assert dt.RECORD(('a', dt.INT), ('b', dt.FLOAT), nullable=False) == \ - 'RECORD(`a` INT NULL, `b` FLOAT NULL) NOT NULL' + # with self.assertRaises(AssertionError): + # dt.RECORD() + # assert dt.RECORD(('a', dt.INT), ('b', dt.FLOAT)) == \ + # 'RECORD(`a` INT NULL, `b` FLOAT NULL) NULL' + # assert dt.RECORD(('a', dt.INT), ('b', dt.FLOAT), nullable=False) == \ + # 'RECORD(`a` INT NULL, `b` FLOAT NULL) NOT NULL' - assert dt.ARRAY(dt.INT) == 'ARRAY(INT NULL) NULL' - assert dt.ARRAY(dt.INT, nullable=False) == 'ARRAY(INT NULL) NOT NULL' + # assert dt.ARRAY(dt.INT) == 'ARRAY(INT NULL) NULL' + # assert dt.ARRAY(dt.INT, nullable=False) == 'ARRAY(INT NULL) NOT NULL' # assert dt.VECTOR(8) == 'VECTOR(8, F32) NULL' # assert dt.VECTOR(8, dt.F32) == 'VECTOR(8, F32) NULL' diff --git a/singlestoredb/tests/test_udf_returns.py b/singlestoredb/tests/test_udf_returns.py new file mode 100644 index 000000000..8fe105ba4 --- /dev/null +++ b/singlestoredb/tests/test_udf_returns.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import unittest +from typing import Any +from typing import Callable +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import TypedDict + +import numpy as np +import numpy.typing as npt +import pandas as pd +import polars as pl +import pyarrow as pa +from pydantic import BaseModel + +from singlestoredb.functions import Table +from singlestoredb.functions import udf +from singlestoredb.functions.signature import get_signature +from singlestoredb.functions.signature import signature_to_sql + + +def to_sql(func: 
Callable[..., Any]) -> str: + """Convert a function signature to SQL.""" + out = signature_to_sql(get_signature(func)) + return out.split('EXTERNAL FUNCTION ')[1].split('AS REMOTE')[0].strip() + + +class Parameters(NamedTuple): + x: Optional[str] = '' + + +class UDFTuple(NamedTuple): + value: str + + +class TVFTuple(NamedTuple): + idx: int + value: Optional[str] + + +class TVFDict(TypedDict): + idx: int + value: Optional[str] + + +class TVFBaseModel(BaseModel): + idx: int + value: Optional[str] + + +class UDFResultsTest(unittest.TestCase): + + def test_udf_returns(self) -> None: + # Plain UDF + @udf + def foo_a(x: str) -> str: + return f'0: {x}' + + foo_a_out = foo_a('cat') + + assert type(foo_a_out) is str + assert foo_a_out == '0: cat' + assert to_sql(foo_a) == '`foo_a`(`x` TEXT NOT NULL) RETURNS TEXT NOT NULL' + + # Vectorized UDF using lists + @udf + def foo_b(x: List[str]) -> List[str]: + return [f'{i}: {y}' for i, y in enumerate(x)] + + foo_b_out = foo_b(['cat', 'dog', 'monkey']) + + assert type(foo_b_out) is list + assert foo_b_out == ['0: cat', '1: dog', '2: monkey'] + assert to_sql(foo_b) == '`foo_b`(`x` TEXT NOT NULL) RETURNS TEXT NOT NULL' + + # Illegal return type for UDF + @udf + def foo_c(x: List[str]) -> List[UDFTuple]: + return [UDFTuple(value='invalid')] + + # Vectorized UDF using pandas Series + @udf(args=Parameters, returns=UDFTuple) + def foo_d(x: pd.Series) -> pd.Series: + return pd.Series([f'{i}: {y}' for i, y in enumerate(x)]) + + foo_d_out = foo_d(pd.Series(['cat', 'dog', 'monkey'])) + + assert type(foo_d_out) is pd.Series + assert list(foo_d_out) == ['0: cat', '1: dog', '2: monkey'] + assert to_sql(foo_d) == "`foo_d`(`x` TEXT NULL DEFAULT '') RETURNS TEXT NOT NULL" + + # Vectorized UDF using polars Series + @udf(args=Parameters, returns=UDFTuple) + def foo_e(x: pl.Series) -> pl.Series: + return pl.Series([f'{i}: {y}' for i, y in enumerate(x)]) + + foo_e_out = foo_e(pl.Series(['cat', 'dog', 'monkey'])) + + assert type(foo_e_out) is pl.Series + assert list(foo_e_out) == ['0: cat', '1: dog', '2: monkey'] + assert to_sql(foo_e) == "`foo_e`(`x` TEXT NULL DEFAULT '') RETURNS TEXT NOT NULL" + + # Vectorized UDF using numpy arrays + @udf(args=Parameters, returns=UDFTuple) + def foo_f(x: np.ndarray) -> np.ndarray: + return np.array([f'{i}: {y}' for i, y in enumerate(x)]) + + foo_f_out = foo_f(np.array(['cat', 'dog', 'monkey'])) + + assert type(foo_f_out) is np.ndarray + assert list(foo_f_out) == ['0: cat', '1: dog', '2: monkey'] + assert to_sql(foo_f) == "`foo_f`(`x` TEXT NULL DEFAULT '') RETURNS TEXT NOT NULL" + + # Vectorized UDF using typed numpy arrays + @udf + def foo_g(x: npt.NDArray[np.str_]) -> npt.NDArray[np.str_]: + return np.array([f'{i}: {y}' for i, y in enumerate(x)]) + + foo_g_out = foo_g(np.array(['cat', 'dog', 'monkey'])) + + assert type(foo_g_out) is np.ndarray + assert list(foo_g_out) == ['0: cat', '1: dog', '2: monkey'] + assert to_sql(foo_g) == '`foo_g`(`x` TEXT NOT NULL) RETURNS TEXT NOT NULL' + + # Plain TVF using one list + @udf + def foo_h_(x: str) -> Table[List[str]]: + return Table([x] * 3) + + foo_h__out = foo_h_('cat') + + assert type(foo_h__out) is Table + assert foo_h__out == Table(['cat', 'cat', 'cat']) + + assert to_sql(foo_h_) == \ + '`foo_h_`(`x` TEXT NOT NULL) RETURNS TABLE(`a` TEXT NOT NULL)' + + # Plain TVF using multiple lists -- Illegal! 
+ @udf + def foo_h(x: str) -> Table[List[int], List[str]]: + return Table(list(range(3)), [x] * 3) + + foo_h_out = foo_h('cat') + + assert type(foo_h_out) is Table + assert foo_h_out == Table([0, 1, 2], ['cat', 'cat', 'cat']) + + with self.assertRaises(TypeError): + to_sql(foo_h) + + # Plain TVF using lists of NamedTuples + @udf + def foo_i(x: str) -> Table[List[TVFTuple]]: + return Table([ + TVFTuple(idx=0, value=x), + TVFTuple(idx=1, value=x), + TVFTuple(idx=2, value=x), + ]) + + foo_i_out = foo_i('cat') + + assert type(foo_i_out) is Table + assert foo_i_out == Table([ + TVFTuple(idx=0, value='cat'), + TVFTuple(idx=1, value='cat'), + TVFTuple(idx=2, value='cat'), + ]) + assert to_sql(foo_i) == ( + '`foo_i`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using lists of TypedDicts + @udf + def foo_j(x: str) -> Table[List[TVFDict]]: + return Table([ + dict(idx=0, value=x), + dict(idx=1, value=x), + dict(idx=2, value=x), + ]) + + foo_j_out = foo_j('cat') + + assert type(foo_j_out) is Table + assert foo_j_out == Table([ + dict(idx=0, value='cat'), + dict(idx=1, value='cat'), + dict(idx=2, value='cat'), + ]) + assert to_sql(foo_j) == ( + '`foo_j`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using lists of pydantic BaseModels + @udf + def foo_k(x: str) -> Table[List[TVFBaseModel]]: + return Table([ + TVFBaseModel(idx=0, value=x), + TVFBaseModel(idx=1, value=x), + TVFBaseModel(idx=2, value=x), + ]) + + foo_k_out = foo_k('cat') + + assert type(foo_k_out) is Table + assert foo_k_out == Table([ + TVFBaseModel(idx=0, value='cat'), + TVFBaseModel(idx=1, value='cat'), + TVFBaseModel(idx=2, value='cat'), + ]) + assert to_sql(foo_k) == ( + '`foo_k`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using pandas Series + @udf(returns=TVFTuple) + def foo_l(x: str) -> Table[pd.Series, pd.Series]: + return Table(pd.Series(range(3)), pd.Series([x] * 3)) + + foo_l_out = foo_l('cat') + + assert type(foo_l_out) is Table + assert len(foo_l_out) == 2 + assert type(foo_l_out[0]) is pd.Series + assert list(foo_l_out[0]) == [0, 1, 2] + assert type(foo_l_out[1]) is pd.Series + assert list(foo_l_out[1]) == ['cat', 'cat', 'cat'] + assert to_sql(foo_l) == ( + '`foo_l`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using polars Series + @udf(returns=TVFTuple) + def foo_m(x: str) -> Table[pl.Series, pl.Series]: + return Table(pl.Series(range(3)), pl.Series([x] * 3)) + + foo_m_out = foo_m('cat') + + assert type(foo_m_out) is Table + assert len(foo_m_out) == 2 + assert type(foo_m_out[0]) is pl.Series + assert list(foo_m_out[0]) == [0, 1, 2] + assert type(foo_m_out[1]) is pl.Series + assert list(foo_m_out[1]) == ['cat', 'cat', 'cat'] + assert to_sql(foo_m) == ( + '`foo_m`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using pyarrow Array + @udf(returns=TVFTuple) + def foo_n(x: str) -> Table[pa.Array, pa.Array]: + return Table(pa.array(range(3)), pa.array([x] * 3)) + + foo_n_out = foo_n('cat') + + assert type(foo_n_out) is Table + assert foo_n_out == Table(pa.array([0, 1, 2]), pa.array(['cat', 'cat', 'cat'])) + assert to_sql(foo_n) == ( + '`foo_n`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using numpy arrays + @udf(returns=TVFTuple) + def foo_o(x: str) -> Table[np.ndarray, np.ndarray]: + return 
Table(np.array(range(3)), np.array([x] * 3)) + + foo_o_out = foo_o('cat') + + assert type(foo_o_out) is Table + assert len(foo_o_out) == 2 + assert type(foo_o_out[0]) is np.ndarray + assert list(foo_o_out[0]) == [0, 1, 2] + assert type(foo_o_out[1]) is np.ndarray + assert list(foo_o_out[1]) == ['cat', 'cat', 'cat'] + assert to_sql(foo_o) == ( + '`foo_o`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using typed numpy arrays + @udf + def foo_p(x: str) -> Table[npt.NDArray[np.int_], npt.NDArray[np.str_]]: + return Table(np.array(range(3)), np.array([x] * 3)) + + foo_p_out = foo_p('cat') + + assert type(foo_p_out) is Table + assert len(foo_p_out) == 2 + assert type(foo_p_out[0]) is np.ndarray + assert list(foo_p_out[0]) == [0, 1, 2] + assert type(foo_p_out[1]) is np.ndarray + assert list(foo_p_out[1]) == ['cat', 'cat', 'cat'] + assert to_sql(foo_p) == ( + '`foo_p`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`a` BIGINT NOT NULL, `b` TEXT NOT NULL)' + ) + + # Plain TVF using pandas DataFrame + @udf(returns=TVFTuple) + def foo_q(x: str) -> Table[pd.DataFrame]: + return Table(pd.DataFrame([[0, x], [1, x], [2, x]])) # columns??? + + foo_q_out = foo_q('cat') + + assert type(foo_q_out) is Table + assert len(foo_q_out) == 1 + assert list(foo_q_out[0].iloc[:, 0]) == [0, 1, 2] + assert list(foo_q_out[0].iloc[:, 1]) == ['cat', 'cat', 'cat'] + assert to_sql(foo_q) == ( + '`foo_q`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using polars DataFrame + @udf(returns=TVFTuple) + def foo_r(x: str) -> Table[pl.DataFrame]: + return Table(pl.DataFrame([[0, 1, 2], [x] * 3])) # columns??? + + foo_r_out = foo_r('cat') + + assert type(foo_r_out) is Table + assert len(foo_r_out) == 1 + assert list(foo_r_out[0][:, 0]) == [0, 1, 2] + assert list(foo_r_out[0][:, 1]) == ['cat', 'cat', 'cat'] + assert to_sql(foo_r) == ( + '`foo_r`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Plain TVF using pyarrow Table + @udf(returns=TVFTuple) + def foo_s(x: str) -> Table[pa.Table]: + return Table( + pa.Table.from_pylist([ + dict(idx=0, value='cat'), + dict(idx=1, value='cat'), + dict(idx=2, value='cat'), + ]), + ) # columns??? + + foo_s_out = foo_s('cat') + + assert type(foo_s_out) is Table + assert foo_s_out == Table( + pa.Table.from_pylist([ + dict(idx=0, value='cat'), + dict(idx=1, value='cat'), + dict(idx=2, value='cat'), + ]), + ) + assert to_sql(foo_s) == ( + '`foo_s`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Vectorized TVF using lists -- Illegal! 
+ @udf + def foo_t(x: List[str]) -> Table[List[int], List[str]]: + return Table(list(range(len(x))), x) + + foo_t_out = foo_t(['cat', 'dog', 'monkey']) + + assert type(foo_t_out) is Table + assert foo_t_out == Table([0, 1, 2], ['cat', 'dog', 'monkey']) + with self.assertRaises(TypeError): + to_sql(foo_t) + + # Vectorized TVF using pandas Series + @udf(args=Parameters, returns=TVFTuple) + def foo_u(x: pd.Series) -> Table[pd.Series, pd.Series]: + return Table(pd.Series(range(len(x))), pd.Series(x)) + + foo_u_out = foo_u(pd.Series(['cat', 'dog', 'monkey'])) + + assert type(foo_u_out) is Table + assert len(foo_u_out) == 2 + assert list(foo_u_out[0]) == [0, 1, 2] + assert list(foo_u_out[1]) == ['cat', 'dog', 'monkey'] + assert to_sql(foo_u) == ( + "`foo_u`(`x` TEXT NULL DEFAULT '') " + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Vectorized TVF using polars Series + @udf(args=Parameters, returns=TVFTuple) + def foo_v(x: pl.Series) -> Table[pl.Series, pl.Series]: + return Table(pl.Series(range(len(x))), pl.Series(x)) + + foo_v_out = foo_v(pl.Series(['cat', 'dog', 'monkey'])) + + assert type(foo_v_out) is Table + assert len(foo_v_out) == 2 + assert list(foo_v_out[0]) == [0, 1, 2] + assert list(foo_v_out[1]) == ['cat', 'dog', 'monkey'] + assert to_sql(foo_v) == ( + "`foo_v`(`x` TEXT NULL DEFAULT '') " + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Vectorized TVF using pyarrow Array + @udf(args=Parameters, returns=TVFTuple) + def foo_w(x: pa.Array) -> Table[pa.Array, pa.Array]: + return Table(pa.array(range(len(x))), pa.array(x)) + + foo_w_out = foo_w(pa.array(['cat', 'dog', 'monkey'])) + + assert type(foo_w_out) is Table + assert foo_w_out == Table( + pa.array([0, 1, 2]), pa.array(['cat', 'dog', 'monkey']), + ) + assert to_sql(foo_w) == ( + "`foo_w`(`x` TEXT NULL DEFAULT '') " + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Vectorized TVF using numpy arrays + @udf(args=Parameters, returns=TVFTuple) + def foo_x(x: np.ndarray) -> Table[np.ndarray, np.ndarray]: + return Table(np.array(range(len(x))), np.array(x)) + + foo_x_out = foo_x(np.array(['cat', 'dog', 'monkey'])) + + assert type(foo_x_out) is Table + assert len(foo_x_out) == 2 + assert list(foo_x_out[0]) == [0, 1, 2] + assert list(foo_x_out[1]) == ['cat', 'dog', 'monkey'] + assert to_sql(foo_x) == ( + "`foo_x`(`x` TEXT NULL DEFAULT '') " + 'RETURNS TABLE(`idx` BIGINT NOT NULL, `value` TEXT NULL)' + ) + + # Vectorized TVF using typed numpy arrays + @udf + def foo_y( + x: npt.NDArray[np.str_], + ) -> Table[npt.NDArray[np.int_], npt.NDArray[np.str_]]: + return Table(np.array(range(len(x))), np.array(x)) + + foo_y_out = foo_y(np.array(['cat', 'dog', 'monkey'])) + + assert type(foo_y_out) is Table + assert len(foo_y_out) == 2 + assert list(foo_y_out[0]) == [0, 1, 2] + assert list(foo_y_out[1]) == ['cat', 'dog', 'monkey'] + assert to_sql(foo_y) == ( + '`foo_y`(`x` TEXT NOT NULL) ' + 'RETURNS TABLE(`a` BIGINT NOT NULL, `b` TEXT NOT NULL)' + ) + + +if __name__ == '__main__': + unittest.main() From 25535926b7c01954102b2e5fbd7b979afa542f3f Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:04:56 -0500 Subject: [PATCH 02/16] Downgrade mypy for Python 3.8 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e5a4b4cf1..ce627decd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: hooks: - id: setup-cfg-fmt - repo: 
https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.14.1 hooks: - id: mypy additional_dependencies: [types-requests] From 32c84207b98872abf17948d287e2a365d2379704 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:15:25 -0500 Subject: [PATCH 03/16] Fix annotations for older versions of Python --- requirements.txt | 1 + setup.cfg | 1 + singlestoredb/functions/typing.py | 6 ++++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 39a11197a..bdc386103 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ requests setuptools sqlparams tomli>=1.1.0; python_version < '3.11' +typing_extensions<=4.13.2 wheel diff --git a/setup.cfg b/setup.cfg index 59a4571c8..3eb4d7aad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ install_requires = requests setuptools sqlparams + typing-extensions<=4.13.2 wheel tomli>=1.1.0;python_version < '3.11' python_requires = >=3.8 diff --git a/singlestoredb/functions/typing.py b/singlestoredb/functions/typing.py index e002e0944..87fcec86b 100644 --- a/singlestoredb/functions/typing.py +++ b/singlestoredb/functions/typing.py @@ -5,9 +5,11 @@ try: from typing import TypeVarTuple + from typing import Unpack except ImportError: # Python 3.8 and earlier do not have TypeVarTuple from typing_extensions import TypeVarTuple # type: ignore + from typing_extensions import Unpack # type: ignore T = TypeVar('T', bound=Iterable[Any]) # Generic type for iterable types @@ -32,8 +34,8 @@ def __new__(cls, *args: T) -> 'Masked[T]': Ts = TypeVarTuple('Ts') -class Table(Tuple[*Ts]): +class Table(Tuple[Unpack[Ts]]): """Return type for a table valued function.""" - def __new__(cls, *args: *Ts) -> 'Table[*Ts]': + def __new__(cls, *args: Unpack[Ts]) -> 'Table[Unpack[Ts]]': return tuple.__new__(cls, args) # type: ignore From 0579b8b4cbcf77b0211d9bfbb63115c0bc56cff8 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:22:02 -0500 Subject: [PATCH 04/16] Fix annotations for older versions of Python --- singlestoredb/functions/signature.py | 4 +++- singlestoredb/functions/typing.py | 6 +++--- singlestoredb/tests/test_udf_returns.py | 3 +-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 6a0f97c1a..ff2ddd1af 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -1089,7 +1089,7 @@ def get_signature( Dict[str, Any] ''' - signature = inspect.signature(func, eval_str=True) + signature = inspect.signature(func) args: List[Dict[str, Any]] = [] returns: List[Dict[str, Any]] = [] @@ -1106,6 +1106,8 @@ def get_signature( elif p.kind == inspect.Parameter.VAR_KEYWORD: raise TypeError('variable keyword arguments are not supported') + # TODO: Use typing.get_type_hints() for parameters / return values? 
+ # Generate the parameter type and the corresponding SQL code for that parameter args_schema = [] args_data_formats = [] diff --git a/singlestoredb/functions/typing.py b/singlestoredb/functions/typing.py index 87fcec86b..1ef0500af 100644 --- a/singlestoredb/functions/typing.py +++ b/singlestoredb/functions/typing.py @@ -4,8 +4,8 @@ from typing import TypeVar try: - from typing import TypeVarTuple - from typing import Unpack + from typing import TypeVarTuple # type: ignore + from typing import Unpack # type: ignore except ImportError: # Python 3.8 and earlier do not have TypeVarTuple from typing_extensions import TypeVarTuple # type: ignore @@ -37,5 +37,5 @@ def __new__(cls, *args: T) -> 'Masked[T]': class Table(Tuple[Unpack[Ts]]): """Return type for a table valued function.""" - def __new__(cls, *args: Unpack[Ts]) -> 'Table[Unpack[Ts]]': + def __new__(cls, *args: Unpack[Ts]) -> 'Table[Tuple[Unpack[Ts]]]': return tuple.__new__(cls, args) # type: ignore diff --git a/singlestoredb/tests/test_udf_returns.py b/singlestoredb/tests/test_udf_returns.py index 8fe105ba4..dd317b865 100644 --- a/singlestoredb/tests/test_udf_returns.py +++ b/singlestoredb/tests/test_udf_returns.py @@ -1,5 +1,4 @@ -from __future__ import annotations - +# from __future__ import annotations import unittest from typing import Any from typing import Callable From 42f578e1b62eed5d570653bec695832170193f9e Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:52:21 -0500 Subject: [PATCH 05/16] Fix annotations for older versions of Python --- singlestoredb/functions/signature.py | 2 +- singlestoredb/functions/typing.py | 2 +- singlestoredb/management/utils.py | 2 +- singlestoredb/tests/ext_funcs/__init__.py | 1 + singlestoredb/tests/test_udf_returns.py | 1 + 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index ff2ddd1af..f0a638c4e 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -639,7 +639,7 @@ def get_namedtuple_schema( def get_table_schema( - obj: Table[Any], + obj: Any, include_default: bool = False, ) -> List[Union[Tuple[Any, str], Tuple[Any, str, Any]]]: """ diff --git a/singlestoredb/functions/typing.py b/singlestoredb/functions/typing.py index 1ef0500af..cd966b16c 100644 --- a/singlestoredb/functions/typing.py +++ b/singlestoredb/functions/typing.py @@ -37,5 +37,5 @@ def __new__(cls, *args: T) -> 'Masked[T]': class Table(Tuple[Unpack[Ts]]): """Return type for a table valued function.""" - def __new__(cls, *args: Unpack[Ts]) -> 'Table[Tuple[Unpack[Ts]]]': + def __new__(cls, *args: Unpack[Ts]) -> 'Table[Tuple[Unpack[Ts]]]': # type: ignore return tuple.__new__(cls, args) # type: ignore diff --git a/singlestoredb/management/utils.py b/singlestoredb/management/utils.py index a1aee2236..c398f006d 100644 --- a/singlestoredb/management/utils.py +++ b/singlestoredb/management/utils.py @@ -30,7 +30,7 @@ T = TypeVar('T') if sys.version_info < (3, 10): - PathLike = Union[str, os.PathLike] + PathLike = Union[str, os.PathLike] # type: ignore PathLikeABC = os.PathLike else: PathLike = Union[str, os.PathLike[str]] diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index e8bb013e4..0a8ee44c7 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# mypy: disable-error-code="type-arg" from typing import Optional import numpy as np diff --git 
a/singlestoredb/tests/test_udf_returns.py b/singlestoredb/tests/test_udf_returns.py index dd317b865..eec61c100 100644 --- a/singlestoredb/tests/test_udf_returns.py +++ b/singlestoredb/tests/test_udf_returns.py @@ -1,3 +1,4 @@ +# type ignore[type-arg] # from __future__ import annotations import unittest from typing import Any From a5f1dd790ab9d164e19c08f03cdc09dc467195f0 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:54:08 -0500 Subject: [PATCH 06/16] Fix annotations for older versions of Python --- singlestoredb/tests/test_udf_returns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/singlestoredb/tests/test_udf_returns.py b/singlestoredb/tests/test_udf_returns.py index eec61c100..f1daa58de 100644 --- a/singlestoredb/tests/test_udf_returns.py +++ b/singlestoredb/tests/test_udf_returns.py @@ -1,4 +1,4 @@ -# type ignore[type-arg] +# mypy: disable-error-code="type-arg" # from __future__ import annotations import unittest from typing import Any From f337bc69b15f24456fda5c9ba1c5f976b8a3cef2 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 13:57:40 -0500 Subject: [PATCH 07/16] Fix annotations for older versions of Python --- singlestoredb/functions/signature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index f0a638c4e..3415ec623 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -951,7 +951,7 @@ def get_schema( # Multiple return values elif inspect.isclass(typing.get_origin(spec)) \ - and issubclass(typing.get_origin(spec), tuple): + and issubclass(typing.get_origin(spec), tuple): # type: ignore[arg-type] out_names, out_overrides = [], [] From 2d6ff9d1cc9c0128519a20039ed6d087ae4fb1bf Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 22 Apr 2025 16:01:46 -0500 Subject: [PATCH 08/16] Add null masks --- setup.cfg | 2 +- singlestoredb/functions/__init__.py | 1 - singlestoredb/functions/decorator.py | 87 ++-------------------- singlestoredb/functions/ext/asgi.py | 57 +++++++++----- singlestoredb/functions/signature.py | 91 ++++++++++++++++------- singlestoredb/functions/typing.py | 2 +- singlestoredb/tests/ext_funcs/__init__.py | 9 +-- 7 files changed, 114 insertions(+), 135 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3eb4d7aad..ae1034402 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,9 +25,9 @@ install_requires = requests setuptools sqlparams - typing-extensions<=4.13.2 wheel tomli>=1.1.0;python_version < '3.11' + typing-extensions<=4.13.2;python_version < '3.11' python_requires = >=3.8 include_package_data = True tests_require = diff --git a/singlestoredb/functions/__init__.py b/singlestoredb/functions/__init__.py index bca08470f..a156a80c9 100644 --- a/singlestoredb/functions/__init__.py +++ b/singlestoredb/functions/__init__.py @@ -1,4 +1,3 @@ from .decorator import udf # noqa: F401 -from .decorator import udf_with_null_masks # noqa: F401 from .typing import Masked # noqa: F401 from .typing import Table # noqa: F401 diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 4917944ad..2280ed401 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -1,6 +1,5 @@ import functools import inspect -import typing from typing import Any from typing import Callable from typing import List @@ -10,7 +9,6 @@ from . 
import utils from .dtypes import SQLString -from .typing import Masked ParameterType = Union[ @@ -61,27 +59,6 @@ def is_valid_callable(obj: Any) -> bool: ) -def verify_mask(obj: Any) -> bool: - """Verify that the object is a tuple of two vector types.""" - if not typing.get_origin(obj) is Masked: - raise TypeError( - f'expected a Masked type, but got {type(obj)}', - ) - return True - - -def verify_masks(obj: Callable[..., Any]) -> bool: - """Verify that the function parameters and return value are all masks.""" - ann = utils.get_annotations(obj) - for name, value in ann.items(): - if not verify_mask(value): - raise TypeError( - f'Expected a vector type for the parameter {name} ' - f'in function {obj.__name__}, but got {value}', - ) - return True - - def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]: """Expand the types for the function arguments / return values.""" if args is None: @@ -123,7 +100,6 @@ def _func( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, - with_null_masks: bool = False, ) -> Callable[..., Any]: """Generic wrapper for UDF and TVF decorators.""" @@ -132,7 +108,6 @@ def _func( name=name, args=expand_types(args), returns=expand_types(returns), - with_null_masks=with_null_masks, ).items() if v is not None } @@ -141,8 +116,6 @@ def _func( # in at that time. if func is None: def decorate(func: Callable[..., Any]) -> Callable[..., Any]: - if with_null_masks: - verify_masks(func) def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: return func(*args, **kwargs) # type: ignore @@ -153,9 +126,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: return decorate - if with_null_masks: - verify_masks(func) - def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: return func(*args, **kwargs) # type: ignore @@ -180,54 +150,7 @@ def udf( The UDF to apply parameters to name : str, optional The name to use for the UDF in the database - args : str | Callable | List[str | Callable], optional - Specifies the data types of the function arguments. Typically, - the function data types are derived from the function parameter - annotations. These annotations can be overridden. If the function - takes a single type for all parameters, `args` can be set to a - SQL string describing all parameters. If the function takes more - than one parameter and all of the parameters are being manually - defined, a list of SQL strings may be used (one for each parameter). - A dictionary of SQL strings may be used to specify a parameter type - for a subset of parameters; the keys are the names of the - function parameters. Callables may also be used for datatypes. This - is primarily for using the functions in the ``dtypes`` module that - are associated with SQL types with all default options (e.g., ``dt.FLOAT``). - returns : str, optional - Specifies the return data type of the function. If not specified, - the type annotation from the function is used. - - Returns - ------- - Callable - - """ - return _func( - func=func, - name=name, - args=args, - returns=returns, - with_null_masks=False, - ) - - -def udf_with_null_masks( - func: Optional[Callable[..., Any]] = None, - *, - name: Optional[str] = None, - args: Optional[ParameterType] = None, - returns: Optional[ReturnType] = None, -) -> Callable[..., Any]: - """ - Define a user-defined function (UDF) with null masks. 
- - Parameters - ---------- - func : callable, optional - The UDF to apply parameters to - name : str, optional - The name to use for the UDF in the database - args : str | Callable | List[str | Callable], optional + args : str | Type | Callable | List[str | Callable], optional Specifies the data types of the function arguments. Typically, the function data types are derived from the function parameter annotations. These annotations can be overridden. If the function @@ -240,9 +163,10 @@ def udf_with_null_masks( function parameters. Callables may also be used for datatypes. This is primarily for using the functions in the ``dtypes`` module that are associated with SQL types with all default options (e.g., ``dt.FLOAT``). - returns : str, optional - Specifies the return data type of the function. If not specified, - the type annotation from the function is used. + returns : str | Type | Callable | List[str | Callable] | Table, optional + Specifies the return data type of the function. This parameter + works the same way as `args`. If the function is a table-valued + function, the return type should be a `Table` object. Returns ------- @@ -254,5 +178,4 @@ def udf_with_null_masks( name=name, args=args, returns=returns, - with_null_masks=True, ) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 413b7cd3f..2c0a58d21 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -26,6 +26,7 @@ import asyncio import dataclasses import importlib.util +import inspect import io import itertools import json @@ -36,6 +37,7 @@ import sys import tempfile import textwrap +import typing import urllib import zipfile import zipimport @@ -62,6 +64,7 @@ from ...mysql.constants import FIELD_TYPE as ft from ..signature import get_signature from ..signature import signature_to_sql +from ..typing import Masked try: import cloudpickle @@ -207,6 +210,25 @@ def get_array_class(data_format: str) -> Callable[..., Any]: return array_cls +def get_masked_params(func: Callable[..., Any]) -> List[bool]: + """ + Get the list of masked parameters for the function. + + Parameters + ---------- + func : Callable + The function to call as the endpoint + + Returns + ------- + List[bool] + Boolean list of masked parameters + + """ + params = inspect.signature(func).parameters + return [typing.get_origin(x.annotation) is Masked for x in params.values()] + + def make_func( name: str, func: Callable[..., Any], @@ -226,8 +248,6 @@ def make_func( (Callable, Dict[str, Any]) """ - attrs = getattr(func, '_singlestoredb_attrs', {}) - with_null_masks = attrs.get('with_null_masks', False) info: Dict[str, Any] = {} sig = get_signature(func, func_name=name) @@ -236,6 +256,8 @@ def make_func( args_data_format = sig.get('args_data_format', 'scalar') returns_data_format = sig.get('returns_data_format', 'scalar') + masks = get_masked_params(func) + if function_type == 'tvf': # Scalar (Python) types if returns_data_format == 'scalar': @@ -265,24 +287,21 @@ async def do_func( # type: ignore # each result row, so we just have to use the same # row ID for all rows in the result. - # If `with_null_masks` is set, the function is expected to return - # a tuple of (data, mask) for each column. 
- if with_null_masks: - out = func(*cols) - assert isinstance(out, tuple) - row_ids = array_cls([row_ids[0]] * len(out[0][0])) - return row_ids, [out] + def build_tuple(x: Any) -> Any: + return tuple(x) if isinstance(x, Masked) else (x, None) # Call function on each column of data if cols and cols[0]: - res = get_dataframe_columns(func(*[x[0] for x in cols])) + res = get_dataframe_columns( + func(*[x if m else x[0] for x, m in zip(cols, masks)]), + ) else: res = get_dataframe_columns(func()) # Generate row IDs row_ids = array_cls([row_ids[0]] * len(res[0])) - return row_ids, [(x, None) for x in res] + return row_ids, [build_tuple(x) for x in res] else: # Scalar (Python) types @@ -305,22 +324,22 @@ async def do_func( # type: ignore '''Call function on given cols of data.''' row_ids = array_cls(row_ids) - # If `with_null_masks` is set, the function is expected to return - # a tuple of (data, mask) for each column.` - if with_null_masks: - out = func(*cols) - assert isinstance(out, tuple) - return row_ids, [out] + def build_tuple(x: Any) -> Any: + return tuple(x) if isinstance(x, Masked) else (x, None) # Call the function with `cols` as the function parameters if cols and cols[0]: - out = func(*[x[0] for x in cols]) + out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) else: out = func() + # Single masked value + if isinstance(out, Masked): + return row_ids, [tuple(out)] + # Multiple return values if isinstance(out, tuple): - return row_ids, [(x, None) for x in out] + return row_ids, [build_tuple(x) for x in out] # Single return value return row_ids, [(out, None)] diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 3415ec623..feb8c7b15 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -30,6 +30,7 @@ from . import dtypes as dt from . import utils from .typing import Table +from .typing import Masked from ..mysql.converters import escape_item # type: ignore if sys.version_info >= (3, 10): @@ -750,22 +751,15 @@ def unpack_masked_type(obj: Any) -> Any: The unpacked type """ - # TODO: Fix checks - # if typing.get_origin(obj) not in MASK_TYPES: - # raise TypeError(f'masked type must be a tuple, got {obj}') - args = typing.get_args(obj) - if len(args) != 1: - raise TypeError(f'masked type must be a tuple of length 1, got {obj}') - # if not utils.is_vector(args[0]): - # raise TypeError(f'masked type must be a vector, got {args[0]}') - return args[0] + if typing.get_origin(obj) is Masked: + return typing.get_args(obj)[0] + return obj def get_schema( spec: Any, overrides: Optional[Union[List[str], Type[Any]]] = None, mode: str = 'parameter', - with_null_masks: bool = False, ) -> Tuple[List[Tuple[str, Any, Optional[str]]], str, str]: """ Expand a return type annotation into a list of types and field names. 
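# [Editor's note] The simplified unpack_masked_type() above now strips a
# single Masked[...] wrapper and passes every other annotation through
# unchanged. A minimal standalone sketch of the behavior it relies on
# (illustration only, not part of the patch; this local Masked stands in
# for singlestoredb.functions.Masked):
#
#     import typing
#     from typing import Tuple, TypeVar
#
#     T = TypeVar('T')
#
#     class Masked(Tuple[T, T]):
#         """Stand-in for the library's Masked (data, null-mask) pair."""
#
#     def unpack_masked_type(obj: typing.Any) -> typing.Any:
#         # Masked[np.ndarray] -> np.ndarray; plain annotations pass through
#         if typing.get_origin(obj) is Masked:
#             return typing.get_args(obj)[0]
#         return obj
#
#     assert unpack_masked_type(Masked[int]) is int  # unwrapped
#     assert unpack_masked_type(int) is int          # unchanged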
@@ -778,8 +772,6 @@ def get_schema( List of SQL type specifications for the return type mode : str The mode of the function, either 'parameter' or 'return' - with_null_masks : bool - Whether to use null masks for the parameters and return value Returns ------- @@ -977,11 +969,10 @@ def get_schema( # Get the colspec for each item in the tuple for i, x in enumerate(typing.get_args(spec)): out_item, out_data_format, _ = get_schema( - x if not with_null_masks else unpack_masked_type(x), + unpack_masked_type(x), overrides=out_overrides[i] if out_overrides else [], # Always pass UDF mode for individual items mode=mode, - with_null_masks=with_null_masks, ) # Use the name from the overrides if specified @@ -1027,7 +1018,6 @@ def get_schema( k, collapse_dtypes( [normalize_dtype(x) for x in simplify_dtype(v)], - include_null=with_null_masks, ), v if isinstance(v, str) else None, )) @@ -1070,6 +1060,44 @@ def vector_check(obj: Any) -> Tuple[Any, str]: return obj, 'scalar' +def get_masks(func: Callable[..., Any]) -> Tuple[List[bool], List[bool]]: + """ + Get the list of masked parameters and return values for the function. + + Parameters + ---------- + func : Callable + The function to call as the endpoint + + Returns + ------- + Tuple[List[bool], List[bool]] + A Tuple containing the parameter / return value masks + as lists of booleans + + + """ + params = inspect.signature(func).parameters + returns = inspect.signature(func).return_annotation + + ret_masks = [] + if typing.get_origin(returns) is Masked: + ret_masks = [True] + elif typing.get_origin(returns) is Table: + for x in typing.get_args(returns): + if typing.get_origin(x) is Masked: + ret_masks.append(True) + else: + ret_masks.append(False) + if not any(ret_masks): + ret_masks = [] + + return ( + [typing.get_origin(x.annotation) is Masked for x in params.values()], + ret_masks, + ) + + def get_signature( func: Callable[..., Any], func_name: Optional[str] = None, @@ -1094,7 +1122,6 @@ def get_signature( returns: List[Dict[str, Any]] = [] attrs = getattr(func, '_singlestoredb_attrs', {}) - with_null_masks = attrs.get('with_null_masks', False) name = attrs.get('name', func_name if func_name else func.__name__) out: Dict[str, Any] = dict(name=name, args=args, returns=returns) @@ -1114,6 +1141,9 @@ def get_signature( args_colspec = [x for x in get_colspec(attrs.get('args', []), include_default=True)] args_overrides = [x[1] for x in args_colspec] args_defaults = [x[2] for x in args_colspec] # type: ignore + args_masks, ret_masks = get_masks(func) + + print(func, args_masks, ret_masks) if args_overrides and len(args_overrides) != len(signature.parameters): raise ValueError( @@ -1126,11 +1156,9 @@ def get_signature( # Get the colspec for each parameter for i, param in enumerate(params): arg_schema, args_data_format, _ = get_schema( - param.annotation - if not with_null_masks else unpack_masked_type(param.annotation), + unpack_masked_type(param.annotation), overrides=args_overrides[i] if args_overrides else [], mode='parameter', - with_null_masks=with_null_masks, ) args_data_formats.append(args_data_format) @@ -1138,10 +1166,10 @@ def get_signature( if not arg_schema[0][0]: args_schema.append((param.name, *arg_schema[0][1:])) - # Insert default values as needed for i, (name, atype, sql) in enumerate(args_schema): default_option = {} + # Insert default values as needed if args_defaults: if args_defaults[i] is not NO_DEFAULT: default_option['default'] = args_defaults[i] @@ -1150,7 +1178,9 @@ def get_signature( default_option['default'] = 
params[i].default # Generate SQL code for the parameter - sql = sql or dtype_to_sql(atype, **default_option) + sql = sql or dtype_to_sql( + atype, force_nullable=args_masks[i], **default_option, + ) # Add parameter to args definitions args.append(dict(name=name, dtype=atype, sql=sql, **default_option)) @@ -1166,11 +1196,9 @@ def get_signature( # Generate the return types and the corresponding SQL code for those values ret_schema, out['returns_data_format'], function_type = get_schema( - signature.return_annotation - if not with_null_masks else unpack_masked_type(signature.return_annotation), + unpack_masked_type(signature.return_annotation), overrides=attrs.get('returns', None), mode='return', - with_null_masks=with_null_masks, ) out['returns_data_format'] = out['returns_data_format'] or 'scalar' @@ -1189,7 +1217,11 @@ def get_signature( # Generate SQL code for the return values for i, (name, rtype, sql) in enumerate(ret_schema): - sql = sql or dtype_to_sql(rtype, function_type=function_type) + sql = sql or dtype_to_sql( + rtype, + force_nullable=ret_masks[i] if ret_masks else False, + function_type=function_type, + ) returns.append(dict(name=name, dtype=rtype, sql=sql)) # Set the function endpoint @@ -1247,6 +1279,7 @@ def dtype_to_sql( default: Any = NO_DEFAULT, field_names: Optional[List[str]] = None, function_type: str = 'udf', + force_nullable: bool = False, ) -> str: """ Convert a collapsed dtype string to a SQL type. @@ -1259,6 +1292,10 @@ def dtype_to_sql( Default value field_names : List[str], optional Field names for tuple types + function_type : str, optional + Function type, either 'udf' or 'tvf' + force_nullable : bool, optional + Whether to force the type to be nullable Returns ------- @@ -1266,7 +1303,9 @@ def dtype_to_sql( """ nullable = ' NOT NULL' - if dtype.endswith('?'): + if force_nullable: + nullable = ' NULL' + elif dtype.endswith('?'): nullable = ' NULL' dtype = dtype[:-1] elif '|null' in dtype: diff --git a/singlestoredb/functions/typing.py b/singlestoredb/functions/typing.py index cd966b16c..12770a78d 100644 --- a/singlestoredb/functions/typing.py +++ b/singlestoredb/functions/typing.py @@ -27,7 +27,7 @@ class Masked(Tuple[T, T]): - def __new__(cls, *args: T) -> 'Masked[T]': + def __new__(cls, *args: T) -> 'Masked[Tuple[T, T]]': # type: ignore return tuple.__new__(cls, (args[0], args[1])) # type: ignore diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index 0a8ee44c7..afa65e9ef 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -10,7 +10,6 @@ from singlestoredb.functions import Masked from singlestoredb.functions import udf -from singlestoredb.functions import udf_with_null_masks from singlestoredb.functions.dtypes import BIGINT from singlestoredb.functions.dtypes import BLOB from singlestoredb.functions.dtypes import DOUBLE @@ -438,7 +437,7 @@ def nullable_string_mult(x: Optional[str], times: Optional[int]) -> Optional[str return x * times -@udf_with_null_masks( +@udf( args=[TINYINT(nullable=True), TINYINT(nullable=True)], returns=TINYINT(nullable=True), ) @@ -450,7 +449,7 @@ def pandas_nullable_tinyint_mult_with_masks( return Masked(x_data * y_data, x_nulls | y_nulls) -@udf_with_null_masks +@udf def numpy_nullable_tinyint_mult_with_masks( x: Masked[npt.NDArray[np.int8]], y: Masked[npt.NDArray[np.int8]], ) -> Masked[npt.NDArray[np.int8]]: @@ -459,7 +458,7 @@ def numpy_nullable_tinyint_mult_with_masks( return Masked(x_data * y_data, x_nulls | y_nulls) 
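# [Editor's note] With udf_with_null_masks folded into the plain @udf
# decorator, a masked vector UDF such as
# numpy_nullable_tinyint_mult_with_masks above simply receives and returns
# Masked (data, null-mask) pairs. A hedged sketch of exercising that logic
# directly in Python (illustration only; assumes the wrapper returned by
# @udf remains directly callable, as the decorator code in this patch
# suggests):
#
#     import numpy as np
#     from singlestoredb.functions import Masked
#
#     x = Masked(np.array([1, 2, 3], dtype=np.int8),
#                np.array([False, True, False]))   # second element is NULL
#     y = Masked(np.array([4, 5, 6], dtype=np.int8),
#                np.array([False, False, False]))
#
#     data, nulls = numpy_nullable_tinyint_mult_with_masks(x, y)
#     # data  -> array([ 4, 10, 18], dtype=int8)
#     # nulls -> array([False,  True, False]); NULL wherever either input was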
-@udf_with_null_masks( +@udf( args=[TINYINT(nullable=True), TINYINT(nullable=True)], returns=TINYINT(nullable=True), ) @@ -471,7 +470,7 @@ def polars_nullable_tinyint_mult_with_masks( return Masked(x_data * y_data, x_nulls | y_nulls) -@udf_with_null_masks( +@udf( args=[TINYINT(nullable=True), TINYINT(nullable=True)], returns=TINYINT(nullable=True), ) From 5423ed7d648638a8d6e9a5c4895e1d236dc1e1b4 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 23 Apr 2025 09:54:33 -0500 Subject: [PATCH 09/16] Fix difference in numpy detection --- singlestoredb/functions/signature.py | 2 -- singlestoredb/functions/utils.py | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index feb8c7b15..57c2585c1 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -1143,8 +1143,6 @@ def get_signature( args_defaults = [x[2] for x in args_colspec] # type: ignore args_masks, ret_masks = get_masks(func) - print(func, args_masks, ret_masks) - if args_overrides and len(args_overrides) != len(signature.parameters): raise ValueError( 'number of args in the decorator does not match ' diff --git a/singlestoredb/functions/utils.py b/singlestoredb/functions/utils.py index 9895085ef..11dc92dda 100644 --- a/singlestoredb/functions/utils.py +++ b/singlestoredb/functions/utils.py @@ -45,6 +45,9 @@ def get_type_name(obj: Any) -> str: def is_numpy(obj: Any) -> bool: """Check if an object is a numpy array.""" + if str(obj).startswith('numpy.ndarray['): + return True + if inspect.isclass(obj): if get_module(obj) == 'numpy': return get_type_name(obj) == 'ndarray' From 69c1deb7054aac1d108733dcdcc3a7cb217af5b8 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 23 Apr 2025 11:22:55 -0500 Subject: [PATCH 10/16] Short circuit common valid types --- singlestoredb/functions/signature.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 57c2585c1..ffb7cd83c 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -794,7 +794,9 @@ def get_schema( # See if it's a Table subclass with annotations if inspect.isclass(origin) and origin is Table: + function_type = 'tvf' + if utils.is_dataframe(args[0]): if not overrides: raise TypeError( @@ -828,6 +830,10 @@ def get_schema( 'or tuple of vectors', ) + # Short circuit check for common valid types + elif utils.is_vector(spec) or spec in [str, float, int, bytes]: + pass + # Try to catch some common mistakes elif origin in [tuple, dict] or tuple in args_origins or \ ( @@ -841,9 +847,14 @@ def get_schema( ) ): raise TypeError( - 'return type for table-valued functions must be annotated with a Table,', + 'invalid return type for a UDF; ' + f'expecting a scalar or vector, but got {spec}', ) + # Short circuit check for common valid types + elif utils.is_vector(spec) or spec in [str, float, int, bytes]: + pass + # Error out for incorrect parameter types elif origin in [tuple, dict] or tuple in args_origins or \ ( From 9b046a6e480c1fb78d1d7dd06c41cde51a0e4486 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 23 Apr 2025 11:24:50 -0500 Subject: [PATCH 11/16] Add 3.13 checks --- .github/workflows/pre-commit.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 31d043ab1..7d535e37f 100644 --- a/.github/workflows/pre-commit.yml +++ 
b/.github/workflows/pre-commit.yml @@ -14,6 +14,7 @@ jobs: - "3.10" - "3.11" - "3.12" + - "3.13" steps: - uses: actions/checkout@v3 From 879d14fbe8524aba5d5143cda9e2da0989642cab Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 23 Apr 2025 11:28:19 -0500 Subject: [PATCH 12/16] Update autopep --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce627decd..524ce3581 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: exclude: singlestoredb/clients/pymysqlsv/ additional_dependencies: [flake8-typing-imports==1.12.0] - repo: https://github.com/hhatto/autopep8 - rev: v2.0.4 + rev: v2.3.1 hooks: - id: autopep8 args: [--diff] From 67b4641b0ff1bcb7a8ef697a50a7b203530f994e Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 23 Apr 2025 11:31:23 -0500 Subject: [PATCH 13/16] Add 3.13 to smoke tests --- .github/workflows/smoke-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml index b1d7562ed..5fd113eaa 100644 --- a/.github/workflows/smoke-test.yml +++ b/.github/workflows/smoke-test.yml @@ -54,6 +54,7 @@ jobs: - "3.10" - "3.11" - "3.12" + - "3.13" driver: - mysql - https From be0e64ab182ef97e9a350377a5e57694a8906867 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 23 Apr 2025 16:43:25 -0500 Subject: [PATCH 14/16] Fix Table wrappers --- singlestoredb/functions/ext/asgi.py | 31 ++++++++---- singlestoredb/functions/signature.py | 2 +- singlestoredb/tests/ext_funcs/__init__.py | 57 ++++++++++++++++++++++- singlestoredb/tests/test_ext_func.py | 55 ++++++++++++++++++++++ 4 files changed, 133 insertions(+), 12 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 2c0a58d21..e518d1672 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -65,6 +65,7 @@ from ..signature import get_signature from ..signature import signature_to_sql from ..typing import Masked +from ..typing import Table try: import cloudpickle @@ -159,18 +160,28 @@ def as_tuple(x: Any) -> Any: def as_list_of_tuples(x: Any) -> Any: """Convert object to a list of tuples.""" + if isinstance(x, Table): + x = x[0] if isinstance(x, (list, tuple)) and len(x) > 0: + if isinstance(x[0], (list, tuple)): + return x if has_pydantic and isinstance(x[0], BaseModel): return [tuple(y.model_dump().values()) for y in x] if dataclasses.is_dataclass(x[0]): return [dataclasses.astuple(y) for y in x] if isinstance(x[0], dict): return [tuple(y.values()) for y in x] + return [(y,) for y in x] return x def get_dataframe_columns(df: Any) -> List[Any]: """Return columns of data from a dataframe/table.""" + if isinstance(df, Table): + if len(df) == 1: + df = df[0] + else: + return list(df) if isinstance(df, tuple): return list(df) rtype = str(type(df)).lower() @@ -259,8 +270,8 @@ def make_func( masks = get_masked_params(func) if function_type == 'tvf': - # Scalar (Python) types - if returns_data_format == 'scalar': + # Scalar / list types (row-based) + if returns_data_format in ['scalar', 'list']: async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], @@ -274,7 +285,7 @@ async def do_func( out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out - # Vector formats + # Vector formats (column-based) else: array_cls = get_array_class(returns_data_format) @@ -304,8 +315,8 @@ def build_tuple(x: Any) -> Any: return row_ids, 
[build_tuple(x) for x in res] else: - # Scalar (Python) types - if returns_data_format == 'scalar': + # Scalar / list types (row-based) + if returns_data_format in ['scalar', 'list']: async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], @@ -313,7 +324,7 @@ async def do_func( '''Call function on given rows of data.''' return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))] - # Vector formats + # Vector formats (column-based) else: array_cls = get_array_class(returns_data_format) @@ -471,8 +482,8 @@ class Application(object): response=rowdat_1_response_dict, ), (b'application/octet-stream', b'1.0', 'list'): dict( - load=rowdat_1.load_list, - dump=rowdat_1.dump_list, + load=rowdat_1.load, + dump=rowdat_1.dump, response=rowdat_1_response_dict, ), (b'application/octet-stream', b'1.0', 'pandas'): dict( @@ -501,8 +512,8 @@ class Application(object): response=json_response_dict, ), (b'application/json', b'1.0', 'list'): dict( - load=jdata.load_list, - dump=jdata.dump_list, + load=jdata.load, + dump=jdata.dump, response=json_response_dict, ), (b'application/json', b'1.0', 'pandas'): dict( diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index ffb7cd83c..1ead1168d 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -1002,7 +1002,7 @@ def get_schema( f'{", ".join(out_data_formats)}', ) - if out_data_formats: + if data_format != 'list' and out_data_formats: data_format = out_data_formats[0] # Since the colspec was computed by get_schema already, don't go diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index afa65e9ef..d96357f1f 100644 --- a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 # mypy: disable-error-code="type-arg" +import typing +from typing import List +from typing import NamedTuple from typing import Optional +from typing import Tuple import numpy as np import numpy.typing as npt @@ -8,7 +12,9 @@ import polars as pl import pyarrow as pa +import singlestoredb.functions.dtypes as dt from singlestoredb.functions import Masked +from singlestoredb.functions import Table from singlestoredb.functions import udf from singlestoredb.functions.dtypes import BIGINT from singlestoredb.functions.dtypes import BLOB @@ -18,7 +24,6 @@ from singlestoredb.functions.dtypes import SMALLINT from singlestoredb.functions.dtypes import TEXT from singlestoredb.functions.dtypes import TINYINT -from singlestoredb.functions.typing import Table @udf @@ -525,3 +530,53 @@ def numpy_fixed_binary() -> Table[npt.NDArray[np.bytes_]]: @udf def no_args_no_return_value() -> None: pass + + +@udf +def table_function(n: int) -> Table[List[int]]: + return Table([10] * n) + + +@udf( + returns=[ + dt.INT(name='c_int', nullable=False), + dt.DOUBLE(name='c_float', nullable=False), + dt.TEXT(name='c_str', nullable=False), + ], +) +def table_function_tuple(n: int) -> Table[List[Tuple[int, float, str]]]: + return Table([(10, 10.0, 'ten')] * n) + + +class MyTable(NamedTuple): + c_int: int + c_float: float + c_str: str + + +@udf +def table_function_struct(n: int) -> Table[List[MyTable]]: + return Table([MyTable(10, 10.0, 'ten')] * n) + + +@udf +def vec_function( + x: npt.NDArray[np.float64], y: npt.NDArray[np.float64], +) -> npt.NDArray[np.float64]: + return x * y + + +class VecInputs(typing.NamedTuple): + x: np.int8 + y: np.int8 + + +class VecOutputs(typing.NamedTuple): + res: np.int16 + + 
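# [Editor's note] Passing NamedTuple classes through args= / returns= (as
# the decorator application just below does) lets field names and numpy
# dtypes override the function's own annotations: each field maps to a SQL
# parameter or return column of the corresponding type (np.int8 -> TINYINT,
# np.int16 -> SMALLINT, and so on). The tests later in this series exercise
# it from SQL roughly like this (sketch based on those tests):
#
#     SELECT vec_function_ints(5, 10) AS res;  -- -> 50, typed per VecOutputs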
+@udf(args=VecInputs, returns=VecOutputs) +def vec_function_ints( + x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], +) -> npt.NDArray[np.int_]: + return x * y diff --git a/singlestoredb/tests/test_ext_func.py b/singlestoredb/tests/test_ext_func.py index 651500bb6..ed1eab3ba 100755 --- a/singlestoredb/tests/test_ext_func.py +++ b/singlestoredb/tests/test_ext_func.py @@ -1234,3 +1234,58 @@ def test_no_args_no_return_value(self): assert desc[0].name == 'res' assert desc[0].type_code == ft.TINY assert desc[0].null_ok is True + + def test_table_function(self): + self.cur.execute('select * from table_function(5)') + + assert [x[0] for x in self.cur] == [10, 10, 10, 10, 10] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'a' + assert desc[0].type_code == ft.LONGLONG + assert desc[0].null_ok is False + + def test_table_function_tuple(self): + self.cur.execute('select * from table_function_tuple(3)') + + out = list(self.cur) + + assert out == [ + (10, 10.0, 'ten'), + (10, 10.0, 'ten'), + (10, 10.0, 'ten'), + ] + + desc = self.cur.description + assert len(desc) == 3 + assert desc[0].name == 'c_int' + assert desc[1].name == 'c_float' + assert desc[2].name == 'c_str' + + def test_table_function_struct(self): + self.cur.execute('select * from table_function_struct(3)') + + out = list(self.cur) + + assert out == [ + (10, 10.0, 'ten'), + (10, 10.0, 'ten'), + (10, 10.0, 'ten'), + ] + + desc = self.cur.description + assert len(desc) == 3 + assert desc[0].name == 'c_int' + assert desc[1].name == 'c_float' + assert desc[2].name == 'c_str' + + def test_vec_function(self): + self.cur.execute('select vec_function(5, 10) as res') + + assert [tuple(x) for x in self.cur] == [(50.0,)] + + def test_vec_function_ints(self): + self.cur.execute('select vec_function(5, 10) as res') + + assert [tuple(x) for x in self.cur] == [(50,)] From 5039ec23b8dcc7c009717ab5da41352c139d5631 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 24 Apr 2025 10:17:42 -0500 Subject: [PATCH 15/16] Fix masks in table results --- singlestoredb/functions/ext/asgi.py | 21 ++++- singlestoredb/functions/signature.py | 18 +++-- singlestoredb/functions/utils.py | 15 +++- singlestoredb/tests/ext_funcs/__init__.py | 42 ++++++++++ singlestoredb/tests/test_ext_func.py | 93 ++++++++++++++++++++++- 5 files changed, 175 insertions(+), 14 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index e518d1672..2756461bc 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -182,8 +182,13 @@ def get_dataframe_columns(df: Any) -> List[Any]: df = df[0] else: return list(df) + + if isinstance(df, Masked): + return [df] + if isinstance(df, tuple): return list(df) + rtype = str(type(df)).lower() if 'dataframe' in rtype: return [df[x] for x in df.columns] @@ -195,6 +200,7 @@ def get_dataframe_columns(df: Any) -> List[Any]: return [df] elif 'tuple' in rtype: return list(df) + raise TypeError( 'Unsupported data type for dataframe columns: ' f'{rtype}', @@ -292,7 +298,10 @@ async def do_func( async def do_func( # type: ignore row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], - ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: + ) -> Tuple[ + Sequence[int], + List[Tuple[Sequence[Any], Optional[Sequence[bool]]]], + ]: '''Call function on given cols of data.''' # NOTE: There is no way to determine which row ID belongs to # each result row, so we just have to use the same @@ -310,7 +319,10 @@ def build_tuple(x: 
Any) -> Any: res = get_dataframe_columns(func()) # Generate row IDs - row_ids = array_cls([row_ids[0]] * len(res[0])) + if isinstance(res[0], Masked): + row_ids = array_cls([row_ids[0]] * len(res[0][0])) + else: + row_ids = array_cls([row_ids[0]] * len(res[0])) return row_ids, [build_tuple(x) for x in res] @@ -331,7 +343,10 @@ async def do_func( async def do_func( # type: ignore row_ids: Sequence[int], cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]], - ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: + ) -> Tuple[ + Sequence[int], + List[Tuple[Sequence[Any], Optional[Sequence[bool]]]], + ]: '''Call function on given cols of data.''' row_ids = array_cls(row_ids) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 1ead1168d..35504401d 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -817,14 +817,16 @@ def get_schema( if len(args) != 1: raise TypeError( 'only one list is supported within a table; to ' - 'return multiple columns, use a NamedTuple, dataclass, ' - 'TypedDict, or pydantic model', + 'return multiple columns, use a tuple, NamedTuple, ' + 'dataclass, TypedDict, or pydantic model', ) spec = typing.get_args(args[0])[0] data_format = 'list' - elif not all([utils.is_vector(x) for x in args]): - # TODO: Don't fail if types are specified in np.ndarrays + elif all([utils.is_vector(x, include_masks=True) for x in args]): + pass + + else: raise TypeError( 'return type for TVF must be a list, DataFrame / Table, ' 'or tuple of vectors', @@ -970,7 +972,7 @@ def get_schema( # return types or parameter types if out_overrides and len(typing.get_args(spec)) != len(out_overrides): raise ValueError( - 'number of {mode} types does not match the number of ' + f'number of {mode} types does not match the number of ' 'overrides specified', ) @@ -1312,14 +1314,14 @@ def dtype_to_sql( """ nullable = ' NOT NULL' - if force_nullable: - nullable = ' NULL' - elif dtype.endswith('?'): + if dtype.endswith('?'): nullable = ' NULL' dtype = dtype[:-1] elif '|null' in dtype: nullable = ' NULL' dtype = dtype.replace('|null', '') + elif force_nullable: + nullable = ' NULL' if dtype == 'null': nullable = '' diff --git a/singlestoredb/functions/utils.py b/singlestoredb/functions/utils.py index 11dc92dda..f639fa44b 100644 --- a/singlestoredb/functions/utils.py +++ b/singlestoredb/functions/utils.py @@ -6,6 +6,7 @@ from typing import Any from typing import Dict +from .typing import Masked if sys.version_info >= (3, 10): _UNION_TYPES = {typing.Union, types.UnionType} @@ -16,6 +17,15 @@ is_dataclass = dataclasses.is_dataclass +def is_masked(obj: Any) -> bool: + """Check if an object is a Masked type.""" + origin = typing.get_origin(obj) + if origin is not None: + return origin is Masked or \ + (inspect.isclass(origin) and issubclass(origin, Masked)) + return False + + def is_union(x: Any) -> bool: """Check if the object is a Union.""" return typing.get_origin(x) in _UNION_TYPES @@ -77,12 +87,13 @@ def is_dataframe(obj: Any) -> bool: return False -def is_vector(obj: Any) -> bool: +def is_vector(obj: Any, include_masks: bool = False) -> bool: """Check if an object is a vector type.""" return is_pandas_series(obj) \ or is_polars_series(obj) \ or is_pyarrow_array(obj) \ - or is_numpy(obj) + or is_numpy(obj) \ + or is_masked(obj) def get_data_format(obj: Any) -> str: diff --git a/singlestoredb/tests/ext_funcs/__init__.py b/singlestoredb/tests/ext_funcs/__init__.py index d96357f1f..d481af9e5 100644 --- 
a/singlestoredb/tests/ext_funcs/__init__.py +++ b/singlestoredb/tests/ext_funcs/__init__.py @@ -580,3 +580,45 @@ def vec_function_ints( x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], ) -> npt.NDArray[np.int_]: return x * y + + +class DFOutputs(typing.NamedTuple): + res: np.int16 + res2: np.float64 + + +@udf(args=VecInputs, returns=DFOutputs) +def vec_function_df( + x: npt.NDArray[np.int_], y: npt.NDArray[np.int_], +) -> Table[pd.DataFrame]: + return pd.DataFrame(dict(res=[1, 2, 3], res2=[1.1, 2.2, 3.3])) + + +class MaskOutputs(typing.NamedTuple): + res: Optional[np.int16] + + +@udf(args=VecInputs, returns=MaskOutputs) +def vec_function_ints_masked( + x: Masked[npt.NDArray[np.int_]], y: Masked[npt.NDArray[np.int_]], +) -> Table[Masked[npt.NDArray[np.int_]]]: + x_data, x_nulls = x + y_data, y_nulls = y + return Table(Masked(x_data * y_data, x_nulls | y_nulls)) + + +class MaskOutputs2(typing.NamedTuple): + res: Optional[np.int16] + res2: Optional[np.int16] + + +@udf(args=VecInputs, returns=MaskOutputs2) +def vec_function_ints_masked2( + x: Masked[npt.NDArray[np.int_]], y: Masked[npt.NDArray[np.int_]], +) -> Table[Masked[npt.NDArray[np.int_]], Masked[npt.NDArray[np.int_]]]: + x_data, x_nulls = x + y_data, y_nulls = y + return Table( + Masked(x_data * y_data, x_nulls | y_nulls), + Masked(x_data * y_data, x_nulls | y_nulls), + ) diff --git a/singlestoredb/tests/test_ext_func.py b/singlestoredb/tests/test_ext_func.py index ed1eab3ba..60e1ecf2a 100755 --- a/singlestoredb/tests/test_ext_func.py +++ b/singlestoredb/tests/test_ext_func.py @@ -1286,6 +1286,97 @@ def test_vec_function(self): assert [tuple(x) for x in self.cur] == [(50.0,)] def test_vec_function_ints(self): - self.cur.execute('select vec_function(5, 10) as res') + self.cur.execute('select vec_function_ints(5, 10) as res') + + assert [tuple(x) for x in self.cur] == [(50,)] + + def test_vec_function_df(self): + self.cur.execute('select * from vec_function_df(5, 10)') + + out = list(self.cur) + + assert out == [ + (1, 1.1), + (2, 2.2), + (3, 3.3), + ] + + desc = self.cur.description + assert len(desc) == 2 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is False + assert desc[1].name == 'res2' + assert desc[1].type_code == ft.DOUBLE + assert desc[1].null_ok is False + + def test_vec_function_ints_masked(self): + self.cur.execute('select * from vec_function_ints_masked(5, 10)') assert [tuple(x) for x in self.cur] == [(50,)] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is True + + self.cur.execute('select * from vec_function_ints_masked(NULL, 10)') + + assert [tuple(x) for x in self.cur] == [(None,)] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is True + + self.cur.execute('select * from vec_function_ints_masked(5, NULL)') + + assert [tuple(x) for x in self.cur] == [(None,)] + + desc = self.cur.description + assert len(desc) == 1 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is True + + def test_vec_function_ints_masked2(self): + self.cur.execute('select * from vec_function_ints_masked2(5, 10)') + + assert [tuple(x) for x in self.cur] == [(50, 50)] + + desc = self.cur.description + assert len(desc) == 2 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is True + assert desc[1].name == 'res2' 
+ assert desc[1].type_code == ft.SHORT + assert desc[1].null_ok is True + + self.cur.execute('select * from vec_function_ints_masked2(NULL, 10)') + + assert [tuple(x) for x in self.cur] == [(None, None)] + + desc = self.cur.description + assert len(desc) == 2 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is True + assert desc[1].name == 'res2' + assert desc[1].type_code == ft.SHORT + assert desc[1].null_ok is True + + self.cur.execute('select * from vec_function_ints_masked2(5, NULL)') + + assert [tuple(x) for x in self.cur] == [(None, None)] + + desc = self.cur.description + assert len(desc) == 2 + assert desc[0].name == 'res' + assert desc[0].type_code == ft.SHORT + assert desc[0].null_ok is True + assert desc[1].name == 'res2' + assert desc[1].type_code == ft.SHORT + assert desc[1].null_ok is True From 829775e4dd2c27c61bc2adcf71ad547b7d6a1599 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 24 Apr 2025 15:39:10 -0500 Subject: [PATCH 16/16] Vector utility functions --- singlestoredb/functions/__init__.py | 9 ++ singlestoredb/functions/utils.py | 164 ++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) diff --git a/singlestoredb/functions/__init__.py b/singlestoredb/functions/__init__.py index a156a80c9..ccb9b597b 100644 --- a/singlestoredb/functions/__init__.py +++ b/singlestoredb/functions/__init__.py @@ -1,3 +1,12 @@ from .decorator import udf # noqa: F401 from .typing import Masked # noqa: F401 from .typing import Table # noqa: F401 +from .utils import VectorTypes + + +F32 = VectorTypes.F32 +F64 = VectorTypes.F64 +I8 = VectorTypes.I8 +I16 = VectorTypes.I16 +I32 = VectorTypes.I32 +I64 = VectorTypes.I64 diff --git a/singlestoredb/functions/utils.py b/singlestoredb/functions/utils.py index f639fa44b..0b2d53d69 100644 --- a/singlestoredb/functions/utils.py +++ b/singlestoredb/functions/utils.py @@ -1,10 +1,13 @@ import dataclasses import inspect +import struct import sys import types import typing +from enum import Enum from typing import Any from typing import Dict +from typing import Iterable from .typing import Masked @@ -176,3 +179,164 @@ def is_pydantic(obj: Any) -> bool: if get_module(x) == 'pydantic' and get_type_name(x) == 'BaseModel' ]) + + +class VectorTypes(str, Enum): + """Enum for vector types.""" + F16 = 'f16' + F32 = 'f32' + F64 = 'f64' + I8 = 'i8' + I16 = 'i16' + I32 = 'i32' + I64 = 'i64' + + +def unpack_vector( + obj: Any, + element_type: VectorTypes = VectorTypes.F32, +) -> Iterable[Any]: + """ + Unpack a vector from bytes. + + Parameters + ---------- + obj : Any + The object to unpack. + element_type : VectorTypes + The type of the elements in the vector. + Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'. + Default is 'f32'. + + Returns + ------- + Iterable[Any] + The unpacked vector. 
+ + """ + if isinstance(obj, (bytes, bytearray, list, tuple)): + if element_type == 'f32': + n = len(obj) // 4 + fmt = 'f' + elif element_type == 'f64': + n = len(obj) // 8 + fmt = 'd' + elif element_type == 'i8': + n = len(obj) + fmt = 'b' + elif element_type == 'i16': + n = len(obj) // 2 + fmt = 'h' + elif element_type == 'i32': + n = len(obj) // 4 + fmt = 'i' + elif element_type == 'i64': + n = len(obj) // 8 + fmt = 'q' + else: + raise ValueError(f'unsupported element type: {element_type}') + + if isinstance(obj, (bytes, bytearray)): + return struct.unpack(f'<{n}{fmt}', obj) + return tuple([struct.unpack(f'<{n}{fmt}', x) for x in obj]) + + if element_type == 'f32': + np_type = 'f4' + elif element_type == 'f64': + np_type = 'f8' + elif element_type == 'i8': + np_type = 'i1' + elif element_type == 'i16': + np_type = 'i2' + elif element_type == 'i32': + np_type = 'i4' + elif element_type == 'i64': + np_type = 'i8' + else: + raise ValueError(f'unsupported element type: {element_type}') + + if is_numpy(obj): + import numpy as np + return np.array([np.frombuffer(x, dtype=np_type) for x in obj]) + + if is_pandas_series(obj): + import numpy as np + import pandas as pd + return pd.Series([np.frombuffer(x, dtype=np_type) for x in obj]) + + if is_polars_series(obj): + import numpy as np + import polars as pl + return pl.Series([np.frombuffer(x, dtype=np_type) for x in obj]) + + if is_pyarrow_array(obj): + import numpy as np + import pyarrow as pa + return pa.array([np.frombuffer(x, dtype=np_type) for x in obj]) + + raise ValueError( + f'unsupported object type: {type(obj)}', + ) + + +def pack_vector( + obj: Any, + element_type: VectorTypes = VectorTypes.F32, +) -> bytes: + """ + Pack a vector into bytes. + + Parameters + ---------- + obj : Any + The object to pack. + element_type : VectorTypes + The type of the elements in the vector. + Can be one of 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'. + Default is 'f32'. + + Returns + ------- + bytes + The packed vector. + + """ + if element_type == 'f32': + fmt = 'f' + elif element_type == 'f64': + fmt = 'd' + elif element_type == 'i8': + fmt = 'b' + elif element_type == 'i16': + fmt = 'h' + elif element_type == 'i32': + fmt = 'i' + elif element_type == 'i64': + fmt = 'q' + else: + raise ValueError(f'unsupported element type: {element_type}') + + if isinstance(obj, (list, tuple)): + return struct.pack(f'<{len(obj)}{fmt}', *obj) + + elif is_numpy(obj): + return obj.tobytes() + + elif is_pandas_series(obj): + # TODO: Nested vectors + import pandas as pd + return pd.Series(obj).to_numpy().tobytes() + + elif is_polars_series(obj): + # TODO: Nested vectors + import polars as pl + return pl.Series(obj).to_numpy().tobytes() + + elif is_pyarrow_array(obj): + # TODO: Nested vectors + import pyarrow as pa + return pa.array(obj).to_numpy().tobytes() + + raise ValueError( + f'unsupported object type: {type(obj)}', + )
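[Editor's note] The final patch's pack_vector() / unpack_vector() helpers convert between packed little-endian vector payloads and Python sequences. A short usage sketch (illustration only; it relies just on the list/bytes code paths shown above, and since this patch re-exports only the VectorTypes aliases, the helpers are imported from the utils module):

    from singlestoredb.functions.utils import (
        VectorTypes, pack_vector, unpack_vector,
    )

    # Pack three float32 values into a 12-byte little-endian payload
    payload = pack_vector([1.0, 2.0, 3.0], VectorTypes.F32)
    assert len(payload) == 12

    # Unpacking bytes yields a tuple of Python floats
    assert unpack_vector(payload, VectorTypes.F32) == (1.0, 2.0, 3.0)

    # Integer element types work the same way, e.g. 8-bit ints
    assert pack_vector([1, 2, 3], VectorTypes.I8) == b'\x01\x02\x03'

The numpy / pandas / polars / pyarrow branches additionally accept whole vector columns; per the TODO comments above, nested-vector handling in those branches is still pending.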