core[patch]: base language model disable_streaming (#25070)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Authored by Erick Friis on 2024-08-07 09:26:21 -07:00; committed by GitHub
parent 130e80b60f
commit dff83cce66
2 changed files with 152 additions and 37 deletions


@@ -214,6 +214,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"""An optional rate limiter to use for limiting the number of requests."""
disable_streaming: Union[bool, Literal["tool_calling"]] = False
"""Whether to disable streaming for this model.
If streaming is bypassed, then ``stream()/astream()`` will defer to
``invoke()/ainvoke()``.
- If True, will always bypass streaming case.
- If "tool_calling", will bypass streaming case only when the model is called
with a ``tools`` keyword argument.
- If False (default), will always use streaming case if available.
"""
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.
@@ -302,6 +314,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def _should_stream(
self,
*,
async_api: bool,
run_manager: Optional[
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
] = None,
**kwargs: Any,
) -> bool:
"""Determine if a given model call should hit the streaming API."""
sync_not_implemented = type(self)._stream == BaseChatModel._stream
async_not_implemented = type(self)._astream == BaseChatModel._astream
# Check if streaming is implemented.
if (not async_api) and sync_not_implemented:
return False
# Note, since async falls back to sync we check both here.
if async_api and async_not_implemented and sync_not_implemented:
return False
# Check if streaming has been disabled on this instance.
if self.disable_streaming is True:
return False
# We assume tools are passed in via "tools" kwarg in all models.
if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
return False
# Check if a runtime streaming flag has been passed in.
if "stream" in kwargs:
return kwargs["stream"]
# Check if any streaming callback handlers have been passed in.
handlers = run_manager.handlers if run_manager else []
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
def stream(
self,
input: LanguageModelInput,
@@ -310,7 +357,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
-if type(self)._stream == BaseChatModel._stream:
+if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +427,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
-if (
-type(self)._astream is BaseChatModel._astream
-and type(self)._stream is BaseChatModel._stream
-):
+if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
BaseMessageChunk,
@@ -760,20 +804,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
-if type(self)._stream != BaseChatModel._stream and kwargs.pop(
-"stream",
-(
-next(
-(
-True
-for h in run_manager.handlers
-if isinstance(h, _StreamingCallbackHandler)
-),
-False,
-)
-if run_manager
-else False
-),
-):
+if self._should_stream(
+async_api=False,
+run_manager=run_manager,
+**kwargs,
+):
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -847,23 +881,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
-if (
-type(self)._astream != BaseChatModel._astream
-or type(self)._stream != BaseChatModel._stream
-) and kwargs.pop(
-"stream",
-(
-next(
-(
-True
-for h in run_manager.handlers
-if isinstance(h, _StreamingCallbackHandler)
-),
-False,
-)
-if run_manager
-else False
-),
-):
+if self._should_stream(
+async_api=True,
+run_manager=run_manager,
+**kwargs,
+):
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
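Taken together, every streaming entry point now consults _should_stream before choosing a code path. Below is a minimal, hedged sketch of the resulting behavior; it is not part of this commit, and the MyChatModel class is hypothetical (it mirrors the StreamingModel fixture added in the tests further down).

# Hedged illustration: a hypothetical chat model showing how disable_streaming
# makes stream() defer to invoke(). MyChatModel is an assumption, not library code.
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class MyChatModel(BaseChatModel):
    """Hypothetical model used only to illustrate the new flag."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Non-streaming path: return a single complete message.
        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Streaming path: yield chunks.
        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))

    @property
    def _llm_type(self) -> str:
        return "my-chat-model"


# disable_streaming=True: stream() bypasses _stream and yields the invoke() result.
assert next(MyChatModel(disable_streaming=True).stream("hi")).content == "invoke"
# Default (False): the native _stream implementation is used.
assert next(MyChatModel().stream("hi")).content == "stream"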


@@ -1,7 +1,7 @@
"""Test base chat model."""
import uuid
-from typing import Any, AsyncIterator, Iterator, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
import pytest
@@ -18,6 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import collect_runs
from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
from langchain_core.tracers.schemas import Run
from tests.unit_tests.fake.callbacks import (
BaseFakeCallbackHandler,
@@ -272,3 +273,96 @@
uid3 = uuid.uuid4()
await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
assert cb.traced_run_ids == [uid1, uid2, uid3]
class NoStreamingModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
@property
def _llm_type(self) -> str:
return "model1"
class StreamingModel(NoStreamingModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = StreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
expected = "invoke" if disable_streaming is True else "stream"
assert next(model.stream([])).content == expected
async for c in model.astream([]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
== expected
)
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == expected
expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == expected
async for c in model.astream([], tools=[{}]):
assert c.content == expected
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
).content
== expected
)
assert (
await model.ainvoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
)
).content == expected
@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
async def test_disable_streaming_no_streaming_model(
disable_streaming: Union[bool, Literal["tool_calling"]],
) -> None:
model = NoStreamingModel(disable_streaming=disable_streaming)
assert model.invoke([]).content == "invoke"
assert (await model.ainvoke([])).content == "invoke"
assert next(model.stream([])).content == "invoke"
async for c in model.astream([]):
assert c.content == "invoke"
break
assert (
model.invoke(
[], config={"callbacks": [_AstreamEventsCallbackHandler()]}
).content
== "invoke"
)
assert (
await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
).content == "invoke"
assert next(model.stream([], tools=[{}])).content == "invoke"
async for c in model.astream([], tools=[{}]):
assert c.content == "invoke"
break
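Finally, a similarly hedged sketch of the "tool_calling" setting, reusing the StreamingModel test class defined above: streaming is bypassed only for calls that pass a tools keyword argument, matching the expectations in test_disable_streaming.

# Hedged illustration based on the tests above; assumes StreamingModel is in scope.
model = StreamingModel(disable_streaming="tool_calling")
# No tools passed: the native streaming path is still used.
assert next(model.stream("hello")).content == "stream"
# Tools passed: streaming is bypassed and stream() defers to invoke().
assert next(model.stream("hello", tools=[{"type": "function"}])).content == "invoke"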