From c74e7b997daaae36166e71ce685850f8dc9db28e Mon Sep 17 00:00:00 2001
From: ccurme
Date: Fri, 14 Mar 2025 15:14:23 -0400
Subject: [PATCH] openai[patch]: support structured output via Responses API
 (#30265)

Also runs all standard tests using Responses API.
---
 .../langchain_openai/chat_models/base.py      | 120 +++++++++++++++---
 .../chat_models/test_responses_api.py         | 117 ++++++++++++++++-
 .../chat_models/test_responses_standard.py    |  23 ++++
 .../test_responses_standard.ambr              |  31 +++++
 .../tests/unit_tests/chat_models/test_base.py |  17 ---
 .../chat_models/test_responses_standard.py    |  36 ++++++
 .../integration_tests/chat_models.py          |  14 +-
 7 files changed, 308 insertions(+), 50 deletions(-)
 create mode 100644 libs/partners/openai/tests/integration_tests/chat_models/test_responses_standard.py
 create mode 100644 libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr
 create mode 100644 libs/partners/openai/tests/unit_tests/chat_models/test_responses_standard.py

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index a2720764a86..4f0e9b3f805 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -751,11 +751,12 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         context_manager = self.root_client.responses.create(**payload)
+        original_schema_obj = kwargs.get("response_format")
 
         with context_manager as response:
             for chunk in response:
                 if generation_chunk := _convert_responses_chunk_to_generation_chunk(
-                    chunk
+                    chunk, schema=original_schema_obj
                 ):
                     if run_manager:
                         run_manager.on_llm_new_token(
@@ -773,11 +774,12 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         context_manager = await self.root_async_client.responses.create(**payload)
+        original_schema_obj = kwargs.get("response_format")
 
         async with context_manager as response:
             async for chunk in response:
                 if generation_chunk := _convert_responses_chunk_to_generation_chunk(
-                    chunk
+                    chunk, schema=original_schema_obj
                 ):
                     if run_manager:
                         await run_manager.on_llm_new_token(
@@ -880,8 +882,14 @@ class BaseChatOpenAI(BaseChatModel):
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         elif self._use_responses_api(payload):
-            response = self.root_client.responses.create(**payload)
-            return _construct_lc_result_from_responses_api(response)
+            original_schema_obj = kwargs.get("response_format")
+            if original_schema_obj and _is_pydantic_class(original_schema_obj):
+                response = self.root_client.responses.parse(**payload)
+            else:
+                response = self.root_client.responses.create(**payload)
+            return _construct_lc_result_from_responses_api(
+                response, schema=original_schema_obj
+            )
         else:
             response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
@@ -1062,8 +1070,15 @@ class BaseChatOpenAI(BaseChatModel):
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         elif self._use_responses_api(payload):
-            response = await self.root_async_client.responses.create(**payload)
-            return _construct_lc_result_from_responses_api(response)
+            original_schema_obj = kwargs.get("response_format")
+            if original_schema_obj and _is_pydantic_class(original_schema_obj):
+                response = await self.root_async_client.responses.parse(**payload)
+            else:
+                response = await self.root_async_client.responses.create(**payload)
+            return _construct_lc_result_from_responses_api(
+                response, schema=original_schema_obj
+            )
+
         else:
             response = await self.async_client.create(**payload)
         return await run_in_executor(
@@ -2833,23 +2848,45 @@ def _construct_responses_api_payload(
     if tool_choice := payload.pop("tool_choice", None):
         # chat api: {"type": "function", "function": {"name": "..."}}
         # responses api: {"type": "function", "name": "..."}
-        if tool_choice["type"] == "function" and "function" in tool_choice:
+        if (
+            isinstance(tool_choice, dict)
+            and tool_choice["type"] == "function"
+            and "function" in tool_choice
+        ):
             payload["tool_choice"] = {"type": "function", **tool_choice["function"]}
         else:
             payload["tool_choice"] = tool_choice
-    if response_format := payload.pop("response_format", None):
+
+    # Structured output
+    if schema := payload.pop("response_format", None):
         if payload.get("text"):
            text = payload["text"]
            raise ValueError(
                "Can specify at most one of 'response_format' or 'text', received both:"
-                f"\n{response_format=}\n{text=}"
+                f"\n{schema=}\n{text=}"
            )
-        # chat api: {"type": "json_schema, "json_schema": {"schema": {...}, "name": "...", "description": "...", "strict": ...}}  # noqa: E501
-        # responses api: {"type": "json_schema, "schema": {...}, "name": "...", "description": "...", "strict": ...}  # noqa: E501
-        if response_format["type"] == "json_schema":
-            payload["text"] = {"type": "json_schema", **response_format["json_schema"]}
+
+        # For pydantic + non-streaming case, we use responses.parse.
+        # Otherwise, we use responses.create.
+        if not payload.get("stream") and _is_pydantic_class(schema):
+            payload["text_format"] = schema
         else:
-            payload["text"] = response_format
+            if _is_pydantic_class(schema):
+                schema_dict = schema.model_json_schema()
+            else:
+                schema_dict = schema
+            if schema_dict == {"type": "json_object"}:  # JSON mode
+                payload["text"] = {"format": {"type": "json_object"}}
+            elif (
+                (response_format := _convert_to_openai_response_format(schema_dict))
+                and (isinstance(response_format, dict))
+                and (response_format["type"] == "json_schema")
+            ):
+                payload["text"] = {
+                    "format": {"type": "json_schema", **response_format["json_schema"]}
+                }
+            else:
+                pass
     return payload
@@ -2857,6 +2894,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
     input_ = []
     for lc_msg in messages:
         msg = _convert_message_to_dict(lc_msg)
+        # "name" parameter unsupported
+        if "name" in msg:
+            msg.pop("name")
         if msg["role"] == "tool":
             tool_output = msg["content"]
             if not isinstance(tool_output, str):
@@ -2872,17 +2912,20 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
             if tool_calls := msg.pop("tool_calls", None):
                 # TODO: should you be able to preserve the function call object id on
                 #  the langchain tool calls themselves?
-                if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY):
-                    raise ValueError("")
-                function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY]
+                function_call_ids = lc_msg.additional_kwargs.get(
+                    _FUNCTION_CALL_IDS_MAP_KEY
+                )
                 for tool_call in tool_calls:
                     function_call = {
                         "type": "function_call",
                         "name": tool_call["function"]["name"],
                         "arguments": tool_call["function"]["arguments"],
                         "call_id": tool_call["id"],
-                        "id": function_call_ids[tool_call["id"]],
                     }
+                    if function_call_ids is not None and (
+                        _id := function_call_ids.get(tool_call["id"])
+                    ):
+                        function_call["id"] = _id
                     function_calls.append(function_call)
 
             msg["content"] = msg.get("content") or []
@@ -2949,7 +2992,9 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
     return input_
 
 
-def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
+def _construct_lc_result_from_responses_api(
+    response: Response, schema: Optional[Type[_BM]] = None
+) -> ChatResult:
     """Construct ChatResponse from OpenAI Response API response."""
     if response.error:
         raise ValueError(response.error)
@@ -2994,6 +3039,8 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
                         ],
                     }
                     content_blocks.append(block)
+                    if hasattr(content, "parsed"):
+                        additional_kwargs["parsed"] = content.parsed
                 if content.type == "refusal":
                     additional_kwargs["refusal"] = content.refusal
             msg_id = output.id
@@ -3034,6 +3081,35 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
                 additional_kwargs["tool_outputs"].append(tool_output)
             else:
                 additional_kwargs["tool_outputs"] = [tool_output]
+    # Workaround for parsing structured output in the streaming case.
+    #   from openai import OpenAI
+    #   from pydantic import BaseModel
+
+    #   class Foo(BaseModel):
+    #       response: str
+
+    #   client = OpenAI()
+
+    #   client.responses.parse(
+    #       model="gpt-4o-mini",
+    #       input=[{"content": "how are ya", "role": "user"}],
+    #       text_format=Foo,
+    #       stream=True,  # <-- errors
+    #   )
+    if (
+        schema is not None
+        and "parsed" not in additional_kwargs
+        and response.text
+        and (text_config := response.text.model_dump())
+        and (format_ := text_config.get("format", {}))
+        and (format_.get("type") == "json_schema")
+    ):
+        parsed_dict = json.loads(response.output_text)
+        if schema and _is_pydantic_class(schema):
+            parsed = schema(**parsed_dict)
+        else:
+            parsed = parsed_dict
+        additional_kwargs["parsed"] = parsed
     message = AIMessage(
         content=content_blocks,
         id=msg_id,
@@ -3047,7 +3123,7 @@ def _construct_lc_result_from_responses_api(response: Response) -> ChatResult:
 
 
 def _convert_responses_chunk_to_generation_chunk(
-    chunk: Any,
+    chunk: Any, schema: Optional[Type[_BM]] = None
 ) -> Optional[ChatGenerationChunk]:
     content = []
     tool_call_chunks: list = []
@@ -3074,11 +3150,13 @@ def _convert_responses_chunk_to_generation_chunk(
         msg = cast(
             AIMessage,
             (
-                _construct_lc_result_from_responses_api(chunk.response)
+                _construct_lc_result_from_responses_api(chunk.response, schema=schema)
                 .generations[0]
                 .message
             ),
         )
+        if parsed := msg.additional_kwargs.get("parsed"):
+            additional_kwargs["parsed"] = parsed
         usage_metadata = msg.usage_metadata
         response_metadata = {
             k: v for k, v in msg.response_metadata.items() if k != "id"
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
index c320083e6ef..a9e4c3ca20b 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
@@ -1,5 +1,6 @@
 """Test Responses API usage."""
 
+import json
 import os
 from typing import Any, Optional, cast
 
@@ -10,9 +11,13 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
 )
+from pydantic import BaseModel
+from typing_extensions import TypedDict
 
 from langchain_openai import ChatOpenAI
 
+MODEL_NAME = "gpt-4o-mini"
+
 
 def _check_response(response: Optional[BaseMessage]) -> None:
     assert isinstance(response, AIMessage)
@@ -48,7 +53,7 @@
 
 
 def test_web_search() -> None:
-    llm = ChatOpenAI(model="gpt-4o-mini")
+    llm = ChatOpenAI(model=MODEL_NAME)
     first_response = llm.invoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -94,7 +99,7 @@
 
 
 async def test_web_search_async() -> None:
-    llm = ChatOpenAI(model="gpt-4o-mini")
+    llm = ChatOpenAI(model=MODEL_NAME)
     response = await llm.ainvoke(
         "What was a positive news story from today?",
         tools=[{"type": "web_search_preview"}],
@@ -119,7 +124,7 @@ def test_function_calling() -> None:
         """return x * y"""
         return x * y
 
-    llm = ChatOpenAI(model="gpt-4o-mini")
+    llm = ChatOpenAI(model=MODEL_NAME)
     bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
     ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
     assert len(ai_msg.tool_calls) == 1
@@ -138,8 +143,110 @@ def test_function_calling() -> None:
     _check_response(response)
 
 
+class Foo(BaseModel):
+    response: str
+
+
+class FooDict(TypedDict):
+    response: str
+
+
+def test_parsed_pydantic_schema() -> None:
+    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
+    response = llm.invoke("how are ya", response_format=Foo)
+    parsed = Foo(**json.loads(response.text()))
+    assert parsed == response.additional_kwargs["parsed"]
+    assert parsed.response
+
+    # Test stream
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("how are ya", response_format=Foo):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    parsed = Foo(**json.loads(full.text()))
+    assert parsed == full.additional_kwargs["parsed"]
+    assert parsed.response
+
+
+async def test_parsed_pydantic_schema_async() -> None:
+    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
+    response = await llm.ainvoke("how are ya", response_format=Foo)
+    parsed = Foo(**json.loads(response.text()))
+    assert parsed == response.additional_kwargs["parsed"]
+    assert parsed.response
+
+    # Test stream
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("how are ya", response_format=Foo):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    parsed = Foo(**json.loads(full.text()))
+    assert parsed == full.additional_kwargs["parsed"]
+    assert parsed.response
+
+
+@pytest.mark.parametrize("schema", [Foo.model_json_schema(), FooDict])
+def test_parsed_dict_schema(schema: Any) -> None:
+    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
+    response = llm.invoke("how are ya", response_format=schema)
+    parsed = json.loads(response.text())
+    assert parsed == response.additional_kwargs["parsed"]
+    assert parsed["response"] and isinstance(parsed["response"], str)
+
+    # Test stream
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("how are ya", response_format=schema):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    parsed = json.loads(full.text())
+    assert parsed == full.additional_kwargs["parsed"]
+    assert parsed["response"] and isinstance(parsed["response"], str)
+
+
+@pytest.mark.parametrize("schema", [Foo.model_json_schema(), FooDict])
+async def test_parsed_dict_schema_async(schema: Any) -> None:
+    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
+    response = await llm.ainvoke("how are ya", response_format=schema)
+    parsed = json.loads(response.text())
+    assert parsed == response.additional_kwargs["parsed"]
+    assert parsed["response"] and isinstance(parsed["response"], str)
+
+    # Test stream
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("how are ya", response_format=schema):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    parsed = json.loads(full.text())
+    assert parsed == full.additional_kwargs["parsed"]
+    assert parsed["response"] and isinstance(parsed["response"], str)
+
+
+def test_function_calling_and_structured_output() -> None:
+    def multiply(x: int, y: int) -> int:
+        """return x * y"""
+        return x * y
+
+    llm = ChatOpenAI(model=MODEL_NAME)
+    bound_llm = llm.bind_tools([multiply], response_format=Foo, strict=True)
+    # Test structured output
+    response = llm.invoke("how are ya", response_format=Foo)
+    parsed = Foo(**json.loads(response.text()))
+    assert parsed == response.additional_kwargs["parsed"]
+    assert parsed.response
+
+    # Test function calling
+    ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
+    assert len(ai_msg.tool_calls) == 1
+    assert ai_msg.tool_calls[0]["name"] == "multiply"
+    assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
+
+
 def test_stateful_api() -> None:
-    llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
+    llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
     response = llm.invoke("how are you, my name is Bobo")
 
     assert "id" in response.response_metadata
@@ -152,7 +259,7 @@ def test_stateful_api() -> None:
 def test_file_search() -> None:
     pytest.skip()  # TODO: set up infra
 
-    llm = ChatOpenAI(model="gpt-4o-mini")
+    llm = ChatOpenAI(model=MODEL_NAME)
     tool = {
         "type": "file_search",
         "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_standard.py
new file mode 100644
index 00000000000..d2cd27f3dbf
--- /dev/null
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_standard.py
@@ -0,0 +1,23 @@
+"""Standard LangChain interface tests for Responses API"""
+
+from typing import Type
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+
+from langchain_openai import ChatOpenAI
+from tests.integration_tests.chat_models.test_base_standard import TestOpenAIStandard
+
+
+class TestOpenAIResponses(TestOpenAIStandard):
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return ChatOpenAI
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"model": "gpt-4o-mini", "use_responses_api": True}
+
+    @pytest.mark.xfail(reason="Unsupported.")
+    def test_stop_sequence(self, model: BaseChatModel) -> None:
+        super().test_stop_sequence(model)
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr
new file mode 100644
index 00000000000..88a49a27502
--- /dev/null
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_responses_standard.ambr
@@ -0,0 +1,31 @@
+# serializer version: 1
+# name: TestOpenAIResponses.test_serdes[serialized]
+  dict({
+    'id': list([
+      'langchain',
+      'chat_models',
+      'openai',
+      'ChatOpenAI',
+    ]),
+    'kwargs': dict({
+      'max_retries': 2,
+      'max_tokens': 100,
+      'model_name': 'gpt-3.5-turbo',
+      'openai_api_key': dict({
+        'id': list([
+          'OPENAI_API_KEY',
+        ]),
+        'lc': 1,
+        'type': 'secret',
+      }),
+      'request_timeout': 60.0,
+      'stop': list([
+      ]),
+      'temperature': 0.0,
+      'use_responses_api': True,
+    }),
+    'lc': 1,
+    'name': 'ChatOpenAI',
+    'type': 'constructor',
+  })
+# ---
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index e5e89990b78..8bacb3f5f57 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -1569,23 +1569,6 @@ def test__construct_responses_api_input_ai_message_with_tool_calls_and_content()
     assert result[1]["id"] == "func_456"
 
 
-def test__construct_responses_api_input_missing_function_call_ids() -> None:
-    """Test AI messages with tool calls but missing function call IDs raise an error."""
-    tool_calls = [
-        {
-            "id": "call_123",
-            "name": "get_weather",
-            "args": {"location": "San Francisco"},
-            "type": "tool_call",
-        }
-    ]
-
-    ai_message = AIMessage(content="", tool_calls=tool_calls)
-
-    with pytest.raises(ValueError):
-        _construct_responses_api_input([ai_message])
-
-
 def test__construct_responses_api_input_tool_message_conversion() -> None:
     """Test that tool messages are properly converted to function_call_output."""
     messages = [
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_responses_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_standard.py
new file mode 100644
index 00000000000..2d835e57295
--- /dev/null
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_responses_standard.py
@@ -0,0 +1,36 @@
+"""Standard LangChain interface tests"""
+
+from typing import Tuple, Type
+
+from langchain_core.language_models import BaseChatModel
+from langchain_tests.unit_tests import ChatModelUnitTests
+
+from langchain_openai import ChatOpenAI
+
+
+class TestOpenAIResponses(ChatModelUnitTests):
+    @property
+    def chat_model_class(self) -> Type[BaseChatModel]:
+        return ChatOpenAI
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {"use_responses_api": True}
+
+    @property
+    def init_from_env_params(self) -> Tuple[dict, dict, dict]:
+        return (
+            {
+                "OPENAI_API_KEY": "api_key",
+                "OPENAI_ORG_ID": "org_id",
+                "OPENAI_API_BASE": "api_base",
+                "OPENAI_PROXY": "https://proxy.com",
+            },
+            {},
+            {
+                "openai_api_key": "api_key",
+                "openai_organization": "org_id",
+                "openai_api_base": "api_base",
+                "openai_proxy": "https://proxy.com",
+            },
+        )
diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
index ea2dcf52faa..81542017b64 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
@@ -416,7 +416,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         result = model.invoke("Hello")
         assert result is not None
         assert isinstance(result, AIMessage)
-        assert isinstance(result.content, str)
+        assert isinstance(result.text(), str)
         assert len(result.content) > 0
 
     async def test_ainvoke(self, model: BaseChatModel) -> None:
@@ -448,7 +448,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         result = await model.ainvoke("Hello")
         assert result is not None
         assert isinstance(result, AIMessage)
-        assert isinstance(result.content, str)
+        assert isinstance(result.text(), str)
         assert len(result.content) > 0
 
     def test_stream(self, model: BaseChatModel) -> None:
@@ -542,7 +542,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         for result in batch_results:
             assert result is not None
             assert isinstance(result, AIMessage)
-            assert isinstance(result.content, str)
+            assert isinstance(result.text(), str)
             assert len(result.content) > 0
 
     async def test_abatch(self, model: BaseChatModel) -> None:
@@ -571,7 +571,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         for result in batch_results:
             assert result is not None
             assert isinstance(result, AIMessage)
-            assert isinstance(result.content, str)
+            assert isinstance(result.text(), str)
             assert len(result.content) > 0
 
     def test_conversation(self, model: BaseChatModel) -> None:
@@ -600,7 +600,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         result = model.invoke(messages)
         assert result is not None
         assert isinstance(result, AIMessage)
-        assert isinstance(result.content, str)
+        assert isinstance(result.text(), str)
         assert len(result.content) > 0
 
     def test_double_messages_conversation(self, model: BaseChatModel) -> None:
@@ -638,7 +638,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         result = model.invoke(messages)
         assert result is not None
         assert isinstance(result, AIMessage)
-        assert isinstance(result.content, str)
+        assert isinstance(result.text(), str)
         assert len(result.content) > 0
 
     def test_usage_metadata(self, model: BaseChatModel) -> None:
@@ -2136,7 +2136,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         result = model.invoke([HumanMessage("hello", name="example_user")])
         assert result is not None
         assert isinstance(result, AIMessage)
-        assert isinstance(result.content, str)
+        assert isinstance(result.text(), str)
         assert len(result.content) > 0
 
     def test_agent_loop(self, model: BaseChatModel) -> None:
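
For reference, the end-to-end behavior this patch enables can be sketched as follows. This is a minimal usage sketch mirroring the new test_parsed_pydantic_schema integration test; the model name and prompt are illustrative, not part of the patch.

# Minimal sketch, assuming langchain-openai with this patch applied.
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Foo(BaseModel):
    response: str


llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)

# Pydantic schema + non-streaming invoke is routed through
# client.responses.parse(); the parsed object is surfaced in
# additional_kwargs["parsed"].
msg = llm.invoke("how are ya", response_format=Foo)
assert isinstance(msg.additional_kwargs["parsed"], Foo)

# Streaming (and dict/TypedDict schemas) go through client.responses.create()
# with a JSON-schema text format; "parsed" is reconstructed from the
# aggregated output text on the final chunk.
full = None
for chunk in llm.stream("how are ya", response_format=Foo):
    full = chunk if full is None else full + chunk
assert isinstance(full.additional_kwargs["parsed"], Foo)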