core[minor]: Add Graph Store component (#23092)

This PR introduces a GraphVectorStore component. GraphVectorStore extends
VectorStore with the concept of links between documents, based on document
metadata. This allows documents to be linked using a variety of techniques,
including common keywords, explicit links in the content, and other patterns.
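
For illustration, here is a minimal sketch of how such links might be expressed. `Node` comes from the imports in this PR; the `Link` helper, its module path, and the exact constructor fields are assumptions rather than a confirmed API:

```python
from langchain_core.graph_vectorstores.base import Node
# Assumed location of the Link helper; it is not part of this diff excerpt.
from langchain_core.graph_vectorstores.links import Link

# Two nodes sharing the same (kind, tag) link become traversable neighbors.
nodes = [
    Node(
        text="Cassandra is a distributed NoSQL database.",
        metadata={"source": "databases.md"},
        links=[Link.bidir(kind="keyword", tag="cassandra")],
    ),
    Node(
        text="CassandraGraphVectorStore keeps link information in document metadata.",
        metadata={"source": "graph.md"},
        links=[Link.bidir(kind="keyword", tag="cassandra")],
    ),
]
```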

This works with existing Documents, so it is easy to extend existing
VectorStores to be used as GraphVectorStores. The interface can be implemented
for any vector store technology that supports metadata, not only graph
databases.

When retrieving documents for a given query, the first level of search
is done using classical similarity search. Next, links may be followed
using various traversal strategies to get additional documents. This
allows documents to be retrieved that aren’t directly similar to the
query but contain relevant information.

Two retrieval methods are added on top of the existing VectorStore ones (see the usage sketch below):
* traversal_search, which returns all linked documents up to a certain depth;
* mmr_traversal_search, which selects linked documents using an MMR algorithm
to produce more diverse results.

With a retrieval depth of 0, a GraphVectorStore behaves like a plain
VectorStore. This enables an easy transition from a simple VectorStore to a
GraphVectorStore: links between documents can be added as a second step.
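
As a minimal usage sketch (assuming `store` is an already-populated GraphVectorStore implementation, e.g. the Cassandra one below; the method and parameter names are taken from this PR):

```python
# Classical similarity search: only documents directly similar to the query.
docs = store.similarity_search("How do I enable vector search?", k=4)

# Traversal search: similarity search first, then follow links up to `depth`.
docs = list(store.traversal_search("How do I enable vector search?", k=4, depth=1))

# MMR traversal search: follow links, then re-rank candidates for diversity.
docs = list(
    store.mmr_traversal_search(
        "How do I enable vector search?",
        k=4,
        depth=2,
        lambda_mult=0.5,  # balance between relevance and diversity
    )
)
```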

An implementation for Apache Cassandra is also proposed.

See
https://github.com/datastax/ragstack-ai/blob/main/libs/knowledge-store/notebooks/astra_support.ipynb
for a notebook explaining how to use GraphVectorStore and showing that it can
correctly answer questions that a simple VectorStore cannot.
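
For completeness, a hedged construction sketch for the Cassandra implementation. Only the `CassandraGraphVectorStore` constructor arguments and the `from_documents` signature come from this PR; the cluster setup, the embeddings class, the keyspace name, and the `linked_docs` variable are illustrative placeholders, and the `ragstack-knowledge-store` package must be installed:

```python
from cassandra.cluster import Cluster  # provided by the cassandra-driver package
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

session = Cluster(["127.0.0.1"]).connect()  # placeholder local cluster

store = CassandraGraphVectorStore.from_documents(
    documents=linked_docs,        # existing Documents whose metadata carries the links
    embedding=OpenAIEmbeddings(),
    session=session,              # forwarded to the constructor via **kwargs
    keyspace="graph_demo",        # placeholder keyspace
)
```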

**Twitter handle:** _cbornet
Commit 42d049f618 (parent 77f5fc3d55)
Authored by Christophe Bornet on 2024-07-05 18:24:10 +02:00, committed by GitHub.
8 changed files with 1281 additions and 0 deletions.


@@ -0,0 +1,3 @@
from langchain_community.graph_vectorstores.cassandra import CassandraGraphVectorStore

__all__ = ["CassandraGraphVectorStore"]


@@ -0,0 +1,172 @@
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    List,
    Optional,
    Type,
)

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.graph_vectorstores.base import (
    GraphVectorStore,
    Node,
    nodes_to_documents,
)

from langchain_community.utilities.cassandra import SetupMode

if TYPE_CHECKING:
    from cassandra.cluster import Session


class CassandraGraphVectorStore(GraphVectorStore):
    def __init__(
        self,
        embedding: Embeddings,
        *,
        node_table: str = "graph_nodes",
        targets_table: str = "graph_targets",
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ):
"""
Create the hybrid graph store.
Parameters configure the ways that edges should be added between
documents. Many take `Union[bool, Set[str]]`, with `False` disabling
inference, `True` enabling it globally between all documents, and a set
of metadata fields defining a scope in which to enable it. Specifically,
passing a set of metadata fields such as `source` only links documents
with the same `source` metadata value.
Args:
embedding: The embeddings to use for the document content.
setup_mode: Mode used to create the Cassandra table (SYNC,
ASYNC or OFF).
"""
        try:
            from ragstack_knowledge_store import EmbeddingModel, graph_store
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import ragstack-knowledge-store python package. "
                "Please install it with `pip install ragstack-knowledge-store`."
            )

        self._embedding = embedding
        _setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

        class _EmbeddingModelAdapter(EmbeddingModel):
            def __init__(self, embeddings: Embeddings):
                self.embeddings = embeddings

            def embed_texts(self, texts: List[str]) -> List[List[float]]:
                return self.embeddings.embed_documents(texts)

            def embed_query(self, text: str) -> List[float]:
                return self.embeddings.embed_query(text)

            async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
                return await self.embeddings.aembed_documents(texts)

            async def aembed_query(self, text: str) -> List[float]:
                return await self.embeddings.aembed_query(text)

        self.store = graph_store.GraphStore(
            embedding=_EmbeddingModelAdapter(embedding),
            node_table=node_table,
            targets_table=targets_table,
            session=session,
            keyspace=keyspace,
            setup_mode=_setup_mode,
        )

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding

    def add_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> Iterable[str]:
        return self.store.add_nodes(nodes)

    @classmethod
    def from_texts(
        cls: Type["CassandraGraphVectorStore"],
        texts: Iterable[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> "CassandraGraphVectorStore":
        """Return CassandraGraphVectorStore initialized from texts and embeddings."""
        store = cls(embedding, **kwargs)
        store.add_texts(texts, metadatas, ids=ids)
        return store

    @classmethod
    def from_documents(
        cls: Type["CassandraGraphVectorStore"],
        documents: Iterable[Document],
        embedding: Embeddings,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> "CassandraGraphVectorStore":
        """Return CassandraGraphVectorStore initialized from documents and
        embeddings."""
        store = cls(embedding, **kwargs)
        store.add_documents(documents, ids=ids)
        return store

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        embedding_vector = self._embedding.embed_query(query)
        return self.similarity_search_by_vector(
            embedding_vector,
            k=k,
        )

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        nodes = self.store.similarity_search(embedding, k=k)
        return list(nodes_to_documents(nodes))

    def traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> Iterable[Document]:
        nodes = self.store.traversal_search(query, k=k, depth=depth)
        return nodes_to_documents(nodes)

    def mmr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> Iterable[Document]:
        nodes = self.store.mmr_traversal_search(
            query,
            k=k,
            depth=depth,
            fetch_k=fetch_k,
            adjacent_k=adjacent_k,
            lambda_mult=lambda_mult,
            score_threshold=score_threshold,
        )
        return nodes_to_documents(nodes)
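
As a usage note for the class above, here is a hedged sketch (not part of the diff) exercising `from_texts` and the MMR traversal parameters; `embeddings` and `session` are assumed to already exist, and the keyspace name is a placeholder:

```python
# Build the store directly from raw texts; metadatas and ids are optional.
store = CassandraGraphVectorStore.from_texts(
    texts=["LangChain integrates with Cassandra.", "GraphVectorStore follows links."],
    embedding=embeddings,   # any Embeddings implementation
    metadatas=[{"source": "a"}, {"source": "b"}],
    ids=["doc-a", "doc-b"],
    session=session,        # forwarded to the constructor via **kwargs
    keyspace="graph_demo",  # placeholder keyspace
)

# fetch_k and adjacent_k bound how many candidates are gathered (initially and
# while traversing links) before MMR re-ranking selects the final k results.
results = list(
    store.mmr_traversal_search("cassandra", k=2, fetch_k=20, adjacent_k=5)
)
```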