From 4c2e887276a7f9816a496c8e1e0a4f273232cb0c Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 20 Mar 2024 14:05:51 +0100 Subject: [PATCH] core: Simplify astream logic in BaseChatModel and BaseLLM (#19332) Covered by tests in `libs/core/tests/unit_tests/language_models/chat_models/test_base.py`, `libs/core/tests/unit_tests/language_models/llms/test_base.py` and `libs/core/tests/unit_tests/runnables/test_runnable_events.py` --- .../language_models/chat_models.py | 76 +++++++------------ .../langchain_core/language_models/llms.py | 74 +++++++----------- 2 files changed, 53 insertions(+), 97 deletions(-) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 45a6ead8648..60885a84a86 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -8,9 +8,7 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, AsyncIterator, - Callable, Dict, Iterator, List, @@ -99,26 +97,6 @@ async def agenerate_from_stream( ) -def _as_async_iterator(sync_iterator: Callable) -> Callable: - """Convert a sync iterator into an async iterator.""" - - async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator: - iterator = await run_in_executor(None, sync_iterator, *args, **kwargs) - done = object() - while True: - item = await run_in_executor( - None, - next, - iterator, - done, # type: ignore[call-arg, arg-type] - ) - if item is done: - break - yield item # type: ignore[misc] - - return _as_sync_iterator - - class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): """Base class for Chat models.""" @@ -270,28 +248,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): stop: Optional[List[str]] = None, **kwargs: Any, ) -> AsyncIterator[BaseMessageChunk]: - if type(self)._astream is not BaseChatModel._astream: - # Then astream is implemented - _stream_implementation = self._astream - using_sync_stream = False - elif type(self)._stream is not BaseChatModel._stream: - # Then stream is implemented, so we can create an async iterator from it - # The typing is hard to type correctly with mypy here, so we cast - # and do a type ignore, this code is unit tested and should be fine. 
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        List[BaseMessage],
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[ChatGenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:  # No async or sync stream is implemented, so fall back to ainvoke
+        if (
+            type(self)._astream is BaseChatModel._astream
+            and type(self)._stream is BaseChatModel._stream
+        ):
+            # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
                 await self.ainvoke(input, config=config, stop=stop, **kwargs),
@@ -321,13 +282,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             batch_size=1,
         )
 
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[ChatGenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 messages,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
@@ -731,14 +691,32 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     ) -> Iterator[ChatGenerationChunk]:
         raise NotImplementedError()
 
-    def _astream(
+    async def _astream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            messages,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
 
     @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
     def __call__(
diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index d8e1e009898..7fe90a0ffda 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -116,26 +115,6 @@ def create_base_retry_decorator(
     )
 
 
-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
@@ -460,28 +439,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream is not BaseLLM._astream:
-            # model doesn't implement streaming, so use default implementation
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseLLM._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        str,
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[GenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:
+        if (
+            type(self)._astream is BaseLLM._astream
+            and type(self)._stream is BaseLLM._stream
+        ):
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
 
@@ -509,13 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[GenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 prompt,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 yield chunk.text
@@ -571,14 +531,32 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     ) -> Iterator[GenerationChunk]:
         raise NotImplementedError()
 
-    def _astream(
+    async def _astream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            prompt,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
 
     def generate_prompt(
         self,
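
Note (not part of the patch): after this change, a model that overrides only the
sync `_stream` no longer goes through the removed module-level
`_as_async_iterator` wrapper; `astream` simply calls `self._astream`, and the
inherited default `_astream` drives `_stream` on an executor thread via
`run_in_executor`. A minimal sketch of that behavior follows; the subclass name
`EchoChat` and the `main` driver are illustrative assumptions, not code from
this patch.

    import asyncio
    from typing import Any, Iterator, List, Optional

    from langchain_core.callbacks import CallbackManagerForLLMRun
    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import AIMessageChunk, BaseMessage
    from langchain_core.outputs import ChatGenerationChunk

    class EchoChat(BaseChatModel):
        """Toy model that implements only the sync _stream."""

        @property
        def _llm_type(self) -> str:
            return "echo-chat"

        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            # Not exercised here: astream takes the streaming path because
            # _stream is overridden.
            raise NotImplementedError

        def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> Iterator[ChatGenerationChunk]:
            # Echo the last message back one whitespace-separated token at a time.
            for token in str(messages[-1].content).split():
                yield ChatGenerationChunk(message=AIMessageChunk(content=token + " "))

    async def main() -> None:
        # EchoChat does not override _astream, but it does override _stream, so
        # astream skips the ainvoke fallback and the inherited default _astream
        # pulls chunks from _stream on an executor thread, one at a time.
        async for chunk in EchoChat().astream("hello streaming world"):
            print(chunk.content, end="")

    asyncio.run(main())

The design point the diff relies on: `astream` only needs to detect the
"neither `_stream` nor `_astream` is overridden" case for the ainvoke fallback;
every other case is handled uniformly by calling `self._astream`, since the
default `_astream` now performs the sync-to-async bridging itself (converting
the async run manager with `run_manager.get_sync()` when present).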