feat(core): add ContextOverflowError, raise in anthropic and openai (#35099)

ccurme
2026-02-09 15:15:34 -05:00
committed by GitHub
parent 4ca586b322
commit 7c41298355
5 changed files with 445 additions and 137 deletions

View File

@@ -65,6 +65,14 @@ class OutputParserException(ValueError, LangChainException): # noqa: N818
self.send_to_llm = send_to_llm
class ContextOverflowError(LangChainException):
"""Exception raised when input exceeds the model's context limit.
This exception is raised by chat models when the input tokens exceed
the maximum context window supported by the model.
"""
class ErrorCode(Enum):
"""Error codes."""

View File

@@ -17,7 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.exceptions import ContextOverflowError, OutputParserException
from langchain_core.language_models import (
LanguageModelInput,
ModelProfile,
@@ -721,8 +721,16 @@ def _is_code_execution_related_block(
return False
class AnthropicContextOverflowError(anthropic.BadRequestError, ContextOverflowError):
"""BadRequestError raised when input exceeds Anthropic's context limit."""
def _handle_anthropic_bad_request(e: anthropic.BadRequestError) -> None:
"""Handle Anthropic BadRequestError."""
if "prompt is too long" in e.message:
raise AnthropicContextOverflowError(
message=e.message, response=e.response, body=e.body
) from e
if ("messages: at least one message is required") in e.message:
message = "Received only system message(s). "
warnings.warn(message, stacklevel=2)
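
Because `AnthropicContextOverflowError` inherits from both `anthropic.BadRequestError` and `ContextOverflowError`, existing handlers that catch the SDK error keep working while new code can catch the generic type. A small sketch, assuming the class is importable from `langchain_anthropic.chat_models`:

import anthropic
from langchain_core.exceptions import ContextOverflowError
from langchain_anthropic.chat_models import AnthropicContextOverflowError

# Both parents sit in the MRO, so either except clause will catch the error.
assert issubclass(AnthropicContextOverflowError, anthropic.BadRequestError)
assert issubclass(AnthropicContextOverflowError, ContextOverflowError)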

View File

@@ -11,6 +11,7 @@ import anthropic
import pytest
from anthropic.types import Message, TextBlock, Usage
from blockbuster import blockbuster_ctx
from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import BaseTool
@@ -2339,3 +2340,90 @@ def test__format_messages_trailing_whitespace() -> None:
ai_intermediate = AIMessage("thought ") # type: ignore[misc]
_, anthropic_messages = _format_messages([human, ai_intermediate, human])
assert anthropic_messages[1]["content"] == "thought "
# Test fixtures for context overflow error tests
_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR = anthropic.BadRequestError(
message="prompt is too long: 209752 tokens > 200000 maximum",
response=MagicMock(status_code=400),
body={
"type": "error",
"error": {
"type": "invalid_request_error",
"message": "prompt is too long: 209752 tokens > 200000 maximum",
},
},
)
def test_context_overflow_error_invoke_sync() -> None:
"""Test context overflow error on invoke (sync)."""
llm = ChatAnthropic(model=MODEL_NAME)
with ( # noqa: PT012
patch.object(llm._client.messages, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
llm.invoke([HumanMessage(content="test")])
assert "prompt is too long" in str(exc_info.value)
async def test_context_overflow_error_invoke_async() -> None:
"""Test context overflow error on invoke (async)."""
llm = ChatAnthropic(model=MODEL_NAME)
with ( # noqa: PT012
patch.object(llm._async_client.messages, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
await llm.ainvoke([HumanMessage(content="test")])
assert "prompt is too long" in str(exc_info.value)
def test_context_overflow_error_stream_sync() -> None:
"""Test context overflow error on stream (sync)."""
llm = ChatAnthropic(model=MODEL_NAME)
with ( # noqa: PT012
patch.object(llm._client.messages, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
list(llm.stream([HumanMessage(content="test")]))
assert "prompt is too long" in str(exc_info.value)
async def test_context_overflow_error_stream_async() -> None:
"""Test context overflow error on stream (async)."""
llm = ChatAnthropic(model=MODEL_NAME)
with ( # noqa: PT012
patch.object(llm._async_client.messages, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
async for _ in llm.astream([HumanMessage(content="test")]):
pass
assert "prompt is too long" in str(exc_info.value)
def test_context_overflow_error_backwards_compatibility() -> None:
"""Test that ContextOverflowError can be caught as BadRequestError."""
llm = ChatAnthropic(model=MODEL_NAME)
with ( # noqa: PT012
patch.object(llm._client.messages, "create") as mock_create,
pytest.raises(anthropic.BadRequestError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
llm.invoke([HumanMessage(content="test")])
# Verify it's both types (multiple inheritance)
assert isinstance(exc_info.value, anthropic.BadRequestError)
assert isinstance(exc_info.value, ContextOverflowError)

View File

@@ -40,6 +40,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import (
LanguageModelInput,
ModelProfileRegistry,
@@ -449,7 +450,22 @@ def _update_token_usage(
return new_usage
class OpenAIContextOverflowError(openai.BadRequestError, ContextOverflowError):
"""BadRequestError raised when input exceeds OpenAI's context limit."""
class OpenAIAPIContextOverflowError(openai.APIError, ContextOverflowError):
"""APIError raised when input exceeds OpenAI's context limit."""
def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
if (
"context_length_exceeded" in str(e)
or "Input tokens exceed the configured limit" in e.message
):
raise OpenAIContextOverflowError(
message=e.message, response=e.response, body=e.body
) from e
if (
"'response_format' of type 'json_schema' is not supported with this model"
) in e.message:
@@ -474,6 +490,15 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None:
raise
def _handle_openai_api_error(e: openai.APIError) -> None:
error_message = str(e)
if "exceeds the context window" in error_message:
raise OpenAIAPIContextOverflowError(
message=e.message, request=e.request, body=e.body
) from e
raise
def _model_prefers_responses_api(model_name: str | None) -> bool:
if not model_name:
return False
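
The hunks that follow are easier to read with the intended call pattern in mind: each OpenAI request path is wrapped so that overflow-shaped errors are re-raised through the handlers above. A simplified, hypothetical sketch of that guard (the real methods below also manage headers and streaming state):

import openai

def _create_with_overflow_handling(client, payload: dict):
    """Illustrative guard only; mirrors the try/except added in each method below."""
    try:
        return client.create(**payload)
    except openai.BadRequestError as e:
        # Raises OpenAIContextOverflowError on context_length_exceeded, else re-raises.
        _handle_openai_bad_request(e)
    except openai.APIError as e:
        # Raises OpenAIAPIContextOverflowError on Responses API overflow, else re-raises.
        _handle_openai_api_error(e)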
@@ -1146,49 +1171,54 @@ class BaseChatOpenAI(BaseChatModel):
self._ensure_sync_client_available()
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
if self.include_response_headers:
raw_context_manager = self.root_client.with_raw_response.responses.create(
**payload
)
context_manager = raw_context_manager.parse()
headers = {"headers": dict(raw_context_manager.headers)}
else:
context_manager = self.root_client.responses.create(**payload)
headers = {}
original_schema_obj = kwargs.get("response_format")
with context_manager as response:
is_first_chunk = True
current_index = -1
current_output_index = -1
current_sub_index = -1
has_reasoning = False
for chunk in response:
metadata = headers if is_first_chunk else {}
(
current_index,
current_output_index,
current_sub_index,
generation_chunk,
) = _convert_responses_chunk_to_generation_chunk(
chunk,
current_index,
current_output_index,
current_sub_index,
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
try:
if self.include_response_headers:
raw_context_manager = (
self.root_client.with_raw_response.responses.create(**payload)
)
if generation_chunk:
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
is_first_chunk = False
if "reasoning" in generation_chunk.message.additional_kwargs:
has_reasoning = True
yield generation_chunk
context_manager = raw_context_manager.parse()
headers = {"headers": dict(raw_context_manager.headers)}
else:
context_manager = self.root_client.responses.create(**payload)
headers = {}
original_schema_obj = kwargs.get("response_format")
with context_manager as response:
is_first_chunk = True
current_index = -1
current_output_index = -1
current_sub_index = -1
has_reasoning = False
for chunk in response:
metadata = headers if is_first_chunk else {}
(
current_index,
current_output_index,
current_sub_index,
generation_chunk,
) = _convert_responses_chunk_to_generation_chunk(
chunk,
current_index,
current_output_index,
current_sub_index,
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
)
if generation_chunk:
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
is_first_chunk = False
if "reasoning" in generation_chunk.message.additional_kwargs:
has_reasoning = True
yield generation_chunk
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
except openai.APIError as e:
_handle_openai_api_error(e)
async def _astream_responses(
self,
@@ -1199,51 +1229,58 @@ class BaseChatOpenAI(BaseChatModel):
) -> AsyncIterator[ChatGenerationChunk]:
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
if self.include_response_headers:
raw_context_manager = (
await self.root_async_client.with_raw_response.responses.create(
try:
if self.include_response_headers:
raw_context_manager = (
await self.root_async_client.with_raw_response.responses.create(
**payload
)
)
context_manager = raw_context_manager.parse()
headers = {"headers": dict(raw_context_manager.headers)}
else:
context_manager = await self.root_async_client.responses.create(
**payload
)
)
context_manager = raw_context_manager.parse()
headers = {"headers": dict(raw_context_manager.headers)}
else:
context_manager = await self.root_async_client.responses.create(**payload)
headers = {}
original_schema_obj = kwargs.get("response_format")
headers = {}
original_schema_obj = kwargs.get("response_format")
async with context_manager as response:
is_first_chunk = True
current_index = -1
current_output_index = -1
current_sub_index = -1
has_reasoning = False
async for chunk in response:
metadata = headers if is_first_chunk else {}
(
current_index,
current_output_index,
current_sub_index,
generation_chunk,
) = _convert_responses_chunk_to_generation_chunk(
chunk,
current_index,
current_output_index,
current_sub_index,
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
)
if generation_chunk:
if run_manager:
await run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
is_first_chunk = False
if "reasoning" in generation_chunk.message.additional_kwargs:
has_reasoning = True
yield generation_chunk
async with context_manager as response:
is_first_chunk = True
current_index = -1
current_output_index = -1
current_sub_index = -1
has_reasoning = False
async for chunk in response:
metadata = headers if is_first_chunk else {}
(
current_index,
current_output_index,
current_sub_index,
generation_chunk,
) = _convert_responses_chunk_to_generation_chunk(
chunk,
current_index,
current_output_index,
current_sub_index,
schema=original_schema_obj,
metadata=metadata,
has_reasoning=has_reasoning,
output_version=self.output_version,
)
if generation_chunk:
if run_manager:
await run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
is_first_chunk = False
if "reasoning" in generation_chunk.message.additional_kwargs:
has_reasoning = True
yield generation_chunk
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
except openai.APIError as e:
_handle_openai_api_error(e)
def _should_stream_usage(
self, stream_usage: bool | None = None, **kwargs: Any
@@ -1282,24 +1319,26 @@ class BaseChatOpenAI(BaseChatModel):
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
)
payload.pop("stream")
response_stream = self.root_client.beta.chat.completions.stream(**payload)
context_manager = response_stream
else:
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
context_manager = response
try:
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when "
"response_format is specified."
)
payload.pop("stream")
response_stream = self.root_client.beta.chat.completions.stream(
**payload
)
context_manager = response_stream
else:
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
context_manager = response
with context_manager as response:
is_first_chunk = True
for chunk in response:
@@ -1324,6 +1363,8 @@ class BaseChatOpenAI(BaseChatModel):
yield generation_chunk
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
except openai.APIError as e:
_handle_openai_api_error(e)
if hasattr(response, "get_final_completion") and "response_format" in payload:
final_completion = response.get_final_completion()
generation_chunk = self._get_generation_chunk_from_completion(
@@ -1349,15 +1390,10 @@ class BaseChatOpenAI(BaseChatModel):
try:
if "response_format" in payload:
payload.pop("stream")
try:
raw_response = (
self.root_client.chat.completions.with_raw_response.parse(
**payload
)
)
response = raw_response.parse()
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
raw_response = (
self.root_client.chat.completions.with_raw_response.parse(**payload)
)
response = raw_response.parse()
elif self._use_responses_api(payload):
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
@@ -1380,6 +1416,10 @@ class BaseChatOpenAI(BaseChatModel):
else:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
except openai.APIError as e:
_handle_openai_api_error(e)
except Exception as e:
if raw_response is not None and hasattr(raw_response, "http_response"):
e.response = raw_response.http_response # type: ignore[attr-defined]
@@ -1523,28 +1563,28 @@ class BaseChatOpenAI(BaseChatModel):
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
)
payload.pop("stream")
response_stream = self.root_async_client.beta.chat.completions.stream(
**payload
)
context_manager = response_stream
else:
if self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(
try:
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when "
"response_format is specified."
)
payload.pop("stream")
response_stream = self.root_async_client.beta.chat.completions.stream(
**payload
)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
context_manager = response_stream
else:
response = await self.async_client.create(**payload)
context_manager = response
try:
if self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(
**payload
)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = await self.async_client.create(**payload)
context_manager = response
async with context_manager as response:
is_first_chunk = True
async for chunk in response:
@@ -1569,6 +1609,8 @@ class BaseChatOpenAI(BaseChatModel):
yield generation_chunk
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
except openai.APIError as e:
_handle_openai_api_error(e)
if hasattr(response, "get_final_completion") and "response_format" in payload:
final_completion = await response.get_final_completion()
generation_chunk = self._get_generation_chunk_from_completion(
@@ -1593,13 +1635,10 @@ class BaseChatOpenAI(BaseChatModel):
try:
if "response_format" in payload:
payload.pop("stream")
try:
raw_response = await self.root_async_client.chat.completions.with_raw_response.parse( # noqa: E501
**payload
)
response = raw_response.parse()
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
raw_response = await self.root_async_client.chat.completions.with_raw_response.parse( # noqa: E501
**payload
)
response = raw_response.parse()
elif self._use_responses_api(payload):
original_schema_obj = kwargs.get("response_format")
if original_schema_obj and _is_pydantic_class(original_schema_obj):
@@ -1628,6 +1667,10 @@ class BaseChatOpenAI(BaseChatModel):
**payload
)
response = raw_response.parse()
except openai.BadRequestError as e:
_handle_openai_bad_request(e)
except openai.APIError as e:
_handle_openai_api_error(e)
except Exception as e:
if raw_response is not None and hasattr(raw_response, "http_response"):
e.response = raw_response.http_response # type: ignore[attr-defined]

View File

@@ -10,7 +10,9 @@ from typing import Any, Literal, cast
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import openai
import pytest
from langchain_core.exceptions import ContextOverflowError
from langchain_core.load import dumps, loads
from langchain_core.messages import (
AIMessage,
@@ -3231,3 +3233,162 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None:
pass
except ValueError as e:
pytest.fail(f"This is a wrong behavior. Error details: {e}")
# Test fixtures for context overflow error tests
_CONTEXT_OVERFLOW_ERROR_BODY = {
"error": {
"message": (
"Input tokens exceed the configured limit of 272000 tokens. Your messages "
"resulted in 300007 tokens. Please reduce the length of the messages."
),
"type": "invalid_request_error",
"param": "messages",
"code": "context_length_exceeded",
}
}
_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR = openai.BadRequestError(
message=_CONTEXT_OVERFLOW_ERROR_BODY["error"]["message"],
response=MagicMock(status_code=400),
body=_CONTEXT_OVERFLOW_ERROR_BODY,
)
_CONTEXT_OVERFLOW_API_ERROR = openai.APIError(
message=(
"Your input exceeds the context window of this model. Please adjust your input "
"and try again."
),
request=MagicMock(),
body=None,
)
def test_context_overflow_error_invoke_sync() -> None:
"""Test context overflow error on invoke (sync, chat completions API)."""
llm = ChatOpenAI()
with ( # noqa: PT012
patch.object(llm.client, "with_raw_response") as mock_client,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
llm.invoke([HumanMessage(content="test")])
assert "Input tokens exceed the configured limit" in str(exc_info.value)
def test_context_overflow_error_invoke_sync_responses_api() -> None:
"""Test context overflow error on invoke (sync, responses API)."""
llm = ChatOpenAI(use_responses_api=True)
with ( # noqa: PT012
patch.object(llm.root_client.responses, "with_raw_response") as mock_client,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
llm.invoke([HumanMessage(content="test")])
assert "Input tokens exceed the configured limit" in str(exc_info.value)
async def test_context_overflow_error_invoke_async() -> None:
"""Test context overflow error on invoke (async, chat completions API)."""
llm = ChatOpenAI()
with ( # noqa: PT012
patch.object(llm.async_client, "with_raw_response") as mock_client,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
await llm.ainvoke([HumanMessage(content="test")])
assert "Input tokens exceed the configured limit" in str(exc_info.value)
async def test_context_overflow_error_invoke_async_responses_api() -> None:
"""Test context overflow error on invoke (async, responses API)."""
llm = ChatOpenAI(use_responses_api=True)
with ( # noqa: PT012
patch.object(
llm.root_async_client.responses, "with_raw_response"
) as mock_client,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
await llm.ainvoke([HumanMessage(content="test")])
assert "Input tokens exceed the configured limit" in str(exc_info.value)
def test_context_overflow_error_stream_sync() -> None:
"""Test context overflow error on stream (sync, chat completions API)."""
llm = ChatOpenAI()
with ( # noqa: PT012
patch.object(llm.client, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
list(llm.stream([HumanMessage(content="test")]))
assert "Input tokens exceed the configured limit" in str(exc_info.value)
def test_context_overflow_error_stream_sync_responses_api() -> None:
"""Test context overflow error on stream (sync, responses API)."""
llm = ChatOpenAI(use_responses_api=True)
with ( # noqa: PT012
patch.object(llm.root_client.responses, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR
list(llm.stream([HumanMessage(content="test")]))
assert "exceeds the context window" in str(exc_info.value)
async def test_context_overflow_error_stream_async() -> None:
"""Test context overflow error on stream (async, chat completions API)."""
llm = ChatOpenAI()
with ( # noqa: PT012
patch.object(llm.async_client, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
async for _ in llm.astream([HumanMessage(content="test")]):
pass
assert "Input tokens exceed the configured limit" in str(exc_info.value)
async def test_context_overflow_error_stream_async_responses_api() -> None:
"""Test context overflow error on stream (async, responses API)."""
llm = ChatOpenAI(use_responses_api=True)
with ( # noqa: PT012
patch.object(llm.root_async_client.responses, "create") as mock_create,
pytest.raises(ContextOverflowError) as exc_info,
):
mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR
async for _ in llm.astream([HumanMessage(content="test")]):
pass
assert "exceeds the context window" in str(exc_info.value)
def test_context_overflow_error_backwards_compatibility() -> None:
"""Test that ContextOverflowError can be caught as BadRequestError."""
llm = ChatOpenAI()
with ( # noqa: PT012
patch.object(llm.client, "with_raw_response") as mock_client,
pytest.raises(openai.BadRequestError) as exc_info,
):
mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR
llm.invoke([HumanMessage(content="test")])
# Verify it's both types (multiple inheritance)
assert isinstance(exc_info.value, openai.BadRequestError)
assert isinstance(exc_info.value, ContextOverflowError)
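
Taken together, the change lets applications recover from overflow generically instead of sniffing provider error strings. A hedged end-to-end sketch (the `trim_messages` parameters and token budget are illustrative):

from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import trim_messages

def invoke_with_trim_retry(model, messages, max_tokens: int = 100_000):
    """Hypothetical helper: retry once with a trimmed history on overflow."""
    try:
        return model.invoke(messages)
    except ContextOverflowError:
        trimmed = trim_messages(
            messages,
            max_tokens=max_tokens,
            token_counter=model,  # let the chat model count its own tokens
            strategy="last",
            include_system=True,
        )
        return model.invoke(trimmed)

# Usage (any chat model that raises ContextOverflowError):
#     invoke_with_trim_retry(llm, long_history)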