mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
core: Two updates to chat model interface (#19684)
- .stream() and .astream() now call on_llm_new_token themselves, removing the need for subclasses to do so. Backwards compatible because we no longer pass run_manager into ._stream and ._astream.
- .generate() and .agenerate() now handle a `stream: bool` kwarg for _generate and _agenerate, handling it by delegating to ._stream(), so that's one less thing subclasses need to do. Backwards compatible because this is an optional arg that we never pass to the subclasses.
- .generate() and .agenerate() now inspect the callback handlers to decide on a default value for `stream: bool` when it isn't passed in. This auto-enables streaming when using astream_events and astream_log.
- As a result of these three changes, any usage of .astream_events and .astream_log should now yield chat model stream events.
- In future PRs we can update all subclasses to reflect the things now handled by the base class, but in the meantime all will continue to work.
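For context, a hedged sketch (the class name and token values are hypothetical, not from this PR) of what a subclass can look like under the new contract: ._stream() just yields chunks, and the base class's .stream()/.astream() emit on_llm_new_token, so the subclass no longer needs to accept or call run_manager for token callbacks:

    from typing import Any, Iterator, List, Optional

    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
    from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


    class MyChatModel(BaseChatModel):  # hypothetical subclass for illustration
        @property
        def _llm_type(self) -> str:
            return "my-chat-model"

        def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            **kwargs: Any,
        ) -> ChatResult:
            # Non-streaming path. Per this PR, .generate()/.agenerate() can now
            # decide stream=True and delegate to ._stream() on our behalf.
            message = AIMessage(content="hello world!")
            return ChatResult(generations=[ChatGeneration(message=message)])

        def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            **kwargs: Any,
        ) -> Iterator[ChatGenerationChunk]:
            for token in ["hello", " ", "world!"]:
                # No run_manager.on_llm_new_token(...) here: per this PR the
                # base class's .stream()/.astream() fire the token callback.
                yield ChatGenerationChunk(message=AIMessageChunk(content=token))

The changed test file follows.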
@@ -1,4 +1,5 @@
 """Module that contains tests for runnable.astream_events API."""
+import sys
 from itertools import cycle
 from typing import Any, AsyncIterator, Dict, List, Sequence, cast
 
@@ -22,6 +23,7 @@ from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import (
     ConfigurableField,
     Runnable,
+    RunnableConfig,
     RunnableLambda,
 )
 from langchain_core.runnables.history import RunnableWithMessageHistory
@@ -314,9 +316,7 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
 
 async def test_astream_events_from_model() -> None:
     """Test the output of a model."""
-    infinite_cycle = cycle(
-        [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
-    )
+    infinite_cycle = cycle([AIMessage(content="hello world!")])
     # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
     model = (
         GenericFakeChatModel(messages=infinite_cycle)
@@ -373,6 +373,188 @@ async def test_astream_events_from_model() -> None:
         },
     ]
 
+    @RunnableLambda
+    def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
+        if sys.version_info >= (3, 11):
+            return model.invoke(input)
+        else:
+            return model.invoke(input, config)
+
+    events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chain_start",
+            "metadata": {},
+            "name": "i_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
+            "event": "on_chat_model_start",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="hello")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content=" ")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="world!")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {
+                "input": {"messages": [[HumanMessage(content="hello")]]},
+                "output": {
+                    "generations": [
+                        [
+                            {
+                                "generation_info": None,
+                                "message": AIMessage(content="hello world!"),
+                                "text": "hello world!",
+                                "type": "ChatGeneration",
+                            }
+                        ]
+                    ],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            "event": "on_chat_model_end",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessage(content="hello world!")},
+            "event": "on_chain_stream",
+            "metadata": {},
+            "name": "i_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"output": AIMessage(content="hello world!")},
+            "event": "on_chain_end",
+            "metadata": {},
+            "name": "i_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+    ]
+
+    @RunnableLambda
+    async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
+        if sys.version_info >= (3, 11):
+            return await model.ainvoke(input)
+        else:
+            return await model.ainvoke(input, config)
+
+    events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chain_start",
+            "metadata": {},
+            "name": "ai_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
+            "event": "on_chat_model_start",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="hello")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content=" ")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="world!")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {
+                "input": {"messages": [[HumanMessage(content="hello")]]},
+                "output": {
+                    "generations": [
+                        [
+                            {
+                                "generation_info": None,
+                                "message": AIMessage(content="hello world!"),
+                                "text": "hello world!",
+                                "type": "ChatGeneration",
+                            }
+                        ]
+                    ],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            "event": "on_chat_model_end",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessage(content="hello world!")},
+            "event": "on_chain_stream",
+            "metadata": {},
+            "name": "ai_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"output": AIMessage(content="hello world!")},
+            "event": "on_chain_end",
+            "metadata": {},
+            "name": "ai_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+    ]
 
 
 async def test_event_stream_with_simple_chain() -> None:
     """Test as event stream."""
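Taken together, a minimal sketch of the behavior these tests exercise: even a plain, non-streaming .invoke() of a chat model inside a lambda now surfaces on_chat_model_stream events through astream_events. One assumption here: GenericFakeChatModel is imported from langchain_core.language_models, while in this PR's tree it lives with the core test fixtures.

    import asyncio
    from itertools import cycle

    from langchain_core.language_models import GenericFakeChatModel
    from langchain_core.messages import AIMessage
    from langchain_core.runnables import RunnableConfig, RunnableLambda


    async def main() -> None:
        model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world!")]))

        @RunnableLambda
        def i_dont_stream(input: str, config: RunnableConfig) -> AIMessage:
            return model.invoke(input, config)  # plain invoke, no streaming

        # Per-token on_chat_model_stream events still arrive, because the base
        # class now auto-enables streaming when it sees astream_events handlers.
        async for event in i_dont_stream.astream_events("hello", version="v1"):
            if event["event"] == "on_chat_model_stream":
                print(event["data"]["chunk"].content)  # "hello", " ", "world!"


    asyncio.run(main())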