Author: Chester Curme
Date: 2026-02-09 11:28:07 -05:00
parent dfb42a5b84
commit abb0b6eb29
2 changed files with 312 additions and 136 deletions
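
This commit maps OpenAI context-window failures (openai.BadRequestError bodies flagged with context_length_exceeded or "Input tokens exceed the configured limit", and openai.APIError messages containing "exceeds the context window") to langchain_core.exceptions.ContextOverflowError across the sync and async invoke and stream paths of BaseChatOpenAI, for both the Chat Completions and Responses APIs. A minimal sketch of how calling code might rely on the new exception type; the model name and the fallback truncation shown here are illustrative assumptions, not part of the commit:

from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # model name is an arbitrary example

messages = [HumanMessage(content="...a very long conversation...")]
try:
    result = llm.invoke(messages)
except ContextOverflowError:
    # With this commit, oversized prompts surface as ContextOverflowError
    # instead of a raw openai.BadRequestError / openai.APIError, so callers
    # can recover, e.g. by trimming history (naive truncation shown here).
    result = llm.invoke(messages[-1:])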


@@ -40,6 +40,7 @@ from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import (
    LanguageModelInput,
    ModelProfileRegistry,
@@ -450,6 +451,11 @@ def _update_token_usage(
def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
    if (
        "context_length_exceeded" in str(e)
        or "Input tokens exceed the configured limit" in e.message
    ):
        raise ContextOverflowError(e.message) from e
    if (
        "'response_format' of type 'json_schema' is not supported with this model"
    ) in e.message:
@@ -474,6 +480,13 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
    raise


def _handle_openai_api_error(e: openai.APIError) -> None:
    error_message = str(e)
    if "exceeds the context window" in error_message:
        raise ContextOverflowError(error_message) from e
    raise


def _model_prefers_responses_api(model_name: str | None) -> bool:
    if not model_name:
        return False
@@ -1146,49 +1159,54 @@ class BaseChatOpenAI(BaseChatModel):
        self._ensure_sync_client_available()
        kwargs["stream"] = True
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        try:
            if self.include_response_headers:
                raw_context_manager = (
                    self.root_client.with_raw_response.responses.create(**payload)
                )
                context_manager = raw_context_manager.parse()
                headers = {"headers": dict(raw_context_manager.headers)}
            else:
                context_manager = self.root_client.responses.create(**payload)
                headers = {}
            original_schema_obj = kwargs.get("response_format")
            with context_manager as response:
                is_first_chunk = True
                current_index = -1
                current_output_index = -1
                current_sub_index = -1
                has_reasoning = False
                for chunk in response:
                    metadata = headers if is_first_chunk else {}
                    (
                        current_index,
                        current_output_index,
                        current_sub_index,
                        generation_chunk,
                    ) = _convert_responses_chunk_to_generation_chunk(
                        chunk,
                        current_index,
                        current_output_index,
                        current_sub_index,
                        schema=original_schema_obj,
                        metadata=metadata,
                        has_reasoning=has_reasoning,
                        output_version=self.output_version,
                    )
                    if generation_chunk:
                        if run_manager:
                            run_manager.on_llm_new_token(
                                generation_chunk.text, chunk=generation_chunk
                            )
                        is_first_chunk = False
                        if "reasoning" in generation_chunk.message.additional_kwargs:
                            has_reasoning = True
                        yield generation_chunk
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)
        except openai.APIError as e:
            _handle_openai_api_error(e)

    async def _astream_responses(
        self,
@@ -1199,51 +1217,58 @@ class BaseChatOpenAI(BaseChatModel):
    ) -> AsyncIterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        try:
            if self.include_response_headers:
                raw_context_manager = (
                    await self.root_async_client.with_raw_response.responses.create(
                        **payload
                    )
                )
                context_manager = raw_context_manager.parse()
                headers = {"headers": dict(raw_context_manager.headers)}
            else:
                context_manager = await self.root_async_client.responses.create(
                    **payload
                )
                headers = {}
            original_schema_obj = kwargs.get("response_format")
            async with context_manager as response:
                is_first_chunk = True
                current_index = -1
                current_output_index = -1
                current_sub_index = -1
                has_reasoning = False
                async for chunk in response:
                    metadata = headers if is_first_chunk else {}
                    (
                        current_index,
                        current_output_index,
                        current_sub_index,
                        generation_chunk,
                    ) = _convert_responses_chunk_to_generation_chunk(
                        chunk,
                        current_index,
                        current_output_index,
                        current_sub_index,
                        schema=original_schema_obj,
                        metadata=metadata,
                        has_reasoning=has_reasoning,
                        output_version=self.output_version,
                    )
                    if generation_chunk:
                        if run_manager:
                            await run_manager.on_llm_new_token(
                                generation_chunk.text, chunk=generation_chunk
                            )
                        is_first_chunk = False
                        if "reasoning" in generation_chunk.message.additional_kwargs:
                            has_reasoning = True
                        yield generation_chunk
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)
        except openai.APIError as e:
            _handle_openai_api_error(e)

    def _should_stream_usage(
        self, stream_usage: bool | None = None, **kwargs: Any
@@ -1282,24 +1307,26 @@ class BaseChatOpenAI(BaseChatModel):
        default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
        base_generation_info = {}
        try:
            if "response_format" in payload:
                if self.include_response_headers:
                    warnings.warn(
                        "Cannot currently include response headers when "
                        "response_format is specified."
                    )
                payload.pop("stream")
                response_stream = self.root_client.beta.chat.completions.stream(
                    **payload
                )
                context_manager = response_stream
            else:
                if self.include_response_headers:
                    raw_response = self.client.with_raw_response.create(**payload)
                    response = raw_response.parse()
                    base_generation_info = {"headers": dict(raw_response.headers)}
                else:
                    response = self.client.create(**payload)
                context_manager = response
            with context_manager as response:
                is_first_chunk = True
                for chunk in response:
@@ -1324,6 +1351,8 @@ class BaseChatOpenAI(BaseChatModel):
                    yield generation_chunk
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)
        except openai.APIError as e:
            _handle_openai_api_error(e)
        if hasattr(response, "get_final_completion") and "response_format" in payload:
            final_completion = response.get_final_completion()
            generation_chunk = self._get_generation_chunk_from_completion(
@@ -1349,15 +1378,10 @@ class BaseChatOpenAI(BaseChatModel):
        try:
            if "response_format" in payload:
                payload.pop("stream")
                raw_response = (
                    self.root_client.chat.completions.with_raw_response.parse(**payload)
                )
                response = raw_response.parse()
            elif self._use_responses_api(payload):
                original_schema_obj = kwargs.get("response_format")
                if original_schema_obj and _is_pydantic_class(original_schema_obj):
@@ -1380,6 +1404,10 @@ class BaseChatOpenAI(BaseChatModel):
            else:
                raw_response = self.client.with_raw_response.create(**payload)
                response = raw_response.parse()
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)
        except openai.APIError as e:
            _handle_openai_api_error(e)
        except Exception as e:
            if raw_response is not None and hasattr(raw_response, "http_response"):
                e.response = raw_response.http_response  # type: ignore[attr-defined]
@@ -1523,28 +1551,28 @@ class BaseChatOpenAI(BaseChatModel):
        default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
        base_generation_info = {}
        try:
            if "response_format" in payload:
                if self.include_response_headers:
                    warnings.warn(
                        "Cannot currently include response headers when "
                        "response_format is specified."
                    )
                payload.pop("stream")
                response_stream = self.root_async_client.beta.chat.completions.stream(
                    **payload
                )
                context_manager = response_stream
            else:
                if self.include_response_headers:
                    raw_response = await self.async_client.with_raw_response.create(
                        **payload
                    )
                    response = raw_response.parse()
                    base_generation_info = {"headers": dict(raw_response.headers)}
                else:
                    response = await self.async_client.create(**payload)
                context_manager = response
            async with context_manager as response:
                is_first_chunk = True
                async for chunk in response:
@@ -1569,6 +1597,8 @@ class BaseChatOpenAI(BaseChatModel):
                    yield generation_chunk
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)
        except openai.APIError as e:
            _handle_openai_api_error(e)
        if hasattr(response, "get_final_completion") and "response_format" in payload:
            final_completion = await response.get_final_completion()
            generation_chunk = self._get_generation_chunk_from_completion(
@@ -1593,13 +1623,10 @@ class BaseChatOpenAI(BaseChatModel):
        try:
            if "response_format" in payload:
                payload.pop("stream")
                raw_response = await self.root_async_client.chat.completions.with_raw_response.parse(  # noqa: E501
                    **payload
                )
                response = raw_response.parse()
            elif self._use_responses_api(payload):
                original_schema_obj = kwargs.get("response_format")
                if original_schema_obj and _is_pydantic_class(original_schema_obj):
@@ -1628,6 +1655,10 @@ class BaseChatOpenAI(BaseChatModel):
                    **payload
                )
                response = raw_response.parse()
        except openai.BadRequestError as e:
            _handle_openai_bad_request(e)
        except openai.APIError as e:
            _handle_openai_api_error(e)
        except Exception as e:
            if raw_response is not None and hasattr(raw_response, "http_response"):
                e.response = raw_response.http_response  # type: ignore[attr-defined]


@@ -10,7 +10,9 @@ from typing import Any, Literal, cast
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import openai
import pytest
from langchain_core.exceptions import ContextOverflowError
from langchain_core.load import dumps, loads
from langchain_core.messages import (
    AIMessage,
@@ -3231,3 +3233,146 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None:
        pass
    except ValueError as e:
        pytest.fail(f"This is a wrong behavior. Error details: {e}")


# Test fixtures for context overflow error tests
_CONTEXT_OVERFLOW_ERROR_BODY = {
    "error": {
        "message": (
            "Input tokens exceed the configured limit of 272000 tokens. Your messages "
            "resulted in 300007 tokens. Please reduce the length of the messages."
        ),
        "type": "invalid_request_error",
        "param": "messages",
        "code": "context_length_exceeded",
    }
}

_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR = openai.BadRequestError(
    message=_CONTEXT_OVERFLOW_ERROR_BODY["error"]["message"],
    response=MagicMock(status_code=400),
    body=_CONTEXT_OVERFLOW_ERROR_BODY,
)

_CONTEXT_OVERFLOW_API_ERROR = openai.APIError(
    message=(
        "Your input exceeds the context window of this model. Please adjust your input "
        "and try again."
    ),
    request=MagicMock(),
    body=None,
)


def test_context_overflow_error_invoke_sync() -> None:
    """Test context overflow error on invoke (sync, chat completions API)."""
    llm = ChatOpenAI()
    with (  # noqa: PT012
        patch.object(llm.client, "with_raw_response") as mock_client,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
        llm.invoke([HumanMessage(content="test")])
    assert "Input tokens exceed the configured limit" in str(exc_info.value)


def test_context_overflow_error_invoke_sync_responses_api() -> None:
    """Test context overflow error on invoke (sync, responses API)."""
    llm = ChatOpenAI(use_responses_api=True)
    with (  # noqa: PT012
        patch.object(llm.root_client.responses, "with_raw_response") as mock_client,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
        llm.invoke([HumanMessage(content="test")])
    assert "Input tokens exceed the configured limit" in str(exc_info.value)


async def test_context_overflow_error_invoke_async() -> None:
    """Test context overflow error on invoke (async, chat completions API)."""
    llm = ChatOpenAI()
    with (  # noqa: PT012
        patch.object(llm.async_client, "with_raw_response") as mock_client,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
        await llm.ainvoke([HumanMessage(content="test")])
    assert "Input tokens exceed the configured limit" in str(exc_info.value)


async def test_context_overflow_error_invoke_async_responses_api() -> None:
    """Test context overflow error on invoke (async, responses API)."""
    llm = ChatOpenAI(use_responses_api=True)
    with (  # noqa: PT012
        patch.object(
            llm.root_async_client.responses, "with_raw_response"
        ) as mock_client,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
        await llm.ainvoke([HumanMessage(content="test")])
    assert "Input tokens exceed the configured limit" in str(exc_info.value)


def test_context_overflow_error_stream_sync() -> None:
    """Test context overflow error on stream (sync, chat completions API)."""
    llm = ChatOpenAI()
    with (  # noqa: PT012
        patch.object(llm.client, "create") as mock_create,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
        list(llm.stream([HumanMessage(content="test")]))
    assert "Input tokens exceed the configured limit" in str(exc_info.value)


def test_context_overflow_error_stream_sync_responses_api() -> None:
    """Test context overflow error on stream (sync, responses API)."""
    llm = ChatOpenAI(use_responses_api=True)
    with (  # noqa: PT012
        patch.object(llm.root_client.responses, "create") as mock_create,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR
        list(llm.stream([HumanMessage(content="test")]))
    assert "exceeds the context window" in str(exc_info.value)


async def test_context_overflow_error_stream_async() -> None:
    """Test context overflow error on stream (async, chat completions API)."""
    llm = ChatOpenAI()
    with (  # noqa: PT012
        patch.object(llm.async_client, "create") as mock_create,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
        async for _ in llm.astream([HumanMessage(content="test")]):
            pass
    assert "Input tokens exceed the configured limit" in str(exc_info.value)


async def test_context_overflow_error_stream_async_responses_api() -> None:
    """Test context overflow error on stream (async, responses API)."""
    llm = ChatOpenAI(use_responses_api=True)
    with (  # noqa: PT012
        patch.object(llm.root_async_client.responses, "create") as mock_create,
        pytest.raises(ContextOverflowError) as exc_info,
    ):
        mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR
        async for _ in llm.astream([HumanMessage(content="test")]):
            pass
    assert "exceeds the context window" in str(exc_info.value)