Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-05 07:08:03 +00:00)
core: Two updates to chat model interface (#19684)
- .stream() and .astream() now call on_llm_new_token, removing the need for subclasses to do so. Backwards compatible because we no longer pass run_manager into ._stream and ._astream.
- .generate() and .agenerate() now handle a `stream: bool` kwarg for _generate and _agenerate by delegating to ._stream(), so that is one less thing subclasses need to do. Backwards compatible because this is an optional arg that we never pass on to subclasses.
- .generate() and .agenerate() now inspect the callback handlers to decide on a default value for `stream` if it is not passed in. This auto-enables streaming when using astream_events and astream_log.
- As a result of these three changes, any usage of .astream_events and .astream_log should now yield chat model stream events.
- In future PRs we can update all subclasses to reflect what the base class now handles; in the meantime everything continues to work.
This commit is contained in:
Parent: 3685f8ceac
Commit: fdfb51ad8d
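For subclass authors, the practical effect of the first change is that a chat model only needs to yield chunks from ._stream(); the base class now emits on_llm_new_token for each chunk. A minimal sketch of such a subclass (not part of this commit; EchoChatModel and its echoing behavior are made up for illustration):

```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoChatModel(BaseChatModel):
    """Toy model that echoes the last message, word by word when streamed."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        text = str(messages[-1].content)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))]
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Note: no run_manager.on_llm_new_token(...) call here. After this PR,
        # BaseChatModel.stream()/astream() fire that callback for every chunk
        # the subclass yields.
        for word in str(messages[-1].content).split():
            yield ChatGenerationChunk(message=AIMessageChunk(content=word + " "))
```

With this in place, model.stream(...) and the astream_events/astream_log APIs surface chunk callbacks without any callback code in the subclass.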
@@ -50,6 +50,7 @@ from langchain_core.outputs import (
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.tracers.log_stream import LogStreamCallbackHandler

 if TYPE_CHECKING:
     from langchain_core.runnables import RunnableConfig
@@ -219,9 +220,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         generation: Optional[ChatGenerationChunk] = None
         try:
-            for chunk in self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            ):
+            for chunk in self._stream(messages, stop=stop, **kwargs):
+                run_manager.on_llm_new_token(
+                    cast(str, chunk.message.content), chunk=chunk
+                )
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
                 if generation is None:
@@ -287,9 +289,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             async for chunk in self._astream(
                 messages,
                 stop=stop,
-                run_manager=run_manager,
                 **kwargs,
             ):
+                await run_manager.on_llm_new_token(
+                    cast(str, chunk.message.content), chunk=chunk
+                )
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
                 if generation is None:
@@ -585,6 +589,31 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
+        # If stream is not explicitly set, check if implicitly requested by
+        # astream_events() or astream_log(). Bail out if _stream not implemented
+        if type(self)._stream != BaseChatModel._stream and kwargs.pop(
+            "stream",
+            next(
+                (
+                    True
+                    for h in run_manager.handlers
+                    if isinstance(h, LogStreamCallbackHandler)
+                ),
+                False,
+            )
+            if run_manager
+            else False,
+        ):
+            chunks: List[ChatGenerationChunk] = []
+            for chunk in self._stream(messages, stop=stop, **kwargs):
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        cast(str, chunk.message.content), chunk=chunk
+                    )
+                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
+                chunks.append(chunk)
+            result = generate_from_stream(iter(chunks))
+        else:
             if inspect.signature(self._generate).parameters.get("run_manager"):
                 result = self._generate(
                     messages, stop=stop, run_manager=run_manager, **kwargs
@@ -634,6 +663,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
+        # If stream is not explicitly set, check if implicitly requested by
+        # astream_events() or astream_log(). Bail out if _astream not implemented
+        if (
+            type(self)._astream != BaseChatModel._astream
+            or type(self)._stream != BaseChatModel._stream
+        ) and kwargs.pop(
+            "stream",
+            next(
+                (
+                    True
+                    for h in run_manager.handlers
+                    if isinstance(h, LogStreamCallbackHandler)
+                ),
+                False,
+            )
+            if run_manager
+            else False,
+        ):
+            chunks: List[ChatGenerationChunk] = []
+            async for chunk in self._astream(messages, stop=stop, **kwargs):
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        cast(str, chunk.message.content), chunk=chunk
+                    )
+                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
+                chunks.append(chunk)
+            result = generate_from_stream(iter(chunks))
+        else:
             if inspect.signature(self._agenerate).parameters.get("run_manager"):
                 result = await self._agenerate(
                     messages, stop=stop, run_manager=run_manager, **kwargs
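Taken together, the two hunks above mean that streaming is used under the hood whenever a subclass implements _stream and either stream=True is passed through generate()/invoke(), or a LogStreamCallbackHandler (as attached by astream_events/astream_log) is among the run's handlers. A hypothetical usage sketch, reusing the EchoChatModel class from the sketch above (the model, the call_model step, and the inputs are illustrative, not from the PR):

```python
import asyncio
from typing import Any

from langchain_core.runnables import RunnableConfig, RunnableLambda

# Assumes the EchoChatModel sketch defined earlier on this page.
model = EchoChatModel()

# 1. Explicit opt-in: stream=True flows through invoke() -> generate(), and the
#    base class now assembles the final message from _stream() chunks.
message = model.invoke("hello streaming world", stream=True)
print(message.content)


# 2. Implicit opt-in: a chain step that merely invoke()s the model still emits
#    on_chat_model_stream events through astream_events, because the handler
#    inspection above auto-enables streaming.
@RunnableLambda
def call_model(value: Any, config: RunnableConfig) -> Any:
    # Passing config keeps callbacks propagated on Python < 3.11 as well.
    return model.invoke(value, config)


async def main() -> None:
    async for event in call_model.astream_events("hello streaming world", version="v1"):
        if event["event"] == "on_chat_model_stream":
            print(event["data"]["chunk"].content)


asyncio.run(main())
```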
@@ -1,4 +1,5 @@
 """Module that contains tests for runnable.astream_events API."""
+import sys
 from itertools import cycle
 from typing import Any, AsyncIterator, Dict, List, Sequence, cast

@@ -22,6 +23,7 @@ from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import (
     ConfigurableField,
     Runnable,
+    RunnableConfig,
     RunnableLambda,
 )
 from langchain_core.runnables.history import RunnableWithMessageHistory
@@ -314,9 +316,7 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:

 async def test_astream_events_from_model() -> None:
     """Test the output of a model."""
-    infinite_cycle = cycle(
-        [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
-    )
+    infinite_cycle = cycle([AIMessage(content="hello world!")])
     # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
     model = (
         GenericFakeChatModel(messages=infinite_cycle)
@@ -373,6 +373,188 @@ async def test_astream_events_from_model() -> None:
         },
     ]
+
+    @RunnableLambda
+    def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
+        if sys.version_info >= (3, 11):
+            return model.invoke(input)
+        else:
+            return model.invoke(input, config)
+
+    events = await _collect_events(i_dont_stream.astream_events("hello", version="v1"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chain_start",
+            "metadata": {},
+            "name": "i_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
+            "event": "on_chat_model_start",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="hello")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content=" ")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="world!")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {
+                "input": {"messages": [[HumanMessage(content="hello")]]},
+                "output": {
+                    "generations": [
+                        [
+                            {
+                                "generation_info": None,
+                                "message": AIMessage(content="hello world!"),
+                                "text": "hello world!",
+                                "type": "ChatGeneration",
+                            }
+                        ]
+                    ],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            "event": "on_chat_model_end",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessage(content="hello world!")},
+            "event": "on_chain_stream",
+            "metadata": {},
+            "name": "i_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"output": AIMessage(content="hello world!")},
+            "event": "on_chain_end",
+            "metadata": {},
+            "name": "i_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+    ]
+
+    @RunnableLambda
+    async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
+        if sys.version_info >= (3, 11):
+            return await model.ainvoke(input)
+        else:
+            return await model.ainvoke(input, config)
+
+    events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chain_start",
+            "metadata": {},
+            "name": "ai_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
+            "event": "on_chat_model_start",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="hello")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content=" ")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="world!")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {
+                "input": {"messages": [[HumanMessage(content="hello")]]},
+                "output": {
+                    "generations": [
+                        [
+                            {
+                                "generation_info": None,
+                                "message": AIMessage(content="hello world!"),
+                                "text": "hello world!",
+                                "type": "ChatGeneration",
+                            }
+                        ]
+                    ],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            "event": "on_chat_model_end",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessage(content="hello world!")},
+            "event": "on_chain_stream",
+            "metadata": {},
+            "name": "ai_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"output": AIMessage(content="hello world!")},
+            "event": "on_chain_end",
+            "metadata": {},
+            "name": "ai_dont_stream",
+            "run_id": "",
+            "tags": [],
+        },
+    ]


 async def test_event_stream_with_simple_chain() -> None:
     """Test as event stream."""