From 03035b8d732fbc50150823808a051275611f2b27 Mon Sep 17 00:00:00 2001
From: Chester Curme
Date: Tue, 11 Mar 2025 16:05:40 -0400
Subject: [PATCH] (a)stream

---
 .../langchain_openai/chat_models/base.py      | 110 +++++++++++++++---
 .../chat_models/test_base.py                  |  23 ++++
 2 files changed, 118 insertions(+), 15 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 68c7ddce595..12aa3abf64f 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -356,6 +356,27 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content, id=id_)  # type: ignore
 
 
+def _convert_responses_chunk_to_generation_chunk(
+    chunk: Any,
+) -> Optional[ChatGenerationChunk]:
+    content = ""
+    generation_info = None
+    usage_metadata = None
+    if chunk.type == "response.output_text.delta":
+        content = chunk.delta
+    elif chunk.type == "response.completed":
+        token_usage = chunk.response.usage.model_dump() if chunk.response.usage else {}
+        usage_metadata = _create_usage_metadata_responses(token_usage)
+        generation_info = {"model_name": chunk.response.model}
+    else:
+        return None
+
+    return ChatGenerationChunk(
+        message=AIMessageChunk(content=content, usage_metadata=usage_metadata),
+        generation_info=generation_info,
+    )
+
+
 def _update_token_usage(
     overall_token_usage: Union[int, dict], new_usage: Union[int, dict]
 ) -> Union[int, dict]:
@@ -410,7 +431,7 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
 
 
 def _is_builtin_tool(tool: dict) -> bool:
-    return set(tool.keys()) == {"type"}
+    return "type" in tool and tool["type"] != "function"
 
 
 def _transform_payload_for_responses(payload: dict) -> dict:
@@ -741,6 +762,54 @@ class BaseChatOpenAI(BaseChatModel):
         )
         return generation_chunk
 
+    def _stream_responses(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        responses_payload = _transform_payload_for_responses(payload)
+        context_manager = self.root_client.responses.create(**responses_payload)
+
+        with context_manager as response:
+            for chunk in response:
+                if generation_chunk := _convert_responses_chunk_to_generation_chunk(
+                    chunk
+                ):
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
+
+    async def _astream_responses(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        kwargs["stream"] = True
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        responses_payload = _transform_payload_for_responses(payload)
+        context_manager = await self.root_async_client.responses.create(
+            **responses_payload
+        )
+
+        async with context_manager as response:
+            async for chunk in response:
+                if generation_chunk := _convert_responses_chunk_to_generation_chunk(
+                    chunk
+                ):
+                    if run_manager:
+                        await run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -2143,27 +2212,38 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
+        if "tools" in kwargs and any(
+            _is_builtin_tool(tool) for tool in kwargs["tools"]
+        ):
+            return super()._stream_responses(*args, **kwargs)
+        else:
+            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+            # Note: stream_options is not a valid parameter for Azure OpenAI.
+            # To support users proxying Azure through ChatOpenAI, here we only specify
+            # stream_options if include_usage is set to True.
+            # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
+            # for release notes.
+            if stream_usage:
+                kwargs["stream_options"] = {"include_usage": stream_usage}
 
-        return super()._stream(*args, **kwargs)
+            return super()._stream(*args, **kwargs)
 
     async def _astream(
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
+        if "tools" in kwargs and any(
+            _is_builtin_tool(tool) for tool in kwargs["tools"]
+        ):
+            async for chunk in super()._astream_responses(*args, **kwargs):
+                yield chunk
+        else:
+            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+            if stream_usage:
+                kwargs["stream_options"] = {"include_usage": stream_usage}
 
-        async for chunk in super()._astream(*args, **kwargs):
-            yield chunk
+            async for chunk in super()._astream(*args, **kwargs):
+                yield chunk
 
     def with_structured_output(
         self,
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index e4643c1acee..33c3bf4eb00 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1245,6 +1245,18 @@ def test_web_search() -> None:
     assert response.response_metadata["model_name"]
     assert response.response_metadata["status"]
 
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream(
+        "What was a positive news story from today?",
+        tools=[{"type": "web_search_preview"}],
+    ):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.content
+    assert full.usage_metadata
+
 
 async def test_web_search_async() -> None:
     llm = ChatOpenAI(model="gpt-4o")
@@ -1260,3 +1272,14 @@ async def test_web_search_async() -> None:
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
     assert response.response_metadata["status"]
+
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream(
+        "What was a positive news story from today?",
+        tools=[{"type": "web_search_preview"}],
+    ):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.content
+    assert full.usage_metadata
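
A minimal usage sketch of the streaming path this patch enables, mirroring the integration tests above. It assumes a langchain-openai build containing these changes and a valid OPENAI_API_KEY; the model name and prompt are illustrative.

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")

# Passing a built-in tool (any tool whose "type" is not "function") makes
# _stream route the request through the Responses API via _stream_responses.
full = None
for chunk in llm.stream(
    "What was a positive news story from today?",
    tools=[{"type": "web_search_preview"}],
):
    full = chunk if full is None else full + chunk  # accumulate AIMessageChunk deltas

print(full.content)         # text assembled from response.output_text.delta events
print(full.usage_metadata)  # populated from the final response.completed event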