mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 04:28:58 +00:00
langchain[patch]: Add async methods to TimeWeightedVectorStoreRetriever (#19606)
This commit is contained in:
committed by
GitHub
parent
aeb7b6b11d
commit
b3d7b5a653
@@ -2,7 +2,10 @@ import datetime
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
|
CallbackManagerForRetrieverRun,
|
||||||
|
)
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
@@ -89,17 +92,26 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
|||||||
results[buffer_idx] = (doc, relevance)
|
results[buffer_idx] = (doc, relevance)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _get_relevant_documents(
|
async def aget_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
"""Return documents that are salient to the query."""
|
||||||
|
docs_and_scores: List[Tuple[Document, float]]
|
||||||
|
docs_and_scores = (
|
||||||
|
await self.vectorstore.asimilarity_search_with_relevance_scores(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
results = {}
|
||||||
|
for fetched_doc, relevance in docs_and_scores:
|
||||||
|
if "buffer_idx" in fetched_doc.metadata:
|
||||||
|
buffer_idx = fetched_doc.metadata["buffer_idx"]
|
||||||
|
doc = self.memory_stream[buffer_idx]
|
||||||
|
results[buffer_idx] = (doc, relevance)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _get_rescored_docs(
|
||||||
|
self, docs_and_scores: Dict[Any, Tuple[Document, Optional[float]]]
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return documents that are relevant to the query."""
|
|
||||||
current_time = datetime.datetime.now()
|
current_time = datetime.datetime.now()
|
||||||
docs_and_scores = {
|
|
||||||
doc.metadata["buffer_idx"]: (doc, self.default_salience)
|
|
||||||
for doc in self.memory_stream[-self.k :]
|
|
||||||
}
|
|
||||||
# If a doc is considered salient, update the salience score
|
|
||||||
docs_and_scores.update(self.get_salient_docs(query))
|
|
||||||
rescored_docs = [
|
rescored_docs = [
|
||||||
(doc, self._get_combined_score(doc, relevance, current_time))
|
(doc, self._get_combined_score(doc, relevance, current_time))
|
||||||
for doc, relevance in docs_and_scores.values()
|
for doc, relevance in docs_and_scores.values()
|
||||||
@@ -114,6 +126,28 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
|||||||
result.append(buffered_doc)
|
result.append(buffered_doc)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _get_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||||
|
) -> List[Document]:
|
||||||
|
docs_and_scores = {
|
||||||
|
doc.metadata["buffer_idx"]: (doc, self.default_salience)
|
||||||
|
for doc in self.memory_stream[-self.k :]
|
||||||
|
}
|
||||||
|
# If a doc is considered salient, update the salience score
|
||||||
|
docs_and_scores.update(self.get_salient_docs(query))
|
||||||
|
return self._get_rescored_docs(docs_and_scores)
|
||||||
|
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||||
|
) -> List[Document]:
|
||||||
|
docs_and_scores = {
|
||||||
|
doc.metadata["buffer_idx"]: (doc, self.default_salience)
|
||||||
|
for doc in self.memory_stream[-self.k :]
|
||||||
|
}
|
||||||
|
# If a doc is considered salient, update the salience score
|
||||||
|
docs_and_scores.update(await self.aget_salient_docs(query))
|
||||||
|
return self._get_rescored_docs(docs_and_scores)
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||||
"""Add documents to vectorstore."""
|
"""Add documents to vectorstore."""
|
||||||
current_time = kwargs.get("current_time")
|
current_time = kwargs.get("current_time")
|
||||||
|
@@ -36,45 +36,13 @@ class MockVectorStore(VectorStore):
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Run more texts through the embeddings and add to the vectorstore.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: Iterable of strings to add to the vectorstore.
|
|
||||||
metadatas: Optional list of metadatas associated with the texts.
|
|
||||||
kwargs: vectorstore specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ids from adding the texts into the vectorstore.
|
|
||||||
"""
|
|
||||||
return list(texts)
|
return list(texts)
|
||||||
|
|
||||||
async def aadd_texts(
|
|
||||||
self,
|
|
||||||
texts: Iterable[str],
|
|
||||||
metadatas: Optional[List[dict]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[str]:
|
|
||||||
"""Run more texts through the embeddings and add to the vectorstore."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return docs most similar to query."""
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_documents(
|
|
||||||
cls: Type["MockVectorStore"],
|
|
||||||
documents: List[Document],
|
|
||||||
embedding: Embeddings,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> "MockVectorStore":
|
|
||||||
"""Return VectorStore initialized from documents and embeddings."""
|
|
||||||
texts = [d.page_content for d in documents]
|
|
||||||
metadatas = [d.metadata for d in documents]
|
|
||||||
return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls: Type["MockVectorStore"],
|
cls: Type["MockVectorStore"],
|
||||||
@@ -83,7 +51,6 @@ class MockVectorStore(VectorStore):
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "MockVectorStore":
|
) -> "MockVectorStore":
|
||||||
"""Return VectorStore initialized from texts and embeddings."""
|
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
def _similarity_search_with_relevance_scores(
|
def _similarity_search_with_relevance_scores(
|
||||||
@@ -92,12 +59,16 @@ class MockVectorStore(VectorStore):
|
|||||||
k: int = 4,
|
k: int = 4,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
"""Return docs and similarity scores, normalized on a scale from 0 to 1.
|
|
||||||
|
|
||||||
0 is dissimilar, 1 is most similar.
|
|
||||||
"""
|
|
||||||
return [(doc, 0.5) for doc in _get_example_memories()]
|
return [(doc, 0.5) for doc in _get_example_memories()]
|
||||||
|
|
||||||
|
async def _asimilarity_search_with_relevance_scores(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
return self._similarity_search_with_relevance_scores(query, k, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
|
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
|
||||||
@@ -146,6 +117,18 @@ def test_get_salient_docs(
|
|||||||
assert doc in want
|
assert doc in want
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aget_salient_docs(
|
||||||
|
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||||
|
) -> None:
|
||||||
|
query = "Test query"
|
||||||
|
docs_and_scores = await time_weighted_retriever.aget_salient_docs(query)
|
||||||
|
want = [(doc, 0.5) for doc in _get_example_memories()]
|
||||||
|
assert isinstance(docs_and_scores, dict)
|
||||||
|
assert len(docs_and_scores) == len(want)
|
||||||
|
for k, doc in docs_and_scores.items():
|
||||||
|
assert doc in want
|
||||||
|
|
||||||
|
|
||||||
def test_get_relevant_documents(
|
def test_get_relevant_documents(
|
||||||
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -164,6 +147,24 @@ def test_get_relevant_documents(
|
|||||||
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now
|
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aget_relevant_documents(
|
||||||
|
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||||
|
) -> None:
|
||||||
|
query = "Test query"
|
||||||
|
relevant_documents = await time_weighted_retriever.aget_relevant_documents(query)
|
||||||
|
want = [(doc, 0.5) for doc in _get_example_memories()]
|
||||||
|
assert isinstance(relevant_documents, list)
|
||||||
|
assert len(relevant_documents) == len(want)
|
||||||
|
now = datetime.now()
|
||||||
|
for doc in relevant_documents:
|
||||||
|
# assert that the last_accessed_at is close to now.
|
||||||
|
assert now - timedelta(hours=1) < doc.metadata["last_accessed_at"] <= now
|
||||||
|
|
||||||
|
# assert that the last_accessed_at in the memory stream is updated.
|
||||||
|
for d in time_weighted_retriever.memory_stream:
|
||||||
|
assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now
|
||||||
|
|
||||||
|
|
||||||
def test_add_documents(
|
def test_add_documents(
|
||||||
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -175,3 +176,16 @@ def test_add_documents(
|
|||||||
time_weighted_retriever.memory_stream[-1].page_content
|
time_weighted_retriever.memory_stream[-1].page_content
|
||||||
== documents[0].page_content
|
== documents[0].page_content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aadd_documents(
|
||||||
|
time_weighted_retriever: TimeWeightedVectorStoreRetriever,
|
||||||
|
) -> None:
|
||||||
|
documents = [Document(page_content="test_add_documents document")]
|
||||||
|
added_documents = await time_weighted_retriever.aadd_documents(documents)
|
||||||
|
assert isinstance(added_documents, list)
|
||||||
|
assert len(added_documents) == 1
|
||||||
|
assert (
|
||||||
|
time_weighted_retriever.memory_stream[-1].page_content
|
||||||
|
== documents[0].page_content
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user