core: Simplify astream logic in BaseChatModel and BaseLLM (#19332)

Covered by tests in
`libs/core/tests/unit_tests/language_models/chat_models/test_base.py`,
`libs/core/tests/unit_tests/language_models/llms/test_base.py`, and
`libs/core/tests/unit_tests/runnables/test_runnable_events.py`.
Christophe Bornet 2024-03-20 14:05:51 +01:00 committed by GitHub
parent 40f846e65d
commit 4c2e887276
2 changed files with 53 additions and 97 deletions
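
Summary of the change: the module-level `_as_async_iterator` helper and the `_stream_implementation` / `using_sync_stream` bookkeeping inside `astream` are deleted from both files. In their place, `_astream` itself gains a default implementation that runs the sync `_stream` on an executor, so `astream` needs only one remaining check: if neither `_astream` nor `_stream` is overridden, fall back to `ainvoke`.

The core technique is a sentinel-driven loop that pulls items from a synchronous generator one `next()` call at a time, each call on a worker thread. Below is a minimal, self-contained sketch of that pattern using only the standard library; the names `sync_counter` and `as_async` are illustrative, not part of langchain-core, which routes through its own `run_in_executor` helper.

import asyncio
from typing import AsyncIterator, Iterator


def sync_counter(n: int) -> Iterator[int]:
    """A plain sync generator, standing in for a model's _stream."""
    for i in range(n):
        yield i


async def as_async(n: int) -> AsyncIterator[int]:
    """Drive the sync generator from async code without blocking the loop."""
    loop = asyncio.get_running_loop()
    iterator = sync_counter(n)
    done = object()  # unique sentinel; can never equal a real item
    while True:
        # next(iterator, done) returns the sentinel on exhaustion instead of
        # raising StopIteration, which interacts badly with asyncio Futures.
        item = await loop.run_in_executor(None, next, iterator, done)
        if item is done:
            break
        yield item


async def main() -> None:
    async for item in as_async(3):
        print(item)  # 0, 1, 2


asyncio.run(main())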

File: libs/core/langchain_core/language_models/chat_models.py

@@ -8,9 +8,7 @@ from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
     AsyncIterator,
-    Callable,
     Dict,
     Iterator,
     List,
@@ -99,26 +97,6 @@ async def agenerate_from_stream(
     )


-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""
@@ -270,28 +248,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if type(self)._astream is not BaseChatModel._astream:
-            # Then astream is implemented
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseChatModel._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        List[BaseMessage],
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[ChatGenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:  # No async or sync stream is implemented, so fall back to ainvoke
+        if (
+            type(self)._astream is BaseChatModel._astream
+            and type(self)._stream is BaseChatModel._stream
+        ):
+            # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
                 await self.ainvoke(input, config=config, stop=stop, **kwargs),
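
An aside on the override probe used in the new condition: looking `_astream` up on `type(self)` rather than on the instance yields the function object stored on the class, so an identity comparison against `BaseChatModel._astream` reveals whether any subclass redefined it. A standalone illustration of the idiom (class names here are made up):

class Base:
    def hook(self) -> None:
        ...


class Overrides(Base):
    def hook(self) -> None:
        ...


class Inherits(Base):
    pass


assert type(Overrides()).hook is not Base.hook  # redefined by the subclass
assert type(Inherits()).hook is Base.hook  # still the inherited base function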
@@ -321,13 +282,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             batch_size=1,
         )

-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[ChatGenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 messages,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
@@ -731,14 +691,32 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     ) -> Iterator[ChatGenerationChunk]:
         raise NotImplementedError()

-    def _astream(
+    async def _astream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            messages,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]

     @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
     def __call__(
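
Because `_astream` is now a real inherited implementation rather than a stub that raises NotImplementedError, a chat model that only implements the sync `_stream` hook gets working async streaming for free. A hedged sketch of what that looks like for an implementer (`EchoChat` is a made-up example, not part of the PR; it assumes the post-change langchain-core API):

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoChat(BaseChatModel):
    """Toy model that streams the last input message back char by char."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        text = str(messages[-1].content)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))]
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        for ch in str(messages[-1].content):
            yield ChatGenerationChunk(message=AIMessageChunk(content=ch))


# No _astream anywhere, yet this now streams asynchronously:
#     async for chunk in EchoChat().astream("hi"):
#         print(chunk.content, end="")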

File: libs/core/langchain_core/language_models/llms.py

@@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -116,26 +115,6 @@ def create_base_retry_decorator(
     )


-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
@@ -460,28 +439,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream is not BaseLLM._astream:
-            # model doesn't implement streaming, so use default implementation
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseLLM._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        str,
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[GenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:
+        if (
+            type(self)._astream is BaseLLM._astream
+            and type(self)._stream is BaseLLM._stream
+        ):
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
@@ -509,13 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[GenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 prompt,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 yield chunk.text
@@ -571,14 +531,32 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     ) -> Iterator[GenerationChunk]:
         raise NotImplementedError()

-    def _astream(
+    async def _astream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            prompt,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]

     def generate_prompt(
         self,
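
The LLM side mirrors the chat-model change, including the detail that the async callback manager is downgraded with `run_manager.get_sync()` before being handed to the sync `_stream` (previously that juggling lived at the `astream` call site as `run_manager_`). A hedged sketch of a completion model benefiting from the new default (`EchoLLM` is illustrative, not from the PR):

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class EchoLLM(LLM):
    """Toy completion model that streams the prompt back char by char."""

    @property
    def _llm_type(self) -> str:
        return "echo-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for ch in prompt:
            yield GenerationChunk(text=ch)


# EchoLLM defines no async code, but astream() now works via the inherited
# default _astream:
#     async for token in EchoLLM().astream("abc"):
#         print(token, end="")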