diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index 33d553a3cae..7864cc3b097 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -2,7 +2,10 @@ import datetime from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) from langchain_core.documents import Document from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever @@ -89,17 +92,26 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): results[buffer_idx] = (doc, relevance) return results - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + async def aget_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]: + """Return documents that are salient to the query.""" + docs_and_scores: List[Tuple[Document, float]] + docs_and_scores = ( + await self.vectorstore.asimilarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + results = {} + for fetched_doc, relevance in docs_and_scores: + if "buffer_idx" in fetched_doc.metadata: + buffer_idx = fetched_doc.metadata["buffer_idx"] + doc = self.memory_stream[buffer_idx] + results[buffer_idx] = (doc, relevance) + return results + + def _get_rescored_docs( + self, docs_and_scores: Dict[Any, Tuple[Document, Optional[float]]] ) -> List[Document]: - """Return documents that are relevant to the query.""" current_time = datetime.datetime.now() - docs_and_scores = { - doc.metadata["buffer_idx"]: (doc, self.default_salience) - for doc in self.memory_stream[-self.k :] - } - # If a doc is considered salient, update the salience score - docs_and_scores.update(self.get_salient_docs(query)) rescored_docs = [ (doc, self._get_combined_score(doc, relevance, current_time)) for doc, relevance in docs_and_scores.values() @@ -114,6 +126,28 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): result.append(buffered_doc) return result + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + docs_and_scores = { + doc.metadata["buffer_idx"]: (doc, self.default_salience) + for doc in self.memory_stream[-self.k :] + } + # If a doc is considered salient, update the salience score + docs_and_scores.update(self.get_salient_docs(query)) + return self._get_rescored_docs(docs_and_scores) + + async def _aget_relevant_documents( + self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + docs_and_scores = { + doc.metadata["buffer_idx"]: (doc, self.default_salience) + for doc in self.memory_stream[-self.k :] + } + # If a doc is considered salient, update the salience score + docs_and_scores.update(await self.aget_salient_docs(query)) + return self._get_rescored_docs(docs_and_scores) + def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Add documents to vectorstore.""" current_time = kwargs.get("current_time") diff --git a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py index cfbf70de49d..9eeb86a8a05 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py @@ -36,45 +36,13 @@ class MockVectorStore(VectorStore): metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - kwargs: vectorstore specific parameters - - Returns: - List of ids from adding the texts into the vectorstore. - """ return list(texts) - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore.""" - raise NotImplementedError - def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> List[Document]: - """Return docs most similar to query.""" return [] - @classmethod - def from_documents( - cls: Type["MockVectorStore"], - documents: List[Document], - embedding: Embeddings, - **kwargs: Any, - ) -> "MockVectorStore": - """Return VectorStore initialized from documents and embeddings.""" - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs) - @classmethod def from_texts( cls: Type["MockVectorStore"], @@ -83,7 +51,6 @@ class MockVectorStore(VectorStore): metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> "MockVectorStore": - """Return VectorStore initialized from texts and embeddings.""" return cls() def _similarity_search_with_relevance_scores( @@ -92,12 +59,16 @@ class MockVectorStore(VectorStore): k: int = 4, **kwargs: Any, ) -> List[Tuple[Document, float]]: - """Return docs and similarity scores, normalized on a scale from 0 to 1. - - 0 is dissimilar, 1 is most similar. - """ return [(doc, 0.5) for doc in _get_example_memories()] + async def _asimilarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + return self._similarity_search_with_relevance_scores(query, k, **kwargs) + @pytest.fixture def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever: @@ -146,6 +117,18 @@ def test_get_salient_docs( assert doc in want +async def test_aget_salient_docs( + time_weighted_retriever: TimeWeightedVectorStoreRetriever, +) -> None: + query = "Test query" + docs_and_scores = await time_weighted_retriever.aget_salient_docs(query) + want = [(doc, 0.5) for doc in _get_example_memories()] + assert isinstance(docs_and_scores, dict) + assert len(docs_and_scores) == len(want) + for k, doc in docs_and_scores.items(): + assert doc in want + + def test_get_relevant_documents( time_weighted_retriever: TimeWeightedVectorStoreRetriever, ) -> None: @@ -164,6 +147,24 @@ def test_get_relevant_documents( assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now +async def test_aget_relevant_documents( + time_weighted_retriever: TimeWeightedVectorStoreRetriever, +) -> None: + query = "Test query" + relevant_documents = await time_weighted_retriever.aget_relevant_documents(query) + want = [(doc, 0.5) for doc in _get_example_memories()] + assert isinstance(relevant_documents, list) + assert len(relevant_documents) == len(want) + now = datetime.now() + for doc in relevant_documents: + # assert that the last_accessed_at is close to now. + assert now - timedelta(hours=1) < doc.metadata["last_accessed_at"] <= now + + # assert that the last_accessed_at in the memory stream is updated. + for d in time_weighted_retriever.memory_stream: + assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now + + def test_add_documents( time_weighted_retriever: TimeWeightedVectorStoreRetriever, ) -> None: @@ -175,3 +176,16 @@ def test_add_documents( time_weighted_retriever.memory_stream[-1].page_content == documents[0].page_content ) + + +async def test_aadd_documents( + time_weighted_retriever: TimeWeightedVectorStoreRetriever, +) -> None: + documents = [Document(page_content="test_add_documents document")] + added_documents = await time_weighted_retriever.aadd_documents(documents) + assert isinstance(added_documents, list) + assert len(added_documents) == 1 + assert ( + time_weighted_retriever.memory_stream[-1].page_content + == documents[0].page_content + )