diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 701461ab595..6190519e3e5 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -7,7 +7,9 @@ from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
     AsyncIterator,
+    Callable,
     Dict,
     Iterator,
     List,
@@ -100,6 +102,26 @@ async def agenerate_from_stream(
     )
 
 
+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""
 
@@ -259,57 +281,74 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if type(self)._astream == BaseChatModel._astream:
-            # model doesn't implement streaming, so use default implementation
+        if type(self)._astream is not BaseChatModel._astream:
+            # Then astream is implemented
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseChatModel._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        List[BaseMessage],
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[ChatGenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
+            )
+        else:  # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
                 await self.ainvoke(input, config=config, stop=stop, **kwargs),
             )
+            return
+
+        config = ensure_config(config)
+        messages = self._convert_input(input).to_messages()
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        options = {"stop": stop, **kwargs}
+        callback_manager = AsyncCallbackManager.configure(
+            config.get("callbacks"),
+            self.callbacks,
+            self.verbose,
+            config.get("tags"),
+            self.tags,
+            config.get("metadata"),
+            self.metadata,
+        )
+        (run_manager,) = await callback_manager.on_chat_model_start(
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
+            batch_size=1,
+        )
+        generation: Optional[ChatGenerationChunk] = None
+        try:
+            async for chunk in _stream_implementation(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                yield chunk.message
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+        except BaseException as e:
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=[[generation]] if generation else []),
+            )
+            raise e
         else:
-            config = ensure_config(config)
-            messages = self._convert_input(input).to_messages()
-            params = self._get_invocation_params(stop=stop, **kwargs)
-            options = {"stop": stop, **kwargs}
-            callback_manager = AsyncCallbackManager.configure(
-                config.get("callbacks"),
-                self.callbacks,
-                self.verbose,
-                config.get("tags"),
-                self.tags,
-                config.get("metadata"),
-                self.metadata,
+            await run_manager.on_llm_end(
+                LLMResult(generations=[[generation]]),
             )
-            (run_manager,) = await callback_manager.on_chat_model_start(
-                dumpd(self),
-                [messages],
-                invocation_params=params,
-                options=options,
-                name=config.get("run_name"),
-                batch_size=1,
-            )
-            generation: Optional[ChatGenerationChunk] = None
-            try:
-                async for chunk in self._astream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    yield chunk.message
-                    if generation is None:
-                        generation = chunk
-                    else:
-                        generation += chunk
-                assert generation is not None
-            except BaseException as e:
-                await run_manager.on_llm_error(
-                    e,
-                    response=LLMResult(
-                        generations=[[generation]] if generation else []
-                    ),
-                )
-                raise e
-            else:
-                await run_manager.on_llm_end(
-                    LLMResult(generations=[[generation]]),
-                )
 
     # --- Custom methods ---
 
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 24c49f79a3f..9e412137c0a 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -1,8 +1,19 @@
 """Test base chat model."""
+from typing import Any, AsyncIterator, Iterator, List, Optional
+
 import pytest
 
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
 
 from tests.unit_tests.fake.callbacks import (
@@ -106,3 +117,111 @@ async def test_stream_error_callback() -> None:
             pass
 
     eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            message = AIMessage(content="hello")
+            generation = ChatGeneration(message=message)
+            return ChatResult(generations=[generation])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == [AIMessage(content="hello")]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == [AIMessage(content="hello")]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            """Stream the output of the model."""
+            yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
+            yield ChatGenerationChunk(message=AIMessageChunk(content="b"))
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
+    assert type(model)._astream == BaseChatModel._astream
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(  # type: ignore
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[ChatGenerationChunk]:
+            """Stream the output of the model."""
+            yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
+            yield ChatGenerationChunk(message=AIMessageChunk(content="b"))
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
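
A note on the core trick in `_as_async_iterator` above: it pulls items off the sync iterator one `next()` call at a time, each dispatched to a worker thread via `run_in_executor`, and uses a unique sentinel object instead of catching `StopIteration` (which could not propagate out of an async generator anyway, per PEP 479). Below is a minimal, self-contained sketch of the same pattern, with a simplified stand-in for the `run_in_executor` helper used in the diff; `as_async_iterator` and `count_to` are illustrative names, not part of this change:

```python
import asyncio
from typing import Any, AsyncGenerator, Callable, Iterator


async def run_in_executor(
    executor: Any, func: Callable, *args: Any, **kwargs: Any
) -> Any:
    # Simplified stand-in for the helper used in the diff: run a blocking
    # callable in the default thread pool so the event loop stays responsive.
    return await asyncio.get_running_loop().run_in_executor(
        executor, lambda: func(*args, **kwargs)
    )


def as_async_iterator(sync_iterator: Callable) -> Callable:
    """Wrap a sync generator function for consumption with `async for`."""

    async def wrapper(*args: Any, **kwargs: Any) -> AsyncGenerator:
        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
        done = object()  # unique sentinel: next() returns it on exhaustion
        while True:
            # next(iterator, done) never raises StopIteration, so no
            # exception has to cross the thread / async-generator boundary.
            item = await run_in_executor(None, next, iterator, done)
            if item is done:
                break
            yield item

    return wrapper


def count_to(n: int) -> Iterator[int]:
    yield from range(n)


async def main() -> None:
    async for i in as_async_iterator(count_to)(3):
        print(i)  # prints 0, 1, 2; each next() ran in a worker thread


asyncio.run(main())
```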