|
1 | 1 | from math import sqrt |
2 | 2 | import numpy as np |
3 | 3 | from peewee import Model, PostgresqlDatabase, fn |
4 | | -from pgvector import SparseVector |
| 4 | +from pgvector import HalfVector, SparseVector |
5 | 5 | from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField |
6 | 6 |
|
7 | 7 | db = PostgresqlDatabase('pgvector_python_test') |
@@ -77,7 +77,7 @@ def test_vector_l1_distance(self): |
77 | 77 | def test_halfvec(self): |
78 | 78 | Item.create(id=1, half_embedding=[1, 2, 3]) |
79 | 79 | item = Item.get_by_id(1) |
80 | | - assert item.half_embedding.to_list() == [1, 2, 3] |
| 80 | + assert item.half_embedding == HalfVector([1, 2, 3]) |
81 | 81 |
|
82 | 82 | def test_halfvec_l2_distance(self): |
83 | 83 | create_items() |
@@ -129,7 +129,7 @@ def test_bit_jaccard_distance(self): |
129 | 129 | def test_sparsevec(self): |
130 | 130 | Item.create(id=1, sparse_embedding=[1, 2, 3]) |
131 | 131 | item = Item.get_by_id(1) |
132 | | - assert item.sparse_embedding.to_list() == [1, 2, 3] |
| 132 | + assert item.sparse_embedding == SparseVector([1, 2, 3]) |
133 | 133 |
|
134 | 134 | def test_sparsevec_l2_distance(self): |
135 | 135 | create_items() |
@@ -186,15 +186,15 @@ def test_halfvec_avg(self): |
186 | 186 | Item.create(half_embedding=[1, 2, 3]) |
187 | 187 | Item.create(half_embedding=[4, 5, 6]) |
188 | 188 | avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar() |
189 | | - assert avg.to_list() == [2.5, 3.5, 4.5] |
| 189 | + assert avg == HalfVector([2.5, 3.5, 4.5]) |
190 | 190 |
|
191 | 191 | def test_halfvec_sum(self): |
192 | 192 | sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar() |
193 | 193 | assert sum is None |
194 | 194 | Item.create(half_embedding=[1, 2, 3]) |
195 | 195 | Item.create(half_embedding=[4, 5, 6]) |
196 | 196 | sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar() |
197 | | - assert sum.to_list() == [5, 7, 9] |
| 197 | + assert sum == HalfVector([5, 7, 9]) |
198 | 198 |
|
199 | 199 | def test_get_or_create(self): |
200 | 200 | Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]}) |
|
0 commit comments