core[patch]: base language model disable_streaming (#25070)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Erick Friis 2024-08-07 09:26:21 -07:00 committed by GitHub
parent 130e80b60f
commit dff83cce66
2 changed files with 152 additions and 37 deletions


@@ -214,6 +214,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"""An optional rate limiter to use for limiting the number of requests."""
disable_streaming: Union[bool, Literal["tool_calling"]] = False
"""Whether to disable streaming for this model.
If streaming is bypassed, then ``stream()/astream()`` will defer to
``invoke()/ainvoke()``.
- If True, will always bypass streaming case.
- If "tool_calling", will bypass streaming case only when the model is called
with a ``tools`` keyword argument.
- If False (default), will always use streaming case if available.
"""
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.
@@ -302,6 +314,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def _should_stream(
self,
*,
async_api: bool,
run_manager: Optional[
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
] = None,
**kwargs: Any,
) -> bool:
"""Determine if a given model call should hit the streaming API."""
sync_not_implemented = type(self)._stream == BaseChatModel._stream
async_not_implemented = type(self)._astream == BaseChatModel._astream
# Check if streaming is implemented.
if (not async_api) and sync_not_implemented:
return False
# Note, since async falls back to sync we check both here.
if async_api and async_not_implemented and sync_not_implemented:
return False
# Check if streaming has been disabled on this instance.
if self.disable_streaming is True:
return False
# We assume tools are passed in via "tools" kwarg in all models.
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
return False
# Check if a runtime streaming flag has been passed in.
if "stream" in kwargs:
return kwargs["stream"]
# Check if any streaming callback handlers have been passed in.
handlers = run_manager.handlers if run_manager else []
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
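
Usage note (a sketch, not part of the diff): with ``disable_streaming="tool_calling"`` only calls that pass a ``tools`` keyword argument are re-routed through the non-streaming path; plain ``stream()`` calls keep streaming. This reuses the hypothetical ``ToyChatModel`` from the sketch above and mirrors the new test below.

# Sketch only, reusing the hypothetical ToyChatModel defined in the earlier sketch.
model = ToyChatModel(disable_streaming="tool_calling")

# No tools kwarg: the streaming path is still used.
assert next(model.stream("hi")).content == "stream"

# A tools kwarg triggers the bypass, so stream() falls back to invoke().
assert next(model.stream("hi", tools=[{"type": "function"}])).content == "invoke"
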
def stream(
self,
input: LanguageModelInput,
@@ -310,7 +357,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +427,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if (
type(self)._astream is BaseChatModel._astream
and type(self)._stream is BaseChatModel._stream
):
if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
BaseMessageChunk,
@@ -760,20 +804,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
"stream",
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
if run_manager
else False
),
if self._should_stream(
async_api=False,
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
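
The removed inline block above performed the streaming-callback check by hand; ``_should_stream`` now centralizes it. The effect, sketched below (not part of the diff) and mirrored by the new test, is that a plain ``invoke()`` still takes the streaming path when a streaming callback handler is attached (as ``astream_events()``/``astream_log()`` do), unless ``disable_streaming`` blocks it. This again reuses the hypothetical ``ToyChatModel``; ``_AstreamEventsCallbackHandler`` is the private handler the new test uses.

# Sketch only; uses the same private handler as the test added below.
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler

# A streaming callback handler makes invoke() route through _stream().
assert (
    ToyChatModel()
    .invoke("hi", config={"callbacks": [_AstreamEventsCallbackHandler()]})
    .content
    == "stream"
)

# With disable_streaming=True the same call stays on the non-streaming path.
assert (
    ToyChatModel(disable_streaming=True)
    .invoke("hi", config={"callbacks": [_AstreamEventsCallbackHandler()]})
    .content
    == "invoke"
)
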
@@ -847,23 +881,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
if (
type(self)._astream != BaseChatModel._astream
or type(self)._stream != BaseChatModel._stream
) and kwargs.pop(
"stream",
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, _StreamingCallbackHandler)
),
False,
)
if run_manager
else False
),
if self._should_stream(
async_api=True,
run_manager=run_manager,
**kwargs,
):
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):


@@ -1,7 +1,7 @@
"""Test base chat model."""
import uuid
from typing import Any, AsyncIterator, Iterator, List, Optional
from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
import pytest
@@ -18,6 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import collect_runs
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
from langchain_core.tracers.schemas import Run
from tests.unit_tests.fake.callbacks import (
BaseFakeCallbackHandler,
@@ -272,3 +273,96 @@ async def test_async_pass_run_id() -> None:
uid3 = uuid.uuid4()
await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
assert cb.traced_run_ids == [uid1, uid2, uid3]
class NoStreamingModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
@property
def _llm_type(self) -> str:
return "model1"
class StreamingModel(NoStreamingModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
expected = "invoke" if disable_streaming is True else "stream"
assert next(model.stream([])).content == expected
async for c in model.astream([]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
== expected
)
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == expected
expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
async for c in model.astream([], tools=[{}]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
).content
== expected
)
assert (
await model.ainvoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
)
).content == expected
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
assert next(model.stream([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
== "invoke"
)
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == "invoke"
assert next(model.stream([], tools=[{}])).content == "invoke"
async for c in model.astream([], tools=[{}]):
assert c.content == "invoke"
break