core: Simplify astream logic in BaseChatModel and BaseLLM (#19332)

Covered by tests in
`libs/core/tests/unit_tests/language_models/chat_models/test_base.py`,
`libs/core/tests/unit_tests/language_models/llms/test_base.py` and
`libs/core/tests/unit_tests/runnables/test_runnable_events.py`
This commit is contained in:
Christophe Bornet 2024-03-20 14:05:51 +01:00 committed by GitHub
parent 40f846e65d
commit 4c2e887276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 97 deletions

View File

@ -8,9 +8,7 @@ from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
@ -99,26 +97,6 @@ async def agenerate_from_stream(
)
def _as_async_iterator(sync_iterator: Callable) -> Callable:
"""Convert a sync iterator into an async iterator."""
async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
return _as_sync_iterator
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
@ -270,28 +248,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream is not BaseChatModel._astream:
# Then astream is implemented
_stream_implementation = self._astream
using_sync_stream = False
elif type(self)._stream is not BaseChatModel._stream:
# Then stream is implemented, so we can create an async iterator from it
# The typing is hard to type correctly with mypy here, so we cast
# and do a type ignore, this code is unit tested and should be fine.
_stream_implementation = cast( # type: ignore
Callable[
[
List[BaseMessage],
Optional[List[str]],
CallbackManagerForLLMRun,
Any,
],
AsyncIterator[ChatGenerationChunk],
],
_as_async_iterator(self._stream),
)
using_sync_stream = True
else: # No async or sync stream is implemented, so fall back to ainvoke
if (
type(self)._astream is BaseChatModel._astream
and type(self)._stream is BaseChatModel._stream
):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
BaseMessageChunk,
await self.ainvoke(input, config=config, stop=stop, **kwargs),
@ -321,13 +282,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=1,
)
run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
generation: Optional[ChatGenerationChunk] = None
try:
async for chunk in _stream_implementation(
async for chunk in self._astream(
messages,
stop=stop,
run_manager=run_manager_, # type: ignore[arg-type]
run_manager=run_manager,
**kwargs,
):
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
@ -731,14 +691,32 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
def _astream(
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
iterator = await run_in_executor(
None,
self._stream,
messages,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
@deprecated("0.1.7", alternative="invoke", removal="0.2.0")
def __call__(

View File

@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
@ -116,26 +115,6 @@ def create_base_retry_decorator(
)
def _as_async_iterator(sync_iterator: Callable) -> Callable:
"""Convert a sync iterator into an async iterator."""
async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
return _as_sync_iterator
def get_prompts(
params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
@ -460,28 +439,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
if type(self)._astream is not BaseLLM._astream:
# model doesn't implement streaming, so use default implementation
_stream_implementation = self._astream
using_sync_stream = False
elif type(self)._stream is not BaseLLM._stream:
# Then stream is implemented, so we can create an async iterator from it
# The typing is hard to type correctly with mypy here, so we cast
# and do a type ignore, this code is unit tested and should be fine.
_stream_implementation = cast( # type: ignore
Callable[
[
str,
Optional[List[str]],
CallbackManagerForLLMRun,
Any,
],
AsyncIterator[GenerationChunk],
],
_as_async_iterator(self._stream),
)
using_sync_stream = True
else:
if (
type(self)._astream is BaseLLM._astream
and type(self)._stream is BaseLLM._stream
):
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
return
@ -509,13 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_id=config.pop("run_id", None),
batch_size=1,
)
run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
generation: Optional[GenerationChunk] = None
try:
async for chunk in _stream_implementation(
async for chunk in self._astream(
prompt,
stop=stop,
run_manager=run_manager_, # type: ignore[arg-type]
run_manager=run_manager,
**kwargs,
):
yield chunk.text
@ -571,14 +531,32 @@ class BaseLLM(BaseLanguageModel[str], ABC):
) -> Iterator[GenerationChunk]:
raise NotImplementedError()
def _astream(
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
raise NotImplementedError()
iterator = await run_in_executor(
None,
self._stream,
prompt,
stop,
run_manager.get_sync() if run_manager else None,
**kwargs,
)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]
def generate_prompt(
self,