diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 0e95bbb99d2..35d63e10ac3 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -536,6 +536,14 @@ class BaseChatOpenAI(BaseChatModel):
     invocation.
     """
 
+    use_responses_api: Optional[bool] = None
+    """Whether to use the Responses API instead of the Chat API.
+
+    If not specified then will be inferred based on invocation params.
+
+    .. versionadded:: 0.3.9
+    """
+
     model_config = ConfigDict(populate_by_name=True)
 
     @model_validator(mode="before")
@@ -871,13 +879,19 @@ class BaseChatOpenAI(BaseChatModel):
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
-        elif _use_responses_api(payload):
+        elif self._use_responses_api(payload):
             response = self.root_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(response)
         else:
             response = self.client.create(**payload)
         return self._create_chat_result(response, generation_info)
 
+    def _use_responses_api(self, payload: dict) -> bool:
+        if isinstance(self.use_responses_api, bool):
+            return self.use_responses_api
+        else:
+            return _use_responses_api(payload)
+
     def _get_request_payload(
         self,
         input_: LanguageModelInput,
@@ -890,7 +904,7 @@
             kwargs["stop"] = stop
 
         payload = {**self._default_params, **kwargs}
-        if _use_responses_api(payload):
+        if self._use_responses_api(payload):
             payload = _construct_responses_api_payload(messages, payload)
         else:
             payload["messages"] = [_convert_message_to_dict(m) for m in messages]
@@ -933,6 +947,8 @@
             "model_name": response_dict.get("model", self.model_name),
             "system_fingerprint": response_dict.get("system_fingerprint", ""),
         }
+        if "id" in response_dict:
+            llm_output["id"] = response_dict["id"]
 
         if isinstance(response, openai.BaseModel) and getattr(
             response, "choices", None
@@ -1045,7 +1061,7 @@
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
-        elif _use_responses_api(payload):
+        elif self._use_responses_api(payload):
            response = await self.root_async_client.responses.create(**payload)
             return _construct_lc_result_from_responses_api(response)
         else:
@@ -2146,7 +2162,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if _use_responses_api(kwargs):
+        if self._use_responses_api(kwargs):
             return super()._stream_responses(*args, **kwargs)
         else:
             stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2164,7 +2180,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
         """Set default stream_options."""
-        if _use_responses_api(kwargs):
+        if self._use_responses_api(kwargs):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
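Note on the routing change above: `use_responses_api` acts as an explicit override, and the new `BaseChatOpenAI._use_responses_api` method falls back to the module-level `_use_responses_api(payload)` inference only when the field is left unset. A minimal sketch of the resulting caller-side behavior (the model name is illustrative; the inference triggers are taken from the tests below):

```python
from langchain_openai import ChatOpenAI

# Explicit opt-in: every request is routed to client.responses.create,
# even plain chat-style payloads with no Responses-only parameters.
llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)

# Field left unset (None): routing is inferred per invocation from the
# payload, e.g. a Responses-only parameter like previous_response_id or a
# built-in tool such as {"type": "web_search_preview"} selects the
# Responses API; otherwise the request goes to Chat Completions as before.
auto = ChatOpenAI(model="gpt-4o-mini")
```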
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
index b7e7550f231..c320083e6ef 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py
@@ -1,7 +1,7 @@
 """Test Responses API usage."""
 
 import os
-from typing import Optional
+from typing import Any, Optional, cast
 
 import pytest
 from langchain_core.messages import (
@@ -114,6 +114,42 @@ async def test_web_search_async() -> None:
     _check_response(full)
 
 
+def test_function_calling() -> None:
+    def multiply(x: int, y: int) -> int:
+        """return x * y"""
+        return x * y
+
+    llm = ChatOpenAI(model="gpt-4o-mini")
+    bound_llm = llm.bind_tools([multiply, {"type": "web_search_preview"}])
+    ai_msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
+    assert len(ai_msg.tool_calls) == 1
+    assert ai_msg.tool_calls[0]["name"] == "multiply"
+    assert set(ai_msg.tool_calls[0]["args"]) == {"x", "y"}
+
+    full: Any = None
+    for chunk in bound_llm.stream("whats 5 * 4"):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert len(full.tool_calls) == 1
+    assert full.tool_calls[0]["name"] == "multiply"
+    assert set(full.tool_calls[0]["args"]) == {"x", "y"}
+
+    response = bound_llm.invoke("whats some good news from today")
+    _check_response(response)
+
+
+def test_stateful_api() -> None:
+    llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)
+    response = llm.invoke("how are you, my name is Bobo")
+    assert "id" in response.response_metadata
+
+    second_response = llm.invoke(
+        "what's my name", previous_response_id=response.response_metadata["id"]
+    )
+    assert isinstance(second_response.content, list)
+    assert "bobo" in second_response.content[0]["text"].lower()  # type: ignore
+
+
 def test_file_search() -> None:
     pytest.skip()  # TODO: set up infra
     llm = ChatOpenAI(model="gpt-4o-mini")
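For context, `test_stateful_api` exercises the server-side conversation threading that the `llm_output["id"]` change in `base.py` enables. A sketch of the same flow outside the test harness, assuming a valid `OPENAI_API_KEY` in the environment:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)

first = llm.invoke("My name is Bobo.")
# The Responses API id now surfaces in response_metadata via llm_output["id"].
response_id = first.response_metadata["id"]

# Thread the conversation server-side instead of resending the history.
second = llm.invoke("What's my name?", previous_response_id=response_id)
```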