Skip to content

Commit 8f66d13

Browse files
committed
Add pydantic support in query arguments and UDFs
1 parent 8ab721a commit 8f66d13

4 files changed

Lines changed: 89 additions & 3 deletions

File tree

singlestoredb/functions/ext/asgi.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@
6969
except ImportError:
7070
has_cloudpickle = False
7171

72+
try:
73+
from pydantic import BaseModel
74+
has_pydantic = True
75+
except ImportError:
76+
has_pydantic = False
77+
7278

7379
logger = utils.get_logger('singlestoredb.functions.ext.asgi')
7480

@@ -138,13 +144,24 @@ def get_func_names(funcs: str) -> List[Tuple[str, str]]:
138144

139145

140146
def as_tuple(x: Any) -> Any:
141-
if hasattr(x, 'model_fields'):
142-
return tuple(x.model_fields.values())
147+
"""Convert object to tuple."""
148+
if has_pydantic and isinstance(x, BaseModel):
149+
return tuple(x.model_dump().values())
143150
if dataclasses.is_dataclass(x):
144151
return dataclasses.astuple(x)
145152
return x
146153

147154

155+
def as_list_of_tuples(x: Any) -> Any:
156+
"""Convert object to a list of tuples."""
157+
if isinstance(x, (list, tuple)) and len(x) > 0:
158+
if has_pydantic and isinstance(x[0], BaseModel):
159+
return [tuple(y.model_dump().values()) for y in x]
160+
if dataclasses.is_dataclass(x[0]):
161+
return [dataclasses.astuple(y) for y in x]
162+
return x
163+
164+
148165
def make_func(
149166
name: str,
150167
func: Callable[..., Any],
@@ -183,7 +200,7 @@ async def do_func(
183200
out_ids: List[int] = []
184201
out = []
185202
for i, res in zip(row_ids, func_map(func, rows)):
186-
out.extend(as_tuple(res))
203+
out.extend(as_list_of_tuples(res))
187204
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
188205
return out_ids, out
189206

singlestoredb/http/connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@
4343
except ImportError:
4444
has_shapely = False
4545

46+
try:
47+
import pydantic
48+
has_pydantic = True
49+
except ImportError:
50+
has_pydantic = False
51+
4652
from .. import connection
4753
from .. import fusion
4854
from .. import types
@@ -533,6 +539,9 @@ def _execute(
533539
self._expect_results = True
534540
sql_type = 'query'
535541

542+
if has_pydantic and isinstance(params, pydantic.BaseModel):
543+
params = params.model_dump()
544+
536545
self._validate_param_subs(oper, params)
537546

538547
handler = fusion.get_handler(oper)

singlestoredb/mysql/cursors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from ..utils.debug import log_query
99
from ..utils.results import get_schema
1010

11+
try:
12+
from pydantic import BaseModel
13+
has_pydantic = True
14+
except ImportError:
15+
has_pydantic = False
16+
1117

1218
#: Regular expression for :meth:`Cursor.executemany`.
1319
#: executemany only supports simple bulk insert.
@@ -149,6 +155,8 @@ def _escape_args(self, args, conn):
149155
return tuple(literal(arg) for arg in args)
150156
elif dtype is dict or isinstance(args, dict):
151157
return {key: literal(val) for (key, val) in args.items()}
158+
elif has_pydantic and isinstance(args, BaseModel):
159+
return {key: literal(val) for (key, val) in args.model_dump().items()}
152160
# If it's not a dictionary let's try escaping it anyways.
153161
# Worst case it will throw a Value error
154162
return conn.escape(args)

singlestoredb/tests/test_basics.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import math
77
import os
88
import unittest
9+
from typing import Optional
910

1011
from requests.exceptions import InvalidJSONError
1112

@@ -28,6 +29,12 @@
2829
except ImportError:
2930
has_pygeos = False
3031

32+
try:
33+
import pydantic
34+
has_pydantic = True
35+
except ImportError:
36+
has_pydantic = False
37+
3138
import singlestoredb as s2
3239
from . import utils
3340
# import traceback
@@ -1255,6 +1262,51 @@ def test_character_lengths(self):
12551262
except Exception:
12561263
pass
12571264

1265+
def test_pydantic(self):
1266+
if not has_pydantic:
1267+
self.skipTest('Test requires pydantic')
1268+
1269+
tblname = 'foo_' + str(id(self))
1270+
1271+
class FooData(pydantic.BaseModel):
1272+
x: Optional[int]
1273+
y: Optional[float]
1274+
z: Optional[str] = None
1275+
1276+
self.cur.execute(f'''
1277+
CREATE TABLE {tblname}(
1278+
x INT,
1279+
y DOUBLE,
1280+
z TEXT
1281+
)
1282+
''')
1283+
1284+
self.cur.execute(
1285+
f'INSERT INTO {tblname}(x, y) VALUES (%(x)s, %(y)s)',
1286+
FooData(x=2, y=3.23),
1287+
)
1288+
1289+
self.cur.execute('SELECT * FROM ' + tblname)
1290+
1291+
assert list(sorted(self.cur.fetchall())) == \
1292+
list(sorted([(2, 3.23, None)]))
1293+
1294+
self.cur.executemany(
1295+
f'INSERT INTO {tblname}(x) VALUES (%(x)s)',
1296+
[FooData(x=3, y=3.12), FooData(x=10, y=100.11)],
1297+
)
1298+
1299+
self.cur.execute('SELECT * FROM ' + tblname)
1300+
1301+
assert list(sorted(self.cur.fetchall())) == \
1302+
list(
1303+
sorted([
1304+
(2, 3.23, None),
1305+
(3, None, None),
1306+
(10, None, None),
1307+
]),
1308+
)
1309+
12581310

12591311
if __name__ == '__main__':
12601312
import nose2

0 commit comments

Comments
 (0)