core[minor]: Chat models: astream falls back to sync _stream if available (#18748)

Allows all chat models that implement _stream, but not _astream, to still support async streaming.

Among other things, this should resolve issues with streaming community model implementations through langserve, since langserve is exclusively async.
Eugene Yurtsev 2024-03-08 13:27:29 -05:00 committed by GitHub
parent 3624f56ccb
commit cdfb5b4ca1
2 changed files with 204 additions and 46 deletions
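
Before the diff itself, here is a minimal sketch (not part of this commit) of what the change enables. It assumes a hypothetical FakeStreamingChatModel that implements only the sync _stream; after this commit its astream drives that sync generator in an executor instead of collapsing to a single ainvoke result.

import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult


class FakeStreamingChatModel(BaseChatModel):
    """Hypothetical model that only implements the sync _stream."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Sync-only streaming implementation.
        for token in ("hello", " world"):
            yield ChatGenerationChunk(message=AIMessageChunk(content=token))

    @property
    def _llm_type(self) -> str:
        return "fake-streaming-chat-model"


async def main() -> None:
    model = FakeStreamingChatModel()
    # With this commit, astream runs the sync _stream in an executor
    # and yields chunks as they are produced.
    async for chunk in model.astream("anything"):
        print(chunk.content)


asyncio.run(main())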


@@ -7,7 +7,9 @@ from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Iterator,
     List,
@@ -100,6 +102,26 @@ async def agenerate_from_stream(
     )


+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""
@@ -259,57 +281,74 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if type(self)._astream == BaseChatModel._astream:
-            # model doesn't implement streaming, so use default implementation
+        if type(self)._astream is not BaseChatModel._astream:
+            # Then astream is implemented
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseChatModel._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        List[BaseMessage],
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[ChatGenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
+            )
+        else:  # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
                 await self.ainvoke(input, config=config, stop=stop, **kwargs),
             )
+            return
+
+        config = ensure_config(config)
+        messages = self._convert_input(input).to_messages()
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        options = {"stop": stop, **kwargs}
+        callback_manager = AsyncCallbackManager.configure(
+            config.get("callbacks"),
+            self.callbacks,
+            self.verbose,
+            config.get("tags"),
+            self.tags,
+            config.get("metadata"),
+            self.metadata,
+        )
+        (run_manager,) = await callback_manager.on_chat_model_start(
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
+            batch_size=1,
+        )
+        generation: Optional[ChatGenerationChunk] = None
+        try:
+            async for chunk in _stream_implementation(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                yield chunk.message
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+        except BaseException as e:
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=[[generation]] if generation else []),
+            )
+            raise e
         else:
-            config = ensure_config(config)
-            messages = self._convert_input(input).to_messages()
-            params = self._get_invocation_params(stop=stop, **kwargs)
-            options = {"stop": stop, **kwargs}
-            callback_manager = AsyncCallbackManager.configure(
-                config.get("callbacks"),
-                self.callbacks,
-                self.verbose,
-                config.get("tags"),
-                self.tags,
-                config.get("metadata"),
-                self.metadata,
+            await run_manager.on_llm_end(
+                LLMResult(generations=[[generation]]),
             )
-            (run_manager,) = await callback_manager.on_chat_model_start(
-                dumpd(self),
-                [messages],
-                invocation_params=params,
-                options=options,
-                name=config.get("run_name"),
-                batch_size=1,
-            )
-            generation: Optional[ChatGenerationChunk] = None
-            try:
-                async for chunk in self._astream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    yield chunk.message
-                    if generation is None:
-                        generation = chunk
-                    else:
-                        generation += chunk
-                assert generation is not None
-            except BaseException as e:
-                await run_manager.on_llm_error(
-                    e,
-                    response=LLMResult(
-                        generations=[[generation]] if generation else []
-                    ),
-                )
-                raise e
-            else:
-                await run_manager.on_llm_end(
-                    LLMResult(generations=[[generation]]),
-                )

     # --- Custom methods ---
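
The new _as_async_iterator helper above follows a common pattern for bridging a sync generator into async code: run each next() call in a worker thread and stop on a sentinel. The standalone sketch below (illustrative only; it uses plain asyncio.to_thread rather than langchain_core's run_in_executor, and the name as_async_iterator is hypothetical) shows the same idea in isolation.

import asyncio
from typing import Any, AsyncIterator, Callable, Iterator


def as_async_iterator(
    sync_iterator_factory: Callable[..., Iterator[Any]],
) -> Callable[..., AsyncIterator[Any]]:
    """Wrap a sync-iterator factory so each next() call runs in a worker thread."""

    async def wrapper(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        iterator = await asyncio.to_thread(sync_iterator_factory, *args, **kwargs)
        done = object()  # sentinel: next() returns it once the iterator is exhausted
        while True:
            item = await asyncio.to_thread(next, iterator, done)
            if item is done:
                break
            yield item

    return wrapper


async def demo() -> None:
    # Bridge a plain sync iterator (here: iter over a list) into an async for loop.
    async for value in as_async_iterator(iter)([1, 2, 3]):
        print(value)


asyncio.run(demo())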


@@ -1,8 +1,19 @@
 """Test base chat model."""
+from typing import Any, AsyncIterator, Iterator, List, Optional
+
 import pytest

-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
@@ -106,3 +117,111 @@ async def test_stream_error_callback() -> None:
                 pass
         eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            message = AIMessage(content="hello")
+            generation = ChatGeneration(message=message)
+            return ChatResult(generations=[generation])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == [AIMessage(content="hello")]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == [AIMessage(content="hello")]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            """Stream the output of the model."""
+            yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
+            yield ChatGenerationChunk(message=AIMessageChunk(content="b"))
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
+    assert type(model)._astream == BaseChatModel._astream
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(  # type: ignore
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[ChatGenerationChunk]:
+            """Stream the output of the model."""
+            yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
+            yield ChatGenerationChunk(message=AIMessageChunk(content="b"))
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]