Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libs/knowledge-store/ragstack_knowledge_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import KnowledgeStore
from .cassandra import CassandraKnowledgeStore
from .embedding_model import EmbeddingModel
from .knowledge_store import KnowledgeStore, Node, SetupMode, TextNode

__all__ = ["CassandraKnowledgeStore", "KnowledgeStore"]
__all__ = ["EmbeddingModel", "KnowledgeStore", "Node", "SetupMode", "TextNode"]
12 changes: 1 addition & 11 deletions libs/knowledge-store/ragstack_knowledge_store/edge_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,7 @@

import abc
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
Iterable,
Iterator,
Literal,
Set,
TypeVar,
Union,
)
from typing import Any, Dict, Generic, Iterable, Iterator, Literal, Set, TypeVar, Union

from langchain_core.documents import Document
from pydantic import BaseModel
Expand Down
22 changes: 22 additions & 0 deletions libs/knowledge-store/ragstack_knowledge_store/embedding_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import List


class EmbeddingModel(ABC):
    """Abstract interface for models that turn text into embedding vectors.

    Every method is abstract, so a concrete subclass must provide all four:
    the blocking pair (``embed_texts`` / ``embed_query``) and their
    coroutine counterparts (``aembed_texts`` / ``aembed_query``).
    """

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts, returning a list of float vectors."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string, returning one float vector."""

    @abstractmethod
    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
        """Coroutine variant of ``embed_texts``."""

    @abstractmethod
    async def aembed_query(self, text: str) -> List[float]:
        """Coroutine variant of ``embed_query``."""
Original file line number Diff line number Diff line change
@@ -1,48 +1,66 @@
import secrets
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Any,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
)

import numpy as np
from cassandra.cluster import ConsistencyLevel, ResponseFuture, Session
from cassio.config import check_resolve_keyspace, check_resolve_session
from langchain_community.utilities.cassandra import SetupMode
from langchain_community.utils.math import cosine_similarity
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from ragstack_knowledge_store.edge_extractor import get_link_tags

from .base import KnowledgeStore, Node, TextNode
from .concurrency import ConcurrentQueries
from .content import Kind
from .edge_extractor import get_link_tags
from .embedding_model import EmbeddingModel
from .math import cosine_similarity

CONTENT_ID = "content_id"


def _row_to_document(row) -> Document:
return Document(
page_content=row.text_content,
@dataclass
class Node:
    """Node in the KnowledgeStore graph.

    Attributes:
        id: Unique ID for the node. Will be generated by the KnowledgeStore
            if not set.
        metadata: Metadata for the node. May contain information used to
            link this node with other nodes.
    """

    id: Optional[str] = None
    # default_factory gives every instance its own dict rather than a shared one.
    metadata: dict = field(default_factory=dict)


@dataclass
class TextNode(Node):
    """A ``Node`` that additionally carries text.

    Attributes:
        text: Text contained by the node. Defaults to ``None``; a default is
            required here because the inherited ``Node`` fields all have
            defaults and dataclass fields without defaults may not follow
            fields with defaults.
    """

    # Annotated Optional[str] (was plain `str`) so the type matches the
    # None default; runtime behaviour is unchanged.
    text: Optional[str] = None


class SetupMode(Enum):
    """Mode used to create the Cassandra table (SYNC, ASYNC or OFF)."""

    # NOTE(review): presumably SYNC/ASYNC select blocking vs. non-blocking
    # schema creation and OFF skips it — confirm against the store's __init__.
    SYNC = 1
    ASYNC = 2
    OFF = 3


def _row_to_node(row) -> Node:
    """Convert one result row into a ``TextNode``.

    The row must expose ``text_content``, ``content_id`` and ``kind``
    attributes. The node's text is taken from ``text_content``; the other
    two values are stored in the node's metadata under ``CONTENT_ID`` and
    ``"kind"`` respectively.
    """
    node_text = row.text_content
    node_metadata = {CONTENT_ID: row.content_id, "kind": row.kind}
    return TextNode(text=node_text, metadata=node_metadata)


def _results_to_documents(results: Optional[ResponseFuture]) -> Iterable[Document]:
def _results_to_nodes(results: Optional[ResponseFuture]) -> Iterable[TextNode]:
if results:
for row in results:
yield _row_to_document(row)
yield _row_to_node(row)


def _results_to_ids(results: Optional[ResponseFuture]) -> Iterable[str]:
Expand Down Expand Up @@ -94,10 +112,10 @@ def update_for_selection(
self.score = self.similarity_to_query - selected_r_sim


class CassandraKnowledgeStore(KnowledgeStore):
class KnowledgeStore:
def __init__(
self,
embedding: Embeddings,
embedding: EmbeddingModel,
*,
node_table: str = "knowledge_nodes",
edge_table: str = "knowledge_edges",
Expand All @@ -114,8 +132,8 @@ def __init__(
Args:
embedding: The embeddings to use for the document content.
concurrency: Maximum number of queries to have concurrently executing.
apply_schema: If true, the schema will be created if necessary. If false,
the schema must have already been applied.
setup_mode: Mode used to create the Cassandra table (SYNC,
ASYNC or OFF).
"""
session = check_resolve_session(session)
keyspace = check_resolve_keyspace(keyspace)
Expand Down Expand Up @@ -300,18 +318,13 @@ def _apply_schema(self):
"""
)

@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding

def _concurrent_queries(self) -> ConcurrentQueries:
return ConcurrentQueries(self._session, concurrency=self._concurrency)

# TODO: Async (aadd_nodes)
def add_nodes(
self,
nodes: Iterable[Node] = None,
**kwargs: Any,
):
texts = []
metadatas = []
Expand All @@ -321,7 +334,7 @@ def add_nodes(
texts.append(node.text)
metadatas.append(node.metadata)

text_embeddings = self._embedding.embed_documents(texts)
text_embeddings = self._embedding.embed_texts(texts)

ids = []

Expand Down Expand Up @@ -467,63 +480,21 @@ def add_edges_for_targets(

return ids

@classmethod
def from_texts(
cls: Type["CassandraKnowledgeStore"],
texts: Iterable[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> "CassandraKnowledgeStore":
"""Return CassandraKnowledgeStore initialized from texts and embeddings."""
store = cls(embedding, **kwargs)
store.add_texts(texts, metadatas, ids=ids)
return store

@classmethod
def from_documents(
cls: Type["CassandraKnowledgeStore"],
documents: Iterable[Document],
embedding: Embeddings,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> "CassandraKnowledgeStore":
"""Return CassandraKnowledgeStore initialized from documents and embeddings."""
store = cls(embedding, **kwargs)
store.add_documents(documents, ids=ids)
return store

def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
embedding_vector = self._embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k=k,
)

def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
results = self._session.execute(self._query_by_embedding, (embedding, k))
return list(_results_to_documents(results))

def _query_by_ids(
self,
ids: Iterable[str],
) -> Iterable[Document]:
) -> Iterable[TextNode]:
results = []
with self._concurrent_queries() as cq:

def add_documents(rows, index):
results.extend([(index, _row_to_document(row)) for row in rows])
results.extend([(index, _row_to_node(row)) for row in rows])

for index, id in enumerate(ids):
for idx, id in enumerate(ids):
cq.execute(
self._query_by_id,
parameters=(id,),
callback=lambda rows, index=index: add_documents(rows, index),
callback=lambda rows, index=idx: add_documents(rows, index),
)

results.sort(key=lambda tuple: tuple[0])
Expand All @@ -545,7 +516,7 @@ def mmr_traversal_search(
fetch_k: int = 100,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
) -> Iterable[Document]:
) -> Iterable[TextNode]:
"""Retrieve documents from this knowledge store using MMR-traversal.

This strategy first retrieves the top `fetch_k` results by similarity to
Expand Down Expand Up @@ -646,7 +617,7 @@ def mmr_traversal_search(

def traversal_search(
self, query: str, *, k: int = 4, depth: int = 1
) -> Iterable[Document]:
) -> Iterable[TextNode]:
"""Retrieve documents from this knowledge store.

First, `k` nodes are retrieved using a vector search for the `query` string.
Expand Down Expand Up @@ -686,3 +657,11 @@ def visit(d: int, nodes: Sequence[NamedTuple]):
)

return self._query_by_ids(visited.keys())

def similarity_search(
    self,
    embedding: List[float],
    k: int = 4,
) -> Iterable[TextNode]:
    """Yield a node for each row matched by an embedding query.

    Executes ``self._query_by_embedding`` on the session with the parameters
    ``(embedding, k)`` and converts each returned row with ``_row_to_node``.
    This is a generator, so the query is not executed until iteration starts.

    Args:
        embedding: Query vector passed as the first statement parameter.
        k: Passed as the second statement parameter; presumably the maximum
            number of rows to return — confirm against the statement text.
    """
    matched_rows = self._session.execute(self._query_by_embedding, (embedding, k))
    yield from map(_row_to_node, matched_rows)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import KnowledgeStore, Node, TextNode
from .cassandra import CassandraKnowledgeStore

__all__ = ["CassandraKnowledgeStore", "KnowledgeStore", "Node", "TextNode"]
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@ def _has_next(iterator: Iterator) -> None:


class Node(Serializable):
"""Node in the KnowledgeStore graph"""
"""Node in the KnowledgeStore graph."""

id: Optional[str]
"""Unique ID for the node. Shall be generated by the KnowledgeStore if not set"""
"""Unique ID for the node. Will be generated by the KnowledgeStore if not set."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node. May contain information used to link this node
with other nodes."""


class TextNode(Node):
text: str
"""Text contained by the node"""
"""Text contained by the node."""


def _texts_to_nodes(
Expand Down
Loading