mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(fireworks): raise ContextOverflowError on prompt-too-long (#37458)
Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com> Co-authored-by: ccurme <26529506+ccurme@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
|
|||||||
BadGatewayError,
|
BadGatewayError,
|
||||||
FireworksError,
|
FireworksError,
|
||||||
InternalServerError,
|
InternalServerError,
|
||||||
|
InvalidRequestError,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
ServiceUnavailableError,
|
ServiceUnavailableError,
|
||||||
)
|
)
|
||||||
@@ -28,6 +29,7 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
|
from langchain_core.exceptions import ContextOverflowError
|
||||||
from langchain_core.language_models import (
|
from langchain_core.language_models import (
|
||||||
LanguageModelInput,
|
LanguageModelInput,
|
||||||
ModelProfile,
|
ModelProfile,
|
||||||
@@ -436,6 +438,17 @@ def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
|
|||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
class FireworksContextOverflowError(InvalidRequestError, ContextOverflowError):
    """Fireworks `InvalidRequestError` signalling that the input is too long.

    The class derives from both `InvalidRequestError` and
    `ContextOverflowError`, so an `except` clause written for either base
    class will catch it.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_fireworks_invalid_request(e: InvalidRequestError) -> NoReturn:
    """Re-raise `e`, upgrading prompt-too-long errors.

    Args:
        e: The `InvalidRequestError` caught by the caller.

    Raises:
        FireworksContextOverflowError: When the text of `e` contains
            `"prompt is too long"`. The new error carries the same text and
            is chained from `e`.
        InvalidRequestError: `e` itself, unchanged, in every other case.
    """
    detail = str(e)
    # Anything that is not a prompt-length complaint passes through untouched.
    if "prompt is too long" not in detail:
        raise e
    raise FireworksContextOverflowError(detail) from e
|
||||||
|
|
||||||
|
|
||||||
def _raise_empty_stream() -> NoReturn:
|
def _raise_empty_stream() -> NoReturn:
|
||||||
"""Raise a descriptive error when the SDK returns a zero-chunk stream."""
|
"""Raise a descriptive error when the SDK returns a zero-chunk stream."""
|
||||||
msg = "Received empty stream from Fireworks"
|
msg = "Received empty stream from Fireworks"
|
||||||
@@ -791,9 +804,13 @@ class ChatFireworks(BaseChatModel):
|
|||||||
params["stream_options"] = {"include_usage": True}
|
params["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||||
for chunk in _completion_with_retry(
|
try:
|
||||||
self, run_manager=run_manager, messages=message_dicts, **params
|
stream = _completion_with_retry(
|
||||||
):
|
self, run_manager=run_manager, messages=message_dicts, **params
|
||||||
|
)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
_handle_fireworks_invalid_request(e)
|
||||||
|
for chunk in stream:
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.model_dump()
|
chunk = chunk.model_dump()
|
||||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||||
@@ -837,9 +854,12 @@ class ChatFireworks(BaseChatModel):
|
|||||||
**({"stream": stream} if stream is not None else {}),
|
**({"stream": stream} if stream is not None else {}),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
response = _completion_with_retry(
|
try:
|
||||||
self, run_manager=run_manager, messages=message_dicts, **params
|
response = _completion_with_retry(
|
||||||
)
|
self, run_manager=run_manager, messages=message_dicts, **params
|
||||||
|
)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
_handle_fireworks_invalid_request(e)
|
||||||
return self._create_chat_result(response)
|
return self._create_chat_result(response)
|
||||||
|
|
||||||
def _create_message_dicts(
|
def _create_message_dicts(
|
||||||
@@ -899,9 +919,13 @@ class ChatFireworks(BaseChatModel):
|
|||||||
params["stream_options"] = {"include_usage": True}
|
params["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||||
async for chunk in await _acompletion_with_retry(
|
try:
|
||||||
self, run_manager=run_manager, messages=message_dicts, **params
|
stream = await _acompletion_with_retry(
|
||||||
):
|
self, run_manager=run_manager, messages=message_dicts, **params
|
||||||
|
)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
_handle_fireworks_invalid_request(e)
|
||||||
|
async for chunk in stream:
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.model_dump()
|
chunk = chunk.model_dump()
|
||||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||||
@@ -948,9 +972,12 @@ class ChatFireworks(BaseChatModel):
|
|||||||
**({"stream": stream} if stream is not None else {}),
|
**({"stream": stream} if stream is not None else {}),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
response = await _acompletion_with_retry(
|
try:
|
||||||
self, run_manager=run_manager, messages=message_dicts, **params
|
response = await _acompletion_with_retry(
|
||||||
)
|
self, run_manager=run_manager, messages=message_dicts, **params
|
||||||
|
)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
_handle_fireworks_invalid_request(e)
|
||||||
return self._create_chat_result(response)
|
return self._create_chat_result(response)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from fireworks.client.error import ( # type: ignore[import-untyped]
|
|||||||
RateLimitError,
|
RateLimitError,
|
||||||
ServiceUnavailableError,
|
ServiceUnavailableError,
|
||||||
)
|
)
|
||||||
|
from langchain_core.exceptions import ContextOverflowError
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
@@ -25,6 +26,7 @@ from langchain_core.messages import (
|
|||||||
|
|
||||||
from langchain_fireworks import ChatFireworks
|
from langchain_fireworks import ChatFireworks
|
||||||
from langchain_fireworks.chat_models import (
|
from langchain_fireworks.chat_models import (
|
||||||
|
FireworksContextOverflowError,
|
||||||
_acompletion_with_retry,
|
_acompletion_with_retry,
|
||||||
_completion_with_retry,
|
_completion_with_retry,
|
||||||
_convert_chunk_to_message_chunk,
|
_convert_chunk_to_message_chunk,
|
||||||
@@ -1140,3 +1142,106 @@ class TestServiceTier:
|
|||||||
result = model.invoke("Hello")
|
result = model.invoke("Hello")
|
||||||
assert isinstance(result, AIMessage)
|
assert isinstance(result, AIMessage)
|
||||||
assert "service_tier" not in result.response_metadata
|
assert "service_tier" not in result.response_metadata
|
||||||
|
|
||||||
|
|
||||||
|
_CONTEXT_OVERFLOW_MESSAGE = (
|
||||||
|
'{"error": {"object": "error", "type": "invalid_request_error", '
|
||||||
|
'"code": "invalid_request_error", "message": "The prompt is too long: '
|
||||||
|
'500208, model maximum context length: 262143"}}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_overflow_error_invoke_sync() -> None:
    """`invoke` raises `ContextOverflowError` when the prompt is too long."""
    model = _make_llm(max_retries=0)
    # Stub client whose `create` call always fails with the overflow message.
    failing_client = MagicMock()
    failing_client.create.side_effect = InvalidRequestError(
        _CONTEXT_OVERFLOW_MESSAGE
    )
    model.client = failing_client

    with pytest.raises(ContextOverflowError) as raised:
        model.invoke([HumanMessage(content="test")])

    error = raised.value
    assert "prompt is too long" in str(error)
    assert isinstance(error, FireworksContextOverflowError)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_context_overflow_error_invoke_async() -> None:
    """`ainvoke` raises `ContextOverflowError` when the prompt is too long."""
    model = _make_llm(max_retries=0)
    failing_client = MagicMock()

    async def _raise_overflow(**_kwargs: Any) -> dict[str, Any]:
        # Coroutine stand-in for `acreate` that fails as soon as it is awaited.
        raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)

    failing_client.acreate = _raise_overflow
    model.async_client = failing_client

    with pytest.raises(ContextOverflowError) as raised:
        await model.ainvoke([HumanMessage(content="test")])

    error = raised.value
    assert "prompt is too long" in str(error)
    assert isinstance(error, FireworksContextOverflowError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_overflow_error_stream_sync() -> None:
    """`stream` raises `ContextOverflowError` when the prompt is too long."""
    model = _make_llm(max_retries=0)
    # Stub client whose `create` call always fails with the overflow message.
    failing_client = MagicMock()
    failing_client.create.side_effect = InvalidRequestError(
        _CONTEXT_OVERFLOW_MESSAGE
    )
    model.client = failing_client

    with pytest.raises(ContextOverflowError) as raised:
        # The stream has to be consumed for the request to be issued.
        list(model.stream([HumanMessage(content="test")]))

    error = raised.value
    assert "prompt is too long" in str(error)
    assert isinstance(error, FireworksContextOverflowError)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_context_overflow_error_stream_async() -> None:
    """`astream` raises `ContextOverflowError` when the prompt is too long."""
    model = _make_llm(max_retries=0)
    failing_client = MagicMock()

    async def _failing_chunks() -> Any:
        # Async generator that fails on its first step; the unreachable
        # `yield` only exists to make this function a generator.
        raise InvalidRequestError(_CONTEXT_OVERFLOW_MESSAGE)
        yield  # pragma: no cover

    def _open_stream(**_kwargs: Any) -> Any:
        # Plain (non-async) callable that hands back the failing generator.
        return _failing_chunks()

    failing_client.acreate = _open_stream
    model.async_client = failing_client

    with pytest.raises(ContextOverflowError) as raised:
        async for _ in model.astream([HumanMessage(content="test")]):
            pass

    error = raised.value
    assert "prompt is too long" in str(error)
    assert isinstance(error, FireworksContextOverflowError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_overflow_error_backwards_compatibility() -> None:
    """The promoted error can still be caught as `InvalidRequestError`."""
    model = _make_llm(max_retries=0)
    # Stub client whose `create` call always fails with the overflow message.
    failing_client = MagicMock()
    failing_client.create.side_effect = InvalidRequestError(
        _CONTEXT_OVERFLOW_MESSAGE
    )
    model.client = failing_client

    # Existing callers that catch the SDK error type must keep working.
    with pytest.raises(InvalidRequestError) as raised:
        model.invoke([HumanMessage(content="test")])

    error = raised.value
    assert isinstance(error, InvalidRequestError)
    assert isinstance(error, ContextOverflowError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unrelated_invalid_request_error_not_promoted() -> None:
    """An `InvalidRequestError` with other text is not promoted."""
    model = _make_llm(max_retries=0)
    # This message lacks the "prompt is too long" phrase.
    failing_client = MagicMock()
    failing_client.create.side_effect = InvalidRequestError(
        "some other bad request"
    )
    model.client = failing_client

    with pytest.raises(InvalidRequestError) as raised:
        model.invoke([HumanMessage(content="test")])

    error = raised.value
    assert not isinstance(error, ContextOverflowError)
|
||||||
|
|||||||
Reference in New Issue
Block a user