diff --git a/libs/community/langchain_community/graph_vectorstores/__init__.py b/libs/community/langchain_community/graph_vectorstores/__init__.py
new file mode 100644
index 00000000000..f5281743f71
--- /dev/null
+++ b/libs/community/langchain_community/graph_vectorstores/__init__.py
@@ -0,0 +1,3 @@
+from langchain_community.graph_vectorstores.cassandra import CassandraGraphVectorStore
+
+__all__ = ["CassandraGraphVectorStore"]
diff --git a/libs/community/langchain_community/graph_vectorstores/cassandra.py b/libs/community/langchain_community/graph_vectorstores/cassandra.py
new file mode 100644
index 00000000000..19f9453c42b
--- /dev/null
+++ b/libs/community/langchain_community/graph_vectorstores/cassandra.py
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterable,
+    List,
+    Optional,
+    Type,
+)
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.graph_vectorstores.base import (
+    GraphVectorStore,
+    Node,
+    nodes_to_documents,
+)
+
+from langchain_community.utilities.cassandra import SetupMode
+
+if TYPE_CHECKING:
+    from cassandra.cluster import Session
+
+
+class CassandraGraphVectorStore(GraphVectorStore):
+    def __init__(
+        self,
+        embedding: Embeddings,
+        *,
+        node_table: str = "graph_nodes",
+        targets_table: str = "graph_targets",
+        session: Optional[Session] = None,
+        keyspace: Optional[str] = None,
+        setup_mode: SetupMode = SetupMode.SYNC,
+    ):
+        """
+        Create the hybrid graph store backed by Cassandra.
+
+        Args:
+            embedding: The embeddings to use for the document content.
+            node_table: Name of the table containing the nodes.
+            targets_table: Name of the table containing the link targets
+                (the edges between nodes).
+            session: The Cassandra session to use.
+            keyspace: The Cassandra keyspace to use.
+            setup_mode: Mode used to create the Cassandra tables (SYNC,
+                ASYNC or OFF).
+        """
+        try:
+            from ragstack_knowledge_store import EmbeddingModel, graph_store
+        except (ImportError, ModuleNotFoundError):
+            raise ImportError(
+                "Could not import ragstack-knowledge-store python package. "
+                "Please install it with `pip install ragstack-knowledge-store`."
+ ) + + self._embedding = embedding + _setup_mode = getattr(graph_store.SetupMode, setup_mode.name) + + class _EmbeddingModelAdapter(EmbeddingModel): + def __init__(self, embeddings: Embeddings): + self.embeddings = embeddings + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + return self.embeddings.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + return self.embeddings.embed_query(text) + + async def aembed_texts(self, texts: List[str]) -> List[List[float]]: + return await self.embeddings.aembed_documents(texts) + + async def aembed_query(self, text: str) -> List[float]: + return await self.embeddings.aembed_query(text) + + self.store = graph_store.GraphStore( + embedding=_EmbeddingModelAdapter(embedding), + node_table=node_table, + targets_table=targets_table, + session=session, + keyspace=keyspace, + setup_mode=_setup_mode, + ) + + @property + def embeddings(self) -> Optional[Embeddings]: + return self._embedding + + def add_nodes( + self, + nodes: Iterable[Node], + **kwargs: Any, + ) -> Iterable[str]: + return self.store.add_nodes(nodes) + + @classmethod + def from_texts( + cls: Type["CassandraGraphVectorStore"], + texts: Iterable[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from texts and embeddings.""" + store = cls(embedding, **kwargs) + store.add_texts(texts, metadatas, ids=ids) + return store + + @classmethod + def from_documents( + cls: Type["CassandraGraphVectorStore"], + documents: Iterable[Document], + embedding: Embeddings, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from documents and + embeddings.""" + store = cls(embedding, **kwargs) + store.add_documents(documents, ids=ids) + return store + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + embedding_vector = self._embedding.embed_query(query) + return self.similarity_search_by_vector( + embedding_vector, + k=k, + ) + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + nodes = self.store.similarity_search(embedding, k=k) + return list(nodes_to_documents(nodes)) + + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.traversal_search(query, k=k, depth=depth) + return nodes_to_documents(nodes) + + def mmr_traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.mmr_traversal_search( + query, + k=k, + depth=depth, + fetch_k=fetch_k, + adjacent_k=adjacent_k, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + return nodes_to_documents(nodes) diff --git a/libs/community/tests/integration_tests/graph_vectorstores/__init__.py b/libs/community/tests/integration_tests/graph_vectorstores/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py new file mode 100644 index 00000000000..bfa3946fa4a --- /dev/null +++ 
b/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py
@@ -0,0 +1,272 @@
+import math
+import os
+from typing import Iterable, List, Optional, Type
+
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
+
+from langchain_community.graph_vectorstores import CassandraGraphVectorStore
+
+CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"
+
+
+def _get_graph_store(
+    embedding_class: Type[Embeddings], documents: Iterable[Document] = ()
+) -> CassandraGraphVectorStore:
+    import cassio
+    from cassandra.cluster import Cluster
+    from cassio.config import check_resolve_session, resolve_keyspace
+
+    node_table = "graph_test_node_table"
+    edge_table = "graph_test_edge_table"
+
+    if any(
+        env_var in os.environ
+        for env_var in [
+            "CASSANDRA_CONTACT_POINTS",
+            "ASTRA_DB_APPLICATION_TOKEN",
+            "ASTRA_DB_INIT_STRING",
+        ]
+    ):
+        cassio.init(auto=True)
+        session = check_resolve_session()
+    else:
+        cluster = Cluster()
+        session = cluster.connect()
+    keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE
+    cassio.init(session=session, keyspace=keyspace)
+    # Ensure the keyspace exists.
+    session.execute(
+        (
+            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
+            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
+        )
+    )
+    session.execute(f"DROP TABLE IF EXISTS {keyspace}.{node_table}")
+    session.execute(f"DROP TABLE IF EXISTS {keyspace}.{edge_table}")
+    store = CassandraGraphVectorStore.from_documents(
+        documents,
+        embedding=embedding_class(),
+        session=session,
+        keyspace=keyspace,
+        node_table=node_table,
+        targets_table=edge_table,
+    )
+    return store
+
+
+class FakeEmbeddings(Embeddings):
+    """Fake embeddings functionality for testing."""
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings.
+        Embeddings encode each text as its index."""
+        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        return self.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return constant query embeddings.
+        Embeddings are identical to embed_documents(texts)[0].
+        Distance to each text will be that text's index,
+        as it was passed to embed_documents."""
+        return [float(1.0)] * 9 + [float(0.0)]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        return self.embed_query(text)
+
+
+class AngularTwoDimensionalEmbeddings(Embeddings):
+    """
+    From angles (as strings in units of pi) to unit embedding vectors on a circle.
+    """
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Make a list of texts into a list of embedding vectors.
+        """
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Convert input text to a 'vector' (list of floats).
+        If the text is a number, use it as the angle for the
+        unit vector in units of pi.
+        Any other input text becomes the singular result [0, 0]!
+        """
+        try:
+            angle = float(text)
+            return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
+        except ValueError:
+            # Not a number: treat it as an opaque test string and map it
+            # to the origin; the value itself is irrelevant.
+            return [0.0, 0.0]
+
+
+def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
+    return [doc.id for doc in docs]
+
+
+def test_mmr_traversal() -> None:
+    """
+    Test end-to-end construction and MMR search.
+    The embedding function used here ensures `texts` become
+    the following vectors on a circle (numbered v0 through v3):
+
+           ______ v2
+          /      \
+         /        |  v1
+    v3  |     .   | query
+         |        /  v0
+          |______/                 (N.B. very crude drawing)
+
+    With fetch_k==2 and k==2, when the query is at (1, 0),
+    one expects that v2 and v0 are returned (in some order)
+    because v1 is "too close" to v0 (and v0 is closer to the query than v1).
+
+    Both v2 and v3 are reachable via edges from v0, so once it is
+    selected, those are both considered.
+    """
+    store = _get_graph_store(AngularTwoDimensionalEmbeddings)
+
+    v0 = Document(
+        id="v0",
+        page_content="-0.124",
+        metadata={
+            METADATA_LINKS_KEY: [
+                Link.outgoing(kind="explicit", tag="link"),
+            ],
+        },
+    )
+    v1 = Document(
+        id="v1",
+        page_content="+0.127",
+    )
+    v2 = Document(
+        id="v2",
+        page_content="+0.25",
+        metadata={
+            METADATA_LINKS_KEY: [
+                Link.incoming(kind="explicit", tag="link"),
+            ],
+        },
+    )
+    v3 = Document(
+        id="v3",
+        page_content="+1.0",
+        metadata={
+            METADATA_LINKS_KEY: [
+                Link.incoming(kind="explicit", tag="link"),
+            ],
+        },
+    )
+    store.add_documents([v0, v1, v2, v3])
+
+    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
+    assert _result_ids(results) == ["v0", "v2"]
+
+    # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
+    # So it ends up picking "v1" even though it's similar to "v0".
+    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
+    assert _result_ids(results) == ["v0", "v1"]
+
+    # With max depth 0 but a higher `fetch_k`, we encounter v2.
+    results = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
+    assert _result_ids(results) == ["v0", "v2"]
+
+    # v0 score is 0.46, v2 score is 0.16, so v2 won't be chosen.
+    results = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
+    assert _result_ids(results) == ["v0"]
+
+    # With k=4 we should get all of the documents.
+    results = store.mmr_traversal_search("0.0", k=4)
+    assert _result_ids(results) == ["v0", "v2", "v1", "v3"]
+
+
+def test_write_retrieve_keywords() -> None:
+    from langchain_openai import OpenAIEmbeddings
+
+    greetings = Document(
+        id="greetings",
+        page_content="Typical Greetings",
+        metadata={
+            METADATA_LINKS_KEY: [
+                Link.incoming(kind="parent", tag="parent"),
+            ],
+        },
+    )
+    doc1 = Document(
+        id="doc1",
+        page_content="Hello World",
+        metadata={
+            METADATA_LINKS_KEY: [
+                Link.outgoing(kind="parent", tag="parent"),
+                Link.bidir(kind="kw", tag="greeting"),
+                Link.bidir(kind="kw", tag="world"),
+            ],
+        },
+    )
+    doc2 = Document(
+        id="doc2",
+        page_content="Hello Earth",
+        metadata={
+            METADATA_LINKS_KEY: [
+                Link.outgoing(kind="parent", tag="parent"),
+                Link.bidir(kind="kw", tag="greeting"),
+                Link.bidir(kind="kw", tag="earth"),
+            ],
+        },
+    )
+    store = _get_graph_store(OpenAIEmbeddings, [greetings, doc1, doc2])
+
+    # Doc2 is more similar, but World and Earth are similar enough that doc1 also
+    # shows up.
+    results: Iterable[Document] = store.similarity_search("Earth", k=2)
+    assert _result_ids(results) == ["doc2", "doc1"]
+
+    results = store.similarity_search("Earth", k=1)
+    assert _result_ids(results) == ["doc2"]
+
+    results = store.traversal_search("Earth", k=2, depth=0)
+    assert _result_ids(results) == ["doc2", "doc1"]
+
+    results = store.traversal_search("Earth", k=2, depth=1)
+    assert _result_ids(results) == ["doc2", "doc1", "greetings"]
+
+    # K=1 only pulls in doc2 (Hello Earth).
+    results = store.traversal_search("Earth", k=1, depth=0)
+    assert _result_ids(results) == ["doc2"]
+
+    # K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via
+    # keyword edge.
+    results = store.traversal_search("Earth", k=1, depth=1)
+    assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}
+
+
+def test_metadata() -> None:
+    store = _get_graph_store(FakeEmbeddings)
+    store.add_documents(
+        [
+            Document(
+                id="a",
+                page_content="A",
+                metadata={
+                    METADATA_LINKS_KEY: [
+                        Link.incoming(kind="hyperlink", tag="http://a"),
+                        Link.bidir(kind="other", tag="foo"),
+                    ],
+                    "other": "some other field",
+                },
+            )
+        ]
+    )
+    results = store.similarity_search("A")
+    assert len(results) == 1
+    assert results[0].id == "a"
+    metadata = results[0].metadata
+    assert metadata["other"] == "some other field"
+    assert set(metadata[METADATA_LINKS_KEY]) == {
+        Link.incoming(kind="hyperlink", tag="http://a"),
+        Link.bidir(kind="other", tag="foo"),
+    }
diff --git a/libs/core/langchain_core/graph_vectorstores/__init__.py b/libs/core/langchain_core/graph_vectorstores/__init__.py
new file mode 100644
index 00000000000..973f0ef9545
--- /dev/null
+++ b/libs/core/langchain_core/graph_vectorstores/__init__.py
@@ -0,0 +1,15 @@
+from langchain_core.graph_vectorstores.base import (
+    GraphVectorStore,
+    GraphVectorStoreRetriever,
+    Node,
+)
+from langchain_core.graph_vectorstores.links import (
+    Link,
+)
+
+__all__ = [
+    "GraphVectorStore",
+    "GraphVectorStoreRetriever",
+    "Node",
+    "Link",
+]
diff --git a/libs/core/langchain_core/graph_vectorstores/base.py b/libs/core/langchain_core/graph_vectorstores/base.py
new file mode 100644
index 00000000000..37d1b438579
--- /dev/null
+++ b/libs/core/langchain_core/graph_vectorstores/base.py
@@ -0,0 +1,692 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import (
+    Any,
+    AsyncIterable,
+    ClassVar,
+    Collection,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+)
+
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForRetrieverRun,
+    CallbackManagerForRetrieverRun,
+)
+from langchain_core.documents import Document
+from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
+from langchain_core.load import Serializable
+from langchain_core.pydantic_v1 import Field
+from langchain_core.runnables import run_in_executor
+from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
+
+
+def _has_next(iterator: Iterator) -> bool:
+    """Checks if the iterator has more elements.
+
+    Warning: consumes an element from the iterator.
+    """
+    sentinel = object()
+    return next(iterator, sentinel) is not sentinel
+
+
+class Node(Serializable):
+    """Node in the GraphVectorStore.
+
+    Edges exist from nodes with an outgoing link to nodes with a matching
+    incoming link.
+
+    For instance two nodes `a` and `b` connected over a hyperlink
+    `https://some-url` would look like:
+
+    .. code-block:: python
+
+        [
+            Node(
+                id="a",
+                text="some text a",
+                links=[
+                    Link(kind="hyperlink", tag="https://some-url", direction="in")
+                ],
+            ),
+            Node(
+                id="b",
+                text="some text b",
+                links=[
+                    Link(kind="hyperlink", tag="https://some-url", direction="out")
+                ],
+            )
+        ]
+    """
+
+    id: Optional[str] = None
+    """Unique ID for the node.
Will be generated by the GraphVectorStore if not set.""" + text: str + """Text contained by the node.""" + metadata: dict = Field(default_factory=dict) + """Metadata for the node.""" + links: List[Link] = Field(default_factory=list) + """Links associated with the node.""" + + +def _texts_to_nodes( + texts: Iterable[str], + metadatas: Optional[Iterable[dict]], + ids: Optional[Iterable[str]], +) -> Iterator[Node]: + metadatas_it = iter(metadatas) if metadatas else None + ids_it = iter(ids) if ids else None + for text in texts: + try: + _metadata = next(metadatas_it).copy() if metadatas_it else {} + except StopIteration: + raise ValueError("texts iterable longer than metadatas") + try: + _id = next(ids_it) if ids_it else None + except StopIteration: + raise ValueError("texts iterable longer than ids") + + links = _metadata.pop(METADATA_LINKS_KEY, []) + if not isinstance(links, list): + links = list(links) + yield Node( + id=_id, + metadata=_metadata, + text=text, + links=links, + ) + if ids_it and _has_next(ids_it): + raise ValueError("ids iterable longer than texts") + if metadatas_it and _has_next(metadatas_it): + raise ValueError("metadatas iterable longer than texts") + + +def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]: + for doc in documents: + metadata = doc.metadata.copy() + links = metadata.pop(METADATA_LINKS_KEY, []) + if not isinstance(links, list): + links = list(links) + yield Node( + id=doc.id, + metadata=metadata, + text=doc.page_content, + links=links, + ) + + +def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]: + for node in nodes: + metadata = node.metadata.copy() + metadata[METADATA_LINKS_KEY] = [ + # Convert the core `Link` (from the node) back to the local `Link`. + Link(kind=link.kind, direction=link.direction, tag=link.tag) + for link in node.links + ] + + yield Document( + id=node.id, + page_content=node.text, + metadata=metadata, + ) + + +class GraphVectorStore(VectorStore): + """A hybrid vector-and-graph graph store. + + Document chunks support vector-similarity search as well as edges linking + chunks based on structural and semantic properties. + """ + + @abstractmethod + def add_nodes( + self, + nodes: Iterable[Node], + **kwargs: Any, + ) -> Iterable[str]: + """Add nodes to the graph store. + + Args: + nodes: the nodes to add. + """ + + async def aadd_nodes( + self, + nodes: Iterable[Node], + **kwargs: Any, + ) -> AsyncIterable[str]: + """Add nodes to the graph store. + + Args: + nodes: the nodes to add. + """ + iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs)) + done = object() + while True: + doc = await run_in_executor(None, next, iterator, done) + if doc is done: + break + yield doc # type: ignore[misc] + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[Iterable[dict]] = None, + *, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + The Links present in the metadata field `links` will be extracted to create + the `Node` links. + + Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the + function call would look like: + + .. 
code-block:: python + + store.add_texts( + ids=["a", "b"], + texts=["some text a", "some text b"], + metadatas=[ + { + "links": [ + Link.incoming(kind="hyperlink", tag="https://some-url") + ] + }, + { + "links": [ + Link.outgoing(kind="hyperlink", tag="https://some-url") + ] + }, + ], + ) + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + The metadata key `links` shall be an iterable of + :py:class:`~langchain_core.graph_vectorstores.links.Link`. + **kwargs: vectorstore specific parameters. + + Returns: + List of ids from adding the texts into the vectorstore. + """ + nodes = _texts_to_nodes(texts, metadatas, ids) + return list(self.add_nodes(nodes, **kwargs)) + + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[Iterable[dict]] = None, + *, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + The Links present in the metadata field `links` will be extracted to create + the `Node` links. + + Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the + function call would look like: + + .. code-block:: python + + await store.aadd_texts( + ids=["a", "b"], + texts=["some text a", "some text b"], + metadatas=[ + { + "links": [ + Link.incoming(kind="hyperlink", tag="https://some-url") + ] + }, + { + "links": [ + Link.outgoing(kind="hyperlink", tag="https://some-url") + ] + }, + ], + ) + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + The metadata key `links` shall be an iterable of + :py:class:`~langchain_core.graph_vectorstores.links.Link`. + **kwargs: vectorstore specific parameters. + + Returns: + List of ids from adding the texts into the vectorstore. + """ + nodes = _texts_to_nodes(texts, metadatas, ids) + return [_id async for _id in self.aadd_nodes(nodes, **kwargs)] + + def add_documents( + self, + documents: Iterable[Document], + **kwargs: Any, + ) -> List[str]: + """Run more documents through the embeddings and add to the vectorstore. + + The Links present in the document metadata field `links` will be extracted to + create the `Node` links. + + Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the + function call would look like: + + .. code-block:: python + + store.add_documents( + [ + Document( + id="a", + page_content="some text a", + metadata={ + "links": [ + Link.incoming(kind="hyperlink", tag="http://some-url") + ] + } + ), + Document( + id="b", + page_content="some text b", + metadata={ + "links": [ + Link.outgoing(kind="hyperlink", tag="http://some-url") + ] + } + ), + ] + + ) + + Args: + documents: Documents to add to the vectorstore. + The document's metadata key `links` shall be an iterable of + :py:class:`~langchain_core.graph_vectorstores.links.Link`. + + Returns: + List of IDs of the added texts. + """ + nodes = _documents_to_nodes(documents) + return list(self.add_nodes(nodes, **kwargs)) + + async def aadd_documents( + self, + documents: Iterable[Document], + **kwargs: Any, + ) -> List[str]: + """Run more documents through the embeddings and add to the vectorstore. + + The Links present in the document metadata field `links` will be extracted to + create the `Node` links. + + Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the + function call would look like: + + .. 
code-block:: python + + store.add_documents( + [ + Document( + id="a", + page_content="some text a", + metadata={ + "links": [ + Link.incoming(kind="hyperlink", tag="http://some-url") + ] + } + ), + Document( + id="b", + page_content="some text b", + metadata={ + "links": [ + Link.outgoing(kind="hyperlink", tag="http://some-url") + ] + } + ), + ] + + ) + + Args: + documents: Documents to add to the vectorstore. + The document's metadata key `links` shall be an iterable of + :py:class:`~langchain_core.graph_vectorstores.links.Link`. + + Returns: + List of IDs of the added texts. + """ + nodes = _documents_to_nodes(documents) + return [_id async for _id in self.aadd_nodes(nodes, **kwargs)] + + @abstractmethod + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + **kwargs: Any, + ) -> Iterable[Document]: + """Retrieve documents from traversing this graph store. + + First, `k` nodes are retrieved using a search for each `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial search. + Defaults to 4. Applies to each of the query strings. + depth: The maximum depth of edges to traverse. Defaults to 1. + Returns: + Retrieved documents. + """ + + async def atraversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + **kwargs: Any, + ) -> AsyncIterable[Document]: + """Retrieve documents from traversing this graph store. + + First, `k` nodes are retrieved using a search for each `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial search. + Defaults to 4. Applies to each of the query strings. + depth: The maximum depth of edges to traverse. Defaults to 1. + Returns: + Retrieved documents. + """ + iterator = iter( + await run_in_executor( + None, self.traversal_search, query, k=k, depth=depth, **kwargs + ) + ) + done = object() + while True: + doc = await run_in_executor(None, next, iterator, done) + if doc is done: + break + yield doc # type: ignore[misc] + + @abstractmethod + def mmr_traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + **kwargs: Any, + ) -> Iterable[Document]: + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch via similarity. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. + depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. 
+            score_threshold: Only documents with a score greater than or equal
+                to this threshold will be chosen. Defaults to negative infinity.
+        """
+
+    async def ammr_traversal_search(
+        self,
+        query: str,
+        *,
+        k: int = 4,
+        depth: int = 2,
+        fetch_k: int = 100,
+        adjacent_k: int = 10,
+        lambda_mult: float = 0.5,
+        score_threshold: float = float("-inf"),
+        **kwargs: Any,
+    ) -> AsyncIterable[Document]:
+        """Retrieve documents from this graph store using MMR-traversal.
+
+        This strategy first retrieves the top `fetch_k` results by similarity to
+        the question. It then selects the top `k` results based on
+        maximum-marginal relevance using the given `lambda_mult`.
+
+        At each step, it considers the (remaining) documents from `fetch_k` as
+        well as any documents connected by edges to a selected document
+        retrieved based on similarity (a "root").
+
+        Args:
+            query: The query string to search for.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch via similarity.
+                Defaults to 100.
+            adjacent_k: Number of adjacent Documents to fetch.
+                Defaults to 10.
+            depth: Maximum depth of a node (number of edges) from a node
+                retrieved via similarity. Defaults to 2.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding to maximum
+                diversity and 1 to minimum diversity. Defaults to 0.5.
+            score_threshold: Only documents with a score greater than or equal
+                to this threshold will be chosen. Defaults to negative infinity.
+        """
+        iterator = iter(
+            await run_in_executor(
+                None,
+                self.mmr_traversal_search,
+                query,
+                k=k,
+                fetch_k=fetch_k,
+                adjacent_k=adjacent_k,
+                depth=depth,
+                lambda_mult=lambda_mult,
+                score_threshold=score_threshold,
+                **kwargs,
+            )
+        )
+        done = object()
+        while True:
+            doc = await run_in_executor(None, next, iterator, done)
+            if doc is done:
+                break
+            yield doc  # type: ignore[misc]
+
+    def similarity_search(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        return list(self.traversal_search(query, k=k, depth=0))
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        return list(
+            self.mmr_traversal_search(
+                query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
+            )
+        )
+
+    async def asimilarity_search(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
+
+    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
+        if search_type == "similarity":
+            return self.similarity_search(query, **kwargs)
+        elif search_type == "similarity_score_threshold":
+            docs_and_similarities = self.similarity_search_with_relevance_scores(
+                query, **kwargs
+            )
+            return [doc for doc, _ in docs_and_similarities]
+        elif search_type == "mmr":
+            return self.max_marginal_relevance_search(query, **kwargs)
+        elif search_type == "traversal":
+            return list(self.traversal_search(query, **kwargs))
+        elif search_type == "mmr_traversal":
+            return list(self.mmr_traversal_search(query, **kwargs))
+        else:
+            raise ValueError(
+                f"search_type of {search_type} not allowed. Expected "
+                "search_type to be 'similarity', 'similarity_score_threshold', "
+                "'mmr', 'traversal' or 'mmr_traversal'."
+            )
+
+    async def asearch(
+        self, query: str, search_type: str, **kwargs: Any
+    ) -> List[Document]:
+        if search_type == "similarity":
+            return await self.asimilarity_search(query, **kwargs)
+        elif search_type == "similarity_score_threshold":
+            docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
+                query, **kwargs
+            )
+            return [doc for doc, _ in docs_and_similarities]
+        elif search_type == "mmr":
+            return await self.amax_marginal_relevance_search(query, **kwargs)
+        elif search_type == "traversal":
+            return [doc async for doc in self.atraversal_search(query, **kwargs)]
+        elif search_type == "mmr_traversal":
+            return [
+                doc async for doc in self.ammr_traversal_search(query, **kwargs)
+            ]
+        else:
+            raise ValueError(
+                f"search_type of {search_type} not allowed. Expected "
+                "search_type to be 'similarity', 'similarity_score_threshold', "
+                "'mmr', 'traversal' or 'mmr_traversal'."
+            )
+
+    def as_retriever(self, **kwargs: Any) -> "GraphVectorStoreRetriever":
+        """Return GraphVectorStoreRetriever initialized from this GraphVectorStore.
+
+        Args:
+            search_type (Optional[str]): Defines the type of search that
+                the Retriever should perform.
+                Can be "traversal" (default), "mmr_traversal", "similarity",
+                "mmr", or "similarity_score_threshold".
+            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
+                search function. Can include things like:
+                    k: Amount of documents to return (Default: 4)
+                    depth: The maximum depth of edges to traverse (Default: 1)
+                    score_threshold: Minimum relevance threshold
+                        for similarity_score_threshold
+                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
+                    lambda_mult: Diversity of results returned by MMR;
+                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
+        Returns:
+            Retriever for this GraphVectorStore.
+
+        Examples:
+
+        .. code-block:: python
+
+            # Retrieve documents traversing edges
+            docsearch.as_retriever(
+                search_type="traversal",
+                search_kwargs={'k': 6, 'depth': 3}
+            )
+
+            # Retrieve more documents with higher diversity
+            # Useful if your dataset has many similar documents
+            docsearch.as_retriever(
+                search_type="mmr",
+                search_kwargs={'k': 6, 'lambda_mult': 0.25}
+            )
+
+            # Fetch more documents for the MMR algorithm to consider
+            # But only return the top 5
+            docsearch.as_retriever(
+                search_type="mmr",
+                search_kwargs={'k': 5, 'fetch_k': 50}
+            )
+
+            # Only retrieve documents that have a relevance score
+            # Above a certain threshold
+            docsearch.as_retriever(
+                search_type="similarity_score_threshold",
+                search_kwargs={'score_threshold': 0.8}
+            )
+
+            # Only get the single most similar document from the dataset
+            docsearch.as_retriever(search_kwargs={'k': 1})
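+
+            # Traverse the graph with maximal marginal relevance.
+            # (Illustrative settings; mmr_traversal also accepts
+            # adjacent_k and score_threshold.)
+            docsearch.as_retriever(
+                search_type="mmr_traversal",
+                search_kwargs={'k': 5, 'fetch_k': 50, 'depth': 2}
+            )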
Defaults to "traversal".""" + allowed_search_types: ClassVar[Collection[str]] = ( + "similarity", + "similarity_score_threshold", + "mmr", + "traversal", + "mmr_traversal", + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + if self.search_type == "traversal": + return list(self.vectorstore.traversal_search(query, **self.search_kwargs)) + elif self.search_type == "mmr_traversal": + return list( + self.vectorstore.mmr_traversal_search(query, **self.search_kwargs) + ) + else: + return super()._get_relevant_documents(query, run_manager=run_manager) + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + if self.search_type == "traversal": + return [ + doc + async for doc in self.vectorstore.atraversal_search( + query, **self.search_kwargs + ) + ] + elif self.search_type == "mmr_traversal": + return [ + doc + async for doc in self.vectorstore.ammr_traversal_search( + query, **self.search_kwargs + ) + ] + else: + return await super()._aget_relevant_documents( + query, run_manager=run_manager + ) diff --git a/libs/core/langchain_core/graph_vectorstores/links.py b/libs/core/langchain_core/graph_vectorstores/links.py new file mode 100644 index 00000000000..9da58a39276 --- /dev/null +++ b/libs/core/langchain_core/graph_vectorstores/links.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Iterable, List, Literal, Union + +from langchain_core.documents import Document + + +@dataclass(frozen=True) +class Link: + """A link to/from a tag of a given tag. + + Edges exist from nodes with an outgoing link to nodes with a matching incoming link. + """ + + kind: str + """The kind of link. Allows different extractors to use the same tag name without + creating collisions between extractors. For example “keyword” vs “url”.""" + direction: Literal["in", "out", "bidir"] + """The direction of the link.""" + tag: str + """The tag of the link.""" + + @staticmethod + def incoming(kind: str, tag: str) -> "Link": + """Create an incoming link.""" + return Link(kind=kind, direction="in", tag=tag) + + @staticmethod + def outgoing(kind: str, tag: str) -> "Link": + """Create an outgoing link.""" + return Link(kind=kind, direction="out", tag=tag) + + @staticmethod + def bidir(kind: str, tag: str) -> "Link": + """Create a bidirectional link.""" + return Link(kind=kind, direction="bidir", tag=tag) + + +METADATA_LINKS_KEY = "links" + + +def get_links(doc: Document) -> List[Link]: + """Get the links from a document. + Args: + doc: The document to get the link tags from. + Returns: + The set of link tags from the document. + """ + + links = doc.metadata.setdefault(METADATA_LINKS_KEY, []) + if not isinstance(links, list): + # Convert to a list and remember that. + links = list(links) + doc.metadata[METADATA_LINKS_KEY] = links + return links + + +def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None: + """Add links to the given metadata. + Args: + doc: The document to add the links to. + *links: The links to add to the document. 
+ """ + links_in_metadata = get_links(doc) + for link in links: + if isinstance(link, Iterable): + links_in_metadata.extend(link) + else: + links_in_metadata.append(link) diff --git a/libs/core/tests/unit_tests/test_graph_vectorstores.py b/libs/core/tests/unit_tests/test_graph_vectorstores.py new file mode 100644 index 00000000000..2e3c8c5bdaf --- /dev/null +++ b/libs/core/tests/unit_tests/test_graph_vectorstores.py @@ -0,0 +1,59 @@ +import pytest + +from langchain_core.documents import Document +from langchain_core.graph_vectorstores.base import ( + Node, + _documents_to_nodes, + _texts_to_nodes, +) +from langchain_core.graph_vectorstores.links import Link + + +def test_texts_to_nodes() -> None: + assert list(_texts_to_nodes(["a", "b"], [{"a": "b"}, {"c": "d"}], ["a", "b"])) == [ + Node(id="a", metadata={"a": "b"}, text="a"), + Node(id="b", metadata={"c": "d"}, text="b"), + ] + assert list(_texts_to_nodes(["a", "b"], None, ["a", "b"])) == [ + Node(id="a", metadata={}, text="a"), + Node(id="b", metadata={}, text="b"), + ] + assert list(_texts_to_nodes(["a", "b"], [{"a": "b"}, {"c": "d"}], None)) == [ + Node(metadata={"a": "b"}, text="a"), + Node(metadata={"c": "d"}, text="b"), + ] + assert list( + _texts_to_nodes( + ["a"], + [{"links": {Link.incoming(kind="hyperlink", tag="http://b")}}], + None, + ) + ) == [Node(links=[Link.incoming(kind="hyperlink", tag="http://b")], text="a")] + with pytest.raises(ValueError): + list(_texts_to_nodes(["a", "b"], None, ["a"])) + with pytest.raises(ValueError): + list(_texts_to_nodes(["a", "b"], [{"a": "b"}], None)) + with pytest.raises(ValueError): + list(_texts_to_nodes(["a"], [{"a": "b"}, {"c": "d"}], None)) + with pytest.raises(ValueError): + list(_texts_to_nodes(["a"], None, ["a", "b"])) + + +def test_documents_to_nodes() -> None: + documents = [ + Document( + id="a", + page_content="some text a", + metadata={"links": [Link.incoming(kind="hyperlink", tag="http://b")]}, + ), + Document(id="b", page_content="some text b", metadata={"c": "d"}), + ] + assert list(_documents_to_nodes(documents)) == [ + Node( + id="a", + metadata={}, + links=[Link.incoming(kind="hyperlink", tag="http://b")], + text="some text a", + ), + Node(id="b", metadata={"c": "d"}, text="some text b"), + ]