mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(fireworks): raise ContextOverflowError on prompt-too-long (#37458)
Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com> Co-authored-by: ccurme <26529506+ccurme@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
|
||||
BadGatewayError,
|
||||
FireworksError,
|
||||
InternalServerError,
|
||||
InvalidRequestError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
@@ -28,6 +29,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.exceptions import ContextOverflowError
|
||||
from langchain_core.language_models import (
|
||||
LanguageModelInput,
|
||||
ModelProfile,
|
||||
@@ -436,6 +438,17 @@ def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
|
||||
raise exc
|
||||
|
||||
|
||||
class FireworksContextOverflowError(InvalidRequestError, ContextOverflowError):
|
||||
"""`InvalidRequestError` raised when input exceeds Fireworks's context limit."""
|
||||
|
||||
|
||||
def _handle_fireworks_invalid_request(e: InvalidRequestError) -> NoReturn:
|
||||
"""Promote prompt-too-long errors to `FireworksContextOverflowError`."""
|
||||
if "prompt is too long" in str(e):
|
||||
raise FireworksContextOverflowError(str(e)) from e
|
||||
raise e
|
||||
|
||||
|
||||
def _raise_empty_stream() -> NoReturn:
|
||||
"""Raise a descriptive error when the SDK returns a zero-chunk stream."""
|
||||
msg = "Received empty stream from Fireworks"
|
||||
@@ -791,9 +804,13 @@ class ChatFireworks(BaseChatModel):
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in _completion_with_retry(
|
||||
try:
|
||||
stream = _completion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
):
|
||||
)
|
||||
except InvalidRequestError as e:
|
||||
_handle_fireworks_invalid_request(e)
|
||||
for chunk in stream:
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
@@ -837,9 +854,12 @@ class ChatFireworks(BaseChatModel):
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
try:
|
||||
response = _completion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
)
|
||||
except InvalidRequestError as e:
|
||||
_handle_fireworks_invalid_request(e)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
@@ -899,9 +919,13 @@ class ChatFireworks(BaseChatModel):
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await _acompletion_with_retry(
|
||||
try:
|
||||
stream = await _acompletion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
):
|
||||
)
|
||||
except InvalidRequestError as e:
|
||||
_handle_fireworks_invalid_request(e)
|
||||
async for chunk in stream:
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
@@ -948,9 +972,12 @@ class ChatFireworks(BaseChatModel):
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
try:
|
||||
response = await _acompletion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
)
|
||||
except InvalidRequestError as e:
|
||||
_handle_fireworks_invalid_request(e)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
|
||||
@@ -14,6 +14,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
from langchain_core.exceptions import ContextOverflowError
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -25,6 +26,7 @@ from langchain_core.messages import (
|
||||
|
||||
from langchain_fireworks import ChatFireworks
|
||||
from langchain_fireworks.chat_models import (
|
||||
FireworksContextOverflowError,
|
||||
_acompletion_with_retry,
|
||||
_completion_with_retry,
|
||||
_convert_chunk_to_message_chunk,
|
||||
@@ -1140,3 +1142,106 @@ class TestServiceTier:
|
||||
result = model.invoke("Hello")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert "service_tier" not in result.response_metadata
|
||||
|
||||
|
||||
_CONTEXT_OVERFLOW_MESSAGE = (
|
||||
'{"error": {"object": "error", "type": "invalid_request_error", '
|
||||
'"code": "invalid_request_error", "message": "The prompt is too long: '
|
||||
'500208, model maximum context length: 262143"}}'
|
||||
)
|
||||
|
||||
|
||||
def test_context_overflow_error_invoke_sync() -> None:
|
||||
"""Prompt-too-long errors surface as `ContextOverflowError` on invoke."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(ContextOverflowError) as exc_info:
|
||||
llm.invoke([HumanMessage(content="test")])
|
||||
|
||||
assert "prompt is too long" in str(exc_info.value)
|
||||
assert isinstance(exc_info.value, FireworksContextOverflowError)
|
||||
|
||||
|
||||
async def test_context_overflow_error_invoke_async() -> None:
|
||||
"""Prompt-too-long errors surface as `ContextOverflowError` on ainvoke."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_async = MagicMock()
|
||||
|
||||
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
|
||||
raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
|
||||
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
with pytest.raises(ContextOverflowError) as exc_info:
|
||||
await llm.ainvoke([HumanMessage(content="test")])
|
||||
|
||||
assert "prompt is too long" in str(exc_info.value)
|
||||
assert isinstance(exc_info.value, FireworksContextOverflowError)
|
||||
|
||||
|
||||
def test_context_overflow_error_stream_sync() -> None:
|
||||
"""Prompt-too-long errors surface as `ContextOverflowError` on stream."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(ContextOverflowError) as exc_info:
|
||||
list(llm.stream([HumanMessage(content="test")]))
|
||||
|
||||
assert "prompt is too long" in str(exc_info.value)
|
||||
assert isinstance(exc_info.value, FireworksContextOverflowError)
|
||||
|
||||
|
||||
async def test_context_overflow_error_stream_async() -> None:
|
||||
"""Prompt-too-long errors surface as `ContextOverflowError` on astream."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_async = MagicMock()
|
||||
|
||||
def _acreate(**_kwargs: Any) -> Any:
|
||||
async def _failing_agen() -> Any:
|
||||
raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
|
||||
yield # pragma: no cover
|
||||
|
||||
return _failing_agen()
|
||||
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
with pytest.raises(ContextOverflowError) as exc_info:
|
||||
async for _ in llm.astream([HumanMessage(content="test")]):
|
||||
pass
|
||||
|
||||
assert "prompt is too long" in str(exc_info.value)
|
||||
assert isinstance(exc_info.value, FireworksContextOverflowError)
|
||||
|
||||
|
||||
def test_context_overflow_error_backwards_compatibility() -> None:
|
||||
"""`ContextOverflowError` is also catchable as `InvalidRequestError`."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(InvalidRequestError) as exc_info:
|
||||
llm.invoke([HumanMessage(content="test")])
|
||||
|
||||
assert isinstance(exc_info.value, InvalidRequestError)
|
||||
assert isinstance(exc_info.value, ContextOverflowError)
|
||||
|
||||
|
||||
def test_unrelated_invalid_request_error_not_promoted() -> None:
|
||||
"""Unrelated `InvalidRequestError`s should not be wrapped."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = InvalidRequestError("some other bad request")
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(InvalidRequestError) as exc_info:
|
||||
llm.invoke([HumanMessage(content="test")])
|
||||
|
||||
assert not isinstance(exc_info.value, ContextOverflowError)
|
||||
|
||||
Reference in New Issue
Block a user