mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
core[patch]: Add astream events config test (#17055)
Verify that astream events propagates config correctly --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
609ea019b2
commit
fbab8baac5
@ -1,23 +1,29 @@
|
||||
"""Module that contains tests for runnable.astream_events API."""
|
||||
from itertools import cycle
|
||||
from typing import AsyncIterator, List, Sequence, cast
|
||||
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.prompt_values import ChatPromptValue
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
ConfigurableField,
|
||||
Runnable,
|
||||
RunnableLambda,
|
||||
)
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.tools import tool
|
||||
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
||||
@ -1079,3 +1085,130 @@ async def test_runnable_each() -> None:
|
||||
with pytest.raises(NotImplementedError):
|
||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
||||
pass
|
||||
|
||||
|
||||
async def test_events_astream_config() -> None:
|
||||
"""Test that astream events support accepting config"""
|
||||
infinite_cycle = cycle([AIMessage(content="hello world!")])
|
||||
good_world_on_repeat = cycle([AIMessage(content="Goodbye world")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
|
||||
messages=ConfigurableField(
|
||||
id="messages",
|
||||
name="Messages",
|
||||
description="Messages return by the LLM",
|
||||
)
|
||||
)
|
||||
|
||||
model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}})
|
||||
assert model_02.invoke("hello") == AIMessage(content="Goodbye world")
|
||||
|
||||
events = await _collect_events(model_02.astream_events("hello", version="v1"))
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
"event": "on_chat_model_start",
|
||||
"metadata": {},
|
||||
"name": "RunnableConfigurableFields",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="Goodbye")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {},
|
||||
"name": "RunnableConfigurableFields",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {},
|
||||
"name": "RunnableConfigurableFields",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world")},
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {},
|
||||
"name": "RunnableConfigurableFields",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": AIMessageChunk(content="Goodbye world")},
|
||||
"event": "on_chat_model_end",
|
||||
"metadata": {},
|
||||
"name": "RunnableConfigurableFields",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def test_runnable_with_message_history() -> None:
|
||||
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
|
||||
"""In memory implementation of chat message history."""
|
||||
|
||||
# Attention: for the tests use an Any type to work-around a pydantic issue
|
||||
# where it re-instantiates a list, so mutating the list doesn't end up mutating
|
||||
# the content in the store!
|
||||
|
||||
# Using Any type here rather than List[BaseMessage] due to pydantic issue!
|
||||
messages: Any
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a self-created message to the store."""
|
||||
self.messages.append(message)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.messages = []
|
||||
|
||||
# Here we use a global variable to store the chat message history.
|
||||
# This will make it easier to inspect it to see the underlying results.
|
||||
store: Dict = {}
|
||||
|
||||
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
|
||||
"""Get a chat message history"""
|
||||
if session_id not in store:
|
||||
store[session_id] = []
|
||||
return InMemoryHistory(messages=store[session_id])
|
||||
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")])
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a cat"),
|
||||
MessagesPlaceholder(variable_name="history"),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
chain: Runnable = prompt | model
|
||||
with_message_history = RunnableWithMessageHistory(
|
||||
chain,
|
||||
get_session_history=get_by_session_id,
|
||||
input_messages_key="question",
|
||||
history_messages_key="history",
|
||||
)
|
||||
with_message_history.with_config(
|
||||
{"configurable": {"session_id": "session-123"}}
|
||||
).invoke({"question": "hello"})
|
||||
|
||||
assert store == {
|
||||
"session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
|
||||
}
|
||||
|
||||
with_message_history.with_config(
|
||||
{"configurable": {"session_id": "session-123"}}
|
||||
).invoke({"question": "meow"})
|
||||
assert store == {
|
||||
"session-123": [
|
||||
HumanMessage(content="hello"),
|
||||
AIMessage(content="hello"),
|
||||
HumanMessage(content="meow"),
|
||||
AIMessage(content="world"),
|
||||
]
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user