fix(fireworks): honor max_retries (#36973)

`ChatFireworks.max_retries` silently did nothing. The old code assigned
the value to a `ChatCompletionV2` sub-object rather than the base
client, and the pinned Fireworks SDK (0.13.0–0.19.20) never honors its
own `_max_retries` attribute on the base client either. Since the
Stainless-generated 1.x SDK that does implement retries is still
pre-release (1.0.1a63 at time of writing), retry responsibility is
ported to the LangChain side until the pin can be bumped.
This commit is contained in:
Mason Daugherty
2026-04-23 16:40:54 -04:00
committed by GitHub
parent d30ef8a8aa
commit 7b09eb7bda
6 changed files with 604 additions and 13 deletions

View File

@@ -10,10 +10,20 @@ from operator import itemgetter
from typing import (
Any,
Literal,
NoReturn,
cast,
)
import httpx
from fireworks.client import AsyncFireworks, Fireworks # type: ignore[import-untyped]
from fireworks.client.error import ( # type: ignore[import-untyped]
APITimeoutError,
BadGatewayError,
FireworksError,
InternalServerError,
RateLimitError,
ServiceUnavailableError,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -29,6 +39,7 @@ from langchain_core.language_models.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -290,6 +301,137 @@ def _convert_chunk_to_message_chunk(
return default_class(content=content) # type: ignore[call-arg]
class _RetryableHTTPStatusError(FireworksError):
    """Marker exception for 5xx `httpx.HTTPStatusError` responses.

    Only a subset of status codes (500, 502, 503) is mapped by the Fireworks
    SDK to typed exceptions; the others (504, 507-511, Cloudflare-edge
    520-599) propagate as raw `httpx.HTTPStatusError`. Those are promoted to
    this marker inside `_call`, which keeps the retryable set expressible as a
    list of classes for `create_base_retry_decorator` and preserves parity
    with `ChatMistralAI`.
    """
# Exception classes that trigger a retry of a Fireworks call. The tuple is
# converted to a list and handed to `create_base_retry_decorator` by
# `_create_retry_decorator`.
_RETRYABLE_ERRORS: tuple[type[BaseException], ...] = (
    # Typed errors imported from `fireworks.client.error`.
    APITimeoutError,
    BadGatewayError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
    # Raw httpx failures.
    # NOTE(review): `httpx.TimeoutException` appears to derive from
    # `httpx.TransportError`, which would make listing both redundant but
    # harmless — confirm against the httpx exception hierarchy.
    httpx.TimeoutException,
    httpx.TransportError,
    # 5xx `httpx.HTTPStatusError` responses, promoted by
    # `_promote_http_status_error`.
    _RetryableHTTPStatusError,
)
def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
"""Re-raise 5xx `httpx.HTTPStatusError` as a retryable marker."""
if exc.response.status_code >= 500:
msg = f"Retryable {exc.response.status_code} from Fireworks: {exc}"
raise _RetryableHTTPStatusError(msg) from exc
raise exc
def _raise_empty_stream() -> NoReturn:
    """Signal that a streaming call produced no chunks at all.

    Raises:
        FireworksError: Always.
    """
    detail = "Received empty stream from Fireworks"
    raise FireworksError(detail)
def _create_retry_decorator(
    llm: ChatFireworks,
    run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,
) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used around Fireworks SDK calls.

    Retrying lives on this side because the pinned Fireworks SDK 0.x does not
    honor its own `_max_retries` attribute on completion resources.

    Args:
        llm: Model instance whose `max_retries` sets the retry budget.
        run_manager: Optional callback manager forwarded to
            `create_base_retry_decorator`.

    Returns:
        The decorator produced by `create_base_retry_decorator` for the
        classes in `_RETRYABLE_ERRORS`.
    """
    # `llm.max_retries` is the number of retries *after* the first attempt,
    # and both `None` and `0` mean "one attempt, no retries".
    retries = llm.max_retries or 0
    # `create_base_retry_decorator` hands its `max_retries` to
    # `stop_after_attempt`, which counts total attempts, hence the `+ 1`.
    # Note: `ChatMistralAI` passes the raw value instead; the fireworks field
    # docstring is the source of truth here.
    return create_base_retry_decorator(
        error_types=list(_RETRYABLE_ERRORS),
        max_retries=retries + 1,
        run_manager=run_manager,
    )
def _completion_with_retry(
    llm: ChatFireworks,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Any:
    """Run the sync completion call under the retry policy.

    For streaming requests the first chunk is pulled inside the retried
    function, so failures while opening the stream are retried as well.

    Args:
        llm: Model instance providing `client` and the retry configuration.
        run_manager: Optional callback manager passed to the retry decorator.
        **kwargs: Forwarded verbatim to `llm.client.create`.

    Returns:
        The value returned by `llm.client.create`, or, when
        `kwargs["stream"]` is truthy, an iterator that yields the
        already-consumed first chunk followed by the remaining chunks.
    """
    retrying = _create_retry_decorator(llm, run_manager=run_manager)

    def _call() -> Any:
        try:
            response = llm.client.create(**kwargs)
        except httpx.HTTPStatusError as err:
            _promote_http_status_error(err)

        if not kwargs.get("stream"):
            return response

        # The streaming generator is lazy: pull one chunk now so that the
        # HTTP connection, and any transport error, occurs while still
        # inside the retry boundary. The consumed chunk is re-yielded first
        # by `_prepend_chunk`, so callers still observe every event.
        try:
            chunks = iter(response)
            head = next(chunks)
        except StopIteration:
            _raise_empty_stream()
        except httpx.HTTPStatusError as err:
            _promote_http_status_error(err)
        return _prepend_chunk(head, chunks)

    return retrying(_call)()
async def _acompletion_with_retry(
    llm: ChatFireworks,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Any:
    """Run the async completion call under the retry policy.

    For streaming requests the first chunk is awaited inside the retried
    coroutine, so failures while opening the stream are retried as well.

    Args:
        llm: Model instance providing `async_client` and the retry
            configuration.
        run_manager: Optional callback manager passed to the retry decorator.
        **kwargs: Forwarded verbatim to `llm.async_client.acreate`.

    Returns:
        The awaited result of `llm.async_client.acreate`, or, when
        `kwargs["stream"]` is truthy, an async iterator that yields the
        already-consumed first chunk followed by the remaining chunks.
    """
    retrying = _create_retry_decorator(llm, run_manager=run_manager)

    async def _call() -> Any:
        if not kwargs.get("stream"):
            try:
                return await llm.async_client.acreate(**kwargs)
            except httpx.HTTPStatusError as err:
                # Always raises, so control never reaches the stream path.
                _promote_http_status_error(err)

        try:
            # In streaming mode `acreate` is iterated, not awaited.
            stream = llm.async_client.acreate(**kwargs)
            chunks = stream.__aiter__()
            head = await chunks.__anext__()
        except StopAsyncIteration:
            _raise_empty_stream()
        except httpx.HTTPStatusError as err:
            _promote_http_status_error(err)
        return _aprepend_chunk(head, chunks)

    return await retrying(_call)()
def _prepend_chunk(first: Any, rest: Iterator[Any]) -> Iterator[Any]:
    """Yield `first`, then every remaining item of `rest`, in order.

    Used by `_completion_with_retry` to hand back a stream whose first chunk
    was already consumed inside the retry boundary.
    """
    yield first
    # `yield from` delegates to `rest`, so closing this generator early also
    # closes `rest` when it is itself a generator.
    yield from rest
async def _aprepend_chunk(first: Any, rest: AsyncIterator[Any]) -> AsyncIterator[Any]:
yield first
async for item in rest:
yield item
# This is basically a copy and replace for ChatFireworks, except
# - I needed to gut out tiktoken and some of the token estimation logic
# (not sure how important it is)
@@ -416,7 +558,14 @@ class ChatFireworks(BaseChatModel):
"""Maximum number of tokens to generate."""
max_retries: int | None = None
"""Maximum number of retries to make when generating."""
"""Maximum number of retries after the initial attempt when generating.
Retries use exponential backoff and trigger on transient errors:
`RateLimitError`, `APITimeoutError`, 5xx responses (including those that
surface as `httpx.HTTPStatusError` rather than typed SDK errors), and
underlying transport errors (`httpx.TimeoutException`, `httpx.TransportError`).
A value of `None` or `0` disables retries.
"""
model_config = ConfigDict(
populate_by_name=True,
@@ -453,9 +602,6 @@ class ChatFireworks(BaseChatModel):
self.client = Fireworks(**client_params).chat.completions
if not self.async_client:
self.async_client = AsyncFireworks(**client_params).chat.completions
if self.max_retries:
self.client._max_retries = self.max_retries
self.async_client._max_retries = self.max_retries
return self
def _resolve_model_profile(self) -> ModelProfile | None:
@@ -528,7 +674,9 @@ class ChatFireworks(BaseChatModel):
params["stream_options"] = {"include_usage": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.client.create(messages=message_dicts, **params):
for chunk in _completion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
@@ -572,7 +720,9 @@ class ChatFireworks(BaseChatModel):
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = self.client.create(messages=message_dicts, **params)
response = _completion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
return self._create_chat_result(response)
def _create_message_dicts(
@@ -626,7 +776,9 @@ class ChatFireworks(BaseChatModel):
params["stream_options"] = {"include_usage": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
async for chunk in await _acompletion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
@@ -673,7 +825,9 @@ class ChatFireworks(BaseChatModel):
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = await self.async_client.acreate(messages=message_dicts, **params)
response = await _acompletion_with_retry(
self, run_manager=run_manager, messages=message_dicts, **params
)
return self._create_chat_result(response)
@property