mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(fireworks): honor max_retries (#36973)
`ChatFireworks.max_retries` silently did nothing. The old code assigned the value to a `ChatCompletionV2` sub-object rather than the base client, and the pinned Fireworks SDK (0.13.0–0.19.20) never honors its own `_max_retries` attribute on the base client either. Since the Stainless-generated 1.x SDK that does implement retries is still pre-release (1.0.1a63 at time of writing), retry responsibility is ported to the LangChain side until the pin can be bumped.
This commit is contained in:
@@ -10,10 +10,20 @@ from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
NoReturn,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from fireworks.client import AsyncFireworks, Fireworks # type: ignore[import-untyped]
|
||||
from fireworks.client.error import ( # type: ignore[import-untyped]
|
||||
APITimeoutError,
|
||||
BadGatewayError,
|
||||
FireworksError,
|
||||
InternalServerError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -29,6 +39,7 @@ from langchain_core.language_models.chat_models import (
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -290,6 +301,137 @@ def _convert_chunk_to_message_chunk(
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class _RetryableHTTPStatusError(FireworksError):
|
||||
"""Internal marker for 5xx `httpx.HTTPStatusError` responses.
|
||||
|
||||
The Fireworks SDK maps a subset of status codes (500, 502, 503) to typed
|
||||
exceptions but lets others (504, 507-511, Cloudflare-edge 520-599)
|
||||
propagate as raw `httpx.HTTPStatusError`. Promoting those to this marker
|
||||
inside `_call` keeps the retryable set expressible as a list of classes
|
||||
for `create_base_retry_decorator`, preserving parity with `ChatMistralAI`.
|
||||
"""
|
||||
|
||||
|
||||
_RETRYABLE_ERRORS: tuple[type[BaseException], ...] = (
|
||||
APITimeoutError,
|
||||
BadGatewayError,
|
||||
InternalServerError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
httpx.TimeoutException,
|
||||
httpx.TransportError,
|
||||
_RetryableHTTPStatusError,
|
||||
)
|
||||
|
||||
|
||||
def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
|
||||
"""Re-raise 5xx `httpx.HTTPStatusError` as a retryable marker."""
|
||||
if exc.response.status_code >= 500:
|
||||
msg = f"Retryable {exc.response.status_code} from Fireworks: {exc}"
|
||||
raise _RetryableHTTPStatusError(msg) from exc
|
||||
raise exc
|
||||
|
||||
|
||||
def _raise_empty_stream() -> NoReturn:
|
||||
"""Raise a descriptive error when the SDK returns a zero-chunk stream."""
|
||||
msg = "Received empty stream from Fireworks"
|
||||
raise FireworksError(msg)
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatFireworks,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Return a tenacity retry decorator for Fireworks SDK calls.
|
||||
|
||||
Retries are implemented here because the pinned Fireworks SDK 0.x does
|
||||
not honor its own `_max_retries` attribute on completion resources.
|
||||
"""
|
||||
# `max_retries` counts retries *after* the initial attempt.
|
||||
# `create_base_retry_decorator` forwards its `max_retries` to
|
||||
# `stop_after_attempt`, which counts total attempts — so offset by 1.
|
||||
# Note: this diverges from `ChatMistralAI`, which passes the raw value;
|
||||
# the fireworks field docstring is the source of truth here.
|
||||
# `None` and `0` both mean "single attempt, no retries".
|
||||
attempts = (llm.max_retries + 1) if llm.max_retries else 1
|
||||
return create_base_retry_decorator(
|
||||
error_types=list(_RETRYABLE_ERRORS),
|
||||
max_retries=attempts,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
|
||||
|
||||
def _completion_with_retry(
|
||||
llm: ChatFireworks,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Retry the sync completion call, including stream setup."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _call() -> Any:
|
||||
try:
|
||||
result = llm.client.create(**kwargs)
|
||||
except httpx.HTTPStatusError as e:
|
||||
_promote_http_status_error(e)
|
||||
if kwargs.get("stream"):
|
||||
# The streaming generator is lazy — advance once so the HTTP
|
||||
# connection and any transport error happen inside the retry
|
||||
# boundary. `_prepend_chunk` then re-yields the consumed chunk
|
||||
# ahead of the rest so callers still see every event.
|
||||
try:
|
||||
iterator = iter(result)
|
||||
first = next(iterator)
|
||||
except StopIteration:
|
||||
_raise_empty_stream()
|
||||
except httpx.HTTPStatusError as e:
|
||||
_promote_http_status_error(e)
|
||||
return _prepend_chunk(first, iterator)
|
||||
return result
|
||||
|
||||
return _call()
|
||||
|
||||
|
||||
async def _acompletion_with_retry(
|
||||
llm: ChatFireworks,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Retry the async completion call, including stream setup."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _call() -> Any:
|
||||
if kwargs.get("stream"):
|
||||
try:
|
||||
result = llm.async_client.acreate(**kwargs)
|
||||
agen = result.__aiter__()
|
||||
first = await agen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
_raise_empty_stream()
|
||||
except httpx.HTTPStatusError as e:
|
||||
_promote_http_status_error(e)
|
||||
return _aprepend_chunk(first, agen)
|
||||
try:
|
||||
return await llm.async_client.acreate(**kwargs)
|
||||
except httpx.HTTPStatusError as e:
|
||||
_promote_http_status_error(e)
|
||||
|
||||
return await _call()
|
||||
|
||||
|
||||
def _prepend_chunk(first: Any, rest: Iterator[Any]) -> Iterator[Any]:
|
||||
yield first
|
||||
yield from rest
|
||||
|
||||
|
||||
async def _aprepend_chunk(first: Any, rest: AsyncIterator[Any]) -> AsyncIterator[Any]:
|
||||
yield first
|
||||
async for item in rest:
|
||||
yield item
|
||||
|
||||
|
||||
# This is basically a copy and replace for ChatFireworks, except
|
||||
# - I needed to gut out tiktoken and some of the token estimation logic
|
||||
# (not sure how important it is)
|
||||
@@ -416,7 +558,14 @@ class ChatFireworks(BaseChatModel):
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
max_retries: int | None = None
|
||||
"""Maximum number of retries to make when generating."""
|
||||
"""Maximum number of retries after the initial attempt when generating.
|
||||
|
||||
Retries use exponential backoff and trigger on transient errors:
|
||||
`RateLimitError`, `APITimeoutError`, 5xx responses (including those that
|
||||
surface as `httpx.HTTPStatusError` rather than typed SDK errors), and
|
||||
underlying transport errors (`httpx.TimeoutException`, `httpx.TransportError`).
|
||||
A value of `None` or `0` disables retries.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
@@ -453,9 +602,6 @@ class ChatFireworks(BaseChatModel):
|
||||
self.client = Fireworks(**client_params).chat.completions
|
||||
if not self.async_client:
|
||||
self.async_client = AsyncFireworks(**client_params).chat.completions
|
||||
if self.max_retries:
|
||||
self.client._max_retries = self.max_retries
|
||||
self.async_client._max_retries = self.max_retries
|
||||
return self
|
||||
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
@@ -528,7 +674,9 @@ class ChatFireworks(BaseChatModel):
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
for chunk in _completion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
@@ -572,7 +720,9 @@ class ChatFireworks(BaseChatModel):
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = self.client.create(messages=message_dicts, **params)
|
||||
response = _completion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
@@ -626,7 +776,9 @@ class ChatFireworks(BaseChatModel):
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
|
||||
async for chunk in await _acompletion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
@@ -673,7 +825,9 @@ class ChatFireworks(BaseChatModel):
|
||||
**({"stream": stream} if stream is not None else {}),
|
||||
**kwargs,
|
||||
}
|
||||
response = await self.async_client.acreate(messages=message_dicts, **params)
|
||||
response = await _acompletion_with_retry(
|
||||
self, run_manager=run_manager, messages=message_dicts, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
|
||||
@@ -5,11 +5,21 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from fireworks.client.error import ( # type: ignore[import-untyped]
|
||||
AuthenticationError,
|
||||
FireworksError,
|
||||
InvalidRequestError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
|
||||
from langchain_fireworks import ChatFireworks
|
||||
from langchain_fireworks.chat_models import (
|
||||
_acompletion_with_retry,
|
||||
_completion_with_retry,
|
||||
_convert_chunk_to_message_chunk,
|
||||
_convert_dict_to_message,
|
||||
_usage_to_metadata,
|
||||
@@ -82,6 +92,433 @@ def test_convert_dict_to_message_without_reasoning_content() -> None:
|
||||
assert "reasoning_content" not in message.additional_kwargs
|
||||
|
||||
|
||||
def _make_llm(max_retries: int | None = 2) -> ChatFireworks:
|
||||
return ChatFireworks(
|
||||
model="accounts/fireworks/models/test",
|
||||
api_key="fake-key", # type: ignore[arg-type]
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
|
||||
def _success_response() -> dict[str, Any]:
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hello"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_retry_sleep(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Avoid tenacity's exponential backoff in tests."""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
monkeypatch.setattr(time, "sleep", lambda _s: None)
|
||||
|
||||
async def _no_async_sleep(_s: float) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", _no_async_sleep)
|
||||
|
||||
|
||||
def test_completion_with_retry_retries_on_retryable_error() -> None:
|
||||
"""Retryable errors trigger retries up to the configured limit."""
|
||||
llm = _make_llm(max_retries=2)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = [
|
||||
RateLimitError("rate limited"),
|
||||
ServiceUnavailableError("unavailable"),
|
||||
_success_response(),
|
||||
]
|
||||
llm.client = mock_client
|
||||
|
||||
result = _completion_with_retry(llm, messages=[])
|
||||
|
||||
assert result == _success_response()
|
||||
assert mock_client.create.call_count == 3
|
||||
|
||||
|
||||
def test_completion_with_retry_does_not_retry_non_retryable() -> None:
|
||||
"""Non-retryable errors propagate after a single attempt."""
|
||||
llm = _make_llm(max_retries=3)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = AuthenticationError("bad key")
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
_completion_with_retry(llm, messages=[])
|
||||
|
||||
assert mock_client.create.call_count == 1
|
||||
|
||||
|
||||
def test_completion_with_retry_respects_max_retries_none() -> None:
|
||||
"""`max_retries=None` disables retries."""
|
||||
llm = _make_llm(max_retries=None)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = RateLimitError("rate limited")
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
_completion_with_retry(llm, messages=[])
|
||||
|
||||
assert mock_client.create.call_count == 1
|
||||
|
||||
|
||||
def test_completion_with_retry_exhausts_and_raises() -> None:
|
||||
"""When every attempt fails, the last error is re-raised."""
|
||||
llm = _make_llm(max_retries=2)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = RateLimitError("rate limited")
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
_completion_with_retry(llm, messages=[])
|
||||
|
||||
# 1 initial attempt + 2 retries = 3 total attempts
|
||||
assert mock_client.create.call_count == 3
|
||||
|
||||
|
||||
def test_completion_with_retry_streaming_retries_on_setup() -> None:
|
||||
"""Streaming errors raised during the first-chunk pull are retried."""
|
||||
llm = _make_llm(max_retries=1)
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
def _fail_then_stream(**_kwargs: Any) -> Any:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
|
||||
def _failing_gen() -> Any:
|
||||
msg = "rate limited"
|
||||
raise RateLimitError(msg)
|
||||
yield # pragma: no cover
|
||||
|
||||
return _failing_gen()
|
||||
|
||||
def _good_gen() -> Any:
|
||||
yield {"id": 0, "choices": [{"delta": {"content": "one"}}]}
|
||||
yield {"id": 1, "choices": [{"delta": {"content": "two"}}]}
|
||||
|
||||
return _good_gen()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = _fail_then_stream
|
||||
llm.client = mock_client
|
||||
|
||||
chunks = list(_completion_with_retry(llm, messages=[], stream=True))
|
||||
|
||||
# First chunk is preserved and in order — guards `_prepend_chunk` regression
|
||||
assert [c["id"] for c in chunks] == [0, 1]
|
||||
assert calls["n"] == 2
|
||||
|
||||
|
||||
def test_completion_with_retry_streaming_accepts_iterable_only_result() -> None:
|
||||
"""Streaming setup accepts iterable-only custom client wrappers."""
|
||||
|
||||
class _IterableOnlyStream:
|
||||
def __iter__(self) -> Any:
|
||||
yield {"id": 0, "choices": [{"delta": {"content": "one"}}]}
|
||||
yield {"id": 1, "choices": [{"delta": {"content": "two"}}]}
|
||||
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.return_value = _IterableOnlyStream()
|
||||
llm.client = mock_client
|
||||
|
||||
chunks = list(_completion_with_retry(llm, messages=[], stream=True))
|
||||
|
||||
assert [c["id"] for c in chunks] == [0, 1]
|
||||
assert mock_client.create.call_count == 1
|
||||
|
||||
|
||||
def test_completion_with_retry_retries_on_5xx_http_status_error() -> None:
|
||||
"""5xx `httpx.HTTPStatusError` is promoted and retried."""
|
||||
llm = _make_llm(max_retries=1)
|
||||
mock_client = MagicMock()
|
||||
response_504 = httpx.Response(status_code=504, request=httpx.Request("POST", "x"))
|
||||
mock_client.create.side_effect = [
|
||||
httpx.HTTPStatusError(
|
||||
"504", request=response_504.request, response=response_504
|
||||
),
|
||||
_success_response(),
|
||||
]
|
||||
llm.client = mock_client
|
||||
|
||||
result = _completion_with_retry(llm, messages=[])
|
||||
|
||||
assert result == _success_response()
|
||||
assert mock_client.create.call_count == 2
|
||||
|
||||
|
||||
def test_completion_with_retry_does_not_retry_on_4xx_http_status_error() -> None:
|
||||
"""Non-5xx `httpx.HTTPStatusError` passes through unretried."""
|
||||
llm = _make_llm(max_retries=3)
|
||||
mock_client = MagicMock()
|
||||
response_422 = httpx.Response(status_code=422, request=httpx.Request("POST", "x"))
|
||||
mock_client.create.side_effect = httpx.HTTPStatusError(
|
||||
"422", request=response_422.request, response=response_422
|
||||
)
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
_completion_with_retry(llm, messages=[])
|
||||
assert mock_client.create.call_count == 1
|
||||
|
||||
|
||||
def test_completion_with_retry_retries_on_timeout_exception() -> None:
|
||||
"""`httpx.TimeoutException` is in the retryable set."""
|
||||
llm = _make_llm(max_retries=1)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = [
|
||||
httpx.ConnectTimeout("slow"),
|
||||
_success_response(),
|
||||
]
|
||||
llm.client = mock_client
|
||||
|
||||
result = _completion_with_retry(llm, messages=[])
|
||||
|
||||
assert result == _success_response()
|
||||
assert mock_client.create.call_count == 2
|
||||
|
||||
|
||||
def test_completion_with_retry_max_retries_zero_is_single_attempt() -> None:
|
||||
"""`max_retries=0` disables retries (same as `None`)."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = RateLimitError("rate limited")
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
_completion_with_retry(llm, messages=[])
|
||||
assert mock_client.create.call_count == 1
|
||||
|
||||
|
||||
def test_completion_with_retry_raises_on_empty_stream() -> None:
|
||||
"""Empty streams surface as a descriptive `FireworksError`."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_client = MagicMock()
|
||||
|
||||
def _empty_gen(**_kwargs: Any) -> Any:
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
return
|
||||
|
||||
mock_client.create.side_effect = _empty_gen
|
||||
llm.client = mock_client
|
||||
|
||||
with pytest.raises(FireworksError, match="empty stream"):
|
||||
list(_completion_with_retry(llm, messages=[], stream=True))
|
||||
|
||||
|
||||
def test_chat_fireworks_invoke_routes_through_retry() -> None:
|
||||
"""`.invoke()` end-to-end exercises the retry helper on `self.client.create`.
|
||||
|
||||
Guards against a regression that bypasses `_completion_with_retry` from
|
||||
`_generate`.
|
||||
"""
|
||||
llm = _make_llm(max_retries=2)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = [
|
||||
RateLimitError("rate limited"),
|
||||
_success_response(),
|
||||
]
|
||||
llm.client = mock_client
|
||||
|
||||
result = llm.invoke("hello")
|
||||
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content == "hello"
|
||||
assert mock_client.create.call_count == 2
|
||||
|
||||
|
||||
async def test_acompletion_with_retry_streaming_retries_on_setup() -> None:
|
||||
"""Async streaming errors during the first-chunk pull are retried."""
|
||||
llm = _make_llm(max_retries=1)
|
||||
calls = {"n": 0}
|
||||
|
||||
def _acreate(**_kwargs: Any) -> Any:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
|
||||
async def _failing_agen() -> Any:
|
||||
msg = "rate limited"
|
||||
raise RateLimitError(msg)
|
||||
yield # pragma: no cover
|
||||
|
||||
return _failing_agen()
|
||||
|
||||
async def _good_agen() -> Any:
|
||||
yield {"id": 0, "choices": [{"delta": {"content": "one"}}]}
|
||||
yield {"id": 1, "choices": [{"delta": {"content": "two"}}]}
|
||||
|
||||
return _good_agen()
|
||||
|
||||
mock_async = MagicMock()
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
agen = await _acompletion_with_retry(llm, messages=[], stream=True)
|
||||
chunks = [c async for c in agen]
|
||||
|
||||
assert [c["id"] for c in chunks] == [0, 1]
|
||||
assert calls["n"] == 2
|
||||
|
||||
|
||||
async def test_acompletion_with_retry_streaming_accepts_async_iterable_only_result() -> ( # noqa: E501
|
||||
None
|
||||
):
|
||||
"""Async streaming setup accepts async-iterable-only custom wrappers."""
|
||||
|
||||
class _AsyncIterableOnlyStream:
|
||||
def __aiter__(self) -> Any:
|
||||
async def _aiter() -> Any:
|
||||
yield {"id": 0, "choices": [{"delta": {"content": "one"}}]}
|
||||
yield {"id": 1, "choices": [{"delta": {"content": "two"}}]}
|
||||
|
||||
return _aiter()
|
||||
|
||||
llm = _make_llm(max_retries=0)
|
||||
mock_async = MagicMock()
|
||||
mock_async.acreate = MagicMock(return_value=_AsyncIterableOnlyStream())
|
||||
llm.async_client = mock_async
|
||||
|
||||
agen = await _acompletion_with_retry(llm, messages=[], stream=True)
|
||||
chunks = [c async for c in agen]
|
||||
|
||||
assert [c["id"] for c in chunks] == [0, 1]
|
||||
assert mock_async.acreate.call_count == 1
|
||||
|
||||
|
||||
async def test_achat_fireworks_ainvoke_routes_through_retry() -> None:
|
||||
"""`.ainvoke()` end-to-end exercises the async retry helper."""
|
||||
llm = _make_llm(max_retries=2)
|
||||
calls = {"n": 0}
|
||||
|
||||
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
msg = "rate limited"
|
||||
raise RateLimitError(msg)
|
||||
return _success_response()
|
||||
|
||||
mock_async = MagicMock()
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
result = await llm.ainvoke("hello")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content == "hello"
|
||||
assert calls["n"] == 2
|
||||
|
||||
|
||||
async def test_acompletion_with_retry_retries_on_retryable_error() -> None:
|
||||
"""Async retries on retryable errors up to the configured limit."""
|
||||
llm = _make_llm(max_retries=2)
|
||||
mock_async = MagicMock()
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] < 3:
|
||||
msg = "rate limited"
|
||||
raise RateLimitError(msg)
|
||||
return _success_response()
|
||||
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
result = await _acompletion_with_retry(llm, messages=[])
|
||||
assert result == _success_response()
|
||||
assert call_count["n"] == 3
|
||||
|
||||
|
||||
async def test_acompletion_with_retry_does_not_retry_non_retryable() -> None:
|
||||
"""Async does not retry non-retryable errors."""
|
||||
llm = _make_llm(max_retries=3)
|
||||
mock_async = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
|
||||
call_count["n"] += 1
|
||||
msg = "bad input"
|
||||
raise InvalidRequestError(msg)
|
||||
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
with pytest.raises(InvalidRequestError):
|
||||
await _acompletion_with_retry(llm, messages=[HumanMessage(content="hi")])
|
||||
assert call_count["n"] == 1
|
||||
|
||||
|
||||
async def test_acompletion_with_retry_retries_on_5xx_http_status_error() -> None:
|
||||
"""Async 5xx `httpx.HTTPStatusError` is promoted and retried."""
|
||||
llm = _make_llm(max_retries=1)
|
||||
call_count = {"n": 0}
|
||||
response_504 = httpx.Response(status_code=504, request=httpx.Request("POST", "x"))
|
||||
|
||||
async def _acreate(**_kwargs: Any) -> dict[str, Any]:
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
msg = "504"
|
||||
raise httpx.HTTPStatusError(
|
||||
msg, request=response_504.request, response=response_504
|
||||
)
|
||||
return _success_response()
|
||||
|
||||
mock_async = MagicMock()
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
result = await _acompletion_with_retry(llm, messages=[])
|
||||
assert result == _success_response()
|
||||
assert call_count["n"] == 2
|
||||
|
||||
|
||||
async def test_acompletion_with_retry_raises_on_empty_stream() -> None:
|
||||
"""Async empty streams surface as a descriptive `FireworksError`."""
|
||||
llm = _make_llm(max_retries=0)
|
||||
|
||||
def _acreate(**_kwargs: Any) -> Any:
|
||||
async def _empty_agen() -> Any:
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
return
|
||||
|
||||
return _empty_agen()
|
||||
|
||||
mock_async = MagicMock()
|
||||
mock_async.acreate = _acreate
|
||||
llm.async_client = mock_async
|
||||
|
||||
with pytest.raises(FireworksError, match="empty stream"):
|
||||
agen = await _acompletion_with_retry(llm, messages=[], stream=True)
|
||||
async for _ in agen:
|
||||
pass
|
||||
|
||||
|
||||
def test_completion_with_retry_retries_on_transport_error() -> None:
|
||||
"""`httpx.TransportError` is in the retryable set."""
|
||||
llm = _make_llm(max_retries=1)
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.side_effect = [
|
||||
httpx.ConnectError("refused"),
|
||||
_success_response(),
|
||||
]
|
||||
llm.client = mock_client
|
||||
|
||||
result = _completion_with_retry(llm, messages=[])
|
||||
|
||||
assert result == _success_response()
|
||||
assert mock_client.create.call_count == 2
|
||||
|
||||
|
||||
class TestUsageToMetadata:
|
||||
"""Tests for the `_usage_to_metadata` helper."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user