1 change: 1 addition & 0 deletions .github/changes-filter.yaml
@@ -30,6 +30,7 @@ libs_colbert:
 libs_langchain:
   - "libs/langchain/**"
   - "libs/colbert/**"
+  - "libs/knowledge-store/**"
 libs_llamaindex:
   - "libs/llamaindex/**"
   - "libs/colbert/**"
8 changes: 3 additions & 5 deletions .github/workflows/ci-unit-tests.yml
@@ -90,11 +90,9 @@ jobs:
         if: ${{ needs.preconditions.outputs.libs_llamaindex == 'true' }}
         run: tox -e unit-tests -c libs/llamaindex && rm -rf libs/llamaindex/.tox

-      # - name: "Unit tests (knowledge-store)"
-      #   if: ${{ needs.preconditions.outputs.libs_knowledge_store == 'true' }}
-      #   env:
-      #     OPENAI_API_KEY: "${{ secrets.E2E_TESTS_OPEN_AI_KEY }}"
-      #   run: tox -e unit-tests -c libs/knowledge-store && rm -rf libs/knowledge-store/.tox
+      - name: "Unit tests (knowledge-store)"
+        if: ${{ needs.preconditions.outputs.libs_knowledge_store == 'true' }}
+        run: tox -e unit-tests -c libs/knowledge-store && rm -rf libs/knowledge-store/.tox

       - name: "Unit tests (knowledge-graph)"
         # yamllint disable-line rule:line-length
62 changes: 52 additions & 10 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
@@ -1,13 +1,17 @@
+import json
 import secrets
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import (
+    Any,
+    Dict,
     Iterable,
     List,
     NamedTuple,
     Optional,
     Sequence,
     Set,
+    cast,
 )

 import numpy as np
@@ -46,14 +50,48 @@ class SetupMode(Enum):
     ASYNC = 2
     OFF = 3


+def _serialize_metadata(md: Dict[str, Any]) -> str:
+    # JSON has no set type, so convert a links set to a list before dumping.
+    if isinstance(md.get("links"), Set):
+        md = md.copy()
+        md["links"] = list(md["links"])
+    return json.dumps(md)
+
+
+def _serialize_links(links: Set[Link]) -> str:
+    import dataclasses
+
+    class SetAndLinkEncoder(json.JSONEncoder):
+        def default(self, obj):
+            if dataclasses.is_dataclass(obj):
+                return dataclasses.asdict(obj)
+            try:
+                iterable = iter(obj)
+            except TypeError:
+                pass
+            else:
+                return list(iterable)
+            # Let the base class default method raise the TypeError
+            return super().default(obj)
+
+    return json.dumps(list(links), cls=SetAndLinkEncoder)
+
+
+def _deserialize_metadata(json_blob: Optional[str]) -> Dict[str, Any]:
+    # We don't need to convert the links list back to a set -- it will be
+    # converted when accessed, if needed.
+    return cast(Dict[str, Any], json.loads(json_blob or "{}"))
+
+
+def _deserialize_links(json_blob: Optional[str]) -> Set[Link]:
+    return {
+        Link(kind=link["kind"], direction=link["direction"], tag=link["tag"])
+        for link in cast(List[Dict], json.loads(json_blob or "[]"))
+    }
+
+
 def _row_to_node(row) -> Node:
+    metadata = _deserialize_metadata(row.metadata_blob)
+    links = _deserialize_links(row.links_blob)
     return TextNode(
         text=row.text_content,
-        metadata={
-            CONTENT_ID: row.content_id,
-            "kind": row.kind,
-        },
+        metadata=metadata,
+        links=links,
     )
@@ -140,13 +178,13 @@ def __init__(
                 "Only SYNC and OFF are supported at the moment"
             )

-        # TODO: Metadata
         # TODO: Parent ID / source ID / etc.
         self._insert_passage = session.prepare(
             f"""
             INSERT INTO {keyspace}.{node_table} (
-                content_id, kind, text_content, text_embedding, link_to_tags
-            ) VALUES (?, '{Kind.passage}', ?, ?, ?)
+                content_id, kind, text_content, text_embedding, link_to_tags,
+                metadata_blob, links_blob
+            ) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?)
             """
         )
@@ -160,15 +198,15 @@ def __init__(

         self._query_by_id = session.prepare(
             f"""
-            SELECT content_id, kind, text_content
+            SELECT content_id, kind, text_content, metadata_blob, links_blob
             FROM {keyspace}.{node_table}
             WHERE content_id = ?
             """
         )

         self._query_by_embedding = session.prepare(
             f"""
-            SELECT content_id, kind, text_content
+            SELECT content_id, kind, text_content, metadata_blob, links_blob
             FROM {keyspace}.{node_table}
             ORDER BY text_embedding ANN OF ?
             LIMIT ?
@@ -239,6 +277,8 @@ def _apply_schema(self):
                 text_embedding VECTOR<FLOAT, {embedding_dim}>,

                 link_to_tags SET<TUPLE<TEXT, TEXT>>,
+                metadata_blob TEXT,
+                links_blob TEXT,

                 PRIMARY KEY (content_id)
             )
@@ -307,9 +347,11 @@ def add_nodes(
                     if tag.direction == "out" or tag.direction == "bidir":
                         link_to_tags.add((tag.kind, tag.tag))

+            metadata_blob = _serialize_metadata(metadata)
+            links_blob = _serialize_links(links)
             cq.execute(
                 self._insert_passage,
-                parameters=(id, text, text_embedding, link_to_tags),
+                parameters=(id, text, text_embedding, link_to_tags, metadata_blob, links_blob),
             )

             for kind, value in link_from_tags:
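A minimal sketch of the round-trip these helpers implement (assuming `Link` is the dataclass from `ragstack_knowledge_store.links` with `kind`, `direction`, and `tag` fields, and that `"in"`/`"out"`/`"bidir"` are the direction strings, as the deserializer and `add_nodes` above suggest):

from ragstack_knowledge_store.graph_store import (
    _deserialize_links,
    _deserialize_metadata,
    _serialize_links,
    _serialize_metadata,
)
from ragstack_knowledge_store.links import Link

# Plain metadata round-trips through the metadata_blob TEXT column as JSON.
metadata = {"source": "doc-1", "tags": ["a", "b"]}
assert _deserialize_metadata(_serialize_metadata(metadata)) == metadata

# Links round-trip through links_blob: SetAndLinkEncoder turns each Link
# dataclass into a dict and the containing set into a list before dumping.
links = {Link(kind="hyperlink", direction="in", tag="http://a")}
assert _deserialize_links(_serialize_links(links)) == links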
31 changes: 31 additions & 0 deletions libs/knowledge-store/tests/unit_tests/test_graph_store.py
@@ -0,0 +1,31 @@
+from typing import Any, Dict, Set
+from ragstack_knowledge_store.graph_store import _serialize_metadata, _deserialize_metadata, _serialize_links, _deserialize_links
+from ragstack_knowledge_store.links import Link
+
+def test_metadata_serialization():
+    def assert_roundtrip(metadata: Dict[str, Any]):
+        serialized = _serialize_metadata(metadata)
+        deserialized = _deserialize_metadata(serialized)
+        assert metadata == deserialized
+
+    assert_roundtrip({})
+    assert_roundtrip({
+        "a": "hello",
+        "b": ["c", "d"],
+        "c": []
+    })
+
+def test_links_serialization():
+    def assert_roundtrip(links: Set[Link]):
+        serialized = _serialize_links(links)
+        deserialized = _deserialize_links(serialized)
+        assert links == deserialized
+
+    assert_roundtrip(set())
+    assert_roundtrip({
+        Link.incoming("a", "b"),
+        Link.outgoing("a", "b"),
+    })
+    assert_roundtrip({
+        Link.bidir("a", "b")
+    })
19 changes: 19 additions & 0 deletions libs/langchain/ragstack_langchain/graph_store/base.py
@@ -67,6 +67,8 @@ def _texts_to_nodes(
             raise ValueError("texts iterable longer than ids")

         links = _metadata.pop(METADATA_LINKS_KEY, set())
+        if not isinstance(links, Set):
+            links = set(links)
         yield TextNode(
             id=_id,
             metadata=_metadata,
@@ -90,6 +92,8 @@ def _documents_to_nodes(
             raise ValueError("documents iterable longer than ids")
         metadata = doc.metadata.copy()
         links = metadata.pop(METADATA_LINKS_KEY, set())
+        if not isinstance(links, Set):
+            links = set(links)
         yield TextNode(
             id=_id,
             metadata=metadata,
@@ -99,6 +103,21 @@ def _documents_to_nodes(
     if ids and _has_next(ids_it):
         raise ValueError("ids iterable longer than documents")

+def nodes_to_documents(
+    nodes: Iterable[Node]
+) -> Iterator[Document]:
+    for node in nodes:
+        metadata = node.metadata.copy()
+        metadata[METADATA_LINKS_KEY] = {
+            # Convert the core `Link` (from the node) back to the local `Link`.
+            Link(kind=link.kind, direction=link.direction, tag=link.tag)
+            for link in node.links
+        }
+
+        yield Document(
+            page_content=node.text,
+            metadata=metadata,
+        )

 class GraphStore(VectorStore):
     """A hybrid vector-and-graph graph store.
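For context, a sketch of `nodes_to_documents` in use (hypothetical values; the `TextNode` fields follow the constructor calls visible in the diff above, and importing `METADATA_LINKS_KEY` from this module is an assumption):

from ragstack_langchain.graph_store.base import (
    METADATA_LINKS_KEY,
    TextNode,
    nodes_to_documents,
)

# A node with no links still gains a (possibly empty) links entry in the
# resulting Document metadata, so downstream callers can rely on the key.
node = TextNode(id="a", text="A", metadata={"other": "field"}, links=set())
doc = next(nodes_to_documents([node]))
assert doc.page_content == "A"
assert doc.metadata["other"] == "field"
assert doc.metadata[METADATA_LINKS_KEY] == set()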
22 changes: 8 additions & 14 deletions libs/langchain/ragstack_langchain/graph_store/cassandra.py
@@ -11,7 +11,7 @@
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings

-from .base import GraphStore, Node, TextNode
+from .base import GraphStore, Node, TextNode, nodes_to_documents
 from ragstack_knowledge_store import EmbeddingModel, graph_store

@@ -125,8 +125,8 @@ def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
     def similarity_search_by_vector(
         self, embedding: List[float], k: int = 4, **kwargs: Any
     ) -> List[Document]:
-        return [Document(page_content=node.text, metadata=node.metadata)
-                for node in self.store.similarity_search(embedding, k=k)]
+        nodes = self.store.similarity_search(embedding, k=k)
+        return list(nodes_to_documents(nodes))

     def traversal_search(
         self,
@@ -136,11 +136,8 @@ def traversal_search(
         depth: int = 1,
         **kwargs: Any,
     ) -> Iterable[Document]:
-        for node in self.store.traversal_search(query, k=k, depth=depth):
-            yield Document(
-                page_content=node.text,
-                metadata=node.metadata,
-            )
+        nodes = self.store.traversal_search(query, k=k, depth=depth)
+        return nodes_to_documents(nodes)

     def mmr_traversal_search(
         self,
@@ -153,15 +150,12 @@ def mmr_traversal_search(
         score_threshold: float = float("-inf"),
         **kwargs: Any,
     ) -> Iterable[Document]:
-        for node in self.store.mmr_traversal_search(
+        nodes = self.store.mmr_traversal_search(
             query,
             k=k,
             depth=depth,
             fetch_k=fetch_k,
             lambda_mult=lambda_mult,
             score_threshold=score_threshold,
-        ):
-            yield Document(
-                page_content=node.text,
-                metadata=node.metadata,
-            )
+        )
+        return nodes_to_documents(nodes)
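The net effect for callers, sketched below (hypothetical usage; constructing a CassandraGraphStore requires a live Cassandra session and an embedding model, so setup is elided, and the METADATA_LINKS_KEY import location is an assumption):

from ragstack_langchain.graph_store.base import METADATA_LINKS_KEY

# `store` is assumed to be an already-configured CassandraGraphStore.
docs = store.similarity_search("query", k=4)
for doc in docs:
    # Each Document's metadata is now the stored metadata_blob contents,
    # with the node's links restored under METADATA_LINKS_KEY.
    print(doc.metadata.get("content_id"), doc.metadata.get(METADATA_LINKS_KEY))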
23 changes: 23 additions & 0 deletions libs/langchain/tests/integration_tests/test_graph_store.py
@@ -362,3 +362,26 @@ def test_documents_to_nodes():
         list(_documents_to_nodes(documents, ["a"]))
     with pytest.raises(ValueError):
         list(_documents_to_nodes(documents[1:], ["a", "b"]))
+
+def test_metadata(cassandra: GraphStoreFactory) -> None:
+    store = cassandra.store()
+    store.add_documents([
+        Document(
+            page_content="A",
+            metadata={
+                "content_id": "a",
+                METADATA_LINKS_KEY: {
+                    Link.incoming(kind="hyperlink", tag="http://a"),
+                    Link.bidir(kind="other", tag="foo"),
+                },
+                "other": "some other field"
+            },
+        )
+    ])
+    results = store.similarity_search("A")
+    assert len(results) == 1
+    assert results[0].metadata.get("other") == "some other field"
+    assert set(results[0].metadata.get(METADATA_LINKS_KEY)) == {
+        Link.incoming(kind="hyperlink", tag="http://a"),
+        Link.bidir(kind="other", tag="foo"),
+    }