mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core[patch]: base language model disable_streaming (#25070)
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent 130e80b60f
commit dff83cce66
@@ -214,6 +214,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
     """An optional rate limiter to use for limiting the number of requests."""

+    disable_streaming: Union[bool, Literal["tool_calling"]] = False
+    """Whether to disable streaming for this model.
+
+    If streaming is bypassed, then ``stream()/astream()`` will defer to
+    ``invoke()/ainvoke()``.
+
+    - If True, will always bypass streaming case.
+    - If "tool_calling", will bypass streaming case only when the model is called
+      with a ``tools`` keyword argument.
+    - If False (default), will always use streaming case if available.
+    """
+
     @root_validator(pre=True)
     def raise_deprecation(cls, values: Dict) -> Dict:
         """Raise deprecation warning if callback_manager is used.
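For context, a minimal, self-contained sketch of how the new field behaves. It mirrors the StreamingModel test double added at the bottom of this diff; the EchoChatModel name, the canned "invoke"/"stream" responses, and the empty message lists are purely illustrative:

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoChatModel(BaseChatModel):
    """Hypothetical demo model: the invoke path answers "invoke", _stream yields "stream"."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))

    @property
    def _llm_type(self) -> str:
        return "echo"


# "tool_calling" only bypasses streaming when a `tools` kwarg is passed.
model = EchoChatModel(disable_streaming="tool_calling")
print(next(model.stream([])).content)                                # stream
print(next(model.stream([], tools=[{"type": "function"}])).content)  # invoke

# True always bypasses streaming; False (the default) streams whenever _stream exists.
print(next(EchoChatModel(disable_streaming=True).stream([])).content)  # invoke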
@@ -302,6 +314,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         return cast(ChatGeneration, llm_result.generations[0][0]).message

+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """Determine if a given model call should hit the streaming API."""
+        sync_not_implemented = type(self)._stream == BaseChatModel._stream
+        async_not_implemented = type(self)._astream == BaseChatModel._astream
+
+        # Check if streaming is implemented.
+        if (not async_api) and sync_not_implemented:
+            return False
+        # Note, since async falls back to sync we check both here.
+        if async_api and async_not_implemented and sync_not_implemented:
+            return False
+
+        # Check if streaming has been disabled on this instance.
+        if self.disable_streaming is True:
+            return False
+        # We assume tools are passed in via "tools" kwarg in all models.
+        if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
+            return False
+
+        # Check if a runtime streaming flag has been passed in.
+        if "stream" in kwargs:
+            return kwargs["stream"]
+
+        # Check if any streaming callback handlers have been passed in.
+        handlers = run_manager.handlers if run_manager else []
+        return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
+
     def stream(
         self,
         input: LanguageModelInput,
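To make the precedence explicit, here is a standalone sketch (not part of the diff) of the decision order `_should_stream` encodes. The `has_sync_stream`, `has_async_stream`, and `has_streaming_handler` flags stand in for the implementation and callback-handler checks the real method performs on the instance:

from typing import Any, Literal, Union


def should_stream(
    *,
    async_api: bool,
    has_sync_stream: bool,
    has_async_stream: bool,
    disable_streaming: Union[bool, Literal["tool_calling"]] = False,
    has_streaming_handler: bool = False,
    **kwargs: Any,
) -> bool:
    # 1. No streaming implementation at all -> never stream.
    if not async_api and not has_sync_stream:
        return False
    if async_api and not has_async_stream and not has_sync_stream:
        return False
    # 2. Streaming disabled on the instance (always, or only for tool calls).
    if disable_streaming is True:
        return False
    if disable_streaming == "tool_calling" and kwargs.get("tools"):
        return False
    # 3. An explicit runtime "stream" flag wins next.
    if "stream" in kwargs:
        return kwargs["stream"]
    # 4. Otherwise, stream only if a streaming callback handler is attached.
    return has_streaming_handler


# Tool-calling mode bypasses streaming only when tools are actually passed.
assert should_stream(
    async_api=False, has_sync_stream=True, has_async_stream=False,
    disable_streaming="tool_calling", tools=[{"type": "function"}], stream=True,
) is False
assert should_stream(
    async_api=False, has_sync_stream=True, has_async_stream=False,
    disable_streaming="tool_calling", stream=True,
) is True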
@@ -310,7 +357,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[BaseMessageChunk]:
-        if type(self)._stream == BaseChatModel._stream:
+        if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
             # model doesn't implement streaming, so use default implementation
             yield cast(
                 BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +427,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if (
-            type(self)._astream is BaseChatModel._astream
-            and type(self)._stream is BaseChatModel._stream
-        ):
+        if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
             # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
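A short usage sketch of the fallback these two hunks implement, assuming the illustrative EchoChatModel from the earlier sketch is in scope: with streaming disabled, astream() defers to ainvoke() and yields a single chunk produced by _generate() rather than _stream().

import asyncio


async def main() -> None:
    model = EchoChatModel(disable_streaming=True)  # illustrative model from the earlier sketch
    async for chunk in model.astream([]):
        print(chunk.content)  # invoke (not "stream")


asyncio.run(main())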
@@ -760,20 +804,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

-        # If stream is not explicitly set, check if implicitly requested by
-        # astream_events() or astream_log(). Bail out if _stream not implemented
-        if type(self)._stream != BaseChatModel._stream and kwargs.pop(
-            "stream",
-            (
-                next(
-                    (
-                        True
-                        for h in run_manager.handlers
-                        if isinstance(h, _StreamingCallbackHandler)
-                    ),
-                    False,
-                )
-                if run_manager
-                else False
-            ),
+        if self._should_stream(
+            async_api=False,
+            run_manager=run_manager,
+            **kwargs,
         ):
             chunks: List[ChatGenerationChunk] = []
             for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -847,23 +881,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

-        # If stream is not explicitly set, check if implicitly requested by
-        # astream_events() or astream_log(). Bail out if _astream not implemented
-        if (
-            type(self)._astream != BaseChatModel._astream
-            or type(self)._stream != BaseChatModel._stream
-        ) and kwargs.pop(
-            "stream",
-            (
-                next(
-                    (
-                        True
-                        for h in run_manager.handlers
-                        if isinstance(h, _StreamingCallbackHandler)
-                    ),
-                    False,
-                )
-                if run_manager
-                else False
-            ),
+        if self._should_stream(
+            async_api=True,
+            run_manager=run_manager,
+            **kwargs,
         ):
             chunks: List[ChatGenerationChunk] = []
             async for chunk in self._astream(messages, stop=stop, **kwargs):
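The net effect of these last two hunks is that astream_events()/astream_log() behave as before while the new flag is respected. A sketch, again reusing the illustrative EchoChatModel and the same private handler the new tests import:

from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler

# A streaming callback handler still routes invoke() through _stream() ...
out = EchoChatModel().invoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
print(out.content)  # stream

# ... unless streaming has been disabled on the instance.
out = EchoChatModel(disable_streaming=True).invoke(
    [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
)
print(out.content)  # invoke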
@ -1,7 +1,7 @@
|
||||
"""Test base chat model."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||
from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -18,6 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
+from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
 from langchain_core.tracers.schemas import Run
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -272,3 +273,96 @@ async def test_async_pass_run_id() -> None:
     uid3 = uuid.uuid4()
     await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
     assert cb.traced_run_ids == [uid1, uid2, uid3]
+
+
+class NoStreamingModel(BaseChatModel):
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
+
+    @property
+    def _llm_type(self) -> str:
+        return "model1"
+
+
+class StreamingModel(NoStreamingModel):
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    assert next(model.stream([])).content == expected
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == expected
+
+    expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
+    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        )
+    ).content == expected
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+    assert next(model.stream([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == "invoke"
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == "invoke"
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == "invoke"
+        break
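One way to run just these new tests locally; the file path and the reliance on the repository's pytest-asyncio configuration are assumptions, not part of the diff:

import pytest

# Path is illustrative; adjust to wherever this test module lives in your checkout.
pytest.main(
    ["-k", "disable_streaming", "libs/core/tests/unit_tests/language_models/chat_models/test_base.py"]
)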