core[minor]: allow LLMs async streaming to fallback on sync streaming (#18960)
- **Description:** Handle fallbacks when calling async streaming for an LLM that doesn't support it: `astream()` now wraps a synchronous `_stream` implementation in an async iterator instead of silently degrading to a single `ainvoke()` result.
- **Issue:** #18920
- **Twitter handle:** @maximeperrin_

---------

Co-authored-by: Maxime Perrin <mperrin@doing.fr>
This commit is contained in:
parent caf47ab666
commit aa785fa6ec
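In practice, this means an LLM subclass that only implements the synchronous `_stream` can now be consumed through `astream()`. A minimal sketch of that behavior, adapted from the unit tests added below (the `main()` driver and the `"fake-model"` value are illustrative, not part of the commit):

```python
import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult


class ModelWithSyncStream(BaseLLM):
    """Fake LLM that only implements the sync _stream method."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        raise NotImplementedError()

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        yield GenerationChunk(text="a")
        yield GenerationChunk(text="b")

    @property
    def _llm_type(self) -> str:
        return "fake-model"


async def main() -> None:
    model = ModelWithSyncStream()
    # Before this change, astream() fell back to ainvoke()/_generate; now the
    # sync _stream is wrapped into an async iterator, so the real chunks arrive.
    print([chunk async for chunk in model.astream("anything")])  # ['a', 'b']


asyncio.run(main())
```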
langchain_core/language_models/llms.py:

@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -113,6 +114,26 @@ def create_base_retry_decorator(
     )
 
 
+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
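The helper above pulls items off the blocking iterator on an executor thread. The `done = object()` sentinel is what lets `next(iterator, done)` signal exhaustion as a return value rather than a `StopIteration` exception, which would be awkward to propagate through the awaited executor call. A self-contained sketch of the same pattern using plain `asyncio.to_thread` (an illustration of the idiom, not LangChain's actual helper):

```python
import asyncio
from typing import AsyncIterator, Callable, Iterator


def as_async_iterator(
    make_iter: Callable[[], Iterator[str]],
) -> Callable[[], AsyncIterator[str]]:
    """Wrap a blocking iterator factory so it can be consumed with `async for`."""

    async def gen() -> AsyncIterator[str]:
        # Build the iterator off the event loop, in case construction itself blocks.
        iterator = await asyncio.to_thread(make_iter)
        # Unique sentinel: it cannot collide with any real item the iterator yields.
        done = object()
        while True:
            # next(iterator, done) returns the sentinel on exhaustion instead of raising.
            item = await asyncio.to_thread(next, iterator, done)
            if item is done:
                break
            yield item  # type: ignore[misc]

    return gen


async def main() -> None:
    chunks = [c async for c in as_async_iterator(lambda: iter(["a", "b", "c"]))()]
    print(chunks)  # ['a', 'b', 'c']


asyncio.run(main())
```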
@@ -434,10 +455,29 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream == BaseLLM._astream:
+        if type(self)._astream is not BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseLLM._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        str,
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[GenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
+            )
         else:
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            return
+
         prompt = self._convert_input(input).to_string()
         config = ensure_config(config)
         params = self.dict()
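The branch selection above relies on a small idiom: looking up `_astream`/`_stream` on the class and comparing it by identity against the `BaseLLM` attribute reveals whether a subclass overrode the method, without ever calling it. A minimal standalone illustration (the class names here are made up for the example):

```python
class Base:
    def _stream(self):
        raise NotImplementedError


class WithStream(Base):
    def _stream(self):
        yield "chunk"


class WithoutStream(Base):
    pass


# Looking the method up on the class (not the instance) yields the plain function
# object, so an identity check tells us whether the subclass provided its own.
print(type(WithStream())._stream is not Base._stream)     # True  -> streams natively
print(type(WithoutStream())._stream is not Base._stream)  # False -> needs the fallback
```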
@@ -463,7 +503,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         )
         generation: Optional[GenerationChunk] = None
         try:
-            async for chunk in self._astream(
+            async for chunk in _stream_implementation(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
             ):
                 yield chunk.text
@@ -475,9 +515,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         except BaseException as e:
             await run_manager.on_llm_error(
                 e,
-                response=LLMResult(
-                    generations=[[generation]] if generation else []
-                ),
+                response=LLMResult(generations=[[generation]] if generation else []),
             )
             raise e
         else:
In the unit tests for BaseLLM:

@@ -1,6 +1,13 @@
+from typing import Any, AsyncIterator, Iterator, List, Optional
+
 import pytest
 
-from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ async def test_stream_error_callback() -> None:
                 pass
 
         eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            generations = [Generation(text="hello")]
+            return LLMResult(generations=[generations])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["hello"]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["hello"]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["a", "b"]
+    assert type(model)._astream == BaseLLM._astream
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == ["a", "b"]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["a", "b"]