Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-07 22:11:51 +00:00
core[patch]: base language model disable_streaming (#25070)
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
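This patch adds a disable_streaming option to the base chat model. Judging from the tests in the diff below, it accepts True, False, or "tool_calling": True makes stream()/astream() fall back to the non-streaming generation path, False leaves streaming untouched, and "tool_calling" disables streaming only for calls that pass tools. A minimal sketch of that behaviour, modelled on the fake models the tests define (the EchoStreamingModel name and the literal "invoke"/"stream" payloads are illustrative, not part of the commit):

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoStreamingModel(BaseChatModel):
    """Fake model whose non-streaming path answers "invoke" and whose
    streaming path yields "stream", so the chosen code path is observable."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))

    @property
    def _llm_type(self) -> str:
        return "echo-streaming"


# True: stream()/astream() fall back to the non-streaming path.
assert next(EchoStreamingModel(disable_streaming=True).stream([])).content == "invoke"

# False (the default): _stream is used as before.
assert next(EchoStreamingModel(disable_streaming=False).stream([])).content == "stream"

# "tool_calling": streaming is disabled only when tools are passed to the call.
model = EchoStreamingModel(disable_streaming="tool_calling")
assert next(model.stream([])).content == "stream"
assert next(model.stream([], tools=[{"type": "function"}])).content == "invoke"

The parametrized tests in the last hunk check exactly these combinations, and additionally that a model without a _stream implementation keeps returning the non-streaming result regardless of the flag.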
@@ -1,7 +1,7 @@
 """Test base chat model."""
 
 import uuid
-from typing import Any, AsyncIterator, Iterator, List, Optional
+from typing import Any, AsyncIterator, Iterator, List, Literal, Optional, Union
 
 import pytest
 
@@ -18,6 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.context import collect_runs
+from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler
 from langchain_core.tracers.schemas import Run
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -272,3 +273,96 @@ async def test_async_pass_run_id() -> None:
     uid3 = uuid.uuid4()
     await llm.abatch([["Dummy message"]], {"callbacks": [cb], "run_id": uid3})
     assert cb.traced_run_ids == [uid1, uid2, uid3]
+
+
+class NoStreamingModel(BaseChatModel):
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))])
+
+    @property
+    def _llm_type(self) -> str:
+        return "model1"
+
+
+class StreamingModel(NoStreamingModel):
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        yield ChatGenerationChunk(message=AIMessageChunk(content="stream"))
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = StreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+
+    expected = "invoke" if disable_streaming is True else "stream"
+    assert next(model.stream([])).content == expected
+    async for c in model.astream([]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == expected
+
+    expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream"
+    assert next(model.stream([], tools=[{"type": "function"}])).content == expected
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == expected
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        ).content
+        == expected
+    )
+    assert (
+        await model.ainvoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}]
+        )
+    ).content == expected
+
+
+@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"])
+async def test_disable_streaming_no_streaming_model(
+    disable_streaming: Union[bool, Literal["tool_calling"]],
+) -> None:
+    model = NoStreamingModel(disable_streaming=disable_streaming)
+    assert model.invoke([]).content == "invoke"
+    assert (await model.ainvoke([])).content == "invoke"
+    assert next(model.stream([])).content == "invoke"
+    async for c in model.astream([]):
+        assert c.content == "invoke"
+        break
+    assert (
+        model.invoke(
+            [], config={"callbacks": [_AstreamEventsCallbackHandler()]}
+        ).content
+        == "invoke"
+    )
+    assert (
+        await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]})
+    ).content == "invoke"
+    assert next(model.stream([], tools=[{}])).content == "invoke"
+    async for c in model.astream([], tools=[{}]):
+        assert c.content == "invoke"
+        break
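Because the new field lives on the base chat model, the same switch should become available on provider subclasses once they pick up a core release containing this change; a hedged usage sketch (the ChatOpenAI import and model name are illustrative, not part of this commit):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", disable_streaming="tool_calling")
# Plain chat still streams; calls that pass tools fall back to a single
# non-streaming generation, mirroring the behaviour asserted in the tests above.
for chunk in llm.stream("hello"):
    print(chunk.content, end="")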