mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 11:00:37 +00:00
support messages in messages out (#20862)
This commit is contained in:
@@ -25,6 +25,7 @@ from langchain_core.runnables.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import LanguageModelLike
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.tracers.schemas import Run
|
||||
@@ -228,9 +229,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Runnable[
|
||||
MessagesOrDictWithMessages,
|
||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||
runnable: Union[
|
||||
Runnable[
|
||||
Union[MessagesOrDictWithMessages],
|
||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||
],
|
||||
LanguageModelLike,
|
||||
],
|
||||
get_session_history: GetSessionHistoryCallable,
|
||||
*,
|
||||
@@ -364,10 +368,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
return super_schema
|
||||
|
||||
def _get_input_messages(
|
||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
|
||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||
) -> List[BaseMessage]:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
if isinstance(input_val, dict):
|
||||
if self.input_messages_key:
|
||||
key = self.input_messages_key
|
||||
elif len(input_val) == 1:
|
||||
key = list(input_val.keys())[0]
|
||||
else:
|
||||
key = "input"
|
||||
input_val = input_val[key]
|
||||
|
||||
if isinstance(input_val, str):
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -388,7 +401,18 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
if isinstance(output_val, dict):
|
||||
output_val = output_val[self.output_messages_key or "output"]
|
||||
if self.output_messages_key:
|
||||
key = self.output_messages_key
|
||||
elif len(output_val) == 1:
|
||||
key = list(output_val.keys())[0]
|
||||
else:
|
||||
key = "output"
|
||||
# If you are wrapping a chat model directly
|
||||
# The output is actually this weird generations object
|
||||
if key not in output_val and "generations" in output_val:
|
||||
output_val = output_val["generations"][0][0]["message"]
|
||||
else:
|
||||
output_val = output_val[key]
|
||||
|
||||
if isinstance(output_val, str):
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -407,10 +431,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
if not self.history_messages_key:
|
||||
# return all messages
|
||||
input_val = (
|
||||
input if not self.input_messages_key else input[self.input_messages_key]
|
||||
)
|
||||
messages += self._get_input_messages(input_val)
|
||||
messages += self._get_input_messages(input)
|
||||
return messages
|
||||
|
||||
async def _aenter_history(
|
||||
@@ -432,8 +453,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
|
||||
# Get the input messages
|
||||
inputs = load(run.inputs)
|
||||
input_val = inputs[self.input_messages_key or "input"]
|
||||
input_messages = self._get_input_messages(input_val)
|
||||
input_messages = self._get_input_messages(inputs)
|
||||
|
||||
# If historic messages were prepended to the input messages, remove them to
|
||||
# avoid adding duplicate messages to history.
|
||||
|
Reference in New Issue
Block a user