Skip to content

Commit c7cd058

Browse files
committed
Improved tests [skip ci]
1 parent bb3b32c commit c7cd058

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

tests/test_django.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_vector_l1_distance(self):
199199
def test_halfvec(self):
200200
Item(id=1, half_embedding=[1, 2, 3]).save()
201201
item = Item.objects.get(pk=1)
202-
assert item.half_embedding.to_list() == [1, 2, 3]
202+
assert item.half_embedding == HalfVector([1, 2, 3])
203203

204204
def test_halfvec_l2_distance(self):
205205
create_items()
@@ -251,7 +251,7 @@ def test_bit_jaccard_distance(self):
251251
def test_sparsevec(self):
252252
Item(id=1, sparse_embedding=SparseVector([1, 2, 3])).save()
253253
item = Item.objects.get(pk=1)
254-
assert item.sparse_embedding.to_list() == [1, 2, 3]
254+
assert item.sparse_embedding == SparseVector([1, 2, 3])
255255

256256
def test_sparsevec_l2_distance(self):
257257
create_items()
@@ -309,15 +309,15 @@ def test_halfvec_avg(self):
309309
Item(half_embedding=[1, 2, 3]).save()
310310
Item(half_embedding=[4, 5, 6]).save()
311311
avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg']
312-
assert avg.to_list() == [2.5, 3.5, 4.5]
312+
assert avg == HalfVector([2.5, 3.5, 4.5])
313313

314314
def test_halfvec_sum(self):
315315
sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum']
316316
assert sum is None
317317
Item(half_embedding=[1, 2, 3]).save()
318318
Item(half_embedding=[4, 5, 6]).save()
319319
sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum']
320-
assert sum.to_list() == [5, 7, 9]
320+
assert sum == HalfVector([5, 7, 9])
321321

322322
def test_serialization(self):
323323
create_items()
@@ -375,7 +375,7 @@ def test_halfvec_form_save(self):
375375
assert form.has_changed()
376376
assert form.is_valid()
377377
assert form.save()
378-
assert [4, 5, 6] == Item.objects.get(pk=1).half_embedding.to_list()
378+
assert Item.objects.get(pk=1).half_embedding == HalfVector([4, 5, 6])
379379

380380
def test_halfvec_form_save_missing(self):
381381
Item(id=1).save()
@@ -432,7 +432,7 @@ def test_sparsevec_form_save(self):
432432
assert form.has_changed()
433433
assert form.is_valid()
434434
assert form.save()
435-
assert [4, 5, 6] == Item.objects.get(pk=1).sparse_embedding.to_list()
435+
assert Item.objects.get(pk=1).sparse_embedding == SparseVector([4, 5, 6])
436436

437437
def test_sparesevec_form_save_missing(self):
438438
Item(id=1).save()

tests/test_psycopg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,19 @@ def test_halfvec(self):
6969
conn.execute('INSERT INTO psycopg_items (half_embedding) VALUES (%s)', (embedding,))
7070

7171
res = conn.execute('SELECT half_embedding FROM psycopg_items ORDER BY id').fetchone()[0]
72-
assert res.to_list() == [1.5, 2, 3]
72+
assert res == HalfVector([1.5, 2, 3])
7373

7474
def test_halfvec_binary_format(self):
7575
embedding = HalfVector([1.5, 2, 3])
7676
res = conn.execute('SELECT %b::halfvec', (embedding,), binary=True).fetchone()[0]
77+
assert res == HalfVector([1.5, 2, 3])
7778
assert res.to_list() == [1.5, 2, 3]
7879
assert np.array_equal(res.to_numpy(), np.array([1.5, 2, 3]))
7980

8081
def test_halfvec_text_format(self):
8182
embedding = HalfVector([1.5, 2, 3])
8283
res = conn.execute('SELECT %t::halfvec', (embedding,)).fetchone()[0]
84+
assert res == HalfVector([1.5, 2, 3])
8385
assert res.to_list() == [1.5, 2, 3]
8486
assert np.array_equal(res.to_numpy(), np.array([1.5, 2, 3]))
8587

@@ -106,11 +108,12 @@ def test_sparsevec(self):
106108
conn.execute('INSERT INTO psycopg_items (sparse_embedding) VALUES (%s)', (embedding,))
107109

108110
res = conn.execute('SELECT sparse_embedding FROM psycopg_items ORDER BY id').fetchone()[0]
109-
assert res.to_list() == [1.5, 2, 3]
111+
assert res == SparseVector([1.5, 2, 3])
110112

111113
def test_sparsevec_binary_format(self):
112114
embedding = SparseVector([1.5, 0, 2, 0, 3, 0])
113115
res = conn.execute('SELECT %b::sparsevec', (embedding,), binary=True).fetchone()[0]
116+
assert res == embedding
114117
assert res.dimensions() == 6
115118
assert res.indices() == [0, 2, 4]
116119
assert res.values() == [1.5, 2, 3]
@@ -120,6 +123,7 @@ def test_sparsevec_binary_format(self):
120123
def test_sparsevec_text_format(self):
121124
embedding = SparseVector([1.5, 0, 2, 0, 3, 0])
122125
res = conn.execute('SELECT %t::sparsevec', (embedding,)).fetchone()[0]
126+
assert res == embedding
123127
assert res.dimensions() == 6
124128
assert res.indices() == [0, 2, 4]
125129
assert res.values() == [1.5, 2, 3]

0 commit comments

Comments
 (0)