Compare commits

...

6 Commits

Author SHA1 Message Date
Eugene Yurtsev
522c6d1cc2 x 2024-06-04 17:23:45 -04:00
Eugene Yurtsev
1cace093ed x 2024-06-04 17:19:45 -04:00
Eugene Yurtsev
bb65e49630 x 2024-06-04 17:13:59 -04:00
Eugene Yurtsev
7106bc8125 Merge branch 'master' into eugene/async_history_2 2024-06-04 17:13:34 -04:00
Eugene Yurtsev
a1b87065ed x 2024-06-04 17:06:45 -04:00
Eugene Yurtsev
f63202e4b3 qxqx 2024-06-04 17:06:13 -04:00
2 changed files with 80 additions and 2 deletions

View File

@@ -306,8 +306,23 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
bound = (
bound_sync = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
).with_config(run_name="SyncRunnableWithMessageHistory")
bound_async = (
history_chain | runnable.with_alisteners(on_end=self._aexit_history)
).with_config(run_name="RunnableWithMessageHistory")
def bound(*args, **kwargs):
return bound_sync.invoke(*args, **kwargs)
async def abound(*args, **kwargs):
return await bound_async.ainvoke(*args, **kwargs)
bound = RunnableLambda(
bound,
abound,
).with_config(run_name="RunnableWithMessageHistory")
if history_factory_config:
@@ -483,6 +498,23 @@ class RunnableWithMessageHistory(RunnableBindingBase):
output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages)
async def _aexit_history(self, run: Run, config: RunnableConfig) -> None:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
# Get the input messages
inputs = load(run.inputs)
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
historic_messages = config["configurable"]["message_history"].messages
input_messages = input_messages[len(historic_messages) :]
# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
await hist.aadd_messages(input_messages + output_messages)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
expected_keys = [field_spec.id for field_spec in self.history_factory_config]

View File

@@ -3,10 +3,11 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.history import RunnableWithMessageHistory
@@ -404,3 +405,48 @@ def test_using_custom_config_specs() -> None:
]
),
}
class AsyncChatMessageHistory(BaseChatMessageHistory, BaseModel):
"""In memory implementation of chat message history.
Stores messages in an in memory list.
"""
messages: List[BaseMessage] = Field(default_factory=list)
async def aadd_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)
def clear(self) -> None:
self.messages = []
def _get_get_async_session_history(
*,
store: Optional[Dict[str, Any]] = None,
) -> Callable[..., AsyncChatMessageHistory]:
chat_history_store = store if store is not None else {}
def get_session_history(session_id: str, **kwargs: Any) -> AsyncChatMessageHistory:
if session_id not in chat_history_store:
chat_history_store[session_id] = AsyncChatMessageHistory()
return chat_history_store[session_id]
return get_session_history
async def test_async_input_messages() -> None:
runnable = RunnableLambda(
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: Dict = {}
get_session_history = _get_get_async_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}}
output = await with_history.ainvoke([HumanMessage(content="hello")], config)
assert output == "you said: hello"
output = await with_history.ainvoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"