From fdfb51ad8daffa1e6e5c6889fd71627697de178e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 27 Mar 2024 18:45:01 -0700 Subject: [PATCH] core: Two updates to chat model interface (#19684) - .stream() and .astream() call on_llm_new_token, removing the need for subclasses to do so. Backwards compatible because now we don't pass run_manager into ._stream and ._astream - .generate() and .agenerate() now handle `stream: bool` kwarg for _generate and _agenerate. Subclasses handle this arg by delegating to ._stream(), now one less thing they need to do. Backwards compat because this is an optional arg that we now never pass to the subclasses - .generate() and .agenerate() now inspect callback handlers to decide on a default value for stream:bool if not passed in. This auto enables streaming when using astream_events and astream_log - as a result of these three changes any usage of .astream_events and .astream_log should now yield chat model stream events - In future PRs we can update all subclasses to reflect these two things now handled by base class, but in meantime all will continue to work --- .../language_models/chat_models.py | 81 ++++++-- .../runnables/test_runnable_events.py | 188 +++++++++++++++++- 2 files changed, 254 insertions(+), 15 deletions(-) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 60885a84a86..235b5d4b7b2 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -50,6 +50,7 @@ from langchain_core.outputs import ( from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables.config import ensure_config, run_in_executor +from langchain_core.tracers.log_stream import LogStreamCallbackHandler if TYPE_CHECKING: from langchain_core.runnables import RunnableConfig @@ -219,9 +220,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ) generation: Optional[ChatGenerationChunk] = None try: - for chunk in self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ): + for chunk in self._stream(messages, stop=stop, **kwargs): + run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) yield chunk.message if generation is None: @@ -287,9 +289,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async for chunk in self._astream( messages, stop=stop, - run_manager=run_manager, **kwargs, ): + await run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) yield chunk.message if generation is None: @@ -585,12 +589,37 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - if inspect.signature(self._generate).parameters.get("run_manager"): - result = self._generate( - messages, stop=stop, run_manager=run_manager, **kwargs + # If stream is not explicitly set, check if implicitly requested by + # astream_events() or astream_log(). 
Bail out if _stream not implemented + if type(self)._stream != BaseChatModel._stream and kwargs.pop( + "stream", + next( + ( + True + for h in run_manager.handlers + if isinstance(h, LogStreamCallbackHandler) + ), + False, ) + if run_manager + else False, + ): + chunks: List[ChatGenerationChunk] = [] + for chunk in self._stream(messages, stop=stop, **kwargs): + if run_manager: + run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) + chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) + chunks.append(chunk) + result = generate_from_stream(iter(chunks)) else: - result = self._generate(messages, stop=stop, **kwargs) + if inspect.signature(self._generate).parameters.get("run_manager"): + result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = self._generate(messages, stop=stop, **kwargs) # Add response metadata to each generation for generation in result.generations: @@ -634,12 +663,40 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - if inspect.signature(self._agenerate).parameters.get("run_manager"): - result = await self._agenerate( - messages, stop=stop, run_manager=run_manager, **kwargs + # If stream is not explicitly set, check if implicitly requested by + # astream_events() or astream_log(). Bail out if _astream not implemented + if ( + type(self)._astream != BaseChatModel._astream + or type(self)._stream != BaseChatModel._stream + ) and kwargs.pop( + "stream", + next( + ( + True + for h in run_manager.handlers + if isinstance(h, LogStreamCallbackHandler) + ), + False, ) + if run_manager + else False, + ): + chunks: List[ChatGenerationChunk] = [] + async for chunk in self._astream(messages, stop=stop, **kwargs): + if run_manager: + await run_manager.on_llm_new_token( + cast(str, chunk.message.content), chunk=chunk + ) + chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) + chunks.append(chunk) + result = generate_from_stream(iter(chunks)) else: - result = await self._agenerate(messages, stop=stop, **kwargs) + if inspect.signature(self._agenerate).parameters.get("run_manager"): + result = await self._agenerate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + else: + result = await self._agenerate(messages, stop=stop, **kwargs) # Add response metadata to each generation for generation in result.generations: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events.py index 3b822bce6fe..bc5d6102ecc 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events.py @@ -1,4 +1,5 @@ """Module that contains tests for runnable.astream_events API.""" +import sys from itertools import cycle from typing import Any, AsyncIterator, Dict, List, Sequence, cast @@ -22,6 +23,7 @@ from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( ConfigurableField, Runnable, + RunnableConfig, RunnableLambda, ) from langchain_core.runnables.history import RunnableWithMessageHistory @@ -314,9 +316,7 @@ async def test_event_stream_with_lambdas_from_lambda() -> None: async def test_astream_events_from_model() -> None: """Test the output of a model.""" - infinite_cycle = cycle( - [AIMessage(content="hello world!"), AIMessage(content="goodbye world!")] - ) + infinite_cycle = cycle([AIMessage(content="hello world!")]) # When 
streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces model = ( GenericFakeChatModel(messages=infinite_cycle) @@ -373,6 +373,188 @@ async def test_astream_events_from_model() -> None: }, ] + @RunnableLambda + def i_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return model.invoke(input) + else: + return model.invoke(input, config) + + events = await _collect_events(i_dont_stream.astream_events("hello", version="v1")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": { + "generations": [ + [ + { + "generation_info": None, + "message": AIMessage(content="hello world!"), + "text": "hello world!", + "type": "ChatGeneration", + } + ] + ], + "llm_output": None, + "run": None, + }, + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!")}, + "event": "on_chain_stream", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!")}, + "event": "on_chain_end", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + ] + + @RunnableLambda + async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return await model.ainvoke(input) + else: + return await model.ainvoke(input, config) + + events = await _collect_events(ai_dont_stream.astream_events("hello", version="v1")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": { + "generations": [ + [ + { + "generation_info": None, + 
"message": AIMessage(content="hello world!"), + "text": "hello world!", + "type": "ChatGeneration", + } + ] + ], + "llm_output": None, + "run": None, + }, + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!")}, + "event": "on_chain_stream", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!")}, + "event": "on_chain_end", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + ] + async def test_event_stream_with_simple_chain() -> None: """Test as event stream."""