Skip to content

Commit e0c1c4c

Browse files
committed
Add Links to graph store Node
1 parent 5775e7f commit e0c1c4c

File tree

6 files changed

+60
-50
lines changed

6 files changed

+60
-50
lines changed

.github/workflows/ci-unit-tests.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,15 @@ jobs:
139139
rm -rf $dir/.tox
140140
}
141141
142-
run_itests libs/colbert
143-
run_itests libs/langchain
144-
run_itests libs/llamaindex
142+
if [[ "true" == "${{ needs.preconditions.outputs.libs_colbert }}" ]]; then
143+
run_itests libs/colbert
144+
fi
145+
if [[ "true" == "${{ needs.preconditions.outputs.libs_langchain }}" ]]; then
146+
run_itests libs/langchain
147+
fi
148+
if [[ "true" == "${{ needs.preconditions.outputs.libs_llamaindex }}" ]]; then
149+
run_itests libs/llamaindex
150+
fi
145151
146152
- name: Cleanup AstraDB
147153
uses: nicoloboschi/cleanup-astradb@v1

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .concurrency import ConcurrentQueries
2020
from .content import Kind
2121
from .embedding_model import EmbeddingModel
22-
from .link_tag import LinkTag
22+
from .links import Link
2323
from .math import cosine_similarity
2424

2525
CONTENT_ID = "content_id"
@@ -33,7 +33,7 @@ class Node:
3333
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
3434
metadata: dict = field(default_factory=dict)
3535
"""Metadata for the node."""
36-
links: Set[LinkTag] = field(default_factory=set)
36+
links: Set[Link] = field(default_factory=set)
3737
"""Links for the node."""
3838

3939

@@ -330,13 +330,13 @@ def add_nodes(
330330
) -> Iterable[str]:
331331
texts = []
332332
metadatas = []
333-
links: List[Set[LinkTag]] = []
333+
nodes_links: List[Set[Link]] = []
334334
for node in nodes:
335335
if not isinstance(node, TextNode):
336336
raise ValueError("Only adding TextNode is supported at the moment")
337337
texts.append(node.text)
338338
metadatas.append(node.metadata)
339-
links.append(node.links)
339+
nodes_links.append(node.links)
340340

341341
text_embeddings = self._embedding.embed_texts(texts)
342342

@@ -347,8 +347,8 @@ def add_nodes(
347347

348348
# Step 1: Add the nodes, collecting the tags and new sources / targets.
349349
with self._concurrent_queries() as cq:
350-
tuples = zip(texts, text_embeddings, metadatas, links)
351-
for text, text_embedding, metadata, _links in tuples:
350+
tuples = zip(texts, text_embeddings, metadatas, nodes_links)
351+
for text, text_embedding, metadata, links in tuples:
352352
if CONTENT_ID not in metadata:
353353
metadata[CONTENT_ID] = secrets.token_hex(8)
354354
id = metadata[CONTENT_ID]
@@ -357,20 +357,22 @@ def add_nodes(
357357
link_to_tags = set() # link to these tags
358358
link_from_tags = set() # link from these tags
359359

360-
for tag in _links:
361-
tag_str = f"{tag.kind}:{tag.tag}"
362-
if tag.direction == "incoming" or tag.direction == "bidir":
363-
# An incom`ing link should be linked *from* nodes with the given tag.
364-
link_from_tags.add(tag_str)
365-
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
366-
tag.kind,
367-
text_embedding,
368-
)
369-
if tag.direction == "outgoing" or tag.direction == "bidir":
370-
link_to_tags.add(tag_str)
371-
tag_to_new_sources.setdefault(tag_str, list()).append(
372-
(tag.kind, id)
373-
)
360+
for tag in links:
361+
if hasattr(tag, "tag"):
362+
tag_str = f"{tag.kind}:{tag.tag}"
363+
if tag.direction == "incoming" or tag.direction == "bidir":
364+
# An incoming link should be linked *from* nodes with the
365+
# given tag.
366+
link_from_tags.add(tag_str)
367+
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
368+
tag.kind,
369+
text_embedding,
370+
)
371+
if tag.direction == "outgoing" or tag.direction == "bidir":
372+
link_to_tags.add(tag_str)
373+
tag_to_new_sources.setdefault(tag_str, list()).append(
374+
(tag.kind, id)
375+
)
374376

375377
cq.execute(
376378
self._insert_passage,

libs/knowledge-store/ragstack_knowledge_store/link_tag.py renamed to libs/knowledge-store/ragstack_knowledge_store/links.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from dataclasses import dataclass
2-
from typing import Literal, Dict, Any, Set
2+
from typing import Literal
33

44

55
@dataclass(frozen=True)
6-
class _LinkTag:
6+
class Link:
77
kind: str
8-
tag: str
98
direction: Literal["incoming", "outgoing", "bidir"]
109

10+
def __post_init__(self):
11+
if self.__class__ in [Link, LinkTag]:
12+
raise TypeError(
13+
f"Abstract class {self.__class__.__name__} cannot be instantiated"
14+
)
15+
1116

1217
@dataclass(frozen=True)
13-
class LinkTag(_LinkTag):
14-
def __init__(self, kind: str, tag: str, direction: str) -> None:
15-
if self.__class__ == LinkTag:
16-
raise TypeError("Abstract class LinkTag cannot be instantiated")
17-
super().__init__(kind, tag, direction)
18+
class LinkTag(Link):
19+
tag: str
1820

1921

2022
@dataclass(frozen=True)

libs/langchain/ragstack_langchain/graph_store/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
2424
from langchain_core.pydantic_v1 import Field
2525

26-
from ragstack_langchain.graph_store.links import LinkTag, LINKS
26+
from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link
2727

2828

2929
def _has_next(iterator: Iterator) -> None:
@@ -40,7 +40,7 @@ class Node(Serializable):
4040
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
4141
metadata: dict = Field(default_factory=dict)
4242
"""Metadata for the node."""
43-
links: Set[LinkTag] = Field(default_factory=set)
43+
links: Set[Link] = Field(default_factory=set)
4444
"""Links associated with the node."""
4545

4646

@@ -66,7 +66,7 @@ def _texts_to_nodes(
6666
except StopIteration:
6767
raise ValueError("texts iterable longer than ids")
6868

69-
links = _metadata.pop(LINKS, set())
69+
links = _metadata.pop(METADATA_LINKS_KEY, set())
7070
yield TextNode(
7171
id=_id,
7272
metadata=_metadata,
@@ -89,7 +89,7 @@ def _documents_to_nodes(
8989
except StopIteration:
9090
raise ValueError("documents iterable longer than ids")
9191
metadata = doc.metadata.copy()
92-
links = metadata.pop(LINKS, set())
92+
links = metadata.pop(METADATA_LINKS_KEY, set())
9393
yield TextNode(
9494
id=_id,
9595
metadata=metadata,

libs/langchain/ragstack_langchain/graph_store/links.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, kind: str, tag: str) -> None:
3939
super().__init__(kind=kind, tag=tag, direction="bidir")
4040

4141

42-
LINKS = "links"
42+
METADATA_LINKS_KEY = "links"
4343

4444

4545
def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]:
@@ -53,10 +53,10 @@ def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]:
5353
if isinstance(doc_or_md, Document):
5454
doc_or_md = doc_or_md.metadata
5555

56-
links = doc_or_md.setdefault(LINKS, set())
56+
links = doc_or_md.setdefault(METADATA_LINKS_KEY, set())
5757
if not isinstance(links, Set):
5858
links = set(links)
59-
doc_or_md[LINKS] = links
59+
doc_or_md[METADATA_LINKS_KEY] = links
6060
return links
6161

6262

libs/langchain/tests/integration_tests/test_graph_store.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
TextNode,
1515
)
1616
from ragstack_langchain.graph_store.links import (
17-
LINKS,
17+
METADATA_LINKS_KEY,
1818
BidirLinkTag,
1919
IncomingLinkTag,
2020
OutgoingLinkTag,
@@ -37,7 +37,7 @@ def __init__(self, session: Session, keyspace: str, embedding: Embeddings) -> No
3737

3838
def store(
3939
self,
40-
initial_documents: Iterable[Document] = [],
40+
initial_documents: Iterable[Document] = (),
4141
ids: Optional[Iterable[str]] = None,
4242
embedding: Optional[Embeddings] = None,
4343
) -> CassandraGraphStore:
@@ -125,7 +125,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
125125
page_content="A",
126126
metadata={
127127
"content_id": "a",
128-
LINKS: {
128+
METADATA_LINKS_KEY: {
129129
IncomingLinkTag(kind="hyperlink", tag="http://a"),
130130
},
131131
},
@@ -134,7 +134,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
134134
page_content="B",
135135
metadata={
136136
"content_id": "b",
137-
LINKS: {
137+
METADATA_LINKS_KEY: {
138138
IncomingLinkTag(kind="hyperlink", tag="http://b"),
139139
OutgoingLinkTag(kind="hyperlink", tag="http://a"),
140140
},
@@ -144,7 +144,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
144144
page_content="C",
145145
metadata={
146146
"content_id": "c",
147-
LINKS: {
147+
METADATA_LINKS_KEY: {
148148
OutgoingLinkTag(kind="hyperlink", tag="http://a"),
149149
},
150150
},
@@ -153,7 +153,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
153153
page_content="D",
154154
metadata={
155155
"content_id": "d",
156-
LINKS: {
156+
METADATA_LINKS_KEY: {
157157
OutgoingLinkTag(kind="hyperlink", tag="http://a"),
158158
OutgoingLinkTag(kind="hyperlink", tag="http://b"),
159159
},
@@ -198,7 +198,7 @@ def test_mmr_traversal(request, gs_factory: str):
198198
page_content="-0.124",
199199
metadata={
200200
"content_id": "v0",
201-
LINKS: {
201+
METADATA_LINKS_KEY: {
202202
OutgoingLinkTag(kind="explicit", tag="link"),
203203
},
204204
},
@@ -213,7 +213,7 @@ def test_mmr_traversal(request, gs_factory: str):
213213
page_content="+0.25",
214214
metadata={
215215
"content_id": "v2",
216-
LINKS: {
216+
METADATA_LINKS_KEY: {
217217
IncomingLinkTag(kind="explicit", tag="link"),
218218
},
219219
},
@@ -222,7 +222,7 @@ def test_mmr_traversal(request, gs_factory: str):
222222
page_content="+1.0",
223223
metadata={
224224
"content_id": "v3",
225-
LINKS: {
225+
METADATA_LINKS_KEY: {
226226
IncomingLinkTag(kind="explicit", tag="link"),
227227
},
228228
},
@@ -257,7 +257,7 @@ def test_write_retrieve_keywords(request, gs_factory: str):
257257
page_content="Typical Greetings",
258258
metadata={
259259
"content_id": "greetings",
260-
LINKS: {
260+
METADATA_LINKS_KEY: {
261261
IncomingLinkTag(kind="parent", tag="parent"),
262262
},
263263
},
@@ -266,7 +266,7 @@ def test_write_retrieve_keywords(request, gs_factory: str):
266266
page_content="Hello World",
267267
metadata={
268268
"content_id": "doc1",
269-
LINKS: {
269+
METADATA_LINKS_KEY: {
270270
OutgoingLinkTag(kind="parent", tag="parent"),
271271
BidirLinkTag(kind="kw", tag="greeting"),
272272
BidirLinkTag(kind="kw", tag="world"),
@@ -277,7 +277,7 @@ def test_write_retrieve_keywords(request, gs_factory: str):
277277
page_content="Hello Earth",
278278
metadata={
279279
"content_id": "doc2",
280-
LINKS: {
280+
METADATA_LINKS_KEY: {
281281
OutgoingLinkTag(kind="parent", tag="parent"),
282282
BidirLinkTag(kind="kw", tag="greeting"),
283283
BidirLinkTag(kind="kw", tag="earth"),

0 commit comments

Comments
 (0)