core[patch]: base language model disable_streaming (#25070)
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
commit dff83cce66
parent 130e80b60f
@@ -214,6 +214,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
     """An optional rate limiter to use for limiting the number of requests."""

+    disable_streaming: Union[bool, Literal["tool_calling"]] = False
+    """Whether to disable streaming for this model.
+
+    If streaming is bypassed, then ``stream()/astream()`` will defer to
+    ``invoke()/ainvoke()``.
+
+    - If True, will always bypass streaming case.
+    - If "tool_calling", will bypass streaming case only when the model is called
+      with a ``tools`` keyword argument.
+    - If False (default), will always use streaming case if available.
+    """
+
     @root_validator(pre=True)
     def raise_deprecation(cls, values: Dict) -> Dict:
         """Raise deprecation warning if callback_manager is used.
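
For context (not part of the diff): a minimal usage sketch of the new field. Any BaseChatModel subclass inherits disable_streaming; the ChatOpenAI class and model name below are illustrative placeholders, assuming a chat integration built on this base class.

# Illustrative sketch; assumes langchain_openai's ChatOpenAI, but any
# BaseChatModel subclass picks up the new field the same way.
from langchain_openai import ChatOpenAI

# Always bypass streaming: stream()/astream() defer to invoke()/ainvoke()
# and yield a single full message.
llm_no_stream = ChatOpenAI(model="gpt-4o-mini", disable_streaming=True)

# Bypass streaming only for calls that carry a `tools` keyword argument;
# plain calls still stream chunk by chunk.
llm_no_tool_stream = ChatOpenAI(model="gpt-4o-mini", disable_streaming="tool_calling")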
@@ -302,6 +314,41 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         return cast(ChatGeneration, llm_result.generations[0][0]).message

+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        **kwargs: Any,
+    ) -> bool:
+        """Determine if a given model call should hit the streaming API."""
+        sync_not_implemented = type(self)._stream == BaseChatModel._stream
+        async_not_implemented = type(self)._astream == BaseChatModel._astream
+
+        # Check if streaming is implemented.
+        if (not async_api) and sync_not_implemented:
+            return False
+        # Note, since async falls back to sync we check both here.
+        if async_api and async_not_implemented and sync_not_implemented:
+            return False
+
+        # Check if streaming has been disabled on this instance.
+        if self.disable_streaming is True:
+            return False
+        # We assume tools are passed in via "tools" kwarg in all models.
+        if self.disable_streaming == "tool_calling" and kwargs.get("tools"):
+            return False
+
+        # Check if a runtime streaming flag has been passed in.
+        if "stream" in kwargs:
+            return kwargs["stream"]
+
+        # Check if any streaming callback handlers have been passed in.
+        handlers = run_manager.handlers if run_manager else []
+        return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)
+
     def stream(
         self,
         input: LanguageModelInput,
@@ -310,7 +357,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Iterator[BaseMessageChunk]:
-        if type(self)._stream == BaseChatModel._stream:
+        if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}):
             # model doesn't implement streaming, so use default implementation
             yield cast(
                 BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
@@ -380,10 +427,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if (
-            type(self)._astream is BaseChatModel._astream
-            and type(self)._stream is BaseChatModel._stream
-        ):
+        if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}):
             # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
@@ -760,20 +804,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

         # If stream is not explicitly set, check if implicitly requested by
         # astream_events() or astream_log(). Bail out if _stream not implemented
-        if type(self)._stream != BaseChatModel._stream and kwargs.pop(
-            "stream",
-            (
-                next(
-                    (
-                        True
-                        for h in run_manager.handlers
-                        if isinstance(h, _StreamingCallbackHandler)
-                    ),
-                    False,
-                )
-                if run_manager
-                else False
-            ),
+        if self._should_stream(
+            async_api=False,
+            run_manager=run_manager,
+            **kwargs,
         ):
             chunks: List[ChatGenerationChunk] = []
             for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -847,23 +881,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

         # If stream is not explicitly set, check if implicitly requested by
         # astream_events() or astream_log(). Bail out if _astream not implemented
-        if (
-            type(self)._astream != BaseChatModel._astream
-            or type(self)._stream != BaseChatModel._stream
-        ) and kwargs.pop(
-            "stream",
-            (
-                next(
-                    (
-                        True
-                        for h in run_manager.handlers
-                        if isinstance(h, _StreamingCallbackHandler)
-                    ),
-                    False,
-                )
-                if run_manager
-                else False
-            ),
+        if self._should_stream(
+            async_api=True,
+            run_manager=run_manager,
+            **kwargs,
         ):
             chunks: List[ChatGenerationChunk] = []
             async for chunk in self._astream(messages, stop=stop, **kwargs):
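
Not part of the diff: a hedged sketch of how disable_streaming="tool_calling" plays out at runtime, assuming the integration forwards bound tools through the `tools` keyword argument (the same assumption _should_stream makes). The model class, model name, and tool are placeholders.

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # placeholder integration

@tool
def get_weather(city: str) -> str:
    """Return a canned weather report for a city."""
    return f"It is sunny in {city}."

llm = ChatOpenAI(model="gpt-4o-mini", disable_streaming="tool_calling")

# No tools involved: _should_stream() allows streaming, so chunks arrive incrementally.
for chunk in llm.stream("Write one line about rivers"):
    print(chunk.content, end="", flush=True)

# Tools bound: the underlying call carries a `tools` kwarg, _should_stream()
# returns False, and stream() yields a single message produced via invoke().
llm_with_tools = llm.bind_tools([get_weather])
for message in llm_with_tools.stream("What's the weather in Paris?"):
    print(message)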
@@ -1,7 +1,7 @@
 """Test base chat model."""

 import uuid
-from typing import Any, AsyncIterator, Iterator, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union

 import pytest

@@ -18,6 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
+from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
 from langchain_core.tracers.schemas import Run
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -272,3 +273,96 @@ async def test_async_pass_run_id() -> None:
     uid3 = uuid.uuid4()
     await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
     assert cb.traced_run_ids == [uid1, uid2, uid3]
+
+
+class NoStreamingModel(BaseChatModel):
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
+
+    @property
+    def _llm_type(self) -> str:
+        return "model1"
+
+
+class StreamingModel(NoStreamingModel):
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    assert next(model.stream([])).content == expected
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == expected
+
+    expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
+    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        )
+    ).content == expected
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+    assert next(model.stream([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == "invoke"
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == "invoke"
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == "invoke"
+        break