use_responses_api

This commit is contained in:
Bagatur 2025-03-12 04:21:50 -07:00
parent 5493628a51
commit deed6f680c
2 changed files with 58 additions and 6 deletions

View File

@@ -536,6 +536,14 @@ class BaseChatOpenAI(BaseChatModel):
invocation.
"""
use_responses_api: Optional[bool] = None
"""Whether to use the Responses API instead of the Chat API.
If not specified then will be inferred based on invocation params.
.. versionadded:: 0.3.9
"""
model_config = ConfigDict(populate_by_name=True)
@model_validator(mode="before")
@@ -871,13 +879,19 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
elif _use_responses_api(payload):
elif self._use_responses_api(payload):
response = self.root_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(response)
else:
response = self.client.create(**payload)
return self._create_chat_result(response, generation_info)
def _use_responses_api(self, payload: dict) -> bool:
if isinstance(self.use_responses_api, bool):
return self.use_responses_api
else:
return _use_responses_api(payload)
def _get_request_payload(
self,
input_: LanguageModelInput,
@@ -890,7 +904,7 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stop"] = stop
payload = {**self._default_params, **kwargs}
if _use_responses_api(payload):
if self._use_responses_api(payload):
payload = _construct_responses_api_payload(messages, payload)
else:
payload["messages"] = [_convert_message_to_dict(m) for m in messages]
@@ -933,6 +947,8 @@ class BaseChatOpenAI(BaseChatModel):
"model_name": response_dict.get("model", self.model_name),
"system_fingerprint": response_dict.get("system_fingerprint", ""),
}
if "id" in response_dict:
llm_output["id"] = response_dict["id"]
if isinstance(response, openai.BaseModel) and getattr(
response, "choices", None
@@ -1045,7 +1061,7 @@ class BaseChatOpenAI(BaseChatModel):
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
elif _use_responses_api(payload):
elif self._use_responses_api(payload):
response = await self.root_async_client.responses.create(**payload)
return _construct_lc_result_from_responses_api(response)
else:
@@ -2146,7 +2162,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> Iterator[ChatGenerationChunk]:
"""Set default stream_options."""
if _use_responses_api(kwargs):
if self._use_responses_api(kwargs):
return super()._stream_responses(*args, **kwargs)
else:
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
@@ -2164,7 +2180,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
) -> AsyncIterator[ChatGenerationChunk]:
"""Set default stream_options."""
if _use_responses_api(kwargs):
if self._use_responses_api(kwargs):
async for chunk in super()._astream_responses(*args, **kwargs):
yield chunk
else:

View File

@@ -1,7 +1,7 @@
"""Test Responses API usage."""
import os
from typing import Optional
from typing import Any, Optional, cast
import pytest
from langchain_core.messages import (
@@ -114,6 +114,42 @@ async def test_web_search_async() -> None:
_check_response(full)
def test_function_calling() -> None:
    """Exercise tool-call extraction on invoke and stream with a mixed tool list.

    Binds both a Python function tool and a built-in ``web_search_preview``
    tool, then checks that the model emits exactly one ``multiply`` tool call
    (with args ``x`` and ``y``) whether the response is consumed in one shot
    or aggregated from streamed chunks.
    """

    def multiply(x: int, y: int) -> int:
        """return x * y"""
        return x * y

    bound_llm = ChatOpenAI(model="gpt-4o-mini").bind_tools(
        [multiply, {"type": "web_search_preview"}]
    )

    msg = cast(AIMessage, bound_llm.invoke("whats 5 * 4"))
    assert len(msg.tool_calls) == 1
    assert msg.tool_calls[0]["name"] == "multiply"
    assert set(msg.tool_calls[0]["args"]) == {"x", "y"}

    # Aggregate streamed chunks and verify the same tool call shape survives.
    aggregated: Any = None
    for piece in bound_llm.stream("whats 5 * 4"):
        assert isinstance(piece, AIMessageChunk)
        aggregated = piece if aggregated is None else aggregated + piece
    assert len(aggregated.tool_calls) == 1
    assert aggregated.tool_calls[0]["name"] == "multiply"
    assert set(aggregated.tool_calls[0]["args"]) == {"x", "y"}

    response = bound_llm.invoke("whats some good news from today")
    _check_response(response)
def test_stateful_api() -> None:
    """Verify stateful chat via the Responses API's ``previous_response_id``.

    The first invocation must surface a response ``id`` in its metadata;
    passing that id to a follow-up invocation should let the model recall
    context (the user's name) from the earlier turn.
    """
    llm = ChatOpenAI(model="gpt-4o-mini", use_responses_api=True)

    first = llm.invoke("how are you, my name is Bobo")
    assert "id" in first.response_metadata

    followup = llm.invoke(
        "what's my name", previous_response_id=first.response_metadata["id"]
    )
    # Responses API returns content as a list of typed blocks.
    assert isinstance(followup.content, list)
    assert "bobo" in followup.content[0]["text"].lower()  # type: ignore
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model="gpt-4o-mini")