Skip to content

Commit 5a473be

Browse files
Update sqlalchemy_items.py
1 parent 312293c commit 5a473be

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

examples/sqlalchemy_items.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,30 @@ class Item(Base):
3939

4040
engine = create_engine(DATABASE_URI, echo=False)
4141

42-
# Create tables in database
42+
# Create pgvector extension
43+
with engine.begin() as conn:
44+
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
45+
46+
# Drop all tables defined in this model from the database, if they already exist
4347
Base.metadata.drop_all(engine)
48+
# Create all tables defined in this model in the database
4449
Base.metadata.create_all(engine)
4550

4651
# Insert data and issue queries
4752
with Session(engine) as session:
48-
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
49-
53+
# Define an HNSW index (the index access method) that supports vector similarity search using the vector_l2_ops operator class (Euclidean distance). The SQL operator for Euclidean distance is written as <->.
5054
index = Index(
51-
"my_index",
55+
"hnsw_index_for_euclidean_distance_similarity_search",
5256
Item.embedding,
5357
postgresql_using="hnsw",
5458
postgresql_with={"m": 16, "ef_construction": 64},
5559
postgresql_ops={"embedding": "vector_l2_ops"},
5660
)
61+
62+
# Create the HNSW index
5763
index.create(engine)
5864

65+
# Insert three vectors as three separate rows in the items table
5966
session.add_all(
6067
[
6168
Item(embedding=[1, 2, 3]),
@@ -64,20 +71,28 @@ class Item(Base):
6471
]
6572
)
6673

74+
# Find all vectors in table items
75+
all_items = session.scalars(select(Item))
76+
print("All vectors in table items:")
77+
for item in all_items:
78+
print(f"\t{item.embedding}")
79+
6780
# Find 2 closest vectors to [3, 1, 2]
6881
closest_items = session.scalars(select(Item).order_by(Item.embedding.l2_distance([3, 1, 2])).limit(2))
82+
print("Two closest vectors to [3, 1, 2] in table items:")
6983
for item in closest_items:
70-
print(item.embedding)
84+
print(f"\t{item.embedding}")
7185

7286
# Calculate distance between [3, 1, 2] and the first vector returned by the query (no ORDER BY is applied, so this is not necessarily the closest vector)
7387
distance = session.scalars(select(Item.embedding.l2_distance([3, 1, 2]))).first()
74-
print(distance)
88+
print(f"Distance between [3, 1, 2] vector and the one closest to it: {distance}")
7589

7690
# Find vectors within distance 5 from [3, 1, 2]
7791
close_enough_items = session.scalars(select(Item).filter(Item.embedding.l2_distance([3, 1, 2]) < 5))
92+
print("Vectors within a distance of 5 from [3, 1, 2]:")
7893
for item in close_enough_items:
79-
print(item.embedding)
94+
print(f"\t{item.embedding}")
8095

8196
# Calculate average of all vectors
8297
avg_embedding = session.scalars(select(func.avg(Item.embedding))).first()
83-
print(avg_embedding)
98+
print(f"Average of all vectors: {avg_embedding}")

0 commit comments

Comments
 (0)