core, community: move graph vectorstores to community (#26678)

remove beta namespace from core, add to community
This commit is contained in:
Erick Friis 2024-09-19 11:38:14 -07:00 committed by GitHub
parent c77c28e631
commit 311f861547
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 838 additions and 868 deletions

View File

@ -1,7 +1,708 @@
from langchain_core.graph_vectorstores.base import (
GraphVectorStore,
GraphVectorStoreRetriever,
Node,
from __future__ import annotations
from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import (
Any,
ClassVar,
Optional,
)
__all__ = ["GraphVectorStore", "GraphVectorStoreRetriever", "Node"]
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.load import Serializable
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from pydantic import Field
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
def _has_next(iterator: Iterator) -> bool:
"""Checks if the iterator has more elements.
Warning: consumes an element from the iterator"""
sentinel = object()
return next(iterator, sentinel) is not sentinel
@beta()
class Node(Serializable):
"""Node in the GraphVectorStore.
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
would look like:
.. code-block:: python
[
Node(
id="a",
text="some text a",
links= [
Link(kind="hyperlink", tag="https://some-url", direction="incoming")
],
),
Node(
id="b",
text="some text b",
links= [
Link(kind="hyperlink", tag="https://some-url", direction="outgoing")
],
)
]
"""
id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
text: str
"""Text contained by the node."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node."""
links: list[Link] = Field(default_factory=list)
"""Links associated with the node."""
def _texts_to_nodes(
texts: Iterable[str],
metadatas: Optional[Iterable[dict]],
ids: Optional[Iterable[str]],
) -> Iterator[Node]:
metadatas_it = iter(metadatas) if metadatas else None
ids_it = iter(ids) if ids else None
for text in texts:
try:
_metadata = next(metadatas_it).copy() if metadatas_it else {}
except StopIteration as e:
raise ValueError("texts iterable longer than metadatas") from e
try:
_id = next(ids_it) if ids_it else None
except StopIteration as e:
raise ValueError("texts iterable longer than ids") from e
links = _metadata.pop(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
links = list(links)
yield Node(
id=_id,
metadata=_metadata,
text=text,
links=links,
)
if ids_it and _has_next(ids_it):
raise ValueError("ids iterable longer than texts")
if metadatas_it and _has_next(metadatas_it):
raise ValueError("metadatas iterable longer than texts")
def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
for doc in documents:
metadata = doc.metadata.copy()
links = metadata.pop(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
links = list(links)
yield Node(
id=doc.id,
metadata=metadata,
text=doc.page_content,
links=links,
)
@beta()
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
"""Convert nodes to documents.
Args:
nodes: The nodes to convert to documents.
Returns:
The documents generated from the nodes.
"""
for node in nodes:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = [
# Convert the core `Link` (from the node) back to the local `Link`.
Link(kind=link.kind, direction=link.direction, tag=link.tag)
for link in node.links
]
yield Document(
id=node.id,
page_content=node.text,
metadata=metadata,
)
@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
class GraphVectorStore(VectorStore):
"""A hybrid vector-and-graph graph store.
Document chunks support vector-similarity search as well as edges linking
chunks based on structural and semantic properties.
.. versionadded:: 0.2.14
"""
@abstractmethod
def add_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> Iterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
"""
async def aadd_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> AsyncIterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
"""
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc # type: ignore[misc]
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
store.add_texts(
ids=["a", "b"],
texts=["some text a", "some text b"],
metadatas=[
{
"links": [
Link.incoming(kind="hyperlink", tag="https://some-url")
]
},
{
"links": [
Link.outgoing(kind="hyperlink", tag="https://some-url")
]
},
],
)
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
The metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
**kwargs: vectorstore specific parameters.
Returns:
List of ids from adding the texts into the vectorstore.
"""
nodes = _texts_to_nodes(texts, metadatas, ids)
return list(self.add_nodes(nodes, **kwargs))
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
await store.aadd_texts(
ids=["a", "b"],
texts=["some text a", "some text b"],
metadatas=[
{
"links": [
Link.incoming(kind="hyperlink", tag="https://some-url")
]
},
{
"links": [
Link.outgoing(kind="hyperlink", tag="https://some-url")
]
},
],
)
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
The metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
**kwargs: vectorstore specific parameters.
Returns:
List of ids from adding the texts into the vectorstore.
"""
nodes = _texts_to_nodes(texts, metadatas, ids)
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
def add_documents(
self,
documents: Iterable[Document],
**kwargs: Any,
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
create the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
store.add_documents(
[
Document(
id="a",
page_content="some text a",
metadata={
"links": [
Link.incoming(kind="hyperlink", tag="http://some-url")
]
}
),
Document(
id="b",
page_content="some text b",
metadata={
"links": [
Link.outgoing(kind="hyperlink", tag="http://some-url")
]
}
),
]
)
Args:
documents: Documents to add to the vectorstore.
The document's metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
Returns:
List of IDs of the added texts.
"""
nodes = _documents_to_nodes(documents)
return list(self.add_nodes(nodes, **kwargs))
async def aadd_documents(
self,
documents: Iterable[Document],
**kwargs: Any,
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
create the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
store.add_documents(
[
Document(
id="a",
page_content="some text a",
metadata={
"links": [
Link.incoming(kind="hyperlink", tag="http://some-url")
]
}
),
Document(
id="b",
page_content="some text b",
metadata={
"links": [
Link.outgoing(kind="hyperlink", tag="http://some-url")
]
}
),
]
)
Args:
documents: Documents to add to the vectorstore.
The document's metadata key `links` shall be an iterable of
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
Returns:
List of IDs of the added texts.
"""
nodes = _documents_to_nodes(documents)
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
@abstractmethod
def traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 1,
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from traversing this graph store.
First, `k` nodes are retrieved using a search for each `query` string.
Then, additional nodes are discovered up to the given `depth` from those
starting nodes.
Args:
query: The query string.
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
Returns:
Retrieved documents.
"""
async def atraversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 1,
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from traversing this graph store.
First, `k` nodes are retrieved using a search for each `query` string.
Then, additional nodes are discovered up to the given `depth` from those
starting nodes.
Args:
query: The query string.
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
Returns:
Retrieved documents.
"""
iterator = iter(
await run_in_executor(
None, self.traversal_search, query, k=k, depth=depth, **kwargs
)
)
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc # type: ignore[misc]
@abstractmethod
def mmr_traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
This strategy first retrieves the top `fetch_k` results by similarity to
the question. It then selects the top `k` results based on
maximum-marginal relevance using the given `lambda_mult`.
At each step, it considers the (remaining) documents from `fetch_k` as
well as any documents connected by edges to a selected document
retrieved based on similarity (a "root").
Args:
query: The query string to search for.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
adjacent_k: Number of adjacent Documents to fetch.
Defaults to 10.
depth: Maximum depth of a node (number of edges) from a node
retrieved via similarity. Defaults to 2.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
"""
async def ammr_traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
This strategy first retrieves the top `fetch_k` results by similarity to
the question. It then selects the top `k` results based on
maximum-marginal relevance using the given `lambda_mult`.
At each step, it considers the (remaining) documents from `fetch_k` as
well as any documents connected by edges to a selected document
retrieved based on similarity (a "root").
Args:
query: The query string to search for.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
adjacent_k: Number of adjacent Documents to fetch.
Defaults to 10.
depth: Maximum depth of a node (number of edges) from a node
retrieved via similarity. Defaults to 2.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
"""
iterator = iter(
await run_in_executor(
None,
self.mmr_traversal_search,
query,
k=k,
fetch_k=fetch_k,
adjacent_k=adjacent_k,
depth=depth,
lambda_mult=lambda_mult,
score_threshold=score_threshold,
**kwargs,
)
)
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc # type: ignore[misc]
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
return list(self.traversal_search(query, k=k, depth=0))
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
return list(
self.mmr_traversal_search(
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
)
)
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
docs_and_similarities = self.similarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs)
elif search_type == "traversal":
return list(self.traversal_search(query, **kwargs))
elif search_type == "mmr_traversal":
return list(self.mmr_traversal_search(query, **kwargs))
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
)
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> list[Document]:
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs)
elif search_type == "traversal":
return [doc async for doc in self.atraversal_search(query, **kwargs)]
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
)
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
"""Return GraphVectorStoreRetriever initialized from this GraphVectorStore.
Args:
**kwargs: Keyword arguments to pass to the search function.
Can include:
- search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
``similarity_score_threshold``.
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
- k(int): Amount of documents to return (Default: 4).
- depth(int): The maximum depth of edges to traverse (Default: 1).
- score_threshold(float): Minimum relevance threshold
for similarity_score_threshold.
- fetch_k(int): Amount of documents to pass to MMR algorithm
(Default: 20).
- lambda_mult(float): Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5).
Returns:
Retriever for this GraphVectorStore.
Examples:
.. code-block:: python
# Retrieve documents traversing edges
docsearch.as_retriever(
search_type="traversal",
search_kwargs={'k': 6, 'depth': 3}
)
# Retrieve more documents with higher diversity
# Useful if your dataset has many similar documents
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 6, 'lambda_mult': 0.25}
)
# Fetch more documents for the MMR algorithm to consider
# But only return the top 5
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 5, 'fetch_k': 50}
)
# Only retrieve documents that have a relevance score
# Above a certain threshold
docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={'score_threshold': 0.8}
)
# Only get the single most similar document from the dataset
docsearch.as_retriever(search_kwargs={'k': 1})
"""
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
class GraphVectorStoreRetriever(VectorStoreRetriever):
"""Retriever class for GraphVectorStore."""
vectorstore: GraphVectorStore
"""GraphVectorStore to use for retrieval."""
search_type: str = "traversal"
"""Type of search to perform. Defaults to "traversal"."""
allowed_search_types: ClassVar[Collection[str]] = (
"similarity",
"similarity_score_threshold",
"mmr",
"traversal",
"mmr_traversal",
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
elif self.search_type == "mmr_traversal":
return list(
self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
)
else:
return super()._get_relevant_documents(query, run_manager=run_manager)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
return [
doc
async for doc in self.vectorstore.atraversal_search(
query, **self.search_kwargs
)
]
elif self.search_type == "mmr_traversal":
return [
doc
async for doc in self.vectorstore.ammr_traversal_search(
query, **self.search_kwargs
)
]
else:
return await super()._aget_relevant_documents(
query, run_manager=run_manager
)

View File

@ -12,12 +12,12 @@ from typing import (
from langchain_core._api import beta
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.graph_vectorstores.base import (
from langchain_community.graph_vectorstores.base import (
GraphVectorStore,
Node,
nodes_to_documents,
)
from langchain_community.utilities.cassandra import SetupMode
if TYPE_CHECKING:

View File

@ -2,11 +2,11 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union
from langchain_core._api import beta
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors.link_extractor import (
LinkExtractor,
)
from langchain_community.graph_vectorstores.links import Link
# TypeAlias is not available in Python 3.9, we can't use that or the newer `type`.
GLiNERInput = Union[str, Document]
@ -34,7 +34,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
.. seealso::
- :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
- :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
- :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
How to link Documents on common named entities
==============================================
@ -59,12 +59,12 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
We can use :meth:`extract_one` on a document to get the links and add the links
to the document metadata with
:meth:`~langchain_core.graph_vectorstores.links.add_links`::
:meth:`~langchain_community.graph_vectorstores.links.add_links`::
from langchain_community.document_loaders import TextLoader
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor
from langchain_core.graph_vectorstores.links import add_links
from langchain_community.graph_vectorstores.links import add_links
from langchain_text_splitters import CharacterTextSplitter
loader = TextLoader("state_of_the_union.txt")
@ -113,7 +113,7 @@ class GLiNERLinkExtractor(LinkExtractor[GLiNERInput]):
{'source': 'state_of_the_union.txt', 'links': [Link(kind='entity:Person', direction='bidir', tag='President Zelenskyy'), Link(kind='entity:Person', direction='bidir', tag='Vladimir Putin')]}
The documents with named entity links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
The documents with named entity links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

View File

@ -2,7 +2,6 @@ from typing import Callable, List, Set
from langchain_core._api import beta
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors.link_extractor import (
LinkExtractor,
@ -10,6 +9,7 @@ from langchain_community.graph_vectorstores.extractors.link_extractor import (
from langchain_community.graph_vectorstores.extractors.link_extractor_adapter import (
LinkExtractorAdapter,
)
from langchain_community.graph_vectorstores.links import Link
# TypeAlias is not available in Python 3.9, we can't use that or the newer `type`.
HierarchyInput = List[str]

View File

@ -6,8 +6,8 @@ from urllib.parse import urldefrag, urljoin, urlparse
from langchain_core._api import beta
from langchain_core.documents import Document
from langchain_core.graph_vectorstores import Link
from langchain_community.graph_vectorstores import Link
from langchain_community.graph_vectorstores.extractors.link_extractor import (
LinkExtractor,
)
@ -77,7 +77,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
.. seealso::
- :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
- :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
- :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
How to link Documents on hyperlinks in HTML
===========================================
@ -103,7 +103,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
We can use :meth:`extract_one` on a document to get the links and add the links
to the document metadata with
:meth:`~langchain_core.graph_vectorstores.links.add_links`::
:meth:`~langchain_community.graph_vectorstores.links.add_links`::
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.graph_vectorstores.extractors import (
@ -148,7 +148,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.graph_vectorstores.extractors import HtmlLinkExtractor
from langchain_core.graph_vectorstores.links import add_links
from langchain_community.graph_vectorstores.links import add_links
loader = AsyncHtmlLoader(
[
@ -227,7 +227,7 @@ class HtmlLinkExtractor(LinkExtractor[HtmlInput]):
Found link from https://python.langchain.com/v0.2/docs/integrations/providers/astradb/ to https://docs.datastax.com/en/astra/home/astra.html.
The documents with URL links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
The documents with URL links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

View File

@ -2,11 +2,11 @@ from typing import Any, Dict, Iterable, Optional, Set, Union
from langchain_core._api import beta
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors.link_extractor import (
LinkExtractor,
)
from langchain_community.graph_vectorstores.links import Link
KeybertInput = Union[str, Document]
@ -37,7 +37,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
.. seealso::
- :mod:`How to use a graph vector store <langchain_community.graph_vectorstores>`
- :class:`How to create links between documents <langchain_core.graph_vectorstores.links.Link>`
- :class:`How to create links between documents <langchain_community.graph_vectorstores.links.Link>`
How to link Documents on common keywords using Keybert
======================================================
@ -62,12 +62,12 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
We can use :meth:`extract_one` on a document to get the links and add the links
to the document metadata with
:meth:`~langchain_core.graph_vectorstores.links.add_links`::
:meth:`~langchain_community.graph_vectorstores.links.add_links`::
from langchain_community.document_loaders import TextLoader
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
from langchain_core.graph_vectorstores.links import add_links
from langchain_community.graph_vectorstores.links import add_links
from langchain_text_splitters import CharacterTextSplitter
loader = TextLoader("state_of_the_union.txt")
@ -116,7 +116,7 @@ class KeybertLinkExtractor(LinkExtractor[KeybertInput]):
{'source': 'state_of_the_union.txt', 'links': [Link(kind='kw', direction='bidir', tag='ukraine'), Link(kind='kw', direction='bidir', tag='ukrainian'), Link(kind='kw', direction='bidir', tag='putin'), Link(kind='kw', direction='bidir', tag='vladimir'), Link(kind='kw', direction='bidir', tag='russia')]}
The documents with keyword links can then be added to a :class:`~langchain_core.graph_vectorstores.base.GraphVectorStore`::
The documents with keyword links can then be added to a :class:`~langchain_community.graph_vectorstores.base.GraphVectorStore`::
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

View File

@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
from typing import Generic, Iterable, Set, TypeVar
from langchain_core._api import beta
from langchain_core.graph_vectorstores import Link
from langchain_community.graph_vectorstores import Link
InputT = TypeVar("InputT")

View File

@ -1,8 +1,8 @@
from typing import Callable, Iterable, Set, TypeVar
from langchain_core._api import beta
from langchain_core.graph_vectorstores import Link
from langchain_community.graph_vectorstores import Link
from langchain_community.graph_vectorstores.extractors.link_extractor import (
LinkExtractor,
)

View File

@ -3,11 +3,11 @@ from typing import Any, Sequence
from langchain_core._api import beta
from langchain_core.documents import Document
from langchain_core.documents.transformers import BaseDocumentTransformer
from langchain_core.graph_vectorstores.links import copy_with_links
from langchain_community.graph_vectorstores.extractors.link_extractor import (
LinkExtractor,
)
from langchain_community.graph_vectorstores.links import copy_with_links
@beta()

View File

@ -1,8 +1,102 @@
from langchain_core.graph_vectorstores.links import (
Link,
add_links,
copy_with_links,
get_links,
)
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal, Union
__all__ = ["Link", "add_links", "get_links", "copy_with_links"]
from langchain_core._api import beta
from langchain_core.documents import Document
@beta()
@dataclass(frozen=True)
class Link:
"""A link to/from a tag of a given tag.
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
"""
kind: str
"""The kind of link. Allows different extractors to use the same tag name without
creating collisions between extractors. For example keyword vs url."""
direction: Literal["in", "out", "bidir"]
"""The direction of the link."""
tag: str
"""The tag of the link."""
@staticmethod
def incoming(kind: str, tag: str) -> "Link":
"""Create an incoming link."""
return Link(kind=kind, direction="in", tag=tag)
@staticmethod
def outgoing(kind: str, tag: str) -> "Link":
"""Create an outgoing link."""
return Link(kind=kind, direction="out", tag=tag)
@staticmethod
def bidir(kind: str, tag: str) -> "Link":
"""Create a bidirectional link."""
return Link(kind=kind, direction="bidir", tag=tag)
METADATA_LINKS_KEY = "links"
@beta()
def get_links(doc: Document) -> list[Link]:
"""Get the links from a document.
Args:
doc: The document to get the link tags from.
Returns:
The set of link tags from the document.
"""
links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
# Convert to a list and remember that.
links = list(links)
doc.metadata[METADATA_LINKS_KEY] = links
return links
@beta()
def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
"""Add links to the given metadata.
Args:
doc: The document to add the links to.
*links: The links to add to the document.
"""
links_in_metadata = get_links(doc)
for link in links:
if isinstance(link, Iterable):
links_in_metadata.extend(link)
else:
links_in_metadata.append(link)
@beta()
def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
"""Return a document with the given links added.
Args:
doc: The document to add the links to.
*links: The links to add to the document.
Returns:
A document with a shallow-copy of the metadata with the links added.
"""
new_links = set(get_links(doc))
for link in links:
if isinstance(link, Iterable):
new_links.update(link)
else:
new_links.add(link)
return Document(
page_content=doc.page_content,
metadata={
**doc.metadata,
METADATA_LINKS_KEY: list(new_links),
},
)

View File

@ -1,7 +1,7 @@
import pytest
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors import GLiNERLinkExtractor
from langchain_community.graph_vectorstores.links import Link
PAGE_1 = """
Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃ'tjɐnu

View File

@ -1,7 +1,7 @@
import pytest
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors import KeybertLinkExtractor
from langchain_community.graph_vectorstores.links import Link
PAGE_1 = """
Supervised learning is the machine learning task of learning a function that

View File

@ -4,9 +4,9 @@ from typing import Iterable, List, Optional, Type
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"

View File

@ -1,6 +1,5 @@
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.extractors import HierarchyLinkExtractor
from langchain_community.graph_vectorstores.links import Link
PATH_1 = ["Root", "H1", "h2"]

View File

@ -1,6 +1,6 @@
import pytest
from langchain_core.graph_vectorstores import Link
from langchain_community.graph_vectorstores import Link
from langchain_community.graph_vectorstores.extractors import (
HtmlInput,
HtmlLinkExtractor,

View File

@ -1,12 +1,12 @@
from typing import Set
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import Link, get_links
from langchain_community.graph_vectorstores.extractors import (
LinkExtractor,
LinkExtractorTransformer,
)
from langchain_community.graph_vectorstores.links import Link, get_links
TEXT1 = "Text1"
TEXT2 = "Text2"

View File

@ -1,12 +1,12 @@
import pytest
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.base import (
from langchain_community.graph_vectorstores.base import (
Node,
_documents_to_nodes,
_texts_to_nodes,
)
from langchain_core.graph_vectorstores.links import Link
from langchain_community.graph_vectorstores.links import Link
def test_texts_to_nodes() -> None:

View File

@ -1,15 +0,0 @@
from langchain_core.graph_vectorstores.base import (
GraphVectorStore,
GraphVectorStoreRetriever,
Node,
)
from langchain_core.graph_vectorstores.links import (
Link,
)
__all__ = [
"GraphVectorStore",
"GraphVectorStoreRetriever",
"Node",
"Link",
]

View File

@ -1,708 +0,0 @@
from __future__ import annotations
from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import (
Any,
ClassVar,
Optional,
)
from pydantic import Field
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_core.load import Serializable
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
def _has_next(iterator: Iterator) -> bool:
"""Checks if the iterator has more elements.
Warning: consumes an element from the iterator"""
sentinel = object()
return next(iterator, sentinel) is not sentinel
@beta()
class Node(Serializable):
"""Node in the GraphVectorStore.
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
For instance two nodes `a` and `b` connected over a hyperlink ``https://some-url``
would look like:
.. code-block:: python
[
Node(
id="a",
text="some text a",
links= [
Link(kind="hyperlink", tag="https://some-url", direction="incoming")
],
),
Node(
id="b",
text="some text b",
links= [
Link(kind="hyperlink", tag="https://some-url", direction="outgoing")
],
)
]
"""
id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
text: str
"""Text contained by the node."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node."""
links: list[Link] = Field(default_factory=list)
"""Links associated with the node."""
def _texts_to_nodes(
texts: Iterable[str],
metadatas: Optional[Iterable[dict]],
ids: Optional[Iterable[str]],
) -> Iterator[Node]:
metadatas_it = iter(metadatas) if metadatas else None
ids_it = iter(ids) if ids else None
for text in texts:
try:
_metadata = next(metadatas_it).copy() if metadatas_it else {}
except StopIteration as e:
raise ValueError("texts iterable longer than metadatas") from e
try:
_id = next(ids_it) if ids_it else None
except StopIteration as e:
raise ValueError("texts iterable longer than ids") from e
links = _metadata.pop(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
links = list(links)
yield Node(
id=_id,
metadata=_metadata,
text=text,
links=links,
)
if ids_it and _has_next(ids_it):
raise ValueError("ids iterable longer than texts")
if metadatas_it and _has_next(metadatas_it):
raise ValueError("metadatas iterable longer than texts")
def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
for doc in documents:
metadata = doc.metadata.copy()
links = metadata.pop(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
links = list(links)
yield Node(
id=doc.id,
metadata=metadata,
text=doc.page_content,
links=links,
)
@beta()
def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
"""Convert nodes to documents.
Args:
nodes: The nodes to convert to documents.
Returns:
The documents generated from the nodes.
"""
for node in nodes:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = [
# Convert the core `Link` (from the node) back to the local `Link`.
Link(kind=link.kind, direction=link.direction, tag=link.tag)
for link in node.links
]
yield Document(
id=node.id,
page_content=node.text,
metadata=metadata,
)
@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
class GraphVectorStore(VectorStore):
"""A hybrid vector-and-graph graph store.
Document chunks support vector-similarity search as well as edges linking
chunks based on structural and semantic properties.
.. versionadded:: 0.2.14
"""
@abstractmethod
def add_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> Iterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
"""
async def aadd_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> AsyncIterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
"""
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc # type: ignore[misc]
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
store.add_texts(
ids=["a", "b"],
texts=["some text a", "some text b"],
metadatas=[
{
"links": [
Link.incoming(kind="hyperlink", tag="https://some-url")
]
},
{
"links": [
Link.outgoing(kind="hyperlink", tag="https://some-url")
]
},
],
)
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
The metadata key `links` shall be an iterable of
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
**kwargs: vectorstore specific parameters.
Returns:
List of ids from adding the texts into the vectorstore.
"""
nodes = _texts_to_nodes(texts, metadatas, ids)
return list(self.add_nodes(nodes, **kwargs))
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
*,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore.
The Links present in the metadata field `links` will be extracted to create
the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
await store.aadd_texts(
ids=["a", "b"],
texts=["some text a", "some text b"],
metadatas=[
{
"links": [
Link.incoming(kind="hyperlink", tag="https://some-url")
]
},
{
"links": [
Link.outgoing(kind="hyperlink", tag="https://some-url")
]
},
],
)
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
The metadata key `links` shall be an iterable of
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
**kwargs: vectorstore specific parameters.
Returns:
List of ids from adding the texts into the vectorstore.
"""
nodes = _texts_to_nodes(texts, metadatas, ids)
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
def add_documents(
self,
documents: Iterable[Document],
**kwargs: Any,
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
create the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
store.add_documents(
[
Document(
id="a",
page_content="some text a",
metadata={
"links": [
Link.incoming(kind="hyperlink", tag="http://some-url")
]
}
),
Document(
id="b",
page_content="some text b",
metadata={
"links": [
Link.outgoing(kind="hyperlink", tag="http://some-url")
]
}
),
]
)
Args:
documents: Documents to add to the vectorstore.
The document's metadata key `links` shall be an iterable of
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
Returns:
List of IDs of the added texts.
"""
nodes = _documents_to_nodes(documents)
return list(self.add_nodes(nodes, **kwargs))
async def aadd_documents(
self,
documents: Iterable[Document],
**kwargs: Any,
) -> list[str]:
"""Run more documents through the embeddings and add to the vectorstore.
The Links present in the document metadata field `links` will be extracted to
create the `Node` links.
Eg if nodes `a` and `b` are connected over a hyperlink `https://some-url`, the
function call would look like:
.. code-block:: python
store.add_documents(
[
Document(
id="a",
page_content="some text a",
metadata={
"links": [
Link.incoming(kind="hyperlink", tag="http://some-url")
]
}
),
Document(
id="b",
page_content="some text b",
metadata={
"links": [
Link.outgoing(kind="hyperlink", tag="http://some-url")
]
}
),
]
)
Args:
documents: Documents to add to the vectorstore.
The document's metadata key `links` shall be an iterable of
:py:class:`~langchain_core.graph_vectorstores.links.Link`.
Returns:
List of IDs of the added texts.
"""
nodes = _documents_to_nodes(documents)
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
@abstractmethod
def traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 1,
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from traversing this graph store.
First, `k` nodes are retrieved using a search for each `query` string.
Then, additional nodes are discovered up to the given `depth` from those
starting nodes.
Args:
query: The query string.
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
Returns:
Retrieved documents.
"""
async def atraversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 1,
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from traversing this graph store.
First, `k` nodes are retrieved using a search for each `query` string.
Then, additional nodes are discovered up to the given `depth` from those
starting nodes.
Args:
query: The query string.
k: The number of Documents to return from the initial search.
Defaults to 4. Applies to each of the query strings.
depth: The maximum depth of edges to traverse. Defaults to 1.
Returns:
Retrieved documents.
"""
iterator = iter(
await run_in_executor(
None, self.traversal_search, query, k=k, depth=depth, **kwargs
)
)
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc # type: ignore[misc]
@abstractmethod
def mmr_traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
**kwargs: Any,
) -> Iterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
This strategy first retrieves the top `fetch_k` results by similarity to
the question. It then selects the top `k` results based on
maximum-marginal relevance using the given `lambda_mult`.
At each step, it considers the (remaining) documents from `fetch_k` as
well as any documents connected by edges to a selected document
retrieved based on similarity (a "root").
Args:
query: The query string to search for.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
adjacent_k: Number of adjacent Documents to fetch.
Defaults to 10.
depth: Maximum depth of a node (number of edges) from a node
retrieved via similarity. Defaults to 2.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
"""
async def ammr_traversal_search(
self,
query: str,
*,
k: int = 4,
depth: int = 2,
fetch_k: int = 100,
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
**kwargs: Any,
) -> AsyncIterable[Document]:
"""Retrieve documents from this graph store using MMR-traversal.
This strategy first retrieves the top `fetch_k` results by similarity to
the question. It then selects the top `k` results based on
maximum-marginal relevance using the given `lambda_mult`.
At each step, it considers the (remaining) documents from `fetch_k` as
well as any documents connected by edges to a selected document
retrieved based on similarity (a "root").
Args:
query: The query string to search for.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch via similarity.
Defaults to 100.
adjacent_k: Number of adjacent Documents to fetch.
Defaults to 10.
depth: Maximum depth of a node (number of edges) from a node
retrieved via similarity. Defaults to 2.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding to maximum
diversity and 1 to minimum diversity. Defaults to 0.5.
score_threshold: Only documents with a score greater than or equal
this threshold will be chosen. Defaults to negative infinity.
"""
iterator = iter(
await run_in_executor(
None,
self.mmr_traversal_search,
query,
k=k,
fetch_k=fetch_k,
adjacent_k=adjacent_k,
depth=depth,
lambda_mult=lambda_mult,
score_threshold=score_threshold,
**kwargs,
)
)
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
if doc is done:
break
yield doc # type: ignore[misc]
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
return list(self.traversal_search(query, k=k, depth=0))
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> list[Document]:
return list(
self.mmr_traversal_search(
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
)
)
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]
def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]:
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
docs_and_similarities = self.similarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs)
elif search_type == "traversal":
return list(self.traversal_search(query, **kwargs))
elif search_type == "mmr_traversal":
return list(self.mmr_traversal_search(query, **kwargs))
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
)
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> list[Document]:
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs)
elif search_type == "traversal":
return [doc async for doc in self.atraversal_search(query, **kwargs)]
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity', 'similarity_score_threshold', "
"'mmr' or 'traversal'."
)
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
"""Return GraphVectorStoreRetriever initialized from this GraphVectorStore.
Args:
**kwargs: Keyword arguments to pass to the search function.
Can include:
- search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
``similarity_score_threshold``.
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
- k(int): Amount of documents to return (Default: 4).
- depth(int): The maximum depth of edges to traverse (Default: 1).
- score_threshold(float): Minimum relevance threshold
for similarity_score_threshold.
- fetch_k(int): Amount of documents to pass to MMR algorithm
(Default: 20).
- lambda_mult(float): Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5).
Returns:
Retriever for this GraphVectorStore.
Examples:
.. code-block:: python
# Retrieve documents traversing edges
docsearch.as_retriever(
search_type="traversal",
search_kwargs={'k': 6, 'depth': 3}
)
# Retrieve more documents with higher diversity
# Useful if your dataset has many similar documents
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 6, 'lambda_mult': 0.25}
)
# Fetch more documents for the MMR algorithm to consider
# But only return the top 5
docsearch.as_retriever(
search_type="mmr",
search_kwargs={'k': 5, 'fetch_k': 50}
)
# Only retrieve documents that have a relevance score
# Above a certain threshold
docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={'score_threshold': 0.8}
)
# Only get the single most similar document from the dataset
docsearch.as_retriever(search_kwargs={'k': 1})
"""
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
class GraphVectorStoreRetriever(VectorStoreRetriever):
"""Retriever class for GraphVectorStore."""
vectorstore: GraphVectorStore
"""GraphVectorStore to use for retrieval."""
search_type: str = "traversal"
"""Type of search to perform. Defaults to "traversal"."""
allowed_search_types: ClassVar[Collection[str]] = (
"similarity",
"similarity_score_threshold",
"mmr",
"traversal",
"mmr_traversal",
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
elif self.search_type == "mmr_traversal":
return list(
self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
)
else:
return super()._get_relevant_documents(query, run_manager=run_manager)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> list[Document]:
if self.search_type == "traversal":
return [
doc
async for doc in self.vectorstore.atraversal_search(
query, **self.search_kwargs
)
]
elif self.search_type == "mmr_traversal":
return [
doc
async for doc in self.vectorstore.ammr_traversal_search(
query, **self.search_kwargs
)
]
else:
return await super()._aget_relevant_documents(
query, run_manager=run_manager
)

View File

@ -1,102 +0,0 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal, Union
from langchain_core._api import beta
from langchain_core.documents import Document
@beta()
@dataclass(frozen=True)
class Link:
"""A link to/from a tag of a given tag.
Edges exist from nodes with an outgoing link to nodes with a matching incoming link.
"""
kind: str
"""The kind of link. Allows different extractors to use the same tag name without
creating collisions between extractors. For example keyword vs url."""
direction: Literal["in", "out", "bidir"]
"""The direction of the link."""
tag: str
"""The tag of the link."""
@staticmethod
def incoming(kind: str, tag: str) -> "Link":
"""Create an incoming link."""
return Link(kind=kind, direction="in", tag=tag)
@staticmethod
def outgoing(kind: str, tag: str) -> "Link":
"""Create an outgoing link."""
return Link(kind=kind, direction="out", tag=tag)
@staticmethod
def bidir(kind: str, tag: str) -> "Link":
"""Create a bidirectional link."""
return Link(kind=kind, direction="bidir", tag=tag)
METADATA_LINKS_KEY = "links"
@beta()
def get_links(doc: Document) -> list[Link]:
"""Get the links from a document.
Args:
doc: The document to get the link tags from.
Returns:
The set of link tags from the document.
"""
links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
if not isinstance(links, list):
# Convert to a list and remember that.
links = list(links)
doc.metadata[METADATA_LINKS_KEY] = links
return links
@beta()
def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
"""Add links to the given metadata.
Args:
doc: The document to add the links to.
*links: The links to add to the document.
"""
links_in_metadata = get_links(doc)
for link in links:
if isinstance(link, Iterable):
links_in_metadata.extend(link)
else:
links_in_metadata.append(link)
@beta()
def copy_with_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> Document:
"""Return a document with the given links added.
Args:
doc: The document to add the links to.
*links: The links to add to the document.
Returns:
A document with a shallow-copy of the metadata with the links added.
"""
new_links = set(get_links(doc))
for link in links:
if isinstance(link, Iterable):
new_links.update(link)
else:
new_links.add(link)
return Document(
page_content=doc.page_content,
metadata={
**doc.metadata,
METADATA_LINKS_KEY: list(new_links),
},
)