mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
core[patch]: Add astream events config test (#17055)
Verify that astream_events propagates config correctly. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
609ea019b2
commit
fbab8baac5
@ -1,23 +1,29 @@
|
|||||||
"""Module that contains tests for runnable.astream_events API."""
|
"""Module that contains tests for runnable.astream_events API."""
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import AsyncIterator, List, Sequence, cast
|
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||||
|
from langchain_core.chat_history import BaseChatMessageHistory
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.prompt_values import ChatPromptValue
|
from langchain_core.prompt_values import ChatPromptValue
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
|
ConfigurableField,
|
||||||
|
Runnable,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
)
|
)
|
||||||
|
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
|
||||||
@ -1079,3 +1085,130 @@ async def test_runnable_each() -> None:
|
|||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def test_events_astream_config() -> None:
    """Verify that astream_events picks up config applied via ``with_config``."""
    default_cycle = cycle([AIMessage(content="hello world!")])
    goodbye_cycle = cycle([AIMessage(content="Goodbye world")])

    # Base model streams "hello world!"; the `messages` field is configurable.
    model = GenericFakeChatModel(messages=default_cycle).configurable_fields(
        messages=ConfigurableField(
            id="messages",
            name="Messages",
            description="Messages return by the LLM",
        )
    )

    # Override the configurable field so the model streams "Goodbye world".
    model_02 = model.with_config({"configurable": {"messages": goodbye_cycle}})
    assert model_02.invoke("hello") == AIMessage(content="Goodbye world")

    def _event(event: str, data: dict) -> dict:
        # Every event in this run shares the same envelope; only the event
        # type and payload differ.
        return {
            "data": data,
            "event": event,
            "metadata": {},
            "name": "RunnableConfigurableFields",
            "run_id": "",
            "tags": [],
        }

    events = await _collect_events(model_02.astream_events("hello", version="v1"))
    # The streamed chunks must come from the CONFIGURED cycle ("Goodbye world"),
    # proving the config was propagated into astream_events.
    assert events == [
        _event("on_chat_model_start", {"input": "hello"}),
        _event("on_chat_model_stream", {"chunk": AIMessageChunk(content="Goodbye")}),
        _event("on_chat_model_stream", {"chunk": AIMessageChunk(content=" ")}),
        _event("on_chat_model_stream", {"chunk": AIMessageChunk(content="world")}),
        _event(
            "on_chat_model_end", {"output": AIMessageChunk(content="Goodbye world")}
        ),
    ]
async def test_runnable_with_message_history() -> None:
    """Ensure RunnableWithMessageHistory records each turn in the session store."""

    class InMemoryHistory(BaseChatMessageHistory, BaseModel):
        """In memory implementation of chat message history."""

        # NOTE: typed as Any rather than List[BaseMessage] to work around a
        # pydantic issue where a list field is re-instantiated on validation,
        # so mutations would no longer reach the backing list held in `store`.
        messages: Any

        def add_message(self, message: BaseMessage) -> None:
            """Append a single message to this session's history."""
            self.messages.append(message)

        def clear(self) -> None:
            """Drop all messages for this session."""
            self.messages = []

    # Session histories live in a plain dict so the test can inspect the
    # underlying results directly.
    store: Dict = {}

    def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
        """Return the history for `session_id`, creating it if absent."""
        return InMemoryHistory(messages=store.setdefault(session_id, []))

    responses = cycle([AIMessage(content="hello"), AIMessage(content="world")])

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a cat"),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )
    chain: Runnable = prompt | GenericFakeChatModel(messages=responses)

    with_message_history = RunnableWithMessageHistory(
        chain,
        get_session_history=get_by_session_id,
        input_messages_key="question",
        history_messages_key="history",
    )
    # Bind the session id once and reuse the configured runnable for both turns.
    configured = with_message_history.with_config(
        {"configurable": {"session_id": "session-123"}}
    )

    configured.invoke({"question": "hello"})
    assert store == {
        "session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
    }

    configured.invoke({"question": "meow"})
    assert store == {
        "session-123": [
            HumanMessage(content="hello"),
            AIMessage(content="hello"),
            HumanMessage(content="meow"),
            AIMessage(content="world"),
        ]
    }
|
Loading…
Reference in New Issue
Block a user