fix(fireworks): raise ContextOverflowError on prompt-too-long (#37458)

Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com>
Co-authored-by: ccurme <26529506+ccurme@users.noreply.github.com>
This commit is contained in:
open-swe[bot]
2026-05-17 13:35:48 -04:00
committed by GitHub
parent 67eca93d17
commit 40c515c7b1
2 changed files with 144 additions and 12 deletions

View File

@@ -14,6 +14,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
RateLimitError,
ServiceUnavailableError,
)
from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -25,6 +26,7 @@ from langchain_core.messages import (
from langchain_fireworks import ChatFireworks
from langchain_fireworks.chat_models import (
FireworksContextOverflowError,
_acompletion_with_retry,
_completion_with_retry,
_convert_chunk_to_message_chunk,
@@ -1140,3 +1142,106 @@ class TestServiceTier:
result = model.invoke("Hello")
assert isinstance(result, AIMessage)
assert "service_tier" not in result.response_metadata
_CONTEXT_OVERFLOW_MESSAGE = (
'{"error": {"object": "error", "type": "invalid_request_error", '
'"code": "invalid_request_error", "message": "The prompt is too long: '
'500208, model maximum context length: 262143"}}'
)
def test_context_overflow_error_invoke_sync() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on invoke."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
llm.client = mock_client
with pytest.raises(ContextOverflowError) as exc_info:
llm.invoke([HumanMessage(content="test")])
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
async def test_context_overflow_error_invoke_async() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on ainvoke."""
llm = _make_llm(max_retries=0)
mock_async = MagicMock()
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
mock_async.acreate = _acreate
llm.async_client = mock_async
with pytest.raises(ContextOverflowError) as exc_info:
await llm.ainvoke([HumanMessage(content="test")])
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
def test_context_overflow_error_stream_sync() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on stream."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
llm.client = mock_client
with pytest.raises(ContextOverflowError) as exc_info:
list(llm.stream([HumanMessage(content="test")]))
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
async def test_context_overflow_error_stream_async() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on astream."""
llm = _make_llm(max_retries=0)
mock_async = MagicMock()
def _acreate(**_kwargs: Any) -> Any:
async def _failing_agen() -> Any:
raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
yield # pragma: no cover
return _failing_agen()
mock_async.acreate = _acreate
llm.async_client = mock_async
with pytest.raises(ContextOverflowError) as exc_info:
async for _ in llm.astream([HumanMessage(content="test")]):
pass
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
def test_context_overflow_error_backwards_compatibility() -> None:
"""`ContextOverflowError` is also catchable as `InvalidRequestError`."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
llm.client = mock_client
with pytest.raises(InvalidRequestError) as exc_info:
llm.invoke([HumanMessage(content="test")])
assert isinstance(exc_info.value, InvalidRequestError)
assert isinstance(exc_info.value, ContextOverflowError)
def test_unrelated_invalid_request_error_not_promoted() -> None:
"""Unrelated `InvalidRequestError`s should not be wrapped."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError("some other bad request")
llm.client = mock_client
with pytest.raises(InvalidRequestError) as exc_info:
llm.invoke([HumanMessage(content="test")])
assert not isinstance(exc_info.value, ContextOverflowError)