Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-31 08:32:32 +00:00)
core[minor]: Add Graph Store component (#23092)
This PR introduces a GraphVectorStore component. GraphVectorStore extends VectorStore with the concept of links between documents, based on document metadata. This allows linking documents by a variety of techniques, including common keywords, explicit links in the content, and other patterns.

This works with existing Documents, so it's easy to extend existing VectorStores to be used as GraphVectorStores. The interface can be implemented for any vector store technology that supports metadata, not only graph DBs.

When retrieving documents for a given query, the first level of search is done using classical similarity search. Next, links may be followed using various traversal strategies to get additional documents. This allows retrieving documents that aren't directly similar to the query but contain relevant information.

Two retrieval methods are added to the VectorStore ones:

* traversal_search, which gets all linked documents up to a certain depth.
* mmr_traversal_search, which selects linked documents using an MMR algorithm to produce more diverse results.

If a retrieval depth of 0 is used, a GraphVectorStore is effectively a VectorStore. This enables an easy transition from a simple VectorStore to a GraphVectorStore by adding links between documents as a second step.

An implementation for Apache Cassandra is also proposed. See https://github.com/datastax/ragstack-ai/blob/main/libs/knowledge-store/notebooks/astra_support.ipynb for a notebook explaining how to use GraphVectorStore and showing that it can correctly answer questions that a simple VectorStore cannot.

**Twitter handle:** _cbornet
Parent: 77f5fc3d55
Commit: 42d049f618
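A minimal end-to-end sketch of the new API, distilled from the integration tests in this PR. The embeddings object, Cassandra session, and keyspace are placeholders to supply yourself; the document IDs and link tags are illustrative only:

from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

# Two documents linked by a shared keyword tag: "greeting" is bidirectional,
# so either document can be reached from the other during traversal.
doc1 = Document(
    id="doc1",
    page_content="Hello World",
    metadata={METADATA_LINKS_KEY: [Link.bidir(kind="kw", tag="greeting")]},
)
doc2 = Document(
    id="doc2",
    page_content="Hello Earth",
    metadata={METADATA_LINKS_KEY: [Link.bidir(kind="kw", tag="greeting")]},
)

store = CassandraGraphVectorStore.from_documents(
    [doc1, doc2],
    embedding=my_embeddings,  # any langchain_core.embeddings.Embeddings
    session=session,          # an open cassandra.cluster.Session
    keyspace="my_keyspace",
)

# depth=0 is plain similarity search; depth=1 also follows one hop of links.
docs = list(store.traversal_search("Earth", k=1, depth=1))

# The MMR variant trades similarity against diversity while traversing.
docs = list(store.mmr_traversal_search("Earth", k=2, fetch_k=10, depth=2))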
libs/community/langchain_community/graph_vectorstores/__init__.py (3 lines, new file)
@@ -0,0 +1,3 @@
from langchain_community.graph_vectorstores.cassandra import CassandraGraphVectorStore

__all__ = ["CassandraGraphVectorStore"]
libs/community/langchain_community/graph_vectorstores/cassandra.py (172 lines, new file)
@@ -0,0 +1,172 @@
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    List,
    Optional,
    Type,
)

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.graph_vectorstores.base import (
    GraphVectorStore,
    Node,
    nodes_to_documents,
)

from langchain_community.utilities.cassandra import SetupMode

if TYPE_CHECKING:
    from cassandra.cluster import Session


class CassandraGraphVectorStore(GraphVectorStore):
    def __init__(
        self,
        embedding: Embeddings,
        *,
        node_table: str = "graph_nodes",
        targets_table: str = "graph_targets",
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
    ):
        """Create the hybrid graph store.

        Args:
            embedding: The embeddings to use for the document content.
            node_table: Name of the table containing the nodes.
            targets_table: Name of the table containing the link targets.
            session: The Cassandra session to use.
            keyspace: The Cassandra keyspace to use.
            setup_mode: Mode used to create the Cassandra tables (SYNC,
                ASYNC or OFF).
        """
        try:
            from ragstack_knowledge_store import EmbeddingModel, graph_store
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import ragstack-knowledge-store python package. "
                "Please install it with `pip install ragstack-knowledge-store`."
            )

        self._embedding = embedding
        _setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

        class _EmbeddingModelAdapter(EmbeddingModel):
            def __init__(self, embeddings: Embeddings):
                self.embeddings = embeddings

            def embed_texts(self, texts: List[str]) -> List[List[float]]:
                return self.embeddings.embed_documents(texts)

            def embed_query(self, text: str) -> List[float]:
                return self.embeddings.embed_query(text)

            async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
                return await self.embeddings.aembed_documents(texts)

            async def aembed_query(self, text: str) -> List[float]:
                return await self.embeddings.aembed_query(text)

        self.store = graph_store.GraphStore(
            embedding=_EmbeddingModelAdapter(embedding),
            node_table=node_table,
            targets_table=targets_table,
            session=session,
            keyspace=keyspace,
            setup_mode=_setup_mode,
        )

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding

    def add_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> Iterable[str]:
        return self.store.add_nodes(nodes)

    @classmethod
    def from_texts(
        cls: Type["CassandraGraphVectorStore"],
        texts: Iterable[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> "CassandraGraphVectorStore":
        """Return CassandraGraphVectorStore initialized from texts and embeddings."""
        store = cls(embedding, **kwargs)
        store.add_texts(texts, metadatas, ids=ids)
        return store

    @classmethod
    def from_documents(
        cls: Type["CassandraGraphVectorStore"],
        documents: Iterable[Document],
        embedding: Embeddings,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> "CassandraGraphVectorStore":
        """Return CassandraGraphVectorStore initialized from documents and
        embeddings."""
        store = cls(embedding, **kwargs)
        store.add_documents(documents, ids=ids)
        return store

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        embedding_vector = self._embedding.embed_query(query)
        return self.similarity_search_by_vector(
            embedding_vector,
            k=k,
        )

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        nodes = self.store.similarity_search(embedding, k=k)
        return list(nodes_to_documents(nodes))

    def traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> Iterable[Document]:
        nodes = self.store.traversal_search(query, k=k, depth=depth)
        return nodes_to_documents(nodes)

    def mmr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> Iterable[Document]:
        nodes = self.store.mmr_traversal_search(
            query,
            k=k,
            depth=depth,
            fetch_k=fetch_k,
            adjacent_k=adjacent_k,
            lambda_mult=lambda_mult,
            score_threshold=score_threshold,
        )
        return nodes_to_documents(nodes)
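For local experimentation, the store can also be constructed directly; a minimal sketch, assuming cassio can resolve a session and keyspace from environment variables (as the integration test below does) and using OpenAIEmbeddings purely as an example:

import cassio
from langchain_community.graph_vectorstores import CassandraGraphVectorStore
from langchain_community.utilities.cassandra import SetupMode
from langchain_openai import OpenAIEmbeddings

cassio.init(auto=True)  # resolves session/keyspace from environment variables

store = CassandraGraphVectorStore(
    embedding=OpenAIEmbeddings(),
    node_table="graph_nodes",        # default table names, shown explicitly
    targets_table="graph_targets",
    setup_mode=SetupMode.SYNC,       # create the tables synchronously on startup
)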
Integration tests for CassandraGraphVectorStore (272 lines, new file)
@@ -0,0 +1,272 @@
import math
import os
from typing import Iterable, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link

from langchain_community.graph_vectorstores import CassandraGraphVectorStore

CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace"


def _get_graph_store(
    embedding_class: Type[Embeddings], documents: Iterable[Document] = ()
) -> CassandraGraphVectorStore:
    import cassio
    from cassandra.cluster import Cluster
    from cassio.config import check_resolve_session, resolve_keyspace

    node_table = "graph_test_node_table"
    edge_table = "graph_test_edge_table"

    if any(
        env_var in os.environ
        for env_var in [
            "CASSANDRA_CONTACT_POINTS",
            "ASTRA_DB_APPLICATION_TOKEN",
            "ASTRA_DB_INIT_STRING",
        ]
    ):
        cassio.init(auto=True)
        session = check_resolve_session()
    else:
        cluster = Cluster()
        session = cluster.connect()
    keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE
    cassio.init(session=session, keyspace=keyspace)
    # ensure keyspace exists
    session.execute(
        (
            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
        )
    )
    session.execute(f"DROP TABLE IF EXISTS {keyspace}.{node_table}")
    session.execute(f"DROP TABLE IF EXISTS {keyspace}.{edge_table}")
    store = CassandraGraphVectorStore.from_documents(
        documents,
        embedding=embedding_class(),
        session=session,
        keyspace=keyspace,
        node_table=node_table,
        targets_table=edge_table,
    )
    return store


class FakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return simple embeddings.
        Embeddings encode each text as its index."""
        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return constant query embeddings.
        Embeddings are identical to embed_documents(texts)[0].
        Distance to each text will be that text's index,
        as it was passed to embed_documents."""
        return [float(1.0)] * 9 + [float(0.0)]

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)


class AngularTwoDimensionalEmbeddings(Embeddings):
    """
    From angles (as strings in units of pi) to unit embedding vectors on a circle.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Make a list of texts into a list of embedding vectors.
        """
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """
        Convert input text to a 'vector' (list of floats).
        If the text is a number, use it as the angle for the
        unit vector in units of pi.
        Any other input text becomes the singular result [0, 0]!
        """
        try:
            angle = float(text)
            return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
        except ValueError:
            # Assume: just test string, no attention is paid to values.
            return [0.0, 0.0]


def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
    return [doc.id for doc in docs]


def test_mmr_traversal() -> None:
    """
    Test end-to-end construction and MMR search.
    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

           ______ v2
          /      \
         /        |  v1
    v3  |     .   | query
         |        /  v0
          |______/                 (N.B. very crude drawing)

    With fetch_k==2 and k==2, when the query is at (1, 0),
    one expects that v2 and v0 are returned (in some order)
    because v1 is "too close" to v0 (and v0 is closer than v1).

    Both v2 and v3 are reachable via edges from v0, so once it is
    selected, those are both considered.
    """
    store = _get_graph_store(AngularTwoDimensionalEmbeddings)

    v0 = Document(
        id="v0",
        page_content="-0.124",
        metadata={
            METADATA_LINKS_KEY: [
                Link.outgoing(kind="explicit", tag="link"),
            ],
        },
    )
    v1 = Document(
        id="v1",
        page_content="+0.127",
    )
    v2 = Document(
        id="v2",
        page_content="+0.25",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming(kind="explicit", tag="link"),
            ],
        },
    )
    v3 = Document(
        id="v3",
        page_content="+1.0",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming(kind="explicit", tag="link"),
            ],
        },
    )
    store.add_documents([v0, v1, v2, v3])

    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2)
    assert _result_ids(results) == ["v0", "v2"]

    # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
    # So it ends up picking "v1" even though it's similar to "v0".
    results = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
    assert _result_ids(results) == ["v0", "v1"]

    # With max depth 0 but higher `fetch_k`, we encounter v2
    results = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
    assert _result_ids(results) == ["v0", "v2"]

    # v0 scores ~0.46; v2 scores ~0.16, below the 0.2 threshold, so it is not chosen.
    results = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
    assert _result_ids(results) == ["v0"]

    # With k=4 we should get all of the documents.
    results = store.mmr_traversal_search("0.0", k=4)
    assert _result_ids(results) == ["v0", "v2", "v1", "v3"]


def test_write_retrieve_keywords() -> None:
    from langchain_openai import OpenAIEmbeddings

    greetings = Document(
        id="greetings",
        page_content="Typical Greetings",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming(kind="parent", tag="parent"),
            ],
        },
    )
    doc1 = Document(
        id="doc1",
        page_content="Hello World",
        metadata={
            METADATA_LINKS_KEY: [
                Link.outgoing(kind="parent", tag="parent"),
                Link.bidir(kind="kw", tag="greeting"),
                Link.bidir(kind="kw", tag="world"),
            ],
        },
    )
    doc2 = Document(
        id="doc2",
        page_content="Hello Earth",
        metadata={
            METADATA_LINKS_KEY: [
                Link.outgoing(kind="parent", tag="parent"),
                Link.bidir(kind="kw", tag="greeting"),
                Link.bidir(kind="kw", tag="earth"),
            ],
        },
    )
    store = _get_graph_store(OpenAIEmbeddings, [greetings, doc1, doc2])

    # Doc2 is more similar, but World and Earth are similar enough that doc1 also
    # shows up.
    results: Iterable[Document] = store.similarity_search("Earth", k=2)
    assert _result_ids(results) == ["doc2", "doc1"]

    results = store.similarity_search("Earth", k=1)
    assert _result_ids(results) == ["doc2"]

    results = store.traversal_search("Earth", k=2, depth=0)
    assert _result_ids(results) == ["doc2", "doc1"]

    results = store.traversal_search("Earth", k=2, depth=1)
    assert _result_ids(results) == ["doc2", "doc1", "greetings"]

    # k=1 only pulls in doc2 (Hello Earth)
    results = store.traversal_search("Earth", k=1, depth=0)
    assert _result_ids(results) == ["doc2"]

    # k=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via
    # keyword edge.
    results = store.traversal_search("Earth", k=1, depth=1)
    assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}


def test_metadata() -> None:
    store = _get_graph_store(FakeEmbeddings)
    store.add_documents(
        [
            Document(
                id="a",
                page_content="A",
                metadata={
                    METADATA_LINKS_KEY: [
                        Link.incoming(kind="hyperlink", tag="http://a"),
                        Link.bidir(kind="other", tag="foo"),
                    ],
                    "other": "some other field",
                },
            )
        ]
    )
    results = store.similarity_search("A")
    assert len(results) == 1
    assert results[0].id == "a"
    metadata = results[0].metadata
    assert metadata["other"] == "some other field"
    assert set(metadata[METADATA_LINKS_KEY]) == {
        Link.incoming(kind="hyperlink", tag="http://a"),
        Link.bidir(kind="other", tag="foo"),
    }
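The ~0.46 and ~0.16 scores referenced in test_mmr_traversal can be reproduced by hand, assuming the backing store uses the standard MMR objective, score = lambda_mult * sim(query, doc) - (1 - lambda_mult) * max_sim(doc, selected), with cosine similarity (for unit vectors on the circle this is just the cosine of the angle between them; the test's inline comments match these numbers):

import math

LAMBDA = 0.5  # the default lambda_mult

def sim(angle_a: float, angle_b: float) -> float:
    """Cosine similarity of two unit vectors given as angles in units of pi."""
    return math.cos((angle_a - angle_b) * math.pi)

query, v0, v2 = 0.0, -0.124, 0.25

# v0 is scored against the query alone (nothing selected yet).
score_v0 = LAMBDA * sim(query, v0)                               # ~0.46
# v2 is penalized by its similarity to the already-selected v0.
score_v2 = LAMBDA * sim(query, v2) - (1 - LAMBDA) * sim(v2, v0)  # ~0.16

print(round(score_v0, 2), round(score_v2, 2))  # 0.46 0.16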
libs/core/langchain_core/graph_vectorstores/__init__.py (15 lines, new file)
@@ -0,0 +1,15 @@
from langchain_core.graph_vectorstores.base import (
    GraphVectorStore,
    GraphVectorStoreRetriever,
    Node,
)
from langchain_core.graph_vectorstores.links import (
    Link,
)

__all__ = [
    "GraphVectorStore",
    "GraphVectorStoreRetriever",
    "Node",
    "Link",
]
libs/core/langchain_core/graph_vectorstores/base.py (692 lines, new file)
@@ -0,0 +1,692 @@
from __future__ import annotations

from abc import abstractmethod
from typing import (
    Any,
    AsyncIterable,
    ClassVar,
    Collection,
    Iterable,
    Iterator,
    List,
    Optional,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_core.load import Serializable
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever


def _has_next(iterator: Iterator) -> bool:
    """Checks if the iterator has more elements.
    Warning: consumes an element from the iterator."""
    sentinel = object()
    return next(iterator, sentinel) is not sentinel


class Node(Serializable):
    """Node in the GraphVectorStore.

    Edges exist from nodes with an outgoing link to nodes with a matching
    incoming link.

    For instance, two nodes `a` and `b` connected over a hyperlink
    `https://some-url` would look like:

    .. code-block:: python

        [
            Node(
                id="a",
                text="some text a",
                links=[
                    Link(kind="hyperlink", tag="https://some-url", direction="in")
                ],
            ),
            Node(
                id="b",
                text="some text b",
                links=[
                    Link(kind="hyperlink", tag="https://some-url", direction="out")
                ],
            ),
        ]
    """

    id: Optional[str] = None
    """Unique ID for the node. Will be generated by the GraphVectorStore if not set."""
    text: str
    """Text contained by the node."""
    metadata: dict = Field(default_factory=dict)
    """Metadata for the node."""
    links: List[Link] = Field(default_factory=list)
    """Links associated with the node."""


def _texts_to_nodes(
    texts: Iterable[str],
    metadatas: Optional[Iterable[dict]],
    ids: Optional[Iterable[str]],
) -> Iterator[Node]:
    metadatas_it = iter(metadatas) if metadatas else None
    ids_it = iter(ids) if ids else None
    for text in texts:
        try:
            _metadata = next(metadatas_it).copy() if metadatas_it else {}
        except StopIteration:
            raise ValueError("texts iterable longer than metadatas")
        try:
            _id = next(ids_it) if ids_it else None
        except StopIteration:
            raise ValueError("texts iterable longer than ids")

        links = _metadata.pop(METADATA_LINKS_KEY, [])
        if not isinstance(links, list):
            links = list(links)
        yield Node(
            id=_id,
            metadata=_metadata,
            text=text,
            links=links,
        )
    if ids_it and _has_next(ids_it):
        raise ValueError("ids iterable longer than texts")
    if metadatas_it and _has_next(metadatas_it):
        raise ValueError("metadatas iterable longer than texts")


def _documents_to_nodes(documents: Iterable[Document]) -> Iterator[Node]:
    for doc in documents:
        metadata = doc.metadata.copy()
        links = metadata.pop(METADATA_LINKS_KEY, [])
        if not isinstance(links, list):
            links = list(links)
        yield Node(
            id=doc.id,
            metadata=metadata,
            text=doc.page_content,
            links=links,
        )


def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
    for node in nodes:
        metadata = node.metadata.copy()
        metadata[METADATA_LINKS_KEY] = [
            # Convert the core `Link` (from the node) back to the local `Link`.
            Link(kind=link.kind, direction=link.direction, tag=link.tag)
            for link in node.links
        ]

        yield Document(
            id=node.id,
            page_content=node.text,
            metadata=metadata,
        )


class GraphVectorStore(VectorStore):
    """A hybrid vector-and-graph store.

    Document chunks support vector-similarity search as well as edges linking
    chunks based on structural and semantic properties.
    """

    @abstractmethod
    def add_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> Iterable[str]:
        """Add nodes to the graph store.

        Args:
            nodes: the nodes to add.
        """

    async def aadd_nodes(
        self,
        nodes: Iterable[Node],
        **kwargs: Any,
    ) -> AsyncIterable[str]:
        """Add nodes to the graph store.

        Args:
            nodes: the nodes to add.
        """
        iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
        done = object()
        while True:
            doc = await run_in_executor(None, next, iterator, done)
            if doc is done:
                break
            yield doc  # type: ignore[misc]

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        The Links present in the metadata field `links` will be extracted to
        create the `Node` links.

        E.g. if nodes `a` and `b` are connected over a hyperlink
        `https://some-url`, the function call would look like:

        .. code-block:: python

            store.add_texts(
                ids=["a", "b"],
                texts=["some text a", "some text b"],
                metadatas=[
                    {
                        "links": [
                            Link.incoming(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                    {
                        "links": [
                            Link.outgoing(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                ],
            )

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
                The metadata key `links` shall be an iterable of
                :py:class:`~langchain_core.graph_vectorstores.links.Link`.
            **kwargs: vectorstore specific parameters.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        nodes = _texts_to_nodes(texts, metadatas, ids)
        return list(self.add_nodes(nodes, **kwargs))

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        *,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        The Links present in the metadata field `links` will be extracted to
        create the `Node` links.

        E.g. if nodes `a` and `b` are connected over a hyperlink
        `https://some-url`, the function call would look like:

        .. code-block:: python

            await store.aadd_texts(
                ids=["a", "b"],
                texts=["some text a", "some text b"],
                metadatas=[
                    {
                        "links": [
                            Link.incoming(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                    {
                        "links": [
                            Link.outgoing(kind="hyperlink", tag="https://some-url")
                        ]
                    },
                ],
            )

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
                The metadata key `links` shall be an iterable of
                :py:class:`~langchain_core.graph_vectorstores.links.Link`.
            **kwargs: vectorstore specific parameters.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        nodes = _texts_to_nodes(texts, metadatas, ids)
        return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]

    def add_documents(
        self,
        documents: Iterable[Document],
        **kwargs: Any,
    ) -> List[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        The Links present in the document metadata field `links` will be
        extracted to create the `Node` links.

        E.g. if nodes `a` and `b` are connected over a hyperlink
        `https://some-url`, the function call would look like:

        .. code-block:: python

            store.add_documents(
                [
                    Document(
                        id="a",
                        page_content="some text a",
                        metadata={
                            "links": [
                                Link.incoming(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                    Document(
                        id="b",
                        page_content="some text b",
                        metadata={
                            "links": [
                                Link.outgoing(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                ]
            )

        Args:
            documents: Documents to add to the vectorstore.
                The document's metadata key `links` shall be an iterable of
                :py:class:`~langchain_core.graph_vectorstores.links.Link`.

        Returns:
            List of IDs of the added texts.
        """
        nodes = _documents_to_nodes(documents)
        return list(self.add_nodes(nodes, **kwargs))

    async def aadd_documents(
        self,
        documents: Iterable[Document],
        **kwargs: Any,
    ) -> List[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        The Links present in the document metadata field `links` will be
        extracted to create the `Node` links.

        E.g. if nodes `a` and `b` are connected over a hyperlink
        `https://some-url`, the function call would look like:

        .. code-block:: python

            await store.aadd_documents(
                [
                    Document(
                        id="a",
                        page_content="some text a",
                        metadata={
                            "links": [
                                Link.incoming(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                    Document(
                        id="b",
                        page_content="some text b",
                        metadata={
                            "links": [
                                Link.outgoing(kind="hyperlink", tag="http://some-url")
                            ]
                        }
                    ),
                ]
            )

        Args:
            documents: Documents to add to the vectorstore.
                The document's metadata key `links` shall be an iterable of
                :py:class:`~langchain_core.graph_vectorstores.links.Link`.

        Returns:
            List of IDs of the added texts.
        """
        nodes = _documents_to_nodes(documents)
        return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]

    @abstractmethod
    def traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Retrieve documents from traversing this graph store.

        First, `k` nodes are retrieved using a search for each `query` string.
        Then, additional nodes are discovered up to the given `depth` from those
        starting nodes.

        Args:
            query: The query string.
            k: The number of Documents to return from the initial search.
                Defaults to 4. Applies to each of the query strings.
            depth: The maximum depth of edges to traverse. Defaults to 1.
        Returns:
            Retrieved documents.
        """

    async def atraversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 1,
        **kwargs: Any,
    ) -> AsyncIterable[Document]:
        """Retrieve documents from traversing this graph store.

        First, `k` nodes are retrieved using a search for each `query` string.
        Then, additional nodes are discovered up to the given `depth` from those
        starting nodes.

        Args:
            query: The query string.
            k: The number of Documents to return from the initial search.
                Defaults to 4. Applies to each of the query strings.
            depth: The maximum depth of edges to traverse. Defaults to 1.
        Returns:
            Retrieved documents.
        """
        iterator = iter(
            await run_in_executor(
                None, self.traversal_search, query, k=k, depth=depth, **kwargs
            )
        )
        done = object()
        while True:
            doc = await run_in_executor(None, next, iterator, done)
            if doc is done:
                break
            yield doc  # type: ignore[misc]

    @abstractmethod
    def mmr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> Iterable[Document]:
        """Retrieve documents from this graph store using MMR-traversal.

        This strategy first retrieves the top `fetch_k` results by similarity to
        the question. It then selects the top `k` results based on
        maximum-marginal relevance using the given `lambda_mult`.

        At each step, it considers the (remaining) documents from `fetch_k` as
        well as any documents connected by edges to a selected document
        retrieved based on similarity (a "root").

        Args:
            query: The query string to search for.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch via similarity.
                Defaults to 100.
            adjacent_k: Number of adjacent Documents to fetch.
                Defaults to 10.
            depth: Maximum depth of a node (number of edges) from a node
                retrieved via similarity. Defaults to 2.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            score_threshold: Only documents with a score greater than or equal
                to this threshold will be chosen. Defaults to negative infinity.
        """

    async def ammr_traversal_search(
        self,
        query: str,
        *,
        k: int = 4,
        depth: int = 2,
        fetch_k: int = 100,
        adjacent_k: int = 10,
        lambda_mult: float = 0.5,
        score_threshold: float = float("-inf"),
        **kwargs: Any,
    ) -> AsyncIterable[Document]:
        """Retrieve documents from this graph store using MMR-traversal.

        This strategy first retrieves the top `fetch_k` results by similarity to
        the question. It then selects the top `k` results based on
        maximum-marginal relevance using the given `lambda_mult`.

        At each step, it considers the (remaining) documents from `fetch_k` as
        well as any documents connected by edges to a selected document
        retrieved based on similarity (a "root").

        Args:
            query: The query string to search for.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch via similarity.
                Defaults to 100.
            adjacent_k: Number of adjacent Documents to fetch.
                Defaults to 10.
            depth: Maximum depth of a node (number of edges) from a node
                retrieved via similarity. Defaults to 2.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            score_threshold: Only documents with a score greater than or equal
                to this threshold will be chosen. Defaults to negative infinity.
        """
        iterator = iter(
            await run_in_executor(
                None,
                self.mmr_traversal_search,
                query,
                k=k,
                fetch_k=fetch_k,
                adjacent_k=adjacent_k,
                depth=depth,
                lambda_mult=lambda_mult,
                score_threshold=score_threshold,
                **kwargs,
            )
        )
        done = object()
        while True:
            doc = await run_in_executor(None, next, iterator, done)
            if doc is done:
                break
            yield doc  # type: ignore[misc]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        return list(self.traversal_search(query, k=k, depth=0))

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        return list(
            self.mmr_traversal_search(
                query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
            )
        )

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        return [doc async for doc in self.atraversal_search(query, k=k, depth=0)]

    def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
        if search_type == "similarity":
            return self.similarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = self.similarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return self.max_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return list(self.traversal_search(query, **kwargs))
        elif search_type == "mmr_traversal":
            return list(self.mmr_traversal_search(query, **kwargs))
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr', 'traversal' or 'mmr_traversal'."
            )

    async def asearch(
        self, query: str, search_type: str, **kwargs: Any
    ) -> List[Document]:
        if search_type == "similarity":
            return await self.asimilarity_search(query, **kwargs)
        elif search_type == "similarity_score_threshold":
            docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
                query, **kwargs
            )
            return [doc for doc, _ in docs_and_similarities]
        elif search_type == "mmr":
            return await self.amax_marginal_relevance_search(query, **kwargs)
        elif search_type == "traversal":
            return [doc async for doc in self.atraversal_search(query, **kwargs)]
        elif search_type == "mmr_traversal":
            return [doc async for doc in self.ammr_traversal_search(query, **kwargs)]
        else:
            raise ValueError(
                f"search_type of {search_type} not allowed. Expected "
                "search_type to be 'similarity', 'similarity_score_threshold', "
                "'mmr', 'traversal' or 'mmr_traversal'."
            )

    def as_retriever(self, **kwargs: Any) -> "GraphVectorStoreRetriever":
        """Return GraphVectorStoreRetriever initialized from this GraphVectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "traversal" (default), "mmr_traversal", "similarity",
                "mmr", or "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    depth: The maximum depth of edges to traverse (Default: 1)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm
                        (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
        Returns:
            Retriever for this GraphVectorStore.

        Examples:

        .. code-block:: python

            # Retrieve documents traversing edges
            docsearch.as_retriever(
                search_type="traversal",
                search_kwargs={'k': 6, 'depth': 3}
            )

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

        """
        return GraphVectorStoreRetriever(vectorstore=self, **kwargs)


class GraphVectorStoreRetriever(VectorStoreRetriever):
    """Retriever class for GraphVectorStore."""

    vectorstore: GraphVectorStore
    """GraphVectorStore to use for retrieval."""
    search_type: str = "traversal"
    """Type of search to perform. Defaults to "traversal"."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "traversal",
        "mmr_traversal",
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "traversal":
            return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
        elif self.search_type == "mmr_traversal":
            return list(
                self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
            )
        else:
            return super()._get_relevant_documents(query, run_manager=run_manager)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        if self.search_type == "traversal":
            return [
                doc
                async for doc in self.vectorstore.atraversal_search(
                    query, **self.search_kwargs
                )
            ]
        elif self.search_type == "mmr_traversal":
            return [
                doc
                async for doc in self.vectorstore.ammr_traversal_search(
                    query, **self.search_kwargs
                )
            ]
        else:
            return await super()._aget_relevant_documents(
                query, run_manager=run_manager
            )
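To make the abstract surface concrete, here is a hypothetical, deliberately naive in-memory implementation (no such class exists in this PR). It treats substring match as the first-level "similarity" search so it can skip embeddings entirely, and stubs the MMR variant, but it shows exactly which methods a real backend must provide:

import uuid
from typing import Any, Dict, Iterable, List, Optional, Set

from langchain_core.documents import Document
from langchain_core.graph_vectorstores.base import (
    GraphVectorStore,
    Node,
    nodes_to_documents,
)


class InMemoryGraphVectorStore(GraphVectorStore):
    """Toy backend: 'similarity' is substring match, links drive traversal."""

    def __init__(self) -> None:
        self._nodes: Dict[str, Node] = {}

    def add_nodes(self, nodes: Iterable[Node], **kwargs: Any) -> Iterable[str]:
        for node in nodes:
            node_id = node.id or uuid.uuid4().hex
            self._nodes[node_id] = Node(
                id=node_id, text=node.text, metadata=node.metadata, links=node.links
            )
            yield node_id

    def _adjacent(self, node: Node) -> Set[str]:
        # Follow edges: an outgoing/bidir link reaches every node holding a
        # matching incoming/bidir link with the same kind and tag.
        out = {(link.kind, link.tag) for link in node.links
               if link.direction in ("out", "bidir")}
        return {
            other_id
            for other_id, other in self._nodes.items()
            if any(
                (link.kind, link.tag) in out and link.direction in ("in", "bidir")
                for link in other.links
            )
        }

    def traversal_search(
        self, query: str, *, k: int = 4, depth: int = 1, **kwargs: Any
    ) -> Iterable[Document]:
        # Initial "similarity" layer: first k nodes whose text contains the query.
        frontier = [i for i, n in self._nodes.items() if query in n.text][:k]
        visited = set(frontier)
        for _ in range(depth):
            frontier = [
                adj
                for node_id in frontier
                for adj in self._adjacent(self._nodes[node_id])
                if adj not in visited
            ]
            visited.update(frontier)
        return nodes_to_documents(self._nodes[i] for i in visited)

    def mmr_traversal_search(
        self, query: str, *, k: int = 4, depth: int = 2, **kwargs: Any
    ) -> Iterable[Document]:
        # Toy stand-in: a real backend would re-rank candidates by MMR here.
        return list(self.traversal_search(query, k=k, depth=depth))[:k]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Any = None,  # unused in this toy backend
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "InMemoryGraphVectorStore":
        store = cls()
        store.add_texts(texts, metadatas)
        return store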
libs/core/langchain_core/graph_vectorstores/links.py (68 lines, new file)
@@ -0,0 +1,68 @@
from dataclasses import dataclass
from typing import Iterable, List, Literal, Union

from langchain_core.documents import Document


@dataclass(frozen=True)
class Link:
    """A link to/from a tag of a given kind.

    Edges exist from nodes with an outgoing link to nodes with a matching
    incoming link.
    """

    kind: str
    """The kind of link. Allows different extractors to use the same tag name
    without creating collisions between extractors. For example "keyword" vs "url"."""
    direction: Literal["in", "out", "bidir"]
    """The direction of the link."""
    tag: str
    """The tag of the link."""

    @staticmethod
    def incoming(kind: str, tag: str) -> "Link":
        """Create an incoming link."""
        return Link(kind=kind, direction="in", tag=tag)

    @staticmethod
    def outgoing(kind: str, tag: str) -> "Link":
        """Create an outgoing link."""
        return Link(kind=kind, direction="out", tag=tag)

    @staticmethod
    def bidir(kind: str, tag: str) -> "Link":
        """Create a bidirectional link."""
        return Link(kind=kind, direction="bidir", tag=tag)


METADATA_LINKS_KEY = "links"


def get_links(doc: Document) -> List[Link]:
    """Get the links from a document.
    Args:
        doc: The document to get the links from.
    Returns:
        The list of links from the document.
    """
    links = doc.metadata.setdefault(METADATA_LINKS_KEY, [])
    if not isinstance(links, list):
        # Convert to a list and remember that.
        links = list(links)
        doc.metadata[METADATA_LINKS_KEY] = links
    return links


def add_links(doc: Document, *links: Union[Link, Iterable[Link]]) -> None:
    """Add links to the given document.
    Args:
        doc: The document to add the links to.
        *links: The links to add to the document.
    """
    links_in_metadata = get_links(doc)
    for link in links:
        if isinstance(link, Iterable):
            links_in_metadata.extend(link)
        else:
            links_in_metadata.append(link)
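The two helpers above are the intended way to attach links to documents that already exist; a short sketch (the document content and tags are illustrative):

from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import Link, add_links, get_links

doc = Document(page_content="Hello World")

# add_links accepts single links and iterables of links interchangeably.
add_links(doc, Link.bidir(kind="kw", tag="greeting"))
add_links(doc, [Link.incoming(kind="hyperlink", tag="https://a"),
                Link.outgoing(kind="hyperlink", tag="https://b")])

# get_links normalizes doc.metadata["links"] to a list and returns it.
assert len(get_links(doc)) == 3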
libs/core/tests/unit_tests/test_graph_vectorstores.py (59 lines, new file)
@@ -0,0 +1,59 @@
import pytest

from langchain_core.documents import Document
from langchain_core.graph_vectorstores.base import (
    Node,
    _documents_to_nodes,
    _texts_to_nodes,
)
from langchain_core.graph_vectorstores.links import Link


def test_texts_to_nodes() -> None:
    assert list(_texts_to_nodes(["a", "b"], [{"a": "b"}, {"c": "d"}], ["a", "b"])) == [
        Node(id="a", metadata={"a": "b"}, text="a"),
        Node(id="b", metadata={"c": "d"}, text="b"),
    ]
    assert list(_texts_to_nodes(["a", "b"], None, ["a", "b"])) == [
        Node(id="a", metadata={}, text="a"),
        Node(id="b", metadata={}, text="b"),
    ]
    assert list(_texts_to_nodes(["a", "b"], [{"a": "b"}, {"c": "d"}], None)) == [
        Node(metadata={"a": "b"}, text="a"),
        Node(metadata={"c": "d"}, text="b"),
    ]
    assert list(
        _texts_to_nodes(
            ["a"],
            [{"links": {Link.incoming(kind="hyperlink", tag="http://b")}}],
            None,
        )
    ) == [Node(links=[Link.incoming(kind="hyperlink", tag="http://b")], text="a")]
    with pytest.raises(ValueError):
        list(_texts_to_nodes(["a", "b"], None, ["a"]))
    with pytest.raises(ValueError):
        list(_texts_to_nodes(["a", "b"], [{"a": "b"}], None))
    with pytest.raises(ValueError):
        list(_texts_to_nodes(["a"], [{"a": "b"}, {"c": "d"}], None))
    with pytest.raises(ValueError):
        list(_texts_to_nodes(["a"], None, ["a", "b"]))


def test_documents_to_nodes() -> None:
    documents = [
        Document(
            id="a",
            page_content="some text a",
            metadata={"links": [Link.incoming(kind="hyperlink", tag="http://b")]},
        ),
        Document(id="b", page_content="some text b", metadata={"c": "d"}),
    ]
    assert list(_documents_to_nodes(documents)) == [
        Node(
            id="a",
            metadata={},
            links=[Link.incoming(kind="hyperlink", tag="http://b")],
            text="some text a",
        ),
        Node(id="b", metadata={"c": "d"}, text="some text b"),
    ]