Skip to content

Commit f65c361

Browse files
committed
Fixed type casting for Peewee aggregations [skip ci]
1 parent 89ec21d commit f65c361

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ Average vectors
496496
```python
497497
from peewee import fn
498498

499-
Item.select(fn.avg(Item.embedding)).scalar()
499+
Item.select(fn.avg(Item.embedding).coerce(True)).scalar()
500500
```
501501

502502
Also supports `sum`

tests/test_peewee.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,40 +164,36 @@ def test_where(self):
164164
assert [v.id for v in items] == [1]
165165

166166
def test_vector_avg(self):
167-
avg = Item.select(fn.avg(Item.embedding)).scalar()
167+
avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar()
168168
assert avg is None
169169
Item.create(embedding=[1, 2, 3])
170170
Item.create(embedding=[4, 5, 6])
171-
avg = Item.select(fn.avg(Item.embedding)).scalar()
172-
# does not type cast
173-
assert avg == '[2.5,3.5,4.5]'
171+
avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar()
172+
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
174173

175174
def test_vector_sum(self):
176-
sum = Item.select(fn.sum(Item.embedding)).scalar()
175+
sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar()
177176
assert sum is None
178177
Item.create(embedding=[1, 2, 3])
179178
Item.create(embedding=[4, 5, 6])
180-
sum = Item.select(fn.sum(Item.embedding)).scalar()
181-
# does not type cast
182-
assert sum == '[5,7,9]'
179+
sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar()
180+
assert np.array_equal(sum, np.array([5, 7, 9]))
183181

184182
def test_halfvec_avg(self):
185-
avg = Item.select(fn.avg(Item.half_embedding)).scalar()
183+
avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()
186184
assert avg is None
187185
Item.create(half_embedding=[1, 2, 3])
188186
Item.create(half_embedding=[4, 5, 6])
189-
avg = Item.select(fn.avg(Item.half_embedding)).scalar()
190-
# does not type cast
191-
assert avg == '[2.5,3.5,4.5]'
187+
avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()
188+
assert avg.to_list() == [2.5, 3.5, 4.5]
192189

193190
def test_halfvec_sum(self):
194-
sum = Item.select(fn.sum(Item.half_embedding)).scalar()
191+
sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
195192
assert sum is None
196193
Item.create(half_embedding=[1, 2, 3])
197194
Item.create(half_embedding=[4, 5, 6])
198-
sum = Item.select(fn.sum(Item.half_embedding)).scalar()
199-
# does not type cast
200-
assert sum == '[5,7,9]'
195+
sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
196+
assert sum.to_list() == [5, 7, 9]
201197

202198
def test_get_or_create(self):
203199
Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})

0 commit comments

Comments
 (0)