mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
support messages in messages out (#20862)
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
@@ -127,6 +132,41 @@ def test_output_message() -> None:
|
||||
assert output == AIMessage(content="you said: hello\ngood bye")
|
||||
|
||||
|
||||
def test_input_messages_output_message() -> None:
|
||||
class LengthChatModel(BaseChatModel):
|
||||
"""A fake chat model that returns the length of the messages passed in."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=AIMessage(content=str(len(messages))))
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "length-fake-chat-model"
|
||||
|
||||
runnable = LengthChatModel()
|
||||
get_session_history = _get_get_session_history()
|
||||
with_history = RunnableWithMessageHistory(
|
||||
runnable,
|
||||
get_session_history,
|
||||
)
|
||||
config: RunnableConfig = {"configurable": {"session_id": "4"}}
|
||||
output = with_history.invoke([HumanMessage(content="hi")], config)
|
||||
assert output.content == "1"
|
||||
output = with_history.invoke([HumanMessage(content="hi")], config)
|
||||
assert output.content == "3"
|
||||
|
||||
|
||||
def test_output_messages() -> None:
|
||||
runnable = RunnableLambda(
|
||||
lambda input: [
|
||||
|
Reference in New Issue
Block a user