Support messages in, messages out (#20862)

This commit is contained in:
Harrison Chase
2024-04-24 14:58:58 -07:00
committed by GitHub
parent a1614b88ac
commit 43c041cda5
2 changed files with 71 additions and 11 deletions

View File

@@ -1,6 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
@@ -127,6 +132,41 @@ def test_output_message() -> None:
assert output == AIMessage(content="you said: hello\ngood bye")
def test_input_messages_output_message() -> None:
    """History-wrapped chat model: list of messages in, single message out.

    Verifies that ``RunnableWithMessageHistory`` accepts a bare list of
    messages as input and returns the model's single output message, while
    accumulating both sides of the conversation in the session history.
    """

    class LengthChatModel(BaseChatModel):
        """Fake chat model that replies with the count of messages it was given."""

        def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            """Produce a single AI message whose content is ``len(messages)``."""
            message_count = len(messages)
            generation = ChatGeneration(
                message=AIMessage(content=str(message_count))
            )
            return ChatResult(generations=[generation])

        @property
        def _llm_type(self) -> str:
            return "length-fake-chat-model"

    chat_model = LengthChatModel()
    get_session_history = _get_get_session_history()
    with_history = RunnableWithMessageHistory(chat_model, get_session_history)
    config: RunnableConfig = {"configurable": {"session_id": "4"}}

    # First turn: the model only sees the one incoming human message.
    first_response = with_history.invoke([HumanMessage(content="hi")], config)
    assert first_response.content == "1"

    # Second turn: stored human + AI messages plus the new human message = 3.
    second_response = with_history.invoke([HumanMessage(content="hi")], config)
    assert second_response.content == "3"
def test_output_messages() -> None:
runnable = RunnableLambda(
lambda input: [