diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index b4d367572a1..e584565a9a9 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -214,6 +214,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
     """An optional rate limiter to use for limiting the number of requests."""
 
+    disable_streaming: Union[bool, Literal["tool_calling"]] = False
+    """Whether to disable streaming for this model.
+
+    If streaming is bypassed, then ``stream()/astream()`` will defer to
+    ``invoke()/ainvoke()``.
+
+    - If True, will always bypass streaming case.
+    - If "tool_calling", will bypass streaming case only when the model is called
+      with a ``tools`` keyword argument.
+    - If False (default), will always use streaming case if available.
+    """
+
     @root_validator(pre=True)
     def raise_deprecation(cls, values: Dict) -> Dict:
         """Raise deprecation warning if callback_manager is used.
@@ -302,6 +314,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         return cast(ChatGeneration, llm_result.generations[0][0]).message
 
+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """Determine if a given model call should hit the streaming API."""
+        sync_not_implemented = type(self)._stream == BaseChatModel._stream
+        async_not_implemented = type(self)._astream == BaseChatModel._astream
+
+        # Check if streaming is implemented.
+        if (not async_api) and sync_not_implemented:
+            return False
+        # Note, since async falls back to sync we check both here.
+        if async_api and async_not_implemented and sync_not_implemented:
+            return False
+
+        # Check if streaming has been disabled on this instance.
+        if self.disable_streaming is True:
+            return False
+        # We assume tools are passed in via "tools" kwarg in all models.
+        if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
+            return False
+
+        # Check if a runtime streaming flag has been passed in.
+        if "stream" in kwargs:
+            return kwargs["stream"]
+
+        # Check if any streaming callback handlers have been passed in.
+        handlers = run_manager.handlers if run_manager else []
+        return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
+
     def stream(
         self,
         input: LanguageModelInput,
@@ -310,7 +357,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[BaseMessageChunk]:
-        if type(self)._stream == BaseChatModel._stream:
+        if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
             # model doesn't implement streaming, so use default implementation
             yield cast(
                 BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +427,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if (
-            type(self)._astream is BaseChatModel._astream
-            and type(self)._stream is BaseChatModel._stream
-        ):
+        if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
             # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
@@ -760,20 +804,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
 
         # If stream is not explicitly set, check if implicitly requested by
         # astream_events() or astream_log(). Bail out if _stream not implemented
-        if type(self)._stream != BaseChatModel._stream and kwargs.pop(
-            "stream",
-            (
-                next(
-                    (
-                        True
-                        for h in run_manager.handlers
-                        if isinstance(h, _StreamingCallbackHandler)
-                    ),
-                    False,
-                )
-                if run_manager
-                else False
-            ),
+        if self._should_stream(
+            async_api=False,
+            run_manager=run_manager,
+            **kwargs,
         ):
             chunks: List[ChatGenerationChunk] = []
             for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -847,23 +881,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
 
         # If stream is not explicitly set, check if implicitly requested by
         # astream_events() or astream_log(). Bail out if _astream not implemented
-        if (
-            type(self)._astream != BaseChatModel._astream
-            or type(self)._stream != BaseChatModel._stream
-        ) and kwargs.pop(
-            "stream",
-            (
-                next(
-                    (
-                        True
-                        for h in run_manager.handlers
-                        if isinstance(h, _StreamingCallbackHandler)
-                    ),
-                    False,
-                )
-                if run_manager
-                else False
-            ),
+        if self._should_stream(
+            async_api=True,
+            run_manager=run_manager,
+            **kwargs,
         ):
             chunks: List[ChatGenerationChunk] = []
             async for chunk in self._astream(messages, stop=stop, **kwargs):
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 4c08a5441d1..e137d96460e 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -1,7 +1,7 @@
 """Test base chat model."""
 
 import uuid
-from typing import Any, AsyncIterator, Iterator, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
 
 import pytest
 
@@ -18,6 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
+from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
 from langchain_core.tracers.schemas import Run
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -272,3 +273,96 @@ async def test_async_pass_run_id() -> None:
     uid3 = uuid.uuid4()
     await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
     assert cb.traced_run_ids == [uid1, uid2, uid3]
+
+
+class NoStreamingModel(BaseChatModel):
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
+
+    @property
+    def _llm_type(self) -> str:
+        return "model1"
+
+
+class StreamingModel(NoStreamingModel):
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    assert next(model.stream([])).content == expected
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == expected
+
+    expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
+    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        )
+    ).content == expected
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+    assert next(model.stream([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == "invoke"
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == "invoke"
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == "invoke"
+        break
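
For reference, a minimal usage sketch of the new ``disable_streaming`` flag (not part of the patch). It assumes the ``StreamingModel`` test helper defined in test_base.py above is in scope; any concrete BaseChatModel subclass that implements _stream would behave the same way.

# Bypass streaming only when a "tools" kwarg is passed to the call.
model = StreamingModel(disable_streaming="tool_calling")

# No tools bound: _should_stream() returns True, so chunks come from _stream().
assert next(model.stream([])).content == "stream"

# A "tools" kwarg is present: streaming is bypassed and stream() defers to
# invoke(), which returns the single non-streamed message.
assert next(model.stream([], tools=[{"type": "function"}])).content == "invoke"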