Skip to content

Commit 8b92716

Browse files
committed
Improved asyncpg tests [skip ci]
1 parent 1b25460 commit 8b92716

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

tests/test_asyncpg.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncpg
22
import numpy as np
3-
from pgvector import SparseVector
3+
from pgvector import HalfVector, SparseVector, Vector
44
from pgvector.asyncpg import register_vector
55
import pytest
66

@@ -15,13 +15,15 @@ async def test_vector(self):
1515

1616
await register_vector(conn)
1717

18-
embedding = np.array([1.5, 2, 3])
19-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
18+
embedding = Vector([1.5, 2, 3])
19+
embedding2 = np.array([4.5, 5, 6])
20+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
2021

2122
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
22-
assert np.array_equal(res[0]['embedding'], embedding)
23+
assert np.array_equal(res[0]['embedding'], embedding.to_numpy())
2324
assert res[0]['embedding'].dtype == np.float32
24-
assert res[1]['embedding'] is None
25+
assert np.array_equal(res[1]['embedding'], embedding2)
26+
assert res[2]['embedding'] is None
2527

2628
# ensures binary format is correct
2729
text_res = await conn.fetch("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1")
@@ -38,12 +40,14 @@ async def test_halfvec(self):
3840

3941
await register_vector(conn)
4042

41-
embedding = [1.5, 2, 3]
42-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
43+
embedding = HalfVector([1.5, 2, 3])
44+
embedding2 = [4.5, 5, 6]
45+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
4346

4447
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
45-
assert res[0]['embedding'].to_list() == [1.5, 2, 3]
46-
assert res[1]['embedding'] is None
48+
assert res[0]['embedding'] == embedding
49+
assert res[1]['embedding'] == HalfVector(embedding2)
50+
assert res[2]['embedding'] is None
4751

4852
# ensures binary format is correct
4953
text_res = await conn.fetch("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1")
@@ -87,7 +91,7 @@ async def test_sparsevec(self):
8791
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
8892

8993
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
90-
assert res[0]['embedding'].to_list() == [1.5, 2, 3]
94+
assert res[0]['embedding'] == embedding
9195
assert res[1]['embedding'] is None
9296

9397
# ensures binary format is correct
@@ -105,12 +109,15 @@ async def test_vector_array(self):
105109

106110
await register_vector(conn)
107111

108-
embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])]
109-
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings[0], embeddings[1])
112+
embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])]
113+
embeddings2 = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])]
114+
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[]), (ARRAY[$3, $4]::vector[])", embeddings[0], embeddings[1], embeddings2[0], embeddings2[1])
110115

111116
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
112-
assert np.array_equal(res[0]['embeddings'][0], embeddings[0])
113-
assert np.array_equal(res[0]['embeddings'][1], embeddings[1])
117+
assert np.array_equal(res[0]['embeddings'][0], embeddings[0].to_numpy())
118+
assert np.array_equal(res[0]['embeddings'][1], embeddings[1].to_numpy())
119+
assert np.array_equal(res[1]['embeddings'][0], embeddings2[0])
120+
assert np.array_equal(res[1]['embeddings'][1], embeddings2[1])
114121

115122
await conn.close()
116123

@@ -126,10 +133,12 @@ async def init(conn):
126133
await conn.execute('DROP TABLE IF EXISTS asyncpg_items')
127134
await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding vector(3))')
128135

129-
embedding = np.array([1.5, 2, 3])
130-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
136+
embedding = Vector([1.5, 2, 3])
137+
embedding2 = np.array([1.5, 2, 3])
138+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
131139

132140
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
133-
assert np.array_equal(res[0]['embedding'], embedding)
141+
assert np.array_equal(res[0]['embedding'], embedding.to_numpy())
134142
assert res[0]['embedding'].dtype == np.float32
135-
assert res[1]['embedding'] is None
143+
assert np.array_equal(res[1]['embedding'], embedding2)
144+
assert res[2]['embedding'] is None

0 commit comments

Comments
 (0)