Skip to content

Commit 0705e3a

Browse files
committed
Added more tests for aggregates with SQLAlchemy - pgvector#44
1 parent bab9e4b commit 0705e3a

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,15 @@ def test_avg(self):
166166
avg = session.query(func.avg(Item.embedding)).first()[0]
167167
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
168168

169+
def test_avg_orm(self):
170+
with Session(engine) as session:
171+
avg = session.scalars(select(func.avg(Item.embedding))).first()
172+
assert avg is None
173+
session.add(Item(embedding=[1, 2, 3]))
174+
session.add(Item(embedding=[4, 5, 6]))
175+
avg = session.scalars(select(func.avg(Item.embedding))).first()
176+
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
177+
169178
def test_sum(self):
170179
with Session(engine) as session:
171180
sum = session.query(func.sum(Item.embedding)).first()[0]
@@ -175,6 +184,15 @@ def test_sum(self):
175184
sum = session.query(func.sum(Item.embedding)).first()[0]
176185
assert np.array_equal(sum, np.array([5, 7, 9]))
177186

187+
def test_sum_orm(self):
188+
with Session(engine) as session:
189+
sum = session.scalars(select(func.sum(Item.embedding))).first()
190+
assert sum is None
191+
session.add(Item(embedding=[1, 2, 3]))
192+
session.add(Item(embedding=[4, 5, 6]))
193+
sum = session.scalars(select(func.sum(Item.embedding))).first()
194+
assert np.array_equal(sum, np.array([5, 7, 9]))
195+
178196
def test_bad_dimensions(self):
179197
item = Item(embedding=[1, 2])
180198
session = Session(engine)

0 commit comments

Comments
 (0)