core[minor]: allow LLMs async streaming to fallback on sync streaming (#18960)

- **Description:** Fall back to sync streaming when async streaming is
requested for an LLM that doesn't implement it (a sketch of the idea is shown
below).
- **Issue:** #18920
- **Twitter handle:** @maximeperrin_

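A minimal sketch of the fallback idea, using a hypothetical helper name
(`_astream_from_sync`) rather than the exact code added to `BaseLLM.astream`:
when a subclass overrides the sync `_stream` but not `_astream`, the async path
can drain the sync generator in a worker thread so the event loop isn't blocked.

```python
import asyncio
from typing import Any, AsyncIterator, Iterator


async def _astream_from_sync(sync_chunks: Iterator[Any]) -> AsyncIterator[Any]:
    """Yield items from a sync iterator without blocking the event loop.

    Hypothetical helper for illustration only; the actual fallback logic in
    this PR lives inside ``BaseLLM.astream``.
    """
    done = object()  # sentinel signalling the sync iterator is exhausted
    while True:
        # Run next() in a worker thread so a slow sync generator can't stall asyncio.
        chunk = await asyncio.to_thread(next, sync_chunks, done)
        if chunk is done:
            break
        yield chunk
```

For example, `[c async for c in _astream_from_sync(iter(["a", "b"]))]` evaluates
to `["a", "b"]`, which is the behaviour the new tests assert for a model that
only implements `_stream`.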
---------

Co-authored-by: Maxime Perrin <mperrin@doing.fr>

@@ -1,6 +1,13 @@
from typing import Any, AsyncIterator, Iterator, List, Optional

import pytest

from langchain_core.outputs.llm_result import LLMResult
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
    BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ async def test_stream_error_callback() -> None:
            pass
        eval_response(cb_sync, i)


async def test_astream_fallback_to_ainvoke() -> None:
    """Test astream uses appropriate implementation."""

    class ModelWithGenerate(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            generations = [Generation(text="hello")]
            return LLMResult(generations=[generations])

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithGenerate()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == ["hello"]

    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == ["hello"]


async def test_astream_implementation_fallback_to_stream() -> None:
    """Test astream uses appropriate implementation."""

    class ModelWithSyncStream(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            """Top Level call"""
            raise NotImplementedError()

        def _stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> Iterator[GenerationChunk]:
            """Stream the output of the model."""
            yield GenerationChunk(text="a")
            yield GenerationChunk(text="b")

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithSyncStream()
    chunks = [chunk for chunk in model.stream("anything")]
    assert chunks == ["a", "b"]
    assert type(model)._astream == BaseLLM._astream

    astream_chunks = [chunk async for chunk in model.astream("anything")]
    assert astream_chunks == ["a", "b"]


async def test_astream_implementation_uses_astream() -> None:
    """Test astream uses appropriate implementation."""

    class ModelWithAsyncStream(BaseLLM):
        def _generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> LLMResult:
            """Top Level call"""
            raise NotImplementedError()

        async def _astream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> AsyncIterator[GenerationChunk]:
            """Stream the output of the model."""
            yield GenerationChunk(text="a")
            yield GenerationChunk(text="b")

        @property
        def _llm_type(self) -> str:
            return "fake-chat-model"

    model = ModelWithAsyncStream()
    chunks = [chunk async for chunk in model.astream("anything")]
    assert chunks == ["a", "b"]
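Taken together, the three tests pin down the dispatch order `astream` is
expected to follow. A rough summary of that order, written here purely as an
illustration (this is not code from `langchain_core`):

```python
from langchain_core.language_models.llms import BaseLLM


def expected_astream_path(model_cls: type[BaseLLM]) -> str:
    """Describe which branch the tests above expect astream to take.

    Illustrative summary only, not the library's implementation.
    """
    if model_cls._astream is not BaseLLM._astream:
        # test_astream_implementation_uses_astream
        return "native async streaming via _astream"
    if model_cls._stream is not BaseLLM._stream:
        # test_astream_implementation_fallback_to_stream
        return "sync _stream driven from the async path"
    # test_astream_fallback_to_ainvoke
    return "no streaming: one result from _generate, yielded as a single chunk"
```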