fix(openai): Always add raw response object to OpenAI client errors for invoke (#32655)

This commit is contained in:
Jacob Lee
2025-08-26 06:59:25 -07:00
committed by GitHub
parent f33480c2cf
commit 1459d4f4ce
7 changed files with 154 additions and 75 deletions

View File

@@ -1142,42 +1142,51 @@ class BaseChatOpenAI(BaseChatModel):
return generate_from_stream(stream_iter) return generate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs) payload = self._get_request_payload(messages, stop=stop, **kwargs)
generation_info = None generation_info = None
if "response_format" in payload: raw_response = None
if self.include_response_headers: try:
warnings.warn( if "response_format" in payload:
"Cannot currently include response headers when response_format is " payload.pop("stream")
"specified." try:
) raw_response = (
payload.pop("stream") self.root_client.chat.completions.with_raw_response.parse(
try: **payload
response = self.root_client.beta.chat.completions.parse(**payload) )
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
elif self._use_responses_api(payload):
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
response = self.root_client.responses.parse(**payload)
else:
if self.include_response_headers:
raw_response = self.root_client.with_raw_response.responses.create(
**payload
) )
response = raw_response.parse() response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)} except openai.BadRequestError as e:
_handle_openai_bad_request(e)
elif self._use_responses_api(payload):
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
raw_response = self.root_client.responses.with_raw_response.parse(
**payload
)
else: else:
response = self.root_client.responses.create(**payload) raw_response = self.root_client.responses.with_raw_response.create(
return _construct_lc_result_from_responses_api( **payload
response, )
schema=original_schema_obj, response = raw_response.parse()
metadata=generation_info, if self.include_response_headers:
output_version=self.output_version, generation_info = {"headers": dict(raw_response.headers)}
) return _construct_lc_result_from_responses_api(
elif self.include_response_headers: response,
raw_response = self.client.with_raw_response.create(**payload) schema=original_schema_obj,
response = raw_response.parse() metadata=generation_info,
output_version=self.output_version,
)
else:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
except Exception as e:
if raw_response is not None and hasattr(raw_response, "http_response"):
e.response = raw_response.http_response # type: ignore[attr-defined]
raise e
if (
self.include_response_headers
and raw_response is not None
and hasattr(raw_response, "headers")
):
generation_info = {"headers": dict(raw_response.headers)} generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
return self._create_chat_result(response, generation_info) return self._create_chat_result(response, generation_info)
def _use_responses_api(self, payload: dict) -> bool: def _use_responses_api(self, payload: dict) -> bool:
@@ -1375,46 +1384,55 @@ class BaseChatOpenAI(BaseChatModel):
return await agenerate_from_stream(stream_iter) return await agenerate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs) payload = self._get_request_payload(messages, stop=stop, **kwargs)
generation_info = None generation_info = None
if "response_format" in payload: raw_response = None
if self.include_response_headers: try:
warnings.warn( if "response_format" in payload:
"Cannot currently include response headers when response_format is " payload.pop("stream")
"specified." try:
) raw_response = await self.root_async_client.chat.completions.with_raw_response.parse( # noqa: E501
payload.pop("stream") **payload
try: )
response = await self.root_async_client.beta.chat.completions.parse( response = raw_response.parse()
**payload except openai.BadRequestError as e:
) _handle_openai_bad_request(e)
except openai.BadRequestError as e: elif self._use_responses_api(payload):
_handle_openai_bad_request(e) original_schema_obj = kwargs.get("response_format")
elif self._use_responses_api(payload): if original_schema_obj and _is_pydantic_class(original_schema_obj):
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
response = await self.root_async_client.responses.parse(**payload)
else:
if self.include_response_headers:
raw_response = ( raw_response = (
await self.root_async_client.with_raw_response.responses.create( await self.root_async_client.responses.with_raw_response.parse(
**payload **payload
) )
) )
response = raw_response.parse()
generation_info = {"headers": dict(raw_response.headers)}
else: else:
response = await self.root_async_client.responses.create(**payload) raw_response = (
return _construct_lc_result_from_responses_api( await self.root_async_client.responses.with_raw_response.create(
response, **payload
schema=original_schema_obj, )
metadata=generation_info, )
output_version=self.output_version, response = raw_response.parse()
) if self.include_response_headers:
elif self.include_response_headers: generation_info = {"headers": dict(raw_response.headers)}
raw_response = await self.async_client.with_raw_response.create(**payload) return _construct_lc_result_from_responses_api(
response = raw_response.parse() response,
schema=original_schema_obj,
metadata=generation_info,
output_version=self.output_version,
)
else:
raw_response = await self.async_client.with_raw_response.create(
**payload
)
response = raw_response.parse()
except Exception as e:
if raw_response is not None and hasattr(raw_response, "http_response"):
e.response = raw_response.http_response # type: ignore[attr-defined]
raise e
if (
self.include_response_headers
and raw_response is not None
and hasattr(raw_response, "headers")
):
generation_info = {"headers": dict(raw_response.headers)} generation_info = {"headers": dict(raw_response.headers)}
else:
response = await self.async_client.create(**payload)
return await run_in_executor( return await run_in_executor(
None, self._create_chat_result, response, generation_info None, self._create_chat_result, response, generation_info
) )

View File

@@ -27,7 +27,7 @@ from langchain_tests.integration_tests.chat_models import (
_validate_tool_call_message, _validate_tool_call_message,
magic_function, magic_function,
) )
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -1155,3 +1155,62 @@ def test_prompt_cache_key_usage_methods_integration() -> None:
response_model_level = chat_model_level.invoke(messages) response_model_level = chat_model_level.invoke(messages)
assert isinstance(response_model_level, AIMessage) assert isinstance(response_model_level, AIMessage)
assert isinstance(response_model_level.content, str) assert isinstance(response_model_level.content, str)
class BadModel(BaseModel):
response: str
@field_validator("response")
@classmethod
def validate_response(cls, v: str) -> str:
if v != "bad":
raise ValueError('response must be exactly "bad"')
return v
# VCR can't handle parameterized tests
@pytest.mark.vcr()
def test_schema_parsing_failures() -> None:
llm = ChatOpenAI(model="gpt-5-nano", use_responses_api=False)
try:
llm.invoke("respond with good", response_format=BadModel)
except Exception as e:
assert e.response is not None # type: ignore[attr-defined]
else:
assert False
# VCR can't handle parameterized tests
@pytest.mark.vcr()
def test_schema_parsing_failures_responses_api() -> None:
llm = ChatOpenAI(model="gpt-5-nano", use_responses_api=True)
try:
llm.invoke("respond with good", response_format=BadModel)
except Exception as e:
assert e.response is not None # type: ignore[attr-defined]
else:
assert False
# VCR can't handle parameterized tests
@pytest.mark.vcr()
async def test_schema_parsing_failures_async() -> None:
llm = ChatOpenAI(model="gpt-5-nano", use_responses_api=False)
try:
await llm.ainvoke("respond with good", response_format=BadModel)
except Exception as e:
assert e.response is not None # type: ignore[attr-defined]
else:
assert False
# VCR can't handle parameterized tests
@pytest.mark.vcr()
async def test_schema_parsing_failures_responses_api_async() -> None:
llm = ChatOpenAI(model="gpt-5-nano", use_responses_api=True)
try:
await llm.ainvoke("respond with good", response_format=BadModel)
except Exception as e:
assert e.response is not None # type: ignore[attr-defined]
else:
assert False

View File

@@ -601,7 +601,7 @@ def test_openai_invoke(mock_client: MagicMock) -> None:
# headers are not in response_metadata if include_response_headers not set # headers are not in response_metadata if include_response_headers not set
assert "headers" not in res.response_metadata assert "headers" not in res.response_metadata
assert mock_client.create.called assert mock_client.with_raw_response.create.called
async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None: async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
@@ -613,7 +613,7 @@ async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
# headers are not in response_metadata if include_response_headers not set # headers are not in response_metadata if include_response_headers not set
assert "headers" not in res.response_metadata assert "headers" not in res.response_metadata
assert mock_async_client.create.called assert mock_async_client.with_raw_response.create.called
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -638,7 +638,7 @@ def test_openai_invoke_name(mock_client: MagicMock) -> None:
with patch.object(llm, "client", mock_client): with patch.object(llm, "client", mock_client):
messages = [HumanMessage(content="Foo", name="Katie")] messages = [HumanMessage(content="Foo", name="Katie")]
res = llm.invoke(messages) res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args call_args, call_kwargs = mock_client.with_raw_response.create.call_args
assert len(call_args) == 0 # no positional args assert len(call_args) == 0 # no positional args
call_messages = call_kwargs["messages"] call_messages = call_kwargs["messages"]
assert len(call_messages) == 1 assert len(call_messages) == 1
@@ -678,7 +678,7 @@ def test_function_calls_with_tool_calls(mock_client: MagicMock) -> None:
] ]
with patch.object(llm, "client", mock_client): with patch.object(llm, "client", mock_client):
_ = llm.invoke(messages) _ = llm.invoke(messages)
_, call_kwargs = mock_client.create.call_args _, call_kwargs = mock_client.with_raw_response.create.call_args
call_messages = call_kwargs["messages"] call_messages = call_kwargs["messages"]
tool_call_message_payload = call_messages[1] tool_call_message_payload = call_messages[1]
assert "tool_calls" in tool_call_message_payload assert "tool_calls" in tool_call_message_payload
@@ -688,7 +688,7 @@ def test_function_calls_with_tool_calls(mock_client: MagicMock) -> None:
cast(AIMessage, messages[1]).tool_calls = [] cast(AIMessage, messages[1]).tool_calls = []
with patch.object(llm, "client", mock_client): with patch.object(llm, "client", mock_client):
_ = llm.invoke(messages) _ = llm.invoke(messages)
_, call_kwargs = mock_client.create.call_args _, call_kwargs = mock_client.with_raw_response.create.call_args
call_messages = call_kwargs["messages"] call_messages = call_kwargs["messages"]
tool_call_message_payload = call_messages[1] tool_call_message_payload = call_messages[1]
assert "function_call" in tool_call_message_payload assert "function_call" in tool_call_message_payload
@@ -2326,8 +2326,9 @@ def test_mcp_tracing() -> None:
tracer = FakeTracer() tracer = FakeTracer()
mock_client = MagicMock() mock_client = MagicMock()
def mock_create(*args: Any, **kwargs: Any) -> Response: def mock_create(*args: Any, **kwargs: Any) -> MagicMock:
return Response( mock_raw_response = MagicMock()
mock_raw_response.parse.return_value = Response(
id="resp_123", id="resp_123",
created_at=1234567890, created_at=1234567890,
model="o4-mini", model="o4-mini",
@@ -2349,8 +2350,9 @@ def test_mcp_tracing() -> None:
) )
], ],
) )
return mock_raw_response
mock_client.responses.create = mock_create mock_client.responses.with_raw_response.create = mock_create
input_message = HumanMessage("Test query") input_message = HumanMessage("Test query")
tools = [ tools = [
{ {