From 4b3dd345445cc24447e499e01fd6d4eecc1e6977 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Tue, 19 Mar 2024 12:32:33 -0400
Subject: [PATCH] core[patch]: Pass sync run manager for sync stream fallback in astream (#19280)

This PR patches the fallback in chat models and language models to pass
in the appropriate version of the run manager (sync vs. async).
---
 .../language_models/chat_models.py            |  9 ++-
 .../language_models/fake_chat_models.py       | 20 ------
 .../langchain_core/language_models/llms.py    |  8 ++-
 .../runnables/test_runnable_events.py         | 62 +++++++++++++++++++
 4 files changed, 77 insertions(+), 22 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index cfadfafa338..45a6ead8648 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -273,6 +273,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if type(self)._astream is not BaseChatModel._astream:
             # Then astream is implemented
             _stream_implementation = self._astream
+            using_sync_stream = False
         elif type(self)._stream is not BaseChatModel._stream:
             # Then stream is implemented, so we can create an async iterator from it
             # The typing is hard to type correctly with mypy here, so we cast
@@ -289,6 +290,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 ],
                 _as_async_iterator(self._stream),
             )
+            using_sync_stream = True
         else:  # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
@@ -318,10 +320,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
+
+        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[ChatGenerationChunk] = None
         try:
             async for chunk in _stream_implementation(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages,
+                stop=stop,
+                run_manager=run_manager_,  # type: ignore[arg-type]
+                **kwargs,
             ):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py
index 2f0fa6ffeab..8324787c507 100644
--- a/libs/core/langchain_core/language_models/fake_chat_models.py
+++ b/libs/core/langchain_core/language_models/fake_chat_models.py
@@ -11,7 +11,6 @@ from langchain_core.callbacks import (
 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import run_in_executor
 
 
 class FakeMessagesListChatModel(BaseChatModel):
@@ -279,25 +278,6 @@ class GenericFakeChatModel(BaseChatModel):
                 )
             yield chunk
 
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        """Stream the output of the model."""
-        result = await run_in_executor(
-            None,
-            self._stream,
-            messages,
-            stop=stop,
-            run_manager=run_manager.get_sync() if run_manager else None,
-            **kwargs,
-        )
-        for chunk in result:
-            yield chunk
-
     @property
     def _llm_type(self) -> str:
         return "generic-fake-chat-model"
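Why the handoff matters: when astream falls back to the sync _stream, that generator runs in a worker thread and can only invoke plain callback methods; handing it the async run manager would leave coroutine handlers unawaited. The sketch below is a minimal, self-contained illustration of this pattern, not langchain_core code: SyncRunManager, AsyncRunManager, as_async_iterator, and FakeModel are hypothetical stand-ins, with get_sync() mirroring the AsyncCallbackManagerForLLMRun.get_sync() call the hunks above add.

import asyncio
from typing import AsyncIterator, Iterator


class SyncRunManager:
    """Stand-in for CallbackManagerForLLMRun: plain methods, safe to call
    from sync code running in a worker thread."""

    def on_llm_new_token(self, token: str) -> None:
        print(f"[sync callback] token={token!r}")


class AsyncRunManager:
    """Stand-in for AsyncCallbackManagerForLLMRun: coroutine handlers that a
    sync generator has no way to await."""

    async def on_llm_new_token(self, token: str) -> None:
        print(f"[async callback] token={token!r}")

    def get_sync(self) -> SyncRunManager:
        # Analogous to AsyncCallbackManagerForLLMRun.get_sync() in the diff.
        return SyncRunManager()


def as_async_iterator(sync_iter: Iterator[str]) -> AsyncIterator[str]:
    """Pull items from a sync iterator via worker threads, analogous to the
    _as_async_iterator helper used by both patched modules."""

    async def gen() -> AsyncIterator[str]:
        sentinel = object()
        while True:
            item = await asyncio.to_thread(next, sync_iter, sentinel)
            if item is sentinel:
                return
            yield item

    return gen()


class FakeModel:
    def _stream(self, prompt: str, run_manager: SyncRunManager) -> Iterator[str]:
        for token in prompt.split():
            run_manager.on_llm_new_token(token)  # plain call, nothing to await
            yield token

    async def _astream(
        self, prompt: str, run_manager: AsyncRunManager
    ) -> AsyncIterator[str]:
        # The fix in miniature: the sync _stream receives the *sync* view of
        # the run manager instead of the async one.
        sync_manager = run_manager.get_sync()
        async for token in as_async_iterator(self._stream(prompt, sync_manager)):
            yield token


async def main() -> None:
    async for token in FakeModel()._astream("hello world", AsyncRunManager()):
        print("yielded:", token)


asyncio.run(main())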
diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index 9f7ba5ce1ee..d8e1e009898 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -463,6 +463,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         if type(self)._astream is not BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
             _stream_implementation = self._astream
+            using_sync_stream = False
         elif type(self)._stream is not BaseLLM._stream:
             # Then stream is implemented, so we can create an async iterator from it
             # The typing is hard to type correctly with mypy here, so we cast
@@ -479,6 +480,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 ],
                 _as_async_iterator(self._stream),
             )
+            using_sync_stream = True
         else:
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
@@ -507,10 +509,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
+        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[GenerationChunk] = None
         try:
             async for chunk in _stream_implementation(
-                prompt, stop=stop, run_manager=run_manager, **kwargs
+                prompt,
+                stop=stop,
+                run_manager=run_manager_,  # type: ignore[arg-type]
+                **kwargs,
             ):
                 yield chunk.text
                 if generation is None:
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py
index 468806a12ae..d2d76d87799 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py
@@ -312,6 +312,68 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
     ]
 
 
+async def test_astream_events_from_model() -> None:
+    """Test the output of a model."""
+    infinite_cycle = cycle(
+        [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
+    )
+    # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
+    model = (
+        GenericFakeChatModel(messages=infinite_cycle)
+        .with_config(
+            {
+                "metadata": {"a": "b"},
+                "tags": ["my_model"],
+                "run_name": "my_model",
+            }
+        )
+        .bind(stop="")
+    )
+    events = await _collect_events(model.astream_events("hello", version="v1"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chat_model_start",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="hello")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content=" ")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"chunk": AIMessageChunk(content="world!")},
+            "event": "on_chat_model_stream",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+        {
+            "data": {"output": AIMessageChunk(content="hello world!")},
+            "event": "on_chat_model_end",
+            "metadata": {"a": "b"},
+            "name": "my_model",
+            "run_id": "",
+            "tags": ["my_model"],
+        },
+    ]
+
+
 async def test_event_stream_with_simple_chain() -> None:
     """Test as event stream."""
     template = ChatPromptTemplate.from_messages(
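With GenericFakeChatModel's thread-hopping _astream override deleted, calling astream on it goes through exactly the fallback patched above, which is what the new test exercises via astream_events. A short usage sketch, assuming langchain_core at this commit (per the test, the fake model splits each AIMessage into space-delimited chunks when streaming):

import asyncio
from itertools import cycle

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage


async def main() -> None:
    model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world!")]))
    # Only the sync _stream is implemented now, so astream takes the fallback
    # path and hands the sync run manager to _stream under the hood.
    async for chunk in model.astream("hi"):
        print(repr(chunk.content))  # 'hello', ' ', 'world!'


asyncio.run(main())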