diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index d2ff16912dd..e23f9ac5d28 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -3764,7 +3764,7 @@ def _construct_responses_api_payload( if payload.get("stream") and "partial_images" not in tool: # OpenAI requires this parameter be set; we ignore it during # streaming. - tool["partial_images"] = 1 + tool = {**tool, "partial_images": 1} else: pass diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py index 75d31cf5e54..d710e1b4c46 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_stream.py @@ -756,3 +756,38 @@ def test_responses_stream(output_version: str, expected_content: list[dict]) -> dumped = _strip_none(item.model_dump()) _ = dumped.pop("status", None) assert dumped == payload["input"][idx] + + +def test_responses_stream_with_image_generation_multiple_calls() -> None: + """Test that streaming with image_generation tool works across multiple calls. + + Regression test: image_generation tool should not be mutated between calls, + which would cause NotImplementedError on subsequent invocations. + """ + tools: list[dict[str, Any]] = [ + {"type": "image_generation"}, + {"type": "function", "name": "my_tool", "parameters": {}}, + ] + llm = ChatOpenAI( + model="gpt-4o", + use_responses_api=True, + streaming=True, + ) + llm_with_tools = llm.bind_tools(tools) + + mock_client = MagicMock() + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + return MockSyncContextManager(responses_stream) + + mock_client.responses.create = mock_create + + # First call should work + with patch.object(llm, "root_client", mock_client): + chunks = list(llm_with_tools.stream("test")) + assert len(chunks) > 0 + + # Second call should also work (would fail before fix due to tool mutation) + with patch.object(llm, "root_client", mock_client): + chunks = list(llm_with_tools.stream("test again")) + assert len(chunks) > 0