From 40c515c7b18830460672b455b74e9d7140d2a03b Mon Sep 17 00:00:00 2001 From: "open-swe[bot]" <215916821+open-swe[bot]@users.noreply.github.com> Date: Sun, 17 May 2026 13:35:48 -0400 Subject: [PATCH] fix(fireworks): raise `ContextOverflowError` on prompt-too-long (#37458) Co-authored-by: open-swe[bot] Co-authored-by: ccurme <26529506+ccurme@users.noreply.github.com> --- .../langchain_fireworks/chat_models.py | 51 +++++++-- .../tests/unit_tests/test_chat_models.py | 105 ++++++++++++++++++ 2 files changed, 144 insertions(+), 12 deletions(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 4e65cb07dc6..817bde29ff6 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -21,6 +21,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped] BadGatewayError, FireworksError, InternalServerError, + InvalidRequestError, RateLimitError, ServiceUnavailableError, ) @@ -28,6 +29,7 @@ from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.exceptions import ContextOverflowError from langchain_core.language_models import ( LanguageModelInput, ModelProfile, @@ -436,6 +438,17 @@ def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn: raise exc +class FireworksContextOverflowError(InvalidRequestError, ContextOverflowError): + """`InvalidRequestError` raised when input exceeds Fireworks's context limit.""" + + +def _handle_fireworks_invalid_request(e: InvalidRequestError) -> NoReturn: + """Promote prompt-too-long errors to `FireworksContextOverflowError`.""" + if "prompt is too long" in str(e): + raise FireworksContextOverflowError(str(e)) from e + raise e + + def _raise_empty_stream() -> NoReturn: """Raise a descriptive error when the SDK returns a zero-chunk stream.""" msg = "Received empty stream from Fireworks" @@ -791,9 +804,13 @@ class ChatFireworks(BaseChatModel): params["stream_options"] = {"include_usage": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk - for chunk in _completion_with_retry( - self, run_manager=run_manager, messages=message_dicts, **params - ): + try: + stream = _completion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ) + except InvalidRequestError as e: + _handle_fireworks_invalid_request(e) + for chunk in stream: if not isinstance(chunk, dict): chunk = chunk.model_dump() message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) @@ -837,9 +854,12 @@ class ChatFireworks(BaseChatModel): **({"stream": stream} if stream is not None else {}), **kwargs, } - response = _completion_with_retry( - self, run_manager=run_manager, messages=message_dicts, **params - ) + try: + response = _completion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ) + except InvalidRequestError as e: + _handle_fireworks_invalid_request(e) return self._create_chat_result(response) def _create_message_dicts( @@ -899,9 +919,13 @@ class ChatFireworks(BaseChatModel): params["stream_options"] = {"include_usage": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk - async for chunk in await _acompletion_with_retry( - self, run_manager=run_manager, messages=message_dicts, **params - ): + try: + stream = await _acompletion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ) + except InvalidRequestError as e: + _handle_fireworks_invalid_request(e) + async for chunk in stream: if not isinstance(chunk, dict): chunk = chunk.model_dump() message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) @@ -948,9 +972,12 @@ class ChatFireworks(BaseChatModel): **({"stream": stream} if stream is not None else {}), **kwargs, } - response = await _acompletion_with_retry( - self, run_manager=run_manager, messages=message_dicts, **params - ) + try: + response = await _acompletion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ) + except InvalidRequestError as e: + _handle_fireworks_invalid_request(e) return self._create_chat_result(response) @property diff --git a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py index 3259cde5726..7b2eb9c367f 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py @@ -14,6 +14,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped] RateLimitError, ServiceUnavailableError, ) +from langchain_core.exceptions import ContextOverflowError from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -25,6 +26,7 @@ from langchain_core.messages import ( from langchain_fireworks import ChatFireworks from langchain_fireworks.chat_models import ( + FireworksContextOverflowError, _acompletion_with_retry, _completion_with_retry, _convert_chunk_to_message_chunk, @@ -1140,3 +1142,106 @@ class TestServiceTier: result = model.invoke("Hello") assert isinstance(result, AIMessage) assert "service_tier" not in result.response_metadata + + +_CONTEXT_OVERFLOW_MESSAGE = ( + '{"error": {"object": "error", "type": "invalid_request_error", ' + '"code": "invalid_request_error", "message": "The prompt is too long: ' + '500208, model maximum context length: 262143"}}' +) + + +def test_context_overflow_error_invoke_sync() -> None: + """Prompt-too-long errors surface as `ContextOverflowError` on invoke.""" + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE) + llm.client = mock_client + + with pytest.raises(ContextOverflowError) as exc_info: + llm.invoke([HumanMessage(content="test")]) + + assert "prompt is too long" in str(exc_info.value) + assert isinstance(exc_info.value, FireworksContextOverflowError) + + +async def test_context_overflow_error_invoke_async() -> None: + """Prompt-too-long errors surface as `ContextOverflowError` on ainvoke.""" + llm = _make_llm(max_retries=0) + mock_async = MagicMock() + + async def _acreate(**_kwargs: Any) -> dict[str, Any]: + raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE) + + mock_async.acreate = _acreate + llm.async_client = mock_async + + with pytest.raises(ContextOverflowError) as exc_info: + await llm.ainvoke([HumanMessage(content="test")]) + + assert "prompt is too long" in str(exc_info.value) + assert isinstance(exc_info.value, FireworksContextOverflowError) + + +def test_context_overflow_error_stream_sync() -> None: + """Prompt-too-long errors surface as `ContextOverflowError` on stream.""" + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE) + llm.client = mock_client + + with pytest.raises(ContextOverflowError) as exc_info: + list(llm.stream([HumanMessage(content="test")])) + + assert "prompt is too long" in str(exc_info.value) + assert isinstance(exc_info.value, FireworksContextOverflowError) + + +async def test_context_overflow_error_stream_async() -> None: + """Prompt-too-long errors surface as `ContextOverflowError` on astream.""" + llm = _make_llm(max_retries=0) + mock_async = MagicMock() + + def _acreate(**_kwargs: Any) -> Any: + async def _failing_agen() -> Any: + raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE) + yield # pragma: no cover + + return _failing_agen() + + mock_async.acreate = _acreate + llm.async_client = mock_async + + with pytest.raises(ContextOverflowError) as exc_info: + async for _ in llm.astream([HumanMessage(content="test")]): + pass + + assert "prompt is too long" in str(exc_info.value) + assert isinstance(exc_info.value, FireworksContextOverflowError) + + +def test_context_overflow_error_backwards_compatibility() -> None: + """`ContextOverflowError` is also catchable as `InvalidRequestError`.""" + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE) + llm.client = mock_client + + with pytest.raises(InvalidRequestError) as exc_info: + llm.invoke([HumanMessage(content="test")]) + + assert isinstance(exc_info.value, InvalidRequestError) + assert isinstance(exc_info.value, ContextOverflowError) + + +def test_unrelated_invalid_request_error_not_promoted() -> None: + """Unrelated `InvalidRequestError`s should not be wrapped.""" + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + mock_client.create.side_effect = InvalidRequestError("some other bad request") + llm.client = mock_client + + with pytest.raises(InvalidRequestError) as exc_info: + llm.invoke([HumanMessage(content="test")]) + + assert not isinstance(exc_info.value, ContextOverflowError)