Compare commits

...

30 Commits

Author SHA1 Message Date
vowelparrot
db4f5a7318 Update Notebook 2023-04-17 06:54:46 -07:00
vowelparrot
445b5725ab Test fixing 2023-04-16 20:54:25 -07:00
vowelparrot
9e99fbb4b2 Merge branch 'master' into vwp/characters 2023-04-16 20:48:11 -07:00
vowelparrot
5fd6fa80c4 Update Test and Type 2023-04-16 15:50:16 -07:00
vowelparrot
116201d4f7 Update decay rate 2023-04-16 15:47:43 -07:00
vowelparrot
c28524d817 Re-run to test 2023-04-16 13:31:45 -07:00
Harrison Chase
62254e1438 cr 2023-04-15 21:07:36 -07:00
Harrison Chase
5f5670eb00 cr 2023-04-15 21:06:39 -07:00
Harrison Chase
478b5038cc Merge branch 'vwp/characters' of github.com:hwchase17/langchain into vwp/characters 2023-04-15 20:57:27 -07:00
vowelparrot
4681f1f59f cr 2023-04-15 18:56:38 -07:00
vowelparrot
06a66a0e84 Merge branch 'vwp/similarity_search_with_distances' into vwp/characters 2023-04-15 18:50:12 -07:00
vowelparrot
f0ce4c8450 cr 2023-04-15 18:49:51 -07:00
Harrison Chase
1d09d15fdd stash 2023-04-15 18:25:17 -07:00
vowelparrot
585b6654e2 Add Default 2023-04-15 17:15:28 -07:00
vowelparrot
e9cc8db66a Merge branch 'vwp/similarity_search_with_distances' into vwp/characters 2023-04-15 08:14:22 -07:00
vowelparrot
d31d92acdd Update __from 2023-04-15 08:13:49 -07:00
vowelparrot
a742c2ca3d Update the twr implementation 2023-04-15 08:10:21 -07:00
vowelparrot
265fb5f829 Add tests 2023-04-14 16:42:04 -07:00
vowelparrot
285bd56d11 Update nits 2023-04-14 15:26:08 -07:00
vowelparrot
42ea0a0b33 Rm method 2023-04-14 15:25:14 -07:00
vowelparrot
8bd2434862 Update 2023-04-14 15:24:21 -07:00
vowelparrot
fb677984f1 Merge branch 'vwp/similarity_search_with_distances' into vwp/characters 2023-04-14 14:48:46 -07:00
vowelparrot
8b74c6e54a Add similarity_search_with_normalized_similarities
Add a method that exposes a similarity search with corresponding
normalized similarity scores.
2023-04-14 14:47:28 -07:00
Harrison Chase
2c4138f80a cr 2023-04-13 23:38:14 -07:00
Harrison Chase
ee377f4029 stash 2023-04-13 23:25:54 -07:00
Harrison Chase
7226bbcb8b stash 2023-04-13 23:24:26 -07:00
vowelparrot
97ff1bdfcf Update notebook 2023-04-13 23:09:14 -07:00
Harrison Chase
59f16f7abc cr 2023-04-13 21:20:06 -07:00
vowelparrot
faea852c24 Add a test 2023-04-13 20:13:36 -07:00
vowelparrot
c64e7ac357 Draft Characters 2023-04-13 17:17:05 -07:00
7 changed files with 1771 additions and 3 deletions

View File

@@ -0,0 +1,213 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "a90b7557",
"metadata": {},
"source": [
"# Time Weighted VectorStore Retriever\n",
"\n",
"This retriever uses a combination of semantic similarity and recency.\n",
"\n",
"The algorithm for scoring them is:\n",
"\n",
"```\n",
"semantic_similarity + (1.0 - decay_rate) ** hours_passed\n",
"```\n",
"\n",
"Notably, hours_passed refers to the hours passed since the object in the retriever **was last accessed**, not since it was created. This means that frequently accessed objects remain \"fresh.\""
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f22cc96b",
"metadata": {},
"outputs": [],
"source": [
"import faiss\n",
"\n",
"from datetime import datetime, timedelta\n",
"from langchain.docstore import InMemoryDocstore\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers import TimeWeightedVectorStoreRetriever\n",
"from langchain.schema import Document\n",
"from langchain.vectorstores import FAISS\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6af7ea6b",
"metadata": {},
"source": [
"## Low Decay Rate\n",
"\n",
"A low decay rate (in this case, to be extreme, we will set it close to 0) means memories will be \"remembered\" for longer. A decay rate of 0 means memories will never be forgotten, making this retriever equivalent to the vector lookup."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c10e7696",
"metadata": {},
"outputs": [],
"source": [
"# Define your embedding model\n",
"embeddings_model = OpenAIEmbeddings()\n",
"# Initialize the vectorstore as empty\n",
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.0000000000000000000000001, k=1) "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "86dbadb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['30fac87b-0357-450c-9a92-63e0147d005c']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yesterday = datetime.now() - timedelta(days=1)\n",
"retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a580be32",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 4, 17, 6, 52, 55, 670204), 'created_at': datetime.datetime(2023, 4, 17, 6, 52, 54, 819346), 'buffer_idx': 0})]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Hello World\" is returned first because it is most salient, and the decay rate is close to 0., meaning it's still recent enough\n",
"retriever.get_relevant_documents(\"hello world\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ca056896",
"metadata": {},
"source": [
"## High Decay Rate\n",
"\n",
"With a high decay factor (e.g., several 9's), the recency score quickly goes to 0! If you set this all the way to 1, recency is 0 for all objects, once again making this equivalent to a vector lookup.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc37669b",
"metadata": {},
"outputs": [],
"source": [
"# Define your embedding model\n",
"embeddings_model = OpenAIEmbeddings()\n",
"# Initialize the vectorstore as empty\n",
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.999, k=1) "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fa284384",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['a8d823e6-baee-46c8-aaa2-519cfb9b59ba']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yesterday = datetime.now() - timedelta(days=1)\n",
"retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7558f94d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 17, 6, 52, 57, 11210), 'created_at': datetime.datetime(2023, 4, 17, 6, 52, 56, 17459), 'buffer_idx': 1})]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Hello Foo\" is returned first because \"hello world\" is mostly forgotten\n",
"retriever.get_relevant_documents(\"hello world\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf6d8c90",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,9 @@ from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetr
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.svm import SVMRetriever
from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever,
)
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
__all__ = [
@@ -17,5 +20,6 @@ __all__ = [
"TFIDFRetriever",
"WeaviateHybridSearchRetriever",
"DataberryRetriever",
"TimeWeightedVectorStoreRetriever",
"SVMRetriever",
]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List
from typing import List, Optional
import aiohttp
import requests
@@ -13,8 +13,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
url: str
bearer_token: str
top_k: int = 3
filter: dict | None = None
aiosession: aiohttp.ClientSession | None = None
filter: Optional[None] = None
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
"""Configuration for this pydantic object."""

View File

@@ -0,0 +1,138 @@
"""Retriever that combines embedding similarity with recency in retrieving values."""
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
def _get_hours_passed(time: datetime, ref_time: datetime) -> float:
"""Get the hours passed between two datetime objects."""
return (time - ref_time).total_seconds() / 3600
class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
    """Retriever combining embedding similarity with recency.

    Each candidate document is scored as
    ``vector_relevance + (1.0 - decay_rate) ** hours_passed`` (plus any
    ``other_score_keys`` found in its metadata), where ``hours_passed`` is
    measured from the document's ``last_accessed_at`` metadata — so
    frequently accessed documents stay "fresh".
    """

    vectorstore: VectorStore
    """The vectorstore to store documents and determine salience."""

    search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
    """Keyword arguments to pass to the vectorstore similarity search."""

    # TODO: abstract as a queue
    memory_stream: List[Document] = Field(default_factory=list)
    """The memory_stream of documents to search through."""

    decay_rate: float = Field(default=0.01)
    """The exponential decay factor used as (1.0-decay_rate)**(hrs_passed)."""

    k: int = 4
    """The maximum number of documents to retrieve in a given call."""

    # default_factory avoids a shared mutable default (consistent with
    # memory_stream above); pydantic copies either way, but this is explicit.
    other_score_keys: List[str] = Field(default_factory=list)
    """Other keys in the metadata to factor into the score, e.g. 'importance'."""

    default_salience: Optional[float] = None
    """The salience to assign memories not retrieved from the vector store.

    None assigns no salience to documents not fetched from the vector store.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _get_combined_score(
        self,
        document: Document,
        vector_relevance: Optional[float],
        current_time: datetime,
    ) -> float:
        """Return the combined recency + metadata + relevance score for a document."""
        hours_passed = _get_hours_passed(
            current_time,
            document.metadata["last_accessed_at"],
        )
        score = (1.0 - self.decay_rate) ** hours_passed
        # Extra metadata signals (e.g. 'importance') contribute additively.
        for key in self.other_score_keys:
            if key in document.metadata:
                score += document.metadata[key]
        if vector_relevance is not None:
            score += vector_relevance
        return score

    def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
        """Return documents salient to the query, keyed by buffer index."""
        docs_and_scores: List[Tuple[Document, float]]
        docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores(
            query, **self.search_kwargs
        )
        results = {}
        for fetched_doc, relevance in docs_and_scores:
            # Resolve to the canonical copy in memory_stream so metadata
            # updates (e.g. last_accessed_at) apply to a single object.
            buffer_idx = fetched_doc.metadata["buffer_idx"]
            doc = self.memory_stream[buffer_idx]
            results[buffer_idx] = (doc, relevance)
        return results

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``k`` documents relevant to the query.

        Side effect: refreshes ``last_accessed_at`` on every returned document.
        """
        current_time = datetime.now()
        # Seed with the k most recent memories at the default salience.
        docs_and_scores = {
            doc.metadata["buffer_idx"]: (doc, self.default_salience)
            for doc in self.memory_stream[-self.k :]
        }
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        current_time = datetime.now()
        for doc, _ in rescored_docs[: self.k]:
            # TODO: Update vector store doc once `update` method is exposed.
            buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
            buffered_doc.metadata["last_accessed_at"] = current_time
            result.append(buffered_doc)
        return result

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not implemented yet."""
        raise NotImplementedError

    def _prepare_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[Document]:
        """Copy documents, stamp timestamps/buffer indices, extend memory_stream.

        Shared by the sync and async add paths. Deep-copies to avoid mutating
        the caller's documents. ``current_time`` may be overridden via kwargs.
        """
        current_time = kwargs.get("current_time", datetime.now())
        dup_docs = [deepcopy(d) for d in documents]
        for i, doc in enumerate(dup_docs):
            if "last_accessed_at" not in doc.metadata:
                doc.metadata["last_accessed_at"] = current_time
            if "created_at" not in doc.metadata:
                doc.metadata["created_at"] = current_time
            doc.metadata["buffer_idx"] = len(self.memory_stream) + i
        self.memory_stream.extend(dup_docs)
        return dup_docs

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to the vectorstore and the memory stream."""
        dup_docs = self._prepare_documents(documents, **kwargs)
        return self.vectorstore.add_documents(dup_docs, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Async counterpart of :meth:`add_documents`."""
        dup_docs = self._prepare_documents(documents, **kwargs)
        return await self.vectorstore.aadd_documents(dup_docs, **kwargs)

View File

View File

@@ -0,0 +1,163 @@
"""Tests for the time-weighted retriever class."""
from datetime import datetime
from typing import Any, Iterable, List, Optional, Tuple, Type
import pytest
from langchain.embeddings.base import Embeddings
from langchain.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever,
_get_hours_passed,
)
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore
def _get_example_memories(k: int = 4) -> List[Document]:
    """Build *k* identical example documents with fixed access times."""
    fixed_time = datetime(2023, 4, 14, 12, 0)
    memories = []
    for idx in range(k):
        memories.append(
            Document(
                page_content="foo",
                metadata={
                    "buffer_idx": idx,
                    "last_accessed_at": fixed_time,
                },
            )
        )
    return memories
class MockVectorStore(VectorStore):
    """Minimal stand-in vector store for exercising the retriever."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Pretend to index *texts*; the texts themselves serve as ids.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters
        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        return list(texts)

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Async indexing is unsupported in this mock."""
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return an empty result set for every query."""
        return []

    @classmethod
    def from_documents(
        cls: Type["MockVectorStore"],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Delegate to :meth:`from_texts` using the documents' contents."""
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

    @classmethod
    def from_texts(
        cls: Type["MockVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Ignore all inputs and return a fresh empty mock."""
        return cls()

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return the example memories, each with a flat 0.5 relevance.

        Scores are normalized on a scale from 0 (dissimilar) to 1 (most
        similar).
        """
        return [(memory, 0.5) for memory in _get_example_memories()]
@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
    """Provide a retriever backed by the mock store, preloaded with memories."""
    return TimeWeightedVectorStoreRetriever(
        vectorstore=MockVectorStore(),
        memory_stream=_get_example_memories(),
    )
def test__get_hours_passed() -> None:
    """2.5 hours separate 12:00 and 14:30 on the same day."""
    later = datetime(2023, 4, 14, 14, 30)
    earlier = datetime(2023, 4, 14, 12, 0)
    assert _get_hours_passed(later, earlier) == 2.5
def test_get_combined_score(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Combined score equals decayed recency plus the vector salience."""
    last_accessed = datetime(2023, 4, 14, 12, 0)
    now = datetime(2023, 4, 14, 14, 30)
    hours_between = 2.5
    vector_salience = 0.7
    document = Document(
        page_content="Test document",
        metadata={"last_accessed_at": last_accessed},
    )
    combined_score = time_weighted_retriever._get_combined_score(
        document, vector_salience, now
    )
    decay_base = 1.0 - time_weighted_retriever.decay_rate
    expected_score = decay_base**hours_between + vector_salience
    assert combined_score == pytest.approx(expected_score)
def test_get_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Salient docs come back as a dict keyed by buffer index."""
    result = time_weighted_retriever.get_salient_docs("Test query")
    assert isinstance(result, dict)
def test_get_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Retrieval returns a list of documents."""
    result = time_weighted_retriever.get_relevant_documents("Test query")
    assert isinstance(result, list)
def test_add_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    """Adding one document yields one id and appends it to the memory stream."""
    docs = [Document(page_content="test_add_documents document")]
    returned_ids = time_weighted_retriever.add_documents(docs)
    assert isinstance(returned_ids, list)
    assert len(returned_ids) == 1
    newest_memory = time_weighted_retriever.memory_stream[-1]
    assert newest_memory.page_content == docs[0].page_content