mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
support messages in messages out (#20862)
This commit is contained in:
parent
a1614b88ac
commit
43c041cda5
@ -25,6 +25,7 @@ from langchain_core.runnables.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from langchain_core.language_models import LanguageModelLike
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
@ -228,9 +229,12 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
runnable: Runnable[
|
runnable: Union[
|
||||||
MessagesOrDictWithMessages,
|
Runnable[
|
||||||
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
Union[MessagesOrDictWithMessages],
|
||||||
|
Union[str, BaseMessage, MessagesOrDictWithMessages],
|
||||||
|
],
|
||||||
|
LanguageModelLike,
|
||||||
],
|
],
|
||||||
get_session_history: GetSessionHistoryCallable,
|
get_session_history: GetSessionHistoryCallable,
|
||||||
*,
|
*,
|
||||||
@ -364,10 +368,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
return super_schema
|
return super_schema
|
||||||
|
|
||||||
def _get_input_messages(
|
def _get_input_messages(
|
||||||
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage]]
|
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
if isinstance(input_val, dict):
|
||||||
|
if self.input_messages_key:
|
||||||
|
key = self.input_messages_key
|
||||||
|
elif len(input_val) == 1:
|
||||||
|
key = list(input_val.keys())[0]
|
||||||
|
else:
|
||||||
|
key = "input"
|
||||||
|
input_val = input_val[key]
|
||||||
|
|
||||||
if isinstance(input_val, str):
|
if isinstance(input_val, str):
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
@ -388,7 +401,18 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
if isinstance(output_val, dict):
|
if isinstance(output_val, dict):
|
||||||
output_val = output_val[self.output_messages_key or "output"]
|
if self.output_messages_key:
|
||||||
|
key = self.output_messages_key
|
||||||
|
elif len(output_val) == 1:
|
||||||
|
key = list(output_val.keys())[0]
|
||||||
|
else:
|
||||||
|
key = "output"
|
||||||
|
# If you are wrapping a chat model directly
|
||||||
|
# The output is actually this weird generations object
|
||||||
|
if key not in output_val and "generations" in output_val:
|
||||||
|
output_val = output_val["generations"][0][0]["message"]
|
||||||
|
else:
|
||||||
|
output_val = output_val[key]
|
||||||
|
|
||||||
if isinstance(output_val, str):
|
if isinstance(output_val, str):
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
@ -407,10 +431,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
|
|
||||||
if not self.history_messages_key:
|
if not self.history_messages_key:
|
||||||
# return all messages
|
# return all messages
|
||||||
input_val = (
|
messages += self._get_input_messages(input)
|
||||||
input if not self.input_messages_key else input[self.input_messages_key]
|
|
||||||
)
|
|
||||||
messages += self._get_input_messages(input_val)
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def _aenter_history(
|
async def _aenter_history(
|
||||||
@ -432,8 +453,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
|
|
||||||
# Get the input messages
|
# Get the input messages
|
||||||
inputs = load(run.inputs)
|
inputs = load(run.inputs)
|
||||||
input_val = inputs[self.input_messages_key or "input"]
|
input_messages = self._get_input_messages(inputs)
|
||||||
input_messages = self._get_input_messages(input_val)
|
|
||||||
|
|
||||||
# If historic messages were prepended to the input messages, remove them to
|
# If historic messages were prepended to the input messages, remove them to
|
||||||
# avoid adding duplicate messages to history.
|
# avoid adding duplicate messages to history.
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables.base import RunnableLambda
|
from langchain_core.runnables.base import RunnableLambda
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
@ -127,6 +132,41 @@ def test_output_message() -> None:
|
|||||||
assert output == AIMessage(content="you said: hello\ngood bye")
|
assert output == AIMessage(content="you said: hello\ngood bye")
|
||||||
|
|
||||||
|
|
||||||
|
def test_input_messages_output_message() -> None:
|
||||||
|
class LengthChatModel(BaseChatModel):
|
||||||
|
"""A fake chat model that returns the length of the messages passed in."""
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""Top Level call"""
|
||||||
|
return ChatResult(
|
||||||
|
generations=[
|
||||||
|
ChatGeneration(message=AIMessage(content=str(len(messages))))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "length-fake-chat-model"
|
||||||
|
|
||||||
|
runnable = LengthChatModel()
|
||||||
|
get_session_history = _get_get_session_history()
|
||||||
|
with_history = RunnableWithMessageHistory(
|
||||||
|
runnable,
|
||||||
|
get_session_history,
|
||||||
|
)
|
||||||
|
config: RunnableConfig = {"configurable": {"session_id": "4"}}
|
||||||
|
output = with_history.invoke([HumanMessage(content="hi")], config)
|
||||||
|
assert output.content == "1"
|
||||||
|
output = with_history.invoke([HumanMessage(content="hi")], config)
|
||||||
|
assert output.content == "3"
|
||||||
|
|
||||||
|
|
||||||
def test_output_messages() -> None:
|
def test_output_messages() -> None:
|
||||||
runnable = RunnableLambda(
|
runnable = RunnableLambda(
|
||||||
lambda input: [
|
lambda input: [
|
||||||
|
Loading…
Reference in New Issue
Block a user