Skip to content

Commit e90260a

Browse files
committed
Added avg function with type casting to SQLAlchemy - #44
Co-authored-by: lucasgadams
1 parent f65c361 commit e90260a

File tree

5 files changed

+47
-39
lines changed

5 files changed

+47
-39
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
## 0.3.5 (unreleased)
22

3+
- Added `avg` function with type casting to SQLAlchemy
34
- Added `globally` option for Psycopg 2
45

56
## 0.3.4 (2024-09-26)

pgvector/sqlalchemy/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from .bit import BIT
2+
from .functions import avg, sum
23
from .halfvec import HALFVEC
34
from .sparsevec import SPARSEVEC
45
from .vector import VECTOR
@@ -12,5 +13,7 @@
1213
'BIT',
1314
'SPARSEVEC',
1415
'HalfVector',
15-
'SparseVector'
16+
'SparseVector',
17+
'avg',
18+
'sum'
1619
]

pgvector/sqlalchemy/functions.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
# https://docs.sqlalchemy.org/en/20/core/functions.html
2+
# include sum for a consistent API
3+
from sqlalchemy.sql.functions import ReturnTypeFromArgs, sum
4+
5+
6+
class avg(ReturnTypeFromArgs):
7+
inherit_cache = True
8+
package = 'pgvector'

tests/test_sqlalchemy.py

Lines changed: 17 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector
2+
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
33
import pytest
44
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer
55
from sqlalchemy.exc import StatementError
@@ -339,41 +339,39 @@ def test_select_orm(self):
339339

340340
def test_avg(self):
341341
with Session(engine) as session:
342-
avg = session.query(func.avg(Item.embedding)).first()[0]
343-
assert avg is None
342+
res = session.query(avg(Item.embedding)).first()[0]
343+
assert res is None
344344
session.add(Item(embedding=[1, 2, 3]))
345345
session.add(Item(embedding=[4, 5, 6]))
346-
avg = session.query(func.avg(Item.embedding)).first()[0]
347-
# does not type cast
348-
assert avg == '[2.5,3.5,4.5]'
346+
res = session.query(avg(Item.embedding)).first()[0]
347+
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))
349348

350349
def test_avg_orm(self):
351350
with Session(engine) as session:
352-
avg = session.scalars(select(func.avg(Item.embedding))).first()
353-
assert avg is None
351+
res = session.scalars(select(avg(Item.embedding))).first()
352+
assert res is None
354353
session.add(Item(embedding=[1, 2, 3]))
355354
session.add(Item(embedding=[4, 5, 6]))
356-
avg = session.scalars(select(func.avg(Item.embedding))).first()
357-
# does not type cast
358-
assert avg == '[2.5,3.5,4.5]'
355+
res = session.scalars(select(avg(Item.embedding))).first()
356+
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))
359357

360358
def test_sum(self):
361359
with Session(engine) as session:
362-
sum = session.query(func.sum(Item.embedding)).first()[0]
363-
assert sum is None
360+
res = session.query(sum(Item.embedding)).first()[0]
361+
assert res is None
364362
session.add(Item(embedding=[1, 2, 3]))
365363
session.add(Item(embedding=[4, 5, 6]))
366-
sum = session.query(func.sum(Item.embedding)).first()[0]
367-
assert np.array_equal(sum, np.array([5, 7, 9]))
364+
res = session.query(sum(Item.embedding)).first()[0]
365+
assert np.array_equal(res, np.array([5, 7, 9]))
368366

369367
def test_sum_orm(self):
370368
with Session(engine) as session:
371-
sum = session.scalars(select(func.sum(Item.embedding))).first()
372-
assert sum is None
369+
res = session.scalars(select(sum(Item.embedding))).first()
370+
assert res is None
373371
session.add(Item(embedding=[1, 2, 3]))
374372
session.add(Item(embedding=[4, 5, 6]))
375-
sum = session.scalars(select(func.sum(Item.embedding))).first()
376-
assert np.array_equal(sum, np.array([5, 7, 9]))
373+
res = session.scalars(select(sum(Item.embedding))).first()
374+
assert np.array_equal(res, np.array([5, 7, 9]))
377375

378376
def test_bad_dimensions(self):
379377
item = Item(embedding=[1, 2])

tests/test_sqlmodel.py

Lines changed: 17 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector
2+
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
33
import pytest
44
from sqlalchemy import Column, Index
55
from sqlalchemy.exc import StatementError
@@ -198,41 +198,39 @@ def test_select(self):
198198

199199
def test_vector_avg(self):
200200
with Session(engine) as session:
201-
avg = session.exec(select(func.avg(Item.embedding))).first()
202-
assert avg is None
201+
res = session.exec(select(avg(Item.embedding))).first()
202+
assert res is None
203203
session.add(Item(embedding=[1, 2, 3]))
204204
session.add(Item(embedding=[4, 5, 6]))
205-
avg = session.exec(select(func.avg(Item.embedding))).first()
206-
# does not type cast
207-
assert avg == '[2.5,3.5,4.5]'
205+
res = session.exec(select(avg(Item.embedding))).first()
206+
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))
208207

209208
def test_vector_sum(self):
210209
with Session(engine) as session:
211-
sum = session.exec(select(func.sum(Item.embedding))).first()
212-
assert sum is None
210+
res = session.exec(select(sum(Item.embedding))).first()
211+
assert res is None
213212
session.add(Item(embedding=[1, 2, 3]))
214213
session.add(Item(embedding=[4, 5, 6]))
215-
sum = session.exec(select(func.sum(Item.embedding))).first()
216-
assert np.array_equal(sum, np.array([5, 7, 9]))
214+
res = session.exec(select(sum(Item.embedding))).first()
215+
assert np.array_equal(res, np.array([5, 7, 9]))
217216

218217
def test_halfvec_avg(self):
219218
with Session(engine) as session:
220-
avg = session.exec(select(func.avg(Item.half_embedding))).first()
221-
assert avg is None
219+
res = session.exec(select(avg(Item.half_embedding))).first()
220+
assert res is None
222221
session.add(Item(half_embedding=[1, 2, 3]))
223222
session.add(Item(half_embedding=[4, 5, 6]))
224-
avg = session.exec(select(func.avg(Item.half_embedding))).first()
225-
# does not type cast
226-
assert avg == '[2.5,3.5,4.5]'
223+
res = session.exec(select(avg(Item.half_embedding))).first()
224+
assert res.to_list() == [2.5, 3.5, 4.5]
227225

228226
def test_halfvec_sum(self):
229227
with Session(engine) as session:
230-
sum = session.exec(select(func.sum(Item.half_embedding))).first()
231-
assert sum is None
228+
res = session.exec(select(sum(Item.half_embedding))).first()
229+
assert res is None
232230
session.add(Item(half_embedding=[1, 2, 3]))
233231
session.add(Item(half_embedding=[4, 5, 6]))
234-
sum = session.exec(select(func.sum(Item.half_embedding))).first()
235-
assert sum.to_list() == [5, 7, 9]
232+
res = session.exec(select(sum(Item.half_embedding))).first()
233+
assert res.to_list() == [5, 7, 9]
236234

237235
def test_bad_dimensions(self):
238236
item = Item(embedding=[1, 2])

0 commit comments

Comments
 (0)