mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
Generative Characters (#2859)
Add a time-weighted memory retriever and a notebook that approximates a Generative Agent from https://arxiv.org/pdf/2304.03442.pdf. The "daily plan" components are removed for now since they are less useful without a virtual world, but the memory is an interesting component to build off. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
0
tests/unit_tests/retrievers/__init__.py
Normal file
0
tests/unit_tests/retrievers/__init__.py
Normal file
163
tests/unit_tests/retrievers/test_time_weighted_retriever.py
Normal file
163
tests/unit_tests/retrievers/test_time_weighted_retriever.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for the time-weighted retriever class."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.retrievers.time_weighted_retriever import (
|
||||
TimeWeightedVectorStoreRetriever,
|
||||
_get_hours_passed,
|
||||
)
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
def _get_example_memories(k: int = 4) -> List[Document]:
    """Build ``k`` canned memory documents for the tests.

    Every document has the page content ``"foo"`` and the same
    ``last_accessed_at`` timestamp (2023-04-14 12:00); documents differ only
    in ``buffer_idx``, which counts up from 0.

    Args:
        k: Number of documents to create.

    Returns:
        The list of ``k`` documents, ordered by ``buffer_idx``.
    """
    memories: List[Document] = []
    for position in range(k):
        memory_metadata = {
            "buffer_idx": position,
            "last_accessed_at": datetime(2023, 4, 14, 12, 0),
        }
        memories.append(Document(page_content="foo", metadata=memory_metadata))
    return memories
|
||||
|
||||
|
||||
class MockVectorStore(VectorStore):
    """Stub vector store that returns canned values instead of searching."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Pretend to store the texts and echo them back as their ids.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas; ignored by this stub.
            kwargs: Vectorstore specific parameters; ignored by this stub.

        Returns:
            The input texts materialized as a list, standing in for ids.
        """
        return [*texts]

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Async variant of ``add_texts``; not provided by this stub.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return an empty result list regardless of the query."""
        no_matches: List[Document] = []
        return no_matches

    @classmethod
    def from_documents(
        cls: Type["MockVectorStore"],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Split documents into texts and metadatas and defer to ``from_texts``."""
        contents: List[str] = []
        metadata_list: List[dict] = []
        for document in documents:
            contents.append(document.page_content)
            metadata_list.append(document.metadata)
        return cls.from_texts(contents, embedding, metadatas=metadata_list, **kwargs)

    @classmethod
    def from_texts(
        cls: Type["MockVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Create a fresh stub store; all arguments are ignored."""
        store = cls()
        return store

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return the example memories, each paired with a fixed score of 0.5.

        Scores are meant to be on a scale from 0 (dissimilar) to 1 (most
        similar). The ``query`` and ``k`` arguments are not consulted; the
        default set of example memories is always returned.
        """
        fixed_score = 0.5
        scored_memories: List[Tuple[Document, float]] = []
        for memory in _get_example_memories():
            scored_memories.append((memory, fixed_score))
        return scored_memories
|
||||
|
||||
|
||||
@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
    """Provide a retriever backed by the mock store and the example memories."""
    mock_store = MockVectorStore()
    seeded_memories = _get_example_memories()
    retriever = TimeWeightedVectorStoreRetriever(
        vectorstore=mock_store,
        memory_stream=seeded_memories,
    )
    return retriever
|
||||
|
||||
|
||||
def test__get_hours_passed() -> None:
    """Two timestamps 2h30m apart should be reported as 2.5 hours."""
    later = datetime(2023, 4, 14, 14, 30)
    earlier = datetime(2023, 4, 14, 12, 0)
    assert _get_hours_passed(later, earlier) == 2.5
|
||||
|
||||
|
||||
def test_get_combined_score(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Combined score should equal the decayed recency term plus salience.

    The document was last accessed 2.5 hours before the supplied current
    time, so the expected value is
    ``(1.0 - decay_rate) ** 2.5 + vector_salience``.
    """
    last_access = datetime(2023, 4, 14, 12, 0)
    now = datetime(2023, 4, 14, 14, 30)
    elapsed_hours = 2.5
    salience = 0.7
    memory = Document(
        page_content="Test document",
        metadata={"last_accessed_at": last_access},
    )

    actual_score = time_weighted_retriever._get_combined_score(
        memory, salience, now
    )

    retention_per_hour = 1.0 - time_weighted_retriever.decay_rate
    recency_term = retention_per_hour**elapsed_hours
    assert actual_score == pytest.approx(recency_term + salience)
|
||||
|
||||
|
||||
def test_get_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """``get_salient_docs`` should hand back a dict for a query string."""
    salient = time_weighted_retriever.get_salient_docs("Test query")
    assert isinstance(salient, dict)
|
||||
|
||||
|
||||
def test_get_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """``get_relevant_documents`` should hand back a list for a query string."""
    retrieved = time_weighted_retriever.get_relevant_documents("Test query")
    assert isinstance(retrieved, list)
|
||||
|
||||
|
||||
def test_add_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Adding one document returns a one-item list and appends to the stream."""
    new_docs = [Document(page_content="test_add_documents document")]

    result = time_weighted_retriever.add_documents(new_docs)

    assert isinstance(result, list)
    assert len(result) == 1
    # The newly added document must be the last entry of the memory stream.
    newest_memory = time_weighted_retriever.memory_stream[-1]
    assert newest_memory.page_content == new_docs[0].page_content
|
Reference in New Issue
Block a user