@@ -372,6 +372,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage
 
+        # If dictionary, try to pluck the single key representing messages
         if isinstance(input_val, dict):
             if self.input_messages_key:
                 key = self.input_messages_key
@@ -381,13 +382,25 @@ class RunnableWithMessageHistory(RunnableBindingBase):
                 key = "input"
             input_val = input_val[key]
 
+        # If value is a string, convert to a human message
         if isinstance(input_val, str):
             from langchain_core.messages import HumanMessage
 
             return [HumanMessage(content=input_val)]
+        # If value is a single message, convert to a list
         elif isinstance(input_val, BaseMessage):
             return [input_val]
+        # If value is a list or tuple...
        elif isinstance(input_val, (list, tuple)):
+            # Handle empty case
+            if len(input_val) == 0:
+                return list(input_val)
+            # If is a list of list, then return the first value
+            # This occurs for chat models - since we batch inputs
+            if isinstance(input_val[0], list):
+                if len(input_val) != 1:
+                    raise ValueError()
+                return input_val[0]
             return list(input_val)
         else:
             raise ValueError(
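
The normalization above means callers can hand RunnableWithMessageHistory a raw string, a single message, or a message list. A minimal sketch of that behavior (not part of this commit; InMemoryHistory, get_history, and the session ids are illustrative stand-ins, and the inner runnable is a trivial echo in place of a chat model):

from typing import Dict, List

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory


class InMemoryHistory(BaseChatMessageHistory):
    """Trivial in-memory history; stands in for a real persistence layer."""

    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []


_store: Dict[str, InMemoryHistory] = {}


def get_history(session_id: str) -> BaseChatMessageHistory:
    return _store.setdefault(session_id, InMemoryHistory())


# With no message keys configured, the wrapped runnable receives the full
# message list (loaded history + the normalized new input).
chain = RunnableWithMessageHistory(
    RunnableLambda(lambda messages: messages[-1]),
    get_history,
)

cfg = {"configurable": {"session_id": "demo"}}
chain.invoke("hi", config=cfg)                           # str -> [HumanMessage("hi")]
chain.invoke(HumanMessage(content="hi"), config=cfg)     # single message -> [message]
chain.invoke([HumanMessage(content="hi")], config=cfg)   # list passes through as-is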
@@ -400,6 +413,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage
 
+        # If dictionary, try to pluck the single key representing messages
         if isinstance(output_val, dict):
             if self.output_messages_key:
                 key = self.output_messages_key
@@ -418,6 +432,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
             from langchain_core.messages import AIMessage
 
             return [AIMessage(content=output_val)]
+        # If value is a single message, convert to a list
         elif isinstance(output_val, BaseMessage):
             return [output_val]
         elif isinstance(output_val, (list, tuple)):
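
_get_output_messages mirrors the input normalization, so an inner runnable that returns a bare string is still recorded in history as an AIMessage. A hedged continuation of the sketch above (get_history and the session id are illustrative):

# An inner runnable returning a plain string still lands in history as an
# AIMessage, because _get_output_messages wraps raw strings.
pong = RunnableWithMessageHistory(
    RunnableLambda(lambda messages: "pong"),
    get_history,
)
pong.invoke("ping", config={"configurable": {"session_id": "str-demo"}})
assert get_history("str-demo").messages[-1].content == "pong"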
@@ -431,7 +446,10 @@ class RunnableWithMessageHistory(RunnableBindingBase):
 
         if not self.history_messages_key:
             # return all messages
-            messages += self._get_input_messages(input)
+            input_val = (
+                input if not self.input_messages_key else input[self.input_messages_key]
+            )
+            messages += self._get_input_messages(input_val)
         return messages
 
     async def _aenter_history(
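
With this change, _enter_history plucks input_messages_key from a dict input before normalizing, instead of passing the whole dict through. Continuing the sketch (the "question" key and lambda are illustrative):

chain_keyed = RunnableWithMessageHistory(
    RunnableLambda(lambda d: d["question"][-1]),
    get_history,
    input_messages_key="question",
)
# History plus the new input are injected back under "question" before the
# inner runnable sees the dict.
chain_keyed.invoke(
    {"question": [HumanMessage(content="hi")]},
    config={"configurable": {"session_id": "keyed-demo"}},
)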
@@ -454,7 +472,6 @@ class RunnableWithMessageHistory(RunnableBindingBase):
         # Get the input messages
         inputs = load(run.inputs)
         input_messages = self._get_input_messages(inputs)
-
         # If historic messages were prepended to the input messages, remove them to
         # avoid adding duplicate messages to history.
         if not self.history_messages_key:
@@ -466,6 +483,23 @@ class RunnableWithMessageHistory(RunnableBindingBase):
         output_messages = self._get_output_messages(output_val)
         hist.add_messages(input_messages + output_messages)
 
+    async def _aexit_history(self, run: Run, config: RunnableConfig) -> None:
+        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
+
+        # Get the input messages
+        inputs = load(run.inputs)
+        input_messages = self._get_input_messages(inputs)
+
+        # If historic messages were prepended to the input messages, remove them to
+        # avoid adding duplicate messages to history.
+        if not self.history_messages_key:
+            historic_messages = config["configurable"]["message_history"].messages
+            input_messages = input_messages[len(historic_messages) :]
+
+        # Get the output messages
+        output_val = load(run.outputs)
+        output_messages = self._get_output_messages(output_val)
+        await hist.aadd_messages(input_messages + output_messages)
+
     def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
         config = super()._merge_configs(*configs)
         expected_keys = [field_spec.id for field_spec in self.history_factory_config]
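
The new _aexit_history persists history through aadd_messages on async runs; BaseChatMessageHistory's default aadd_messages delegates to the sync add_messages in an executor, so existing synchronous stores keep working. A hedged sketch reusing the first example (the listener wiring is not shown in this hunk, so treating _aexit_history as the async on_end path is an assumption):

import asyncio


async def main() -> None:
    cfg = {"configurable": {"session_id": "async-demo"}}
    # `chain` is the echo chain from the first sketch above.
    await chain.ainvoke("hello", config=cfg)
    # Both the input and the echoed output were persisted via aadd_messages.
    print(get_history("async-demo").messages)


asyncio.run(main())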
@@ -1,7 +1,5 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
-from langchain_core.pydantic_v1 import Field
-
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
@@ -9,7 +7,7 @@ from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
 from langchain_core.outputs import ChatGeneration, ChatResult
-from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.runnables.base import RunnableLambda
 from langchain_core.runnables.config import RunnableConfig
 from langchain_core.runnables.history import RunnableWithMessageHistory