mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(fireworks): honor max_retries (#36973)
`ChatFireworks.max_retries` silently did nothing. The old code assigned the value to a `ChatCompletionV2` sub-object rather than the base client, and the pinned Fireworks SDK (0.13.0–0.19.20) never honors its own `_max_retries` attribute on the base client either. Since the Stainless-generated 1.x SDK that does implement retries is still pre-release (1.0.1a63 at time of writing), retry responsibility is ported to the LangChain side until the pin can be bumped.
This commit is contained in:
@@ -10,10 +10,20 @@ from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
NoReturn,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from fireworks.client import AsyncFireworks, Fireworks # type: ignore[import-untyped]
|
||||
from fireworks.client.error import ( # type: ignore[import-untyped]
|
||||
APITimeoutError,
|
||||
BadGatewayError,
|
||||
FireworksError,
|
||||
InternalServerError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -29,6 +39,7 @@ from langchain_core.language_models.chat_models import (
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -290,6 +301,137 @@ def _convert_chunk_to_message_chunk(
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class _RetryableHTTPStatusError(FireworksError):
    """Internal marker for 5xx `httpx.HTTPStatusError` responses.

    The Fireworks SDK maps a subset of status codes (500, 502, 503) to typed
    exceptions but lets others (504, 507-511, Cloudflare-edge 520-599)
    propagate as raw `httpx.HTTPStatusError`. Promoting those to this marker
    inside `_call` keeps the retryable set expressible as a list of classes
    for `create_base_retry_decorator`, preserving parity with `ChatMistralAI`.

    Instances are raised by `_promote_http_status_error`, which promotes any
    response whose status code is >= 500, and the class is listed in
    `_RETRYABLE_ERRORS` so the retry decorator treats it as transient.
    """
|
||||
|
||||
|
||||
# Exception classes that trigger a retry. `_create_retry_decorator` passes
# this tuple (converted to a list) to `create_base_retry_decorator`.
_RETRYABLE_ERRORS: tuple[type[BaseException], ...] = (
    # Typed errors imported from `fireworks.client.error`.
    APITimeoutError,
    BadGatewayError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
    # Errors raised at the httpx layer.
    httpx.TimeoutException,
    httpx.TransportError,
    # Raw `httpx.HTTPStatusError` responses with status >= 500, re-raised by
    # `_promote_http_status_error`.
    _RetryableHTTPStatusError,
)
|
||||
|
||||
|
||||
def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
    """Re-raise an `httpx.HTTPStatusError`, promoting 5xx to a retryable marker.

    Args:
        exc: The status error caught around a Fireworks SDK call.

    Raises:
        _RetryableHTTPStatusError: When the response status code is >= 500;
            the original error is attached as the cause.
        httpx.HTTPStatusError: The unchanged `exc` for any lower status code.
    """
    # Anything below 500 is not treated as transient: surface it untouched.
    if not exc.response.status_code >= 500:
        raise exc
    msg = f"Retryable {exc.response.status_code} from Fireworks: {exc}"
    raise _RetryableHTTPStatusError(msg) from exc
|
||||
|
||||
|
||||
def _raise_empty_stream() -> NoReturn:
    """Signal that a streaming call produced no chunks at all.

    Raises:
        FireworksError: Always, with a fixed descriptive message.
    """
    empty_stream_message = "Received empty stream from Fireworks"
    raise FireworksError(empty_stream_message)
|
||||
|
||||
|
||||
def _create_retry_decorator(
    llm: ChatFireworks,
    run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,
) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used around Fireworks SDK calls.

    Retrying lives on the LangChain side because the pinned Fireworks SDK 0.x
    does not honor its own `_max_retries` attribute on completion resources.

    Args:
        llm: Chat model whose `max_retries` field sets the retry budget.
        run_manager: Optional callback manager forwarded to
            `create_base_retry_decorator`.

    Returns:
        A decorator that retries the wrapped callable on `_RETRYABLE_ERRORS`.
    """
    # The `max_retries` field counts retries made *after* the first attempt,
    # while `create_base_retry_decorator` forwards its `max_retries` argument
    # to `stop_after_attempt`, which counts total attempts — hence the +1.
    # This deliberately diverges from `ChatMistralAI`, which passes the raw
    # value; the fireworks field docstring is the source of truth here.
    # Both `None` and `0` collapse to a single attempt with no retries.
    retries = llm.max_retries
    total_attempts = 1 if not retries else retries + 1
    return create_base_retry_decorator(
        error_types=[*_RETRYABLE_ERRORS],
        max_retries=total_attempts,
        run_manager=run_manager,
    )
|
||||
|
||||
|
||||
def _completion_with_retry(
    llm: ChatFireworks,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Any:
    """Run the sync completion call under the retry policy.

    For streaming requests the first chunk is pulled inside the retried
    function, so stream setup failures are retried as well.

    Args:
        llm: Chat model providing `client` and the retry budget.
        run_manager: Optional callback manager passed to the retry decorator.
        **kwargs: Arguments forwarded verbatim to `llm.client.create`.

    Returns:
        The SDK response, or for `stream=True` an iterator that yields every
        chunk of the stream, starting with the one already consumed.
    """
    with_retries = _create_retry_decorator(llm, run_manager=run_manager)

    def _call() -> Any:
        try:
            response = llm.client.create(**kwargs)
        except httpx.HTTPStatusError as err:
            _promote_http_status_error(err)

        if not kwargs.get("stream"):
            return response

        # The streaming generator is lazy: pull one chunk now so that the
        # HTTP connection — and any transport error — occurs inside the retry
        # boundary. `_prepend_chunk` puts the consumed chunk back in front of
        # the remainder so callers still observe every event.
        try:
            chunks = iter(response)
            head = next(chunks)
        except StopIteration:
            _raise_empty_stream()
        except httpx.HTTPStatusError as err:
            _promote_http_status_error(err)
        return _prepend_chunk(head, chunks)

    return with_retries(_call)()
|
||||
|
||||
|
||||
async def _acompletion_with_retry(
    llm: ChatFireworks,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Any:
    """Run the async completion call under the retry policy.

    For streaming requests the first chunk is awaited inside the retried
    coroutine, so stream setup failures are retried as well.

    Args:
        llm: Chat model providing `async_client` and the retry budget.
        run_manager: Optional callback manager passed to the retry decorator.
        **kwargs: Arguments forwarded verbatim to `llm.async_client.acreate`.

    Returns:
        The awaited SDK response, or for `stream=True` an async iterator that
        yields every chunk of the stream, starting with the one already
        consumed.
    """
    with_retries = _create_retry_decorator(llm, run_manager=run_manager)

    async def _call() -> Any:
        if not kwargs.get("stream"):
            try:
                return await llm.async_client.acreate(**kwargs)
            except httpx.HTTPStatusError as err:
                _promote_http_status_error(err)

        # Streaming: await the first chunk here so connection-time failures
        # surface inside the retry boundary rather than in the caller's loop.
        try:
            stream = llm.async_client.acreate(**kwargs)
            chunks = stream.__aiter__()
            head = await chunks.__anext__()
        except StopAsyncIteration:
            _raise_empty_stream()
        except httpx.HTTPStatusError as err:
            _promote_http_status_error(err)
        return _aprepend_chunk(head, chunks)

    return await with_retries(_call)()
|
||||
|
||||
|
||||
def _prepend_chunk(first: Any, rest: Iterator[Any]) -> Iterator[Any]:
    """Yield `first`, then every remaining item of `rest`.

    Used to hand back a stream whose first chunk was already consumed.

    Args:
        first: The item that was already pulled from the stream.
        rest: Iterator positioned just after `first`.

    Yields:
        `first`, followed by the items of `rest` in order.
    """
    try:
        yield first
        yield from rest
    finally:
        # `yield from` forwards `close()` only once delegation has started.
        # If the consumer stops while suspended at `yield first`, `rest`
        # would otherwise be left open until garbage collection. Closing an
        # already-finished generator is a no-op, and plain iterators without
        # a `close` method are skipped.
        close = getattr(rest, "close", None)
        if close is not None:
            close()
|
||||
|
||||
|
||||
async def _aprepend_chunk(first: Any, rest: AsyncIterator[Any]) -> AsyncIterator[Any]:
    """Yield `first`, then every remaining item of `rest`.

    Async counterpart of `_prepend_chunk`, used to hand back a stream whose
    first chunk was already awaited.

    Args:
        first: The item that was already pulled from the stream.
        rest: Async iterator positioned just after `first`.

    Yields:
        `first`, followed by the items of `rest` in order.
    """
    try:
        yield first
        async for item in rest:
            yield item
    finally:
        # `async for` never forwards `aclose()` to the iterator it drives, so
        # a consumer that stops early would leave `rest` open until garbage
        # collection. Close it explicitly; this is a no-op when `rest` is
        # already exhausted, and iterators without `aclose` are skipped.
        aclose = getattr(rest, "aclose", None)
        if aclose is not None:
            await aclose()
|
||||
|
||||
|
||||
# This is basically a copy and replace for ChatFireworks, except
|
||||
# - I needed to gut out tiktoken and some of the token estimation logic
|
||||
# (not sure how important it is)
|
||||
@@ -416,7 +558,14 @@ class ChatFireworks(BaseChatModel):
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
max_retries: int | None = None
|
||||
"""Maximum number of retries to make when generating."""
|
||||
"""Maximum number of retries after the initial attempt when generating.
|
||||
|
||||
Retries use exponential backoff and trigger on transient errors:
|
||||
`RateLimitError`, `APITimeoutError`, 5xx responses (including those that
|
||||
surface as `httpx.HTTPStatusError` rather than typed SDK errors), and
|
||||
underlying transport errors (`httpx.TimeoutException`, `httpx.TransportError`).
|
||||
A value of `None` or `0` disables retries.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
@@ -453,9 +602,6 @@ class ChatFireworks(BaseChatModel):
|
||||
self.client = Fireworks(**client_params).chat.completions
|
||||
if not self.async_client:
|
||||
self.async_client = AsyncFireworks(**client_params).chat.completions
|
||||
if self.max_retries:
|
||||
self.client._max_retries = self.max_retries
|
||||
self.async_client._max_retries = self.max_retries
|
||||
return self
|
||||
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
@@ -528,7 +674,9 @@ class ChatFireworks(BaseChatModel):
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
for chunk in _completion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
@@ -572,7 +720,9 @@ class ChatFireworks(BaseChatModel):
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = self.client.create(messages=message_dicts, **params)
|
||||
response = _completion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
@@ -626,7 +776,9 @@ class ChatFireworks(BaseChatModel):
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
|
||||
async for chunk in await _acompletion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
@@ -673,7 +825,9 @@ class ChatFireworks(BaseChatModel):
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = await self.async_client.acreate(messages=message_dicts, **params)
|
||||
response = await _acompletion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user