Skip to content

Commit bf9a0a4

Browse files
committed
Added tests for halfvec and sparsevec types with SQLAlchemy and asyncpg [skip ci]
1 parent 57b7d3b commit bf9a0a4

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,29 @@ async def test_asyncpg_vector(self):
550550

551551
await engine.dispose()
552552

553+
@pytest.mark.asyncio
554+
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
555+
async def test_asyncpg_halfvec(self):
556+
import asyncpg
557+
558+
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
559+
async_session = async_sessionmaker(engine, expire_on_commit=False)
560+
561+
# TODO do not throw error when types are registered
562+
# @event.listens_for(engine.sync_engine, "connect")
563+
# def connect(dbapi_connection, connection_record):
564+
# from pgvector.asyncpg import register_vector
565+
# dbapi_connection.run_async(register_vector)
566+
567+
async with async_session() as session:
568+
async with session.begin():
569+
embedding = [1, 2, 3]
570+
session.add(Item(id=1, half_embedding=embedding))
571+
item = await session.get(Item, 1)
572+
assert item.half_embedding.to_list() == embedding
573+
574+
await engine.dispose()
575+
553576
@pytest.mark.asyncio
554577
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
555578
async def test_asyncpg_bit(self):
@@ -566,3 +589,26 @@ async def test_asyncpg_bit(self):
566589
assert item.binary_embedding == embedding
567590

568591
await engine.dispose()
592+
593+
@pytest.mark.asyncio
594+
@pytest.mark.skipif(sqlalchemy_version == 1, reason='Requires SQLAlchemy 2+')
595+
async def test_asyncpg_sparsevec(self):
596+
import asyncpg
597+
598+
engine = create_async_engine('postgresql+asyncpg://localhost/pgvector_python_test')
599+
async_session = async_sessionmaker(engine, expire_on_commit=False)
600+
601+
# TODO do not throw error when types are registered
602+
# @event.listens_for(engine.sync_engine, "connect")
603+
# def connect(dbapi_connection, connection_record):
604+
# from pgvector.asyncpg import register_vector
605+
# dbapi_connection.run_async(register_vector)
606+
607+
async with async_session() as session:
608+
async with session.begin():
609+
embedding = [1, 2, 3]
610+
session.add(Item(id=1, sparse_embedding=embedding))
611+
item = await session.get(Item, 1)
612+
assert item.sparse_embedding.to_list() == embedding
613+
614+
await engine.dispose()

0 commit comments

Comments
 (0)