fix(fireworks): raise ContextOverflowError on prompt-too-long (#37458)

Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com>
Co-authored-by: ccurme <26529506+ccurme@users.noreply.github.com>
This commit is contained in:
open-swe[bot]
2026-05-17 13:35:48 -04:00
committed by GitHub
parent 67eca93d17
commit 40c515c7b1
2 changed files with 144 additions and 12 deletions

View File

@@ -21,6 +21,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
BadGatewayError,
FireworksError,
InternalServerError,
InvalidRequestError,
RateLimitError,
ServiceUnavailableError,
)
@@ -28,6 +29,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.exceptions import ContextOverflowError
from langchain_core.language_models import (
LanguageModelInput,
ModelProfile,
@@ -436,6 +438,17 @@ def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
raise exc
class FireworksContextOverflowError(InvalidRequestError, ContextOverflowError):
"""`InvalidRequestError` raised when input exceeds Fireworks's context limit."""
def _handle_fireworks_invalid_request(e: InvalidRequestError) -> NoReturn:
"""Promote prompt-too-long errors to `FireworksContextOverflowError`."""
if "prompt is too long" in str(e):
raise FireworksContextOverflowError(str(e)) from e
raise e
def _raise_empty_stream() -> NoReturn:
"""Raise a descriptive error when the SDK returns a zero-chunk stream."""
msg = "Received empty stream from Fireworks"
@@ -791,9 +804,13 @@ class ChatFireworks(BaseChatModel):
params["stream_options"] = {"include_usage": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in _completion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
):
try:
stream = _completion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
except InvalidRequestError as e:
_handle_fireworks_invalid_request(e)
for chunk in stream:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
@@ -837,9 +854,12 @@ class ChatFireworks(BaseChatModel):
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = _completion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
try:
response = _completion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
except InvalidRequestError as e:
_handle_fireworks_invalid_request(e)
return self._create_chat_result(response)
def _create_message_dicts(
@@ -899,9 +919,13 @@ class ChatFireworks(BaseChatModel):
params["stream_options"] = {"include_usage": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await _acompletion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
):
try:
stream = await _acompletion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
except InvalidRequestError as e:
_handle_fireworks_invalid_request(e)
async for chunk in stream:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
@@ -948,9 +972,12 @@ class ChatFireworks(BaseChatModel):
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = await _acompletion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
try:
response = await _acompletion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
except InvalidRequestError as e:
_handle_fireworks_invalid_request(e)
return self._create_chat_result(response)
@property

View File

@@ -14,6 +14,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
RateLimitError,
ServiceUnavailableError,
)
from langchain_core.exceptions import ContextOverflowError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -25,6 +26,7 @@ from langchain_core.messages import (
from langchain_fireworks import ChatFireworks
from langchain_fireworks.chat_models import (
FireworksContextOverflowError,
_acompletion_with_retry,
_completion_with_retry,
_convert_chunk_to_message_chunk,
@@ -1140,3 +1142,106 @@ class TestServiceTier:
result = model.invoke("Hello")
assert isinstance(result, AIMessage)
assert "service_tier" not in result.response_metadata
_CONTEXT_OVERFLOW_MESSAGE = (
'{"error": {"object": "error", "type": "invalid_request_error", '
'"code": "invalid_request_error", "message": "The prompt is too long: '
'500208, model maximum context length: 262143"}}'
)
def test_context_overflow_error_invoke_sync() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on invoke."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
llm.client = mock_client
with pytest.raises(ContextOverflowError) as exc_info:
llm.invoke([HumanMessage(content="test")])
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
async def test_context_overflow_error_invoke_async() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on ainvoke."""
llm = _make_llm(max_retries=0)
mock_async = MagicMock()
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
mock_async.acreate = _acreate
llm.async_client = mock_async
with pytest.raises(ContextOverflowError) as exc_info:
await llm.ainvoke([HumanMessage(content="test")])
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
def test_context_overflow_error_stream_sync() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on stream."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
llm.client = mock_client
with pytest.raises(ContextOverflowError) as exc_info:
list(llm.stream([HumanMessage(content="test")]))
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
async def test_context_overflow_error_stream_async() -> None:
"""Prompt-too-long errors surface as `ContextOverflowError` on astream."""
llm = _make_llm(max_retries=0)
mock_async = MagicMock()
def _acreate(**_kwargs: Any) -> Any:
async def _failing_agen() -> Any:
raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
yield # pragma: no cover
return _failing_agen()
mock_async.acreate = _acreate
llm.async_client = mock_async
with pytest.raises(ContextOverflowError) as exc_info:
async for _ in llm.astream([HumanMessage(content="test")]):
pass
assert "prompt is too long" in str(exc_info.value)
assert isinstance(exc_info.value, FireworksContextOverflowError)
def test_context_overflow_error_backwards_compatibility() -> None:
"""`ContextOverflowError` is also catchable as `InvalidRequestError`."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
llm.client = mock_client
with pytest.raises(InvalidRequestError) as exc_info:
llm.invoke([HumanMessage(content="test")])
assert isinstance(exc_info.value, InvalidRequestError)
assert isinstance(exc_info.value, ContextOverflowError)
def test_unrelated_invalid_request_error_not_promoted() -> None:
"""Unrelated `InvalidRequestError`s should not be wrapped."""
llm = _make_llm(max_retries=0)
mock_client = MagicMock()
mock_client.create.side_effect = InvalidRequestError("some other bad request")
llm.client = mock_client
with pytest.raises(InvalidRequestError) as exc_info:
llm.invoke([HumanMessage(content="test")])
assert not isinstance(exc_info.value, ContextOverflowError)