Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-03 20:16:52 +00:00)
core[minor]: allow LLMs async streaming to fallback on sync streaming (#18960)
- **Description:** Handling fallbacks when calling async streaming for an LLM that doesn't support it.
- **Issue:** #18920
- **Twitter handle:** @maximeperrin_

---------

Co-authored-by: Maxime Perrin <mperrin@doing.fr>
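For context, the idea of the change is that `BaseLLM.astream` should still work when a model only overrides the synchronous `_stream` hook, by driving the sync generator from async code. Below is a minimal, self-contained sketch of that fallback pattern; the names (`FallbackStreamer`, `SyncOnlyModel`, `_demo`) are illustrative assumptions, and this is not the actual `langchain_core` implementation, which also routes through callback managers and run tracking.

```python
# A minimal sketch (assumed names, NOT the real langchain_core implementation)
# of async streaming falling back to a sync `_stream` implementation.
import asyncio
from typing import AsyncIterator, Iterator


class FallbackStreamer:
    def _stream(self, prompt: str) -> Iterator[str]:
        """Sync streaming hook; subclasses may override only this."""
        raise NotImplementedError

    async def _astream(self, prompt: str) -> AsyncIterator[str]:
        """Default async streaming: bridge to the sync `_stream`.

        Each `next()` call runs in a worker thread so the event loop is not blocked.
        """
        iterator = iter(self._stream(prompt))
        done = object()  # sentinel marking exhaustion of the sync iterator
        while True:
            chunk = await asyncio.to_thread(next, iterator, done)
            if chunk is done:
                break
            yield chunk

    async def astream(self, prompt: str) -> AsyncIterator[str]:
        # Public entry point: a subclass's own `_astream` wins; otherwise the
        # default above falls back to the sync `_stream`.
        async for chunk in self._astream(prompt):
            yield chunk


class SyncOnlyModel(FallbackStreamer):
    def _stream(self, prompt: str) -> Iterator[str]:
        yield "a"
        yield "b"


async def _demo() -> None:
    # The sync-only model can still be consumed asynchronously.
    assert [c async for c in SyncOnlyModel().astream("anything")] == ["a", "b"]


if __name__ == "__main__":
    asyncio.run(_demo())
```

The tests added by this commit (below) exercise exactly this behavior: a model with only `_generate`, a model with only a sync `_stream`, and a model with its own `_astream`.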
@@ -1,6 +1,13 @@
from typing import Any, AsyncIterator, Iterator, List, Optional

import pytest

from langchain_core.outputs.llm_result import LLMResult
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
    BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ async def test_stream_error_callback() -> None:
                pass

        eval_response(cb_sync, i)


async def test_astream_fallback_to_ainvoke() -> None:
    """Test astream uses appropriate implementation."""

    class ModelWithGenerate(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            generations = [Generation(text="hello")]
            return LLMResult(generations=[generations])

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithGenerate()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == ["hello"]

    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == ["hello"]


async def test_astream_implementation_fallback_to_stream() -> None:
    """Test astream uses appropriate implementation."""

    class ModelWithSyncStream(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            """Top Level call"""
            raise NotImplementedError()

        def _stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> Iterator[GenerationChunk]:
            """Stream the output of the model."""
            yield GenerationChunk(text="a")
            yield GenerationChunk(text="b")

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithSyncStream()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == ["a", "b"]
    assert type(model)._astream == BaseLLM._astream
    astream_chunks = [chunk async for chunk in model.astream("anything")]
    assert astream_chunks == ["a", "b"]


async def test_astream_implementation_uses_astream() -> None:
    """Test astream uses appropriate implementation."""

    class ModelWithAsyncStream(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            """Top Level call"""
            raise NotImplementedError()

        async def _astream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> AsyncIterator[GenerationChunk]:
            """Stream the output of the model."""
            yield GenerationChunk(text="a")
            yield GenerationChunk(text="b")

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithAsyncStream()
    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == ["a", "b"]