mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00

Harrison/virtual time (#4658)

Co-authored-by: ifsheldon <39153080+ifsheldon@users.noreply.github.com>
Co-authored-by: maple.liang <maple.liang@gempoll.com>

This commit is contained in:
parent f2f2aced6d
commit 243886be93
@@ -70,7 +70,7 @@
    {
     "data": {
      "text/plain": [
-       "['5c9f7c06-c9eb-45f2-aea5-efce5fb9f2bd']"
+       "['d7f85756-2371-4bdf-9140-052780a0f9b3']"
      ]
     },
     "execution_count": 3,
@@ -93,7 +93,7 @@
    {
     "data": {
      "text/plain": [
-       "[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 22, 9, 1, 966261), 'created_at': datetime.datetime(2023, 4, 16, 22, 9, 0, 374683), 'buffer_idx': 0})]"
+       "[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 678341), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]"
      ]
     },
     "execution_count": 4,
@@ -177,10 +177,51 @@
     "retriever.get_relevant_documents(\"hello world\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "32e0131e",
+   "metadata": {},
+   "source": [
+    "## Virtual Time\n",
+    "\n",
+    "Using some utils in LangChain, you can mock out the time component"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "da080d40",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.utils import mock_now\n",
+    "import datetime"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "7c7deff1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[Document(page_content='hello world', metadata={'last_accessed_at': MockDateTime(2011, 2, 3, 10, 11), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Notice the last access time is that date time\n",
+    "with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):\n",
+    "    print(retriever.get_relevant_documents(\"hello world\"))"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "bf6d8c90",
+   "id": "c78d367d",
    "metadata": {},
    "outputs": [],
    "source": []
@@ -88,7 +88,9 @@ Relevant context:
         q2 = f"{entity_name} is {entity_action}"
         return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()

-    def _generate_reaction(self, observation: str, suffix: str) -> str:
+    def _generate_reaction(
+        self, observation: str, suffix: str, now: Optional[datetime] = None
+    ) -> str:
         """React to a given observation or dialogue act."""
         prompt = PromptTemplate.from_template(
             "{agent_summary_description}"
@@ -101,9 +103,13 @@ Relevant context:
             + "\n\n"
             + suffix
         )
-        agent_summary_description = self.get_summary()
+        agent_summary_description = self.get_summary(now=now)
         relevant_memories_str = self.summarize_related_memories(observation)
-        current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
+        current_time_str = (
+            datetime.now().strftime("%B %d, %Y, %I:%M %p")
+            if now is None
+            else now.strftime("%B %d, %Y, %I:%M %p")
+        )
         kwargs: Dict[str, Any] = dict(
             agent_summary_description=agent_summary_description,
             current_time=current_time_str,
@@ -121,7 +127,9 @@ Relevant context:
     def _clean_response(self, text: str) -> str:
         return re.sub(f"^{self.name} ", "", text.strip()).strip()

-    def generate_reaction(self, observation: str) -> Tuple[bool, str]:
+    def generate_reaction(
+        self, observation: str, now: Optional[datetime] = None
+    ) -> Tuple[bool, str]:
         """React to a given observation."""
         call_to_action_template = (
             "Should {agent_name} react to the observation, and if so,"
@@ -130,14 +138,17 @@ Relevant context:
             + "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)."
             + "\nEither do nothing, react, or say something but not both.\n\n"
         )
-        full_result = self._generate_reaction(observation, call_to_action_template)
+        full_result = self._generate_reaction(
+            observation, call_to_action_template, now=now
+        )
         result = full_result.strip().split("\n")[0]
         # AAA
         self.memory.save_context(
             {},
             {
                 self.memory.add_memory_key: f"{self.name} observed "
-                f"{observation} and reacted by {result}"
+                f"{observation} and reacted by {result}",
+                self.memory.now_key: now,
             },
         )
         if "REACT:" in result:
@@ -149,14 +160,18 @@ Relevant context:
         else:
             return False, result

-    def generate_dialogue_response(self, observation: str) -> Tuple[bool, str]:
+    def generate_dialogue_response(
+        self, observation: str, now: Optional[datetime] = None
+    ) -> Tuple[bool, str]:
         """React to a given observation."""
         call_to_action_template = (
             "What would {agent_name} say? To end the conversation, write:"
             ' GOODBYE: "what to say". Otherwise to continue the conversation,'
             ' write: SAY: "what to say next"\n\n'
         )
-        full_result = self._generate_reaction(observation, call_to_action_template)
+        full_result = self._generate_reaction(
+            observation, call_to_action_template, now=now
+        )
         result = full_result.strip().split("\n")[0]
         if "GOODBYE:" in result:
             farewell = self._clean_response(result.split("GOODBYE:")[-1])
@@ -164,7 +179,8 @@ Relevant context:
                 {},
                 {
                     self.memory.add_memory_key: f"{self.name} observed "
-                    f"{observation} and said {farewell}"
+                    f"{observation} and said {farewell}",
+                    self.memory.now_key: now,
                 },
             )
             return False, f"{self.name} said {farewell}"
@@ -174,7 +190,8 @@ Relevant context:
                 {},
                 {
                     self.memory.add_memory_key: f"{self.name} observed "
-                    f"{observation} and said {response_text}"
+                    f"{observation} and said {response_text}",
+                    self.memory.now_key: now,
                 },
             )
             return True, f"{self.name} said {response_text}"
@@ -203,9 +220,11 @@ Relevant context:
             .strip()
         )

-    def get_summary(self, force_refresh: bool = False) -> str:
+    def get_summary(
+        self, force_refresh: bool = False, now: Optional[datetime] = None
+    ) -> str:
         """Return a descriptive summary of the agent."""
-        current_time = datetime.now()
+        current_time = datetime.now() if now is None else now
         since_refresh = (current_time - self.last_refreshed).seconds
         if (
             not self.summary
@@ -221,10 +240,13 @@ Relevant context:
             + f"\n{self.summary}"
         )

-    def get_full_header(self, force_refresh: bool = False) -> str:
+    def get_full_header(
+        self, force_refresh: bool = False, now: Optional[datetime] = None
+    ) -> str:
         """Return a full header of the agent's status, summary, and current time."""
-        summary = self.get_summary(force_refresh=force_refresh)
-        current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
+        now = datetime.now() if now is None else now
+        summary = self.get_summary(force_refresh=force_refresh, now=now)
+        current_time_str = now.strftime("%B %d, %Y, %I:%M %p")
         return (
             f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
         )
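Taken together, these changes thread an optional now through every public entry point of the agent: generate_reaction, generate_dialogue_response, get_summary, and get_full_header all accept it and forward it into prompts and memory writes. A minimal sketch of how a caller might exercise the new parameter (the agent argument is assumed to be an already-constructed GenerativeAgent; its setup is outside this diff):

    from datetime import datetime

    def react_at_virtual_time(agent, observation: str, virtual_now: datetime) -> str:
        """Drive one reaction step at a simulated timestamp.

        The timestamp flows into get_summary(now=...), the prompt's
        current_time string, and the memory write via save_context,
        which stores it under memory.now_key.
        """
        _, reaction = agent.generate_reaction(observation, now=virtual_now)
        return reaction

    # e.g. react_at_virtual_time(agent, "a dog runs past", datetime(2011, 2, 3, 10, 11))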
@@ -1,5 +1,6 @@
 import logging
 import re
+from datetime import datetime
 from typing import Any, Dict, List, Optional

 from langchain import LLMChain
@@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.prompts import PromptTemplate
 from langchain.retrievers import TimeWeightedVectorStoreRetriever
 from langchain.schema import BaseMemory, Document
+from langchain.utils import mock_now

 logger = logging.getLogger(__name__)

@@ -44,6 +46,7 @@ class GenerativeAgentMemory(BaseMemory):
     relevant_memories_key: str = "relevant_memories"
     relevant_memories_simple_key: str = "relevant_memories_simple"
     most_recent_memories_key: str = "most_recent_memories"
+    now_key: str = "now"
     reflecting: bool = False

     def chain(self, prompt: PromptTemplate) -> LLMChain:
@@ -68,7 +71,9 @@ class GenerativeAgentMemory(BaseMemory):
         result = self.chain(prompt).run(observations=observation_str)
         return self._parse_list(result)

-    def _get_insights_on_topic(self, topic: str) -> List[str]:
+    def _get_insights_on_topic(
+        self, topic: str, now: Optional[datetime] = None
+    ) -> List[str]:
         """Generate 'insights' on a topic of reflection, based on pertinent memories."""
         prompt = PromptTemplate.from_template(
             "Statements about {topic}\n"
@@ -76,7 +81,7 @@ class GenerativeAgentMemory(BaseMemory):
             + "What 5 high-level insights can you infer from the above statements?"
             + " (example format: insight (because of 1, 5, 3))"
         )
-        related_memories = self.fetch_memories(topic)
+        related_memories = self.fetch_memories(topic, now=now)
         related_statements = "\n".join(
             [
                 f"{i+1}. {memory.page_content}"
@@ -89,16 +94,16 @@ class GenerativeAgentMemory(BaseMemory):
         # TODO: Parse the connections between memories and insights
         return self._parse_list(result)

-    def pause_to_reflect(self) -> List[str]:
+    def pause_to_reflect(self, now: Optional[datetime] = None) -> List[str]:
         """Reflect on recent observations and generate 'insights'."""
         if self.verbose:
             logger.info("Character is reflecting")
         new_insights = []
         topics = self._get_topics_of_reflection()
         for topic in topics:
-            insights = self._get_insights_on_topic(topic)
+            insights = self._get_insights_on_topic(topic, now=now)
             for insight in insights:
-                self.add_memory(insight)
+                self.add_memory(insight, now=now)
             new_insights.extend(insights)
         return new_insights

@@ -122,14 +127,16 @@ class GenerativeAgentMemory(BaseMemory):
         else:
             return 0.0

-    def add_memory(self, memory_content: str) -> List[str]:
+    def add_memory(
+        self, memory_content: str, now: Optional[datetime] = None
+    ) -> List[str]:
         """Add an observation or memory to the agent's memory."""
         importance_score = self._score_memory_importance(memory_content)
         self.aggregate_importance += importance_score
         document = Document(
             page_content=memory_content, metadata={"importance": importance_score}
         )
-        result = self.memory_retriever.add_documents([document])
+        result = self.memory_retriever.add_documents([document], current_time=now)

         # After an agent has processed a certain amount of memories (as measured by
         # aggregate importance), it is time to reflect on recent events to add
@@ -140,15 +147,21 @@ class GenerativeAgentMemory(BaseMemory):
             and not self.reflecting
         ):
             self.reflecting = True
-            self.pause_to_reflect()
+            self.pause_to_reflect(now=now)
             # Hack to clear the importance from reflection
             self.aggregate_importance = 0.0
             self.reflecting = False
         return result

-    def fetch_memories(self, observation: str) -> List[Document]:
+    def fetch_memories(
+        self, observation: str, now: Optional[datetime] = None
+    ) -> List[Document]:
         """Fetch related memories."""
-        return self.memory_retriever.get_relevant_documents(observation)
+        if now is not None:
+            with mock_now(now):
+                return self.memory_retriever.get_relevant_documents(observation)
+        else:
+            return self.memory_retriever.get_relevant_documents(observation)

     def format_memories_detail(self, relevant_memories: List[Document]) -> str:
         content_strs = set()
@@ -183,9 +196,10 @@ class GenerativeAgentMemory(BaseMemory):
     def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         """Return key-value pairs given the text input to the chain."""
         queries = inputs.get(self.queries_key)
+        now = inputs.get(self.now_key)
         if queries is not None:
             relevant_memories = [
-                mem for query in queries for mem in self.fetch_memories(query)
+                mem for query in queries for mem in self.fetch_memories(query, now=now)
             ]
             return {
                 self.relevant_memories_key: self.format_memories_detail(
@@ -205,12 +219,13 @@ class GenerativeAgentMemory(BaseMemory):
             }
         return {}

-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
         """Save the context of this model run to memory."""
         # TODO: fix the save memory key
         mem = outputs.get(self.add_memory_key)
+        now = outputs.get(self.now_key)
         if mem:
-            self.add_memory(mem)
+            self.add_memory(mem, now=now)

     def clear(self) -> None:
         """Clear memory contents."""
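On the memory side, the same now value is used both when writing (add_memory stamps the document via the retriever's current_time kwarg) and when reading (fetch_memories wraps retrieval in mock_now, so recency scoring runs against the mocked clock). A hedged sketch, assuming a constructed GenerativeAgentMemory instance whose constructor arguments are outside this diff:

    from datetime import datetime

    def backdate_and_query(memory, content: str, written_at: datetime, queried_at: datetime):
        """Write a memory at one virtual time, then retrieve at another.

        `memory` is assumed to be a GenerativeAgentMemory; only the now=
        plumbing added in this diff is exercised.
        """
        # The stored document's created_at/last_accessed_at come from
        # written_at, because add_documents receives current_time=now.
        memory.add_memory(content, now=written_at)
        # Retrieval scores recency relative to queried_at via mock_now.
        return memory.fetch_memories(content, now=queried_at)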
@@ -1,6 +1,6 @@
 """Retriever that combines embedding similarity with recency in retrieving values."""
+import datetime
 from copy import deepcopy
-from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple

 from pydantic import BaseModel, Field
@@ -9,7 +9,7 @@ from langchain.schema import BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore


-def _get_hours_passed(time: datetime, ref_time: datetime) -> float:
+def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
     """Get the hours passed between two datetime objects."""
     return (time - ref_time).total_seconds() / 3600

@@ -51,7 +51,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
         self,
         document: Document,
         vector_relevance: Optional[float],
-        current_time: datetime,
+        current_time: datetime.datetime,
     ) -> float:
         """Return the combined score for a document."""
         hours_passed = _get_hours_passed(
@@ -82,7 +82,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):

     def get_relevant_documents(self, query: str) -> List[Document]:
         """Return documents that are relevant to the query."""
-        current_time = datetime.now()
+        current_time = datetime.datetime.now()
         docs_and_scores = {
             doc.metadata["buffer_idx"]: (doc, self.default_salience)
             for doc in self.memory_stream[-self.k :]
@@ -96,7 +96,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
         rescored_docs.sort(key=lambda x: x[1], reverse=True)
         result = []
         # Ensure frequently accessed memories aren't forgotten
-        current_time = datetime.now()
         for doc, _ in rescored_docs[: self.k]:
             # TODO: Update vector store doc once `update` method is exposed.
             buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
@@ -110,7 +109,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):

     def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
         """Add documents to vectorstore."""
-        current_time = kwargs.get("current_time", datetime.now())
+        current_time = kwargs.get("current_time", datetime.datetime.now())
         # Avoid mutating input documents
         dup_docs = [deepcopy(d) for d in documents]
         for i, doc in enumerate(dup_docs):
@@ -126,7 +125,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
         self, documents: List[Document], **kwargs: Any
     ) -> List[str]:
         """Add documents to vectorstore."""
-        current_time = kwargs.get("current_time", datetime.now())
+        current_time = kwargs.get("current_time", datetime.datetime.now())
         # Avoid mutating input documents
         dup_docs = [deepcopy(d) for d in documents]
         for i, doc in enumerate(dup_docs):
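The retriever's key change here is the import style: from datetime import datetime is replaced by import datetime, with call sites becoming datetime.datetime.now(). That is what makes mock_now effective, since mock_now patches the module attribute datetime.datetime; a name bound by a from-import at import time would keep pointing at the real class and never see the patch. A small illustration of the difference (the bound_datetime alias is illustrative, not from the diff):

    import datetime
    from datetime import datetime as bound_datetime  # early binding

    from langchain.utils import mock_now

    with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
        # Attribute lookup goes through the patched module and sees the mock.
        print(datetime.datetime.now())  # 2011-02-03 10:11:00
        # The early-bound name still points at the real class.
        print(bound_datetime.now())     # actual wall-clock time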
@@ -1,4 +1,6 @@
 """Generic utility functions."""
+import contextlib
+import datetime
 import os
 from typing import Any, Callable, Dict, Optional, Tuple

@@ -78,3 +80,34 @@ def stringify_dict(data: dict) -> str:
     for key, value in data.items():
         text += key + ": " + stringify_value(value) + "\n"
     return text
+
+
+@contextlib.contextmanager
+def mock_now(dt_value):  # type: ignore
+    """Context manager for mocking out datetime.now() in unit tests.
+
+    Example:
+    with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
+        assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
+    """
+
+    class MockDateTime(datetime.datetime):
+        @classmethod
+        def now(cls):  # type: ignore
+            # Create a copy of dt_value.
+            return datetime.datetime(
+                dt_value.year,
+                dt_value.month,
+                dt_value.day,
+                dt_value.hour,
+                dt_value.minute,
+                dt_value.second,
+                dt_value.microsecond,
+                dt_value.tzinfo,
+            )
+
+    real_datetime = datetime.datetime
+    datetime.datetime = MockDateTime
+    try:
+        yield datetime.datetime
+    finally:
+        datetime.datetime = real_datetime
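mock_now swaps datetime.datetime for a subclass whose now() always returns a copy of dt_value, and the try/finally guarantees the real class is restored even if the block raises. A quick sanity check of that contract:

    import datetime

    from langchain.utils import mock_now

    fixed = datetime.datetime(2011, 2, 3, 10, 11)

    with mock_now(fixed):
        assert datetime.datetime.now() == fixed  # clock is frozen in the block

    # After the block the real class is back and the clock moves again.
    assert datetime.datetime.now() != fixed
    assert type(datetime.datetime.now()) is datetime.datetime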