Update decay rate

This commit is contained in:
vowelparrot
2023-04-16 15:47:43 -07:00
parent c28524d817
commit 116201d4f7
2 changed files with 103 additions and 118 deletions

View File

@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "a90b7557",
"metadata": {},
@@ -9,44 +10,46 @@
"\n",
"This retriever uses a combination of semantic similarity and recency.\n",
"\n",
"The algorithm for combining them is basically:\n",
"The algorithm for scoring them is:\n",
"\n",
"```\n",
"semantic_similarity + decay_factor ** hours_passed\n",
"semantic_similarity + (1.0 - decay_rate) ** hours_passed\n",
"```\n",
"\n",
"Notably, hours_passed refers to the hours passed since the object in the retriever was last accessed, not since it ws created."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f22cc96b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import TimeWeightedVectorStoreRetriever\n",
"import faiss\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.docstore import InMemoryDocstore\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.schema import Document\n",
"import time"
]
},
{
"cell_type": "markdown",
"id": "6af7ea6b",
"metadata": {},
"source": [
"## High decay factor\n",
"\n",
"With a relatively high decay factor, this will mostly return documents that are semantically similar"
"Notably, hours_passed refers to the hours passed since the object in the retriever **was last accessed**, not since it was created. This means that frequently accessed objects remain \"fresh.\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f22cc96b",
"metadata": {},
"outputs": [],
"source": [
"import faiss\n",
"\n",
"from datetime import datetime, timedelta\n",
"from langchain.docstore import InMemoryDocstore\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers import TimeWeightedVectorStoreRetriever\n",
"from langchain.schema import Document\n",
"from langchain.vectorstores import FAISS\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6af7ea6b",
"metadata": {},
"source": [
"## Low Decay Rate\n",
"\n",
"A low decay rate (in this, to be extreme, we will set close to 0) means memories will be \"remembered\" for longer. A decay rate of 0 means memories never be forgotten, making this retriever equivalent to the vector lookup."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c10e7696",
"metadata": {},
"outputs": [],
@@ -57,52 +60,19 @@
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_factor=.99, k=1) "
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.0000000000000000000000001, k=1) "
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "86dbadb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['f6303531-d3a5-44af-b7c8-e3cf76916ce5']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.add_documents([Document(page_content=\"hello world\")])\n",
"time.sleep(20)\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a580be32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9999994359334345\n",
"1.8408203353689756\n",
"0.9999428041917008\n",
"1.9999408025741263\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 4, 15, 21, 4, 41, 457055), 'created_at': datetime.datetime(2023, 4, 15, 21, 4, 20, 437090), 'buffer_idx': 0})]"
"['129ba56b-7e7f-480b-83b3-8138a7f5db4a']"
]
},
"execution_count": 4,
@@ -111,21 +81,47 @@
}
],
"source": [
"retriever.get_relevant_documents(\"hello world\")"
]
},
{
"cell_type": "markdown",
"id": "ca056896",
"metadata": {},
"source": [
"## Low decay factor\n",
"With a low decay factor (in this, to be extreme, we will set close to 0) this will return most recent docs"
"yesterday = datetime.now() - timedelta(days=1)\n",
"retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "a580be32",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 15, 46, 43, 860748), 'created_at': datetime.datetime(2023, 4, 16, 15, 46, 14, 469670), 'buffer_idx': 1})]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Hello World\" is returned first because it is most salient, and the decay rate is close to 0., meaning it's still recent enough\n",
"retriever.get_relevant_documents(\"hello world\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ca056896",
"metadata": {},
"source": [
"## High Decay Rate\n",
"\n",
"With a high decay factor (e.g., several 9's), the recency score quickly goes to 0! If you set this all the way to 1, recency is 0 for all objects, once again making this equivalent to a vector lookup.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dc37669b",
"metadata": {},
"outputs": [],
@@ -136,52 +132,19 @@
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_factor=.0000000000000000000000001, k=1) "
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.999, k=1) "
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "fa284384",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['f063e5b2-c2eb-42fc-8894-79c7a7a9038e']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.add_documents([Document(page_content=\"hello world\")])\n",
"time.sleep(20)\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7558f94d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9978079943966991\n",
"1.8412067636745604\n",
"0.7235696005754303\n",
"1.7230094271411442\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 15, 21, 5, 2, 243331), 'created_at': datetime.datetime(2023, 4, 15, 21, 5, 1, 579028), 'buffer_idx': 1})]"
"['8fff7ef8-3a30-40f3-b42e-b8d5c7850863']"
]
},
"execution_count": 7,
@@ -190,6 +153,30 @@
}
],
"source": [
"yesterday = datetime.now() - timedelta(days=1)\n",
"retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7558f94d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 15, 46, 17, 646927), 'created_at': datetime.datetime(2023, 4, 16, 15, 46, 14, 469670), 'buffer_idx': 1})]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Hello Foo\" is returned first because \"hello world\" is mostly forgotten\n",
"retriever.get_relevant_documents(\"hello world\")"
]
},
@@ -218,7 +205,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.2"
}
},
"nbformat": 4,