From 7c41298355ff0e66b1a252a0a4284aa6ca2e3184 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Mon, 9 Feb 2026 15:15:34 -0500
Subject: [PATCH] feat(core): add ContextOverflowError, raise in anthropic and openai (#35099)

---
 libs/core/langchain_core/exceptions.py        |   8 +
 .../langchain_anthropic/chat_models.py        |  10 +-
 .../tests/unit_tests/test_chat_models.py      |  88 +++++
 .../langchain_openai/chat_models/base.py      | 315 ++++++++++--------
 .../tests/unit_tests/chat_models/test_base.py | 161 +++++++++
 5 files changed, 445 insertions(+), 137 deletions(-)

diff --git a/libs/core/langchain_core/exceptions.py b/libs/core/langchain_core/exceptions.py
index 073cfa99022..f58754c9d55 100644
--- a/libs/core/langchain_core/exceptions.py
+++ b/libs/core/langchain_core/exceptions.py
@@ -65,6 +65,14 @@ class OutputParserException(ValueError, LangChainException):  # noqa: N818
         self.send_to_llm = send_to_llm
 
 
+class ContextOverflowError(LangChainException):
+    """Exception raised when input exceeds the model's context limit.
+
+    This exception is raised by chat models when the input tokens exceed
+    the maximum context window supported by the model.
+    """
+
+
 class ErrorCode(Enum):
     """Error codes."""
 
diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 266d60bdce8..9f24778fbb3 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -17,7 +17,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.exceptions import OutputParserException
+from langchain_core.exceptions import ContextOverflowError, OutputParserException
 from langchain_core.language_models import (
     LanguageModelInput,
     ModelProfile,
@@ -721,8 +721,16 @@ def _is_code_execution_related_block(
     return False
 
 
+class AnthropicContextOverflowError(anthropic.BadRequestError, ContextOverflowError):
+    """BadRequestError raised when input exceeds Anthropic's context limit."""
+
+
 def _handle_anthropic_bad_request(e: anthropic.BadRequestError) -> None:
     """Handle Anthropic BadRequestError."""
+    if "prompt is too long" in e.message:
+        raise AnthropicContextOverflowError(
+            message=e.message, response=e.response, body=e.body
+        ) from e
     if ("messages: at least one message is required") in e.message:
         message = "Received only system message(s). 
" warnings.warn(message, stacklevel=2) diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index e669b0655c0..f0e11f177ff 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -11,6 +11,7 @@ import anthropic import pytest from anthropic.types import Message, TextBlock, Usage from blockbuster import blockbuster_ctx +from langchain_core.exceptions import ContextOverflowError from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableBinding from langchain_core.tools import BaseTool @@ -2339,3 +2340,90 @@ def test__format_messages_trailing_whitespace() -> None: ai_intermediate = AIMessage("thought ") # type: ignore[misc] _, anthropic_messages = _format_messages([human, ai_intermediate, human]) assert anthropic_messages[1]["content"] == "thought " + + +# Test fixtures for context overflow error tests +_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR = anthropic.BadRequestError( + message="prompt is too long: 209752 tokens > 200000 maximum", + response=MagicMock(status_code=400), + body={ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": "prompt is too long: 209752 tokens > 200000 maximum", + }, + }, +) + + +def test_context_overflow_error_invoke_sync() -> None: + """Test context overflow error on invoke (sync).""" + llm = ChatAnthropic(model=MODEL_NAME) + + with ( # noqa: PT012 + patch.object(llm._client.messages, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + llm.invoke([HumanMessage(content="test")]) + + assert "prompt is too long" in str(exc_info.value) + + +async def test_context_overflow_error_invoke_async() -> None: + """Test context overflow error on invoke (async).""" + llm = ChatAnthropic(model=MODEL_NAME) + + with ( # noqa: PT012 + patch.object(llm._async_client.messages, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + await llm.ainvoke([HumanMessage(content="test")]) + + assert "prompt is too long" in str(exc_info.value) + + +def test_context_overflow_error_stream_sync() -> None: + """Test context overflow error on stream (sync).""" + llm = ChatAnthropic(model=MODEL_NAME) + + with ( # noqa: PT012 + patch.object(llm._client.messages, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + list(llm.stream([HumanMessage(content="test")])) + + assert "prompt is too long" in str(exc_info.value) + + +async def test_context_overflow_error_stream_async() -> None: + """Test context overflow error on stream (async).""" + llm = ChatAnthropic(model=MODEL_NAME) + + with ( # noqa: PT012 + patch.object(llm._async_client.messages, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + async for _ in llm.astream([HumanMessage(content="test")]): + pass + + assert "prompt is too long" in str(exc_info.value) + + +def test_context_overflow_error_backwards_compatibility() -> None: + """Test that ContextOverflowError can be caught as BadRequestError.""" + llm = ChatAnthropic(model=MODEL_NAME) + + with ( # noqa: PT012 + patch.object(llm._client.messages, "create") as 
mock_create, + pytest.raises(anthropic.BadRequestError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + llm.invoke([HumanMessage(content="test")]) + + # Verify it's both types (multiple inheritance) + assert isinstance(exc_info.value, anthropic.BadRequestError) + assert isinstance(exc_info.value, ContextOverflowError) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 41aecad78c1..0128bfd3232 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -40,6 +40,7 @@ from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.exceptions import ContextOverflowError from langchain_core.language_models import ( LanguageModelInput, ModelProfileRegistry, @@ -449,7 +450,22 @@ def _update_token_usage( return new_usage +class OpenAIContextOverflowError(openai.BadRequestError, ContextOverflowError): + """BadRequestError raised when input exceeds OpenAI's context limit.""" + + +class OpenAIAPIContextOverflowError(openai.APIError, ContextOverflowError): + """APIError raised when input exceeds OpenAI's context limit.""" + + def _handle_openai_bad_request(e: openai.BadRequestError) -> None: + if ( + "context_length_exceeded" in str(e) + or "Input tokens exceed the configured limit" in e.message + ): + raise OpenAIContextOverflowError( + message=e.message, response=e.response, body=e.body + ) from e if ( "'response_format' of type 'json_schema' is not supported with this model" ) in e.message: @@ -474,6 +490,15 @@ def _handle_openai_bad_request(e: openai.BadRequestError) -> None: raise +def _handle_openai_api_error(e: openai.APIError) -> None: + error_message = str(e) + if "exceeds the context window" in error_message: + raise OpenAIAPIContextOverflowError( + message=e.message, request=e.request, body=e.body + ) from e + raise + + def _model_prefers_responses_api(model_name: str | None) -> bool: if not model_name: return False @@ -1146,49 +1171,54 @@ class BaseChatOpenAI(BaseChatModel): self._ensure_sync_client_available() kwargs["stream"] = True payload = self._get_request_payload(messages, stop=stop, **kwargs) - if self.include_response_headers: - raw_context_manager = self.root_client.with_raw_response.responses.create( - **payload - ) - context_manager = raw_context_manager.parse() - headers = {"headers": dict(raw_context_manager.headers)} - else: - context_manager = self.root_client.responses.create(**payload) - headers = {} - original_schema_obj = kwargs.get("response_format") - - with context_manager as response: - is_first_chunk = True - current_index = -1 - current_output_index = -1 - current_sub_index = -1 - has_reasoning = False - for chunk in response: - metadata = headers if is_first_chunk else {} - ( - current_index, - current_output_index, - current_sub_index, - generation_chunk, - ) = _convert_responses_chunk_to_generation_chunk( - chunk, - current_index, - current_output_index, - current_sub_index, - schema=original_schema_obj, - metadata=metadata, - has_reasoning=has_reasoning, - output_version=self.output_version, + try: + if self.include_response_headers: + raw_context_manager = ( + self.root_client.with_raw_response.responses.create(**payload) ) - if generation_chunk: - if run_manager: - run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk - ) - is_first_chunk = False - if "reasoning" in 
generation_chunk.message.additional_kwargs: - has_reasoning = True - yield generation_chunk + context_manager = raw_context_manager.parse() + headers = {"headers": dict(raw_context_manager.headers)} + else: + context_manager = self.root_client.responses.create(**payload) + headers = {} + original_schema_obj = kwargs.get("response_format") + + with context_manager as response: + is_first_chunk = True + current_index = -1 + current_output_index = -1 + current_sub_index = -1 + has_reasoning = False + for chunk in response: + metadata = headers if is_first_chunk else {} + ( + current_index, + current_output_index, + current_sub_index, + generation_chunk, + ) = _convert_responses_chunk_to_generation_chunk( + chunk, + current_index, + current_output_index, + current_sub_index, + schema=original_schema_obj, + metadata=metadata, + has_reasoning=has_reasoning, + output_version=self.output_version, + ) + if generation_chunk: + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk + ) + is_first_chunk = False + if "reasoning" in generation_chunk.message.additional_kwargs: + has_reasoning = True + yield generation_chunk + except openai.BadRequestError as e: + _handle_openai_bad_request(e) + except openai.APIError as e: + _handle_openai_api_error(e) async def _astream_responses( self, @@ -1199,51 +1229,58 @@ class BaseChatOpenAI(BaseChatModel): ) -> AsyncIterator[ChatGenerationChunk]: kwargs["stream"] = True payload = self._get_request_payload(messages, stop=stop, **kwargs) - if self.include_response_headers: - raw_context_manager = ( - await self.root_async_client.with_raw_response.responses.create( + try: + if self.include_response_headers: + raw_context_manager = ( + await self.root_async_client.with_raw_response.responses.create( + **payload + ) + ) + context_manager = raw_context_manager.parse() + headers = {"headers": dict(raw_context_manager.headers)} + else: + context_manager = await self.root_async_client.responses.create( **payload ) - ) - context_manager = raw_context_manager.parse() - headers = {"headers": dict(raw_context_manager.headers)} - else: - context_manager = await self.root_async_client.responses.create(**payload) - headers = {} - original_schema_obj = kwargs.get("response_format") + headers = {} + original_schema_obj = kwargs.get("response_format") - async with context_manager as response: - is_first_chunk = True - current_index = -1 - current_output_index = -1 - current_sub_index = -1 - has_reasoning = False - async for chunk in response: - metadata = headers if is_first_chunk else {} - ( - current_index, - current_output_index, - current_sub_index, - generation_chunk, - ) = _convert_responses_chunk_to_generation_chunk( - chunk, - current_index, - current_output_index, - current_sub_index, - schema=original_schema_obj, - metadata=metadata, - has_reasoning=has_reasoning, - output_version=self.output_version, - ) - if generation_chunk: - if run_manager: - await run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk - ) - is_first_chunk = False - if "reasoning" in generation_chunk.message.additional_kwargs: - has_reasoning = True - yield generation_chunk + async with context_manager as response: + is_first_chunk = True + current_index = -1 + current_output_index = -1 + current_sub_index = -1 + has_reasoning = False + async for chunk in response: + metadata = headers if is_first_chunk else {} + ( + current_index, + current_output_index, + current_sub_index, + generation_chunk, + ) = 
_convert_responses_chunk_to_generation_chunk( + chunk, + current_index, + current_output_index, + current_sub_index, + schema=original_schema_obj, + metadata=metadata, + has_reasoning=has_reasoning, + output_version=self.output_version, + ) + if generation_chunk: + if run_manager: + await run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk + ) + is_first_chunk = False + if "reasoning" in generation_chunk.message.additional_kwargs: + has_reasoning = True + yield generation_chunk + except openai.BadRequestError as e: + _handle_openai_bad_request(e) + except openai.APIError as e: + _handle_openai_api_error(e) def _should_stream_usage( self, stream_usage: bool | None = None, **kwargs: Any @@ -1282,24 +1319,26 @@ class BaseChatOpenAI(BaseChatModel): default_chunk_class: type[BaseMessageChunk] = AIMessageChunk base_generation_info = {} - if "response_format" in payload: - if self.include_response_headers: - warnings.warn( - "Cannot currently include response headers when response_format is " - "specified." - ) - payload.pop("stream") - response_stream = self.root_client.beta.chat.completions.stream(**payload) - context_manager = response_stream - else: - if self.include_response_headers: - raw_response = self.client.with_raw_response.create(**payload) - response = raw_response.parse() - base_generation_info = {"headers": dict(raw_response.headers)} - else: - response = self.client.create(**payload) - context_manager = response try: + if "response_format" in payload: + if self.include_response_headers: + warnings.warn( + "Cannot currently include response headers when " + "response_format is specified." + ) + payload.pop("stream") + response_stream = self.root_client.beta.chat.completions.stream( + **payload + ) + context_manager = response_stream + else: + if self.include_response_headers: + raw_response = self.client.with_raw_response.create(**payload) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = self.client.create(**payload) + context_manager = response with context_manager as response: is_first_chunk = True for chunk in response: @@ -1324,6 +1363,8 @@ class BaseChatOpenAI(BaseChatModel): yield generation_chunk except openai.BadRequestError as e: _handle_openai_bad_request(e) + except openai.APIError as e: + _handle_openai_api_error(e) if hasattr(response, "get_final_completion") and "response_format" in payload: final_completion = response.get_final_completion() generation_chunk = self._get_generation_chunk_from_completion( @@ -1349,15 +1390,10 @@ class BaseChatOpenAI(BaseChatModel): try: if "response_format" in payload: payload.pop("stream") - try: - raw_response = ( - self.root_client.chat.completions.with_raw_response.parse( - **payload - ) - ) - response = raw_response.parse() - except openai.BadRequestError as e: - _handle_openai_bad_request(e) + raw_response = ( + self.root_client.chat.completions.with_raw_response.parse(**payload) + ) + response = raw_response.parse() elif self._use_responses_api(payload): original_schema_obj = kwargs.get("response_format") if original_schema_obj and _is_pydantic_class(original_schema_obj): @@ -1380,6 +1416,10 @@ class BaseChatOpenAI(BaseChatModel): else: raw_response = self.client.with_raw_response.create(**payload) response = raw_response.parse() + except openai.BadRequestError as e: + _handle_openai_bad_request(e) + except openai.APIError as e: + _handle_openai_api_error(e) except Exception as e: if raw_response is not None and hasattr(raw_response, 
"http_response"): e.response = raw_response.http_response # type: ignore[attr-defined] @@ -1523,28 +1563,28 @@ class BaseChatOpenAI(BaseChatModel): default_chunk_class: type[BaseMessageChunk] = AIMessageChunk base_generation_info = {} - if "response_format" in payload: - if self.include_response_headers: - warnings.warn( - "Cannot currently include response headers when response_format is " - "specified." - ) - payload.pop("stream") - response_stream = self.root_async_client.beta.chat.completions.stream( - **payload - ) - context_manager = response_stream - else: - if self.include_response_headers: - raw_response = await self.async_client.with_raw_response.create( + try: + if "response_format" in payload: + if self.include_response_headers: + warnings.warn( + "Cannot currently include response headers when " + "response_format is specified." + ) + payload.pop("stream") + response_stream = self.root_async_client.beta.chat.completions.stream( **payload ) - response = raw_response.parse() - base_generation_info = {"headers": dict(raw_response.headers)} + context_manager = response_stream else: - response = await self.async_client.create(**payload) - context_manager = response - try: + if self.include_response_headers: + raw_response = await self.async_client.with_raw_response.create( + **payload + ) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = await self.async_client.create(**payload) + context_manager = response async with context_manager as response: is_first_chunk = True async for chunk in response: @@ -1569,6 +1609,8 @@ class BaseChatOpenAI(BaseChatModel): yield generation_chunk except openai.BadRequestError as e: _handle_openai_bad_request(e) + except openai.APIError as e: + _handle_openai_api_error(e) if hasattr(response, "get_final_completion") and "response_format" in payload: final_completion = await response.get_final_completion() generation_chunk = self._get_generation_chunk_from_completion( @@ -1593,13 +1635,10 @@ class BaseChatOpenAI(BaseChatModel): try: if "response_format" in payload: payload.pop("stream") - try: - raw_response = await self.root_async_client.chat.completions.with_raw_response.parse( # noqa: E501 - **payload - ) - response = raw_response.parse() - except openai.BadRequestError as e: - _handle_openai_bad_request(e) + raw_response = await self.root_async_client.chat.completions.with_raw_response.parse( # noqa: E501 + **payload + ) + response = raw_response.parse() elif self._use_responses_api(payload): original_schema_obj = kwargs.get("response_format") if original_schema_obj and _is_pydantic_class(original_schema_obj): @@ -1628,6 +1667,10 @@ class BaseChatOpenAI(BaseChatModel): **payload ) response = raw_response.parse() + except openai.BadRequestError as e: + _handle_openai_bad_request(e) + except openai.APIError as e: + _handle_openai_api_error(e) except Exception as e: if raw_response is not None and hasattr(raw_response, "http_response"): e.response = raw_response.http_response # type: ignore[attr-defined] diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 886644c8a0c..fe87233572a 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -10,7 +10,9 @@ from typing import Any, Literal, cast from unittest.mock import AsyncMock, MagicMock, patch import httpx +import openai import pytest +from 
langchain_core.exceptions import ContextOverflowError from langchain_core.load import dumps, loads from langchain_core.messages import ( AIMessage, @@ -3231,3 +3233,162 @@ def test_openai_structured_output_refusal_handling_responses_api() -> None: pass except ValueError as e: pytest.fail(f"This is a wrong behavior. Error details: {e}") + + +# Test fixtures for context overflow error tests +_CONTEXT_OVERFLOW_ERROR_BODY = { + "error": { + "message": ( + "Input tokens exceed the configured limit of 272000 tokens. Your messages " + "resulted in 300007 tokens. Please reduce the length of the messages." + ), + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded", + } +} +_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR = openai.BadRequestError( + message=_CONTEXT_OVERFLOW_ERROR_BODY["error"]["message"], + response=MagicMock(status_code=400), + body=_CONTEXT_OVERFLOW_ERROR_BODY, +) +_CONTEXT_OVERFLOW_API_ERROR = openai.APIError( + message=( + "Your input exceeds the context window of this model. Please adjust your input " + "and try again." + ), + request=MagicMock(), + body=None, +) + + +def test_context_overflow_error_invoke_sync() -> None: + """Test context overflow error on invoke (sync, chat completions API).""" + llm = ChatOpenAI() + + with ( # noqa: PT012 + patch.object(llm.client, "with_raw_response") as mock_client, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + llm.invoke([HumanMessage(content="test")]) + + assert "Input tokens exceed the configured limit" in str(exc_info.value) + + +def test_context_overflow_error_invoke_sync_responses_api() -> None: + """Test context overflow error on invoke (sync, responses API).""" + llm = ChatOpenAI(use_responses_api=True) + + with ( # noqa: PT012 + patch.object(llm.root_client.responses, "with_raw_response") as mock_client, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + llm.invoke([HumanMessage(content="test")]) + + assert "Input tokens exceed the configured limit" in str(exc_info.value) + + +async def test_context_overflow_error_invoke_async() -> None: + """Test context overflow error on invoke (async, chat completions API).""" + llm = ChatOpenAI() + + with ( # noqa: PT012 + patch.object(llm.async_client, "with_raw_response") as mock_client, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + await llm.ainvoke([HumanMessage(content="test")]) + + assert "Input tokens exceed the configured limit" in str(exc_info.value) + + +async def test_context_overflow_error_invoke_async_responses_api() -> None: + """Test context overflow error on invoke (async, responses API).""" + llm = ChatOpenAI(use_responses_api=True) + + with ( # noqa: PT012 + patch.object( + llm.root_async_client.responses, "with_raw_response" + ) as mock_client, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + await llm.ainvoke([HumanMessage(content="test")]) + + assert "Input tokens exceed the configured limit" in str(exc_info.value) + + +def test_context_overflow_error_stream_sync() -> None: + """Test context overflow error on stream (sync, chat completions API).""" + llm = ChatOpenAI() + + with ( # noqa: PT012 + patch.object(llm.client, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = 
_CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + list(llm.stream([HumanMessage(content="test")])) + + assert "Input tokens exceed the configured limit" in str(exc_info.value) + + +def test_context_overflow_error_stream_sync_responses_api() -> None: + """Test context overflow error on stream (sync, responses API).""" + llm = ChatOpenAI(use_responses_api=True) + + with ( # noqa: PT012 + patch.object(llm.root_client.responses, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR + list(llm.stream([HumanMessage(content="test")])) + + assert "exceeds the context window" in str(exc_info.value) + + +async def test_context_overflow_error_stream_async() -> None: + """Test context overflow error on stream (async, chat completions API).""" + llm = ChatOpenAI() + + with ( # noqa: PT012 + patch.object(llm.async_client, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + async for _ in llm.astream([HumanMessage(content="test")]): + pass + + assert "Input tokens exceed the configured limit" in str(exc_info.value) + + +async def test_context_overflow_error_stream_async_responses_api() -> None: + """Test context overflow error on stream (async, responses API).""" + llm = ChatOpenAI(use_responses_api=True) + + with ( # noqa: PT012 + patch.object(llm.root_async_client.responses, "create") as mock_create, + pytest.raises(ContextOverflowError) as exc_info, + ): + mock_create.side_effect = _CONTEXT_OVERFLOW_API_ERROR + async for _ in llm.astream([HumanMessage(content="test")]): + pass + + assert "exceeds the context window" in str(exc_info.value) + + +def test_context_overflow_error_backwards_compatibility() -> None: + """Test that ContextOverflowError can be caught as BadRequestError.""" + llm = ChatOpenAI() + + with ( # noqa: PT012 + patch.object(llm.client, "with_raw_response") as mock_client, + pytest.raises(openai.BadRequestError) as exc_info, + ): + mock_client.create.side_effect = _CONTEXT_OVERFLOW_BAD_REQUEST_ERROR + llm.invoke([HumanMessage(content="test")]) + + # Verify it's both types (multiple inheritance) + assert isinstance(exc_info.value, openai.BadRequestError) + assert isinstance(exc_info.value, ContextOverflowError)
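
Usage sketch (illustrative, not part of the patch): once this change lands, callers can catch the new
exception directly from langchain_core regardless of provider, and because the provider-specific
subclasses also inherit from the SDK error types, existing handlers for openai.BadRequestError /
anthropic.BadRequestError keep working. The model name and the retry/trimming strategy below are
assumptions for the example only.

    from langchain_core.exceptions import ContextOverflowError
    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")  # any chat model covered by this patch
    messages = [HumanMessage("...")]       # a potentially oversized conversation

    try:
        response = llm.invoke(messages)
    except ContextOverflowError:
        # Provider-agnostic fallback: drop the oldest messages and retry once.
        # (Illustrative; real code would trim by token count instead.)
        response = llm.invoke(messages[-10:])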