11import asyncpg
22import numpy as np
3- from pgvector import SparseVector
3+ from pgvector import HalfVector , SparseVector , Vector
44from pgvector .asyncpg import register_vector
55import pytest
66
@@ -15,13 +15,15 @@ async def test_vector(self):
1515
1616 await register_vector (conn )
1717
18- embedding = np .array ([1.5 , 2 , 3 ])
19- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)" , embedding )
18+ embedding = Vector ([1.5 , 2 , 3 ])
19+ embedding2 = np .array ([4.5 , 5 , 6 ])
20+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)" , embedding , embedding2 )
2021
2122 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
22- assert np .array_equal (res [0 ]['embedding' ], embedding )
23+ assert np .array_equal (res [0 ]['embedding' ], embedding . to_numpy () )
2324 assert res [0 ]['embedding' ].dtype == np .float32
24- assert res [1 ]['embedding' ] is None
25+ assert np .array_equal (res [1 ]['embedding' ], embedding2 )
26+ assert res [2 ]['embedding' ] is None
2527
2628 # ensures binary format is correct
2729 text_res = await conn .fetch ("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1" )
@@ -38,12 +40,14 @@ async def test_halfvec(self):
3840
3941 await register_vector (conn )
4042
41- embedding = [1.5 , 2 , 3 ]
42- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)" , embedding )
43+ embedding = HalfVector ([1.5 , 2 , 3 ])
44+ embedding2 = [4.5 , 5 , 6 ]
45+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)" , embedding , embedding2 )
4346
4447 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
45- assert res [0 ]['embedding' ].to_list () == [1.5 , 2 , 3 ]
46- assert res [1 ]['embedding' ] is None
48+ assert res [0 ]['embedding' ] == embedding
49+ assert res [1 ]['embedding' ] == HalfVector (embedding2 )
50+ assert res [2 ]['embedding' ] is None
4751
4852 # ensures binary format is correct
4953 text_res = await conn .fetch ("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1" )
@@ -87,7 +91,7 @@ async def test_sparsevec(self):
8791 await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)" , embedding )
8892
8993 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
90- assert res [0 ]['embedding' ]. to_list () == [ 1.5 , 2 , 3 ]
94+ assert res [0 ]['embedding' ] == embedding
9195 assert res [1 ]['embedding' ] is None
9296
9397 # ensures binary format is correct
@@ -105,12 +109,15 @@ async def test_vector_array(self):
105109
106110 await register_vector (conn )
107111
108- embeddings = [np .array ([1.5 , 2 , 3 ]), np .array ([4.5 , 5 , 6 ])]
109- await conn .execute ("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])" , embeddings [0 ], embeddings [1 ])
112+ embeddings = [Vector ([1.5 , 2 , 3 ]), Vector ([4.5 , 5 , 6 ])]
113+ embeddings2 = [np .array ([1.5 , 2 , 3 ]), np .array ([4.5 , 5 , 6 ])]
114+ await conn .execute ("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[]), (ARRAY[$3, $4]::vector[])" , embeddings [0 ], embeddings [1 ], embeddings2 [0 ], embeddings2 [1 ])
110115
111116 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
112- assert np .array_equal (res [0 ]['embeddings' ][0 ], embeddings [0 ])
113- assert np .array_equal (res [0 ]['embeddings' ][1 ], embeddings [1 ])
117+ assert np .array_equal (res [0 ]['embeddings' ][0 ], embeddings [0 ].to_numpy ())
118+ assert np .array_equal (res [0 ]['embeddings' ][1 ], embeddings [1 ].to_numpy ())
119+ assert np .array_equal (res [1 ]['embeddings' ][0 ], embeddings2 [0 ])
120+ assert np .array_equal (res [1 ]['embeddings' ][1 ], embeddings2 [1 ])
114121
115122 await conn .close ()
116123
@@ -126,10 +133,12 @@ async def init(conn):
126133 await conn .execute ('DROP TABLE IF EXISTS asyncpg_items' )
127134 await conn .execute ('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding vector(3))' )
128135
129- embedding = np .array ([1.5 , 2 , 3 ])
130- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)" , embedding )
136+ embedding = Vector ([1.5 , 2 , 3 ])
137+ embedding2 = np .array ([1.5 , 2 , 3 ])
138+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)" , embedding , embedding2 )
131139
132140 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
133- assert np .array_equal (res [0 ]['embedding' ], embedding )
141+ assert np .array_equal (res [0 ]['embedding' ], embedding . to_numpy () )
134142 assert res [0 ]['embedding' ].dtype == np .float32
135- assert res [1 ]['embedding' ] is None
143+ assert np .array_equal (res [1 ]['embedding' ], embedding2 )
144+ assert res [2 ]['embedding' ] is None
0 commit comments