|
1 | 1 | import numpy as np
|
2 |
| -from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector |
| 2 | +from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum |
3 | 3 | import pytest
|
4 | 4 | from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer
|
5 | 5 | from sqlalchemy.exc import StatementError
|
@@ -339,41 +339,39 @@ def test_select_orm(self):
|
339 | 339 |
|
340 | 340 | def test_avg(self):
|
341 | 341 | with Session(engine) as session:
|
342 |
| - avg = session.query(func.avg(Item.embedding)).first()[0] |
343 |
| - assert avg is None |
| 342 | + res = session.query(avg(Item.embedding)).first()[0] |
| 343 | + assert res is None |
344 | 344 | session.add(Item(embedding=[1, 2, 3]))
|
345 | 345 | session.add(Item(embedding=[4, 5, 6]))
|
346 |
| - avg = session.query(func.avg(Item.embedding)).first()[0] |
347 |
| - # does not type cast |
348 |
| - assert avg == '[2.5,3.5,4.5]' |
| 346 | + res = session.query(avg(Item.embedding)).first()[0] |
| 347 | + assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) |
349 | 348 |
|
350 | 349 | def test_avg_orm(self):
|
351 | 350 | with Session(engine) as session:
|
352 |
| - avg = session.scalars(select(func.avg(Item.embedding))).first() |
353 |
| - assert avg is None |
| 351 | + res = session.scalars(select(avg(Item.embedding))).first() |
| 352 | + assert res is None |
354 | 353 | session.add(Item(embedding=[1, 2, 3]))
|
355 | 354 | session.add(Item(embedding=[4, 5, 6]))
|
356 |
| - avg = session.scalars(select(func.avg(Item.embedding))).first() |
357 |
| - # does not type cast |
358 |
| - assert avg == '[2.5,3.5,4.5]' |
| 355 | + res = session.scalars(select(avg(Item.embedding))).first() |
| 356 | + assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) |
359 | 357 |
|
360 | 358 | def test_sum(self):
|
361 | 359 | with Session(engine) as session:
|
362 |
| - sum = session.query(func.sum(Item.embedding)).first()[0] |
363 |
| - assert sum is None |
| 360 | + res = session.query(sum(Item.embedding)).first()[0] |
| 361 | + assert res is None |
364 | 362 | session.add(Item(embedding=[1, 2, 3]))
|
365 | 363 | session.add(Item(embedding=[4, 5, 6]))
|
366 |
| - sum = session.query(func.sum(Item.embedding)).first()[0] |
367 |
| - assert np.array_equal(sum, np.array([5, 7, 9])) |
| 364 | + res = session.query(sum(Item.embedding)).first()[0] |
| 365 | + assert np.array_equal(res, np.array([5, 7, 9])) |
368 | 366 |
|
369 | 367 | def test_sum_orm(self):
|
370 | 368 | with Session(engine) as session:
|
371 |
| - sum = session.scalars(select(func.sum(Item.embedding))).first() |
372 |
| - assert sum is None |
| 369 | + res = session.scalars(select(sum(Item.embedding))).first() |
| 370 | + assert res is None |
373 | 371 | session.add(Item(embedding=[1, 2, 3]))
|
374 | 372 | session.add(Item(embedding=[4, 5, 6]))
|
375 |
| - sum = session.scalars(select(func.sum(Item.embedding))).first() |
376 |
| - assert np.array_equal(sum, np.array([5, 7, 9])) |
| 373 | + res = session.scalars(select(sum(Item.embedding))).first() |
| 374 | + assert np.array_equal(res, np.array([5, 7, 9])) |
377 | 375 |
|
378 | 376 | def test_bad_dimensions(self):
|
379 | 377 | item = Item(embedding=[1, 2])
|
|
0 commit comments