diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 09199e30dc9..2d0db37c0ac 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -176,22 +176,22 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 dumpd(self), [messages], invocation_params=params, options=options
             )
             try:
-                message: Optional[BaseMessageChunk] = None
+                generation: Optional[ChatGenerationChunk] = None
                 for chunk in self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.message
-                    if message is None:
-                        message = chunk.message
+                    if generation is None:
+                        generation = chunk
                     else:
-                        message += chunk.message
-                assert message is not None
+                        generation += chunk
+                assert generation is not None
             except (KeyboardInterrupt, Exception) as e:
                 run_manager.on_llm_error(e)
                 raise e
             else:
                 run_manager.on_llm_end(
-                    LLMResult(generations=[[ChatGeneration(message=message)]]),
+                    LLMResult(generations=[[generation]]),
                 )
 
     async def astream(
@@ -223,22 +223,22 @@
                 dumpd(self), [messages], invocation_params=params, options=options
             )
             try:
-                message: Optional[BaseMessageChunk] = None
+                generation: Optional[ChatGenerationChunk] = None
                 async for chunk in self._astream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
                     yield chunk.message
-                    if message is None:
-                        message = chunk.message
+                    if generation is None:
+                        generation = chunk
                     else:
-                        message += chunk.message
-                assert message is not None
+                        generation += chunk
+                assert generation is not None
             except (KeyboardInterrupt, Exception) as e:
                 await run_manager.on_llm_error(e)
                 raise e
             else:
                 await run_manager.on_llm_end(
-                    LLMResult(generations=[[generation]]),
+                    LLMResult(generations=[[generation]]),
                 )
 
     # --- Custom methods ---
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_openai.py
index 19adbf1cd41..7637014a4c1 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_openai.py
@@ -1,6 +1,8 @@
 """Test ChatOpenAI wrapper."""
 
 
+from typing import Any
+
 import pytest
 
 from langchain.callbacks.manager import CallbackManager
@@ -89,6 +91,34 @@ def test_chat_openai_streaming() -> None:
     assert isinstance(response, BaseMessage)
 
 
+@pytest.mark.scheduled
+def test_chat_openai_streaming_generation_info() -> None:
+    """Test that generation info is preserved when streaming."""
+
+    class _FakeCallback(FakeCallbackHandler):
+        saved_things: dict = {}
+
+        def on_llm_end(
+            self,
+            *args: Any,
+            **kwargs: Any,
+        ) -> Any:
+            # Save the generation
+            self.saved_things["generation"] = args[0]
+
+    callback = _FakeCallback()
+    callback_manager = CallbackManager([callback])
+    chat = ChatOpenAI(
+        max_tokens=2,
+        temperature=0,
+        callback_manager=callback_manager,
+    )
+    list(chat.stream("hi"))
+    generation = callback.saved_things["generation"]
+    # `Hello!` is two tokens, assert that that is what is returned
+    assert generation.generations[0][0].text == "Hello!"
+
+
 def test_chat_openai_llm_output_contains_model_name() -> None:
     """Test llm_output contains model_name."""
     chat = ChatOpenAI(max_tokens=10)
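The fix hinges on `ChatGenerationChunk` supporting `+`: adding two chunks concatenates the message content and merges their `generation_info` dicts, so metadata such as `finish_reason` (typically attached to the final streamed chunk) survives into the `LLMResult` passed to `on_llm_end`. The old code summed only `chunk.message` and then rebuilt a bare `ChatGeneration`, discarding that metadata. Below is a minimal sketch of the difference; the `langchain.schema.messages`/`langchain.schema.output` import paths and the merge semantics of `ChatGenerationChunk.__add__` are assumptions based on this era of the codebase, not shown in the diff itself:

```python
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGeneration, ChatGenerationChunk

# Two chunks as a streaming chat model might yield them; providers
# usually attach generation_info (e.g. finish_reason) to the last one.
first = ChatGenerationChunk(message=AIMessageChunk(content="Hello"))
last = ChatGenerationChunk(
    message=AIMessageChunk(content="!"),
    generation_info={"finish_reason": "stop"},
)

# Patched behavior: summing whole chunks keeps both text and metadata.
merged = first + last
assert merged.text == "Hello!"
assert merged.generation_info == {"finish_reason": "stop"}

# Pre-patch behavior: summing only the messages and wrapping the result
# in a fresh ChatGeneration loses generation_info entirely.
rebuilt = ChatGeneration(message=first.message + last.message)
assert rebuilt.text == "Hello!"
assert rebuilt.generation_info is None
```

The new integration test exercises exactly this path: the `LLMResult` captured in `on_llm_end` now holds the summed `ChatGenerationChunk` rather than a rebuilt `ChatGeneration`.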