From 7b09eb7bda4dbe99615ff5e5f74539aea682d718 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Thu, 23 Apr 2026 16:40:54 -0400 Subject: [PATCH] fix(fireworks): honor `max_retries` (#36973) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `ChatFireworks.max_retries` silently did nothing. The old code assigned the value to a `ChatCompletionV2` sub-object rather than the base client, and the pinned Fireworks SDK (0.13.0–0.19.20) never honors its own `_max_retries` attribute on the base client either. Since the Stainless-generated 1.x SDK that does implement retries is still pre-release (1.0.1a63 at time of writing), retry responsibility is ported to the LangChain side until the pin can be bumped. --- libs/langchain/uv.lock | 2 +- libs/langchain_v1/uv.lock | 2 +- libs/model-profiles/uv.lock | 2 +- .../langchain_fireworks/chat_models.py | 170 ++++++- .../tests/unit_tests/test_chat_models.py | 439 +++++++++++++++++- libs/text-splitters/uv.lock | 2 +- 6 files changed, 604 insertions(+), 13 deletions(-) diff --git a/libs/langchain/uv.lock b/libs/langchain/uv.lock index 10aac933a34..8ad1fff9b38 100644 --- a/libs/langchain/uv.lock +++ b/libs/langchain/uv.lock @@ -2601,7 +2601,7 @@ typing = [ [[package]] name = "langchain-core" -version = "1.3.0" +version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index a346ff52ee1..1a73a323a1f 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -2208,7 +2208,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.3.0" +version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, diff --git a/libs/model-profiles/uv.lock b/libs/model-profiles/uv.lock index 9797f2e2161..a54a5826a79 100644 --- a/libs/model-profiles/uv.lock +++ b/libs/model-profiles/uv.lock @@ -532,7 +532,7 @@ typing = [ [[package]] name = "langchain-core" -version = "1.3.0" +version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" }, diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 9e2e4a45a8a..05ef1be5f3d 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -10,10 +10,20 @@ from operator import itemgetter from typing import ( Any, Literal, + NoReturn, cast, ) +import httpx from fireworks.client import AsyncFireworks, Fireworks # type: ignore[import-untyped] +from fireworks.client.error import ( # type: ignore[import-untyped] + APITimeoutError, + BadGatewayError, + FireworksError, + InternalServerError, + RateLimitError, + ServiceUnavailableError, +) from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -29,6 +39,7 @@ from langchain_core.language_models.chat_models import ( agenerate_from_stream, generate_from_stream, ) +from langchain_core.language_models.llms import create_base_retry_decorator from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -290,6 +301,137 @@ def _convert_chunk_to_message_chunk( return default_class(content=content) # type: ignore[call-arg] +class _RetryableHTTPStatusError(FireworksError): + """Internal marker for 5xx `httpx.HTTPStatusError` responses. + + The Fireworks SDK maps a subset of status codes (500, 502, 503) to typed + exceptions but lets others (504, 507-511, Cloudflare-edge 520-599) + propagate as raw `httpx.HTTPStatusError`. Promoting those to this marker + inside `_call` keeps the retryable set expressible as a list of classes + for `create_base_retry_decorator`, preserving parity with `ChatMistralAI`. + """ + + +_RETRYABLE_ERRORS: tuple[type[BaseException], ...] = ( + APITimeoutError, + BadGatewayError, + InternalServerError, + RateLimitError, + ServiceUnavailableError, + httpx.TimeoutException, + httpx.TransportError, + _RetryableHTTPStatusError, +) + + +def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn: + """Re-raise 5xx `httpx.HTTPStatusError` as a retryable marker.""" + if exc.response.status_code >= 500: + msg = f"Retryable {exc.response.status_code} from Fireworks: {exc}" + raise _RetryableHTTPStatusError(msg) from exc + raise exc + + +def _raise_empty_stream() -> NoReturn: + """Raise a descriptive error when the SDK returns a zero-chunk stream.""" + msg = "Received empty stream from Fireworks" + raise FireworksError(msg) + + +def _create_retry_decorator( + llm: ChatFireworks, + run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None, +) -> Callable[[Any], Any]: + """Return a tenacity retry decorator for Fireworks SDK calls. + + Retries are implemented here because the pinned Fireworks SDK 0.x does + not honor its own `_max_retries` attribute on completion resources. + """ + # `max_retries` counts retries *after* the initial attempt. + # `create_base_retry_decorator` forwards its `max_retries` to + # `stop_after_attempt`, which counts total attempts — so offset by 1. + # Note: this diverges from `ChatMistralAI`, which passes the raw value; + # the fireworks field docstring is the source of truth here. + # `None` and `0` both mean "single attempt, no retries". + attempts = (llm.max_retries + 1) if llm.max_retries else 1 + return create_base_retry_decorator( + error_types=list(_RETRYABLE_ERRORS), + max_retries=attempts, + run_manager=run_manager, + ) + + +def _completion_with_retry( + llm: ChatFireworks, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, +) -> Any: + """Retry the sync completion call, including stream setup.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + def _call() -> Any: + try: + result = llm.client.create(**kwargs) + except httpx.HTTPStatusError as e: + _promote_http_status_error(e) + if kwargs.get("stream"): + # The streaming generator is lazy — advance once so the HTTP + # connection and any transport error happen inside the retry + # boundary. `_prepend_chunk` then re-yields the consumed chunk + # ahead of the rest so callers still see every event. + try: + iterator = iter(result) + first = next(iterator) + except StopIteration: + _raise_empty_stream() + except httpx.HTTPStatusError as e: + _promote_http_status_error(e) + return _prepend_chunk(first, iterator) + return result + + return _call() + + +async def _acompletion_with_retry( + llm: ChatFireworks, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, +) -> Any: + """Retry the async completion call, including stream setup.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async def _call() -> Any: + if kwargs.get("stream"): + try: + result = llm.async_client.acreate(**kwargs) + agen = result.__aiter__() + first = await agen.__anext__() + except StopAsyncIteration: + _raise_empty_stream() + except httpx.HTTPStatusError as e: + _promote_http_status_error(e) + return _aprepend_chunk(first, agen) + try: + return await llm.async_client.acreate(**kwargs) + except httpx.HTTPStatusError as e: + _promote_http_status_error(e) + + return await _call() + + +def _prepend_chunk(first: Any, rest: Iterator[Any]) -> Iterator[Any]: + yield first + yield from rest + + +async def _aprepend_chunk(first: Any, rest: AsyncIterator[Any]) -> AsyncIterator[Any]: + yield first + async for item in rest: + yield item + + # This is basically a copy and replace for ChatFireworks, except # - I needed to gut out tiktoken and some of the token estimation logic # (not sure how important it is) @@ -416,7 +558,14 @@ class ChatFireworks(BaseChatModel): """Maximum number of tokens to generate.""" max_retries: int | None = None - """Maximum number of retries to make when generating.""" + """Maximum number of retries after the initial attempt when generating. + + Retries use exponential backoff and trigger on transient errors: + `RateLimitError`, `APITimeoutError`, 5xx responses (including those that + surface as `httpx.HTTPStatusError` rather than typed SDK errors), and + underlying transport errors (`httpx.TimeoutException`, `httpx.TransportError`). + A value of `None` or `0` disables retries. + """ model_config = ConfigDict( populate_by_name=True, @@ -453,9 +602,6 @@ class ChatFireworks(BaseChatModel): self.client = Fireworks(**client_params).chat.completions if not self.async_client: self.async_client = AsyncFireworks(**client_params).chat.completions - if self.max_retries: - self.client._max_retries = self.max_retries - self.async_client._max_retries = self.max_retries return self def _resolve_model_profile(self) -> ModelProfile | None: @@ -528,7 +674,9 @@ class ChatFireworks(BaseChatModel): params["stream_options"] = {"include_usage": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk - for chunk in self.client.create(messages=message_dicts, **params): + for chunk in _completion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ): if not isinstance(chunk, dict): chunk = chunk.model_dump() message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) @@ -572,7 +720,9 @@ class ChatFireworks(BaseChatModel): **({"stream": stream} if stream is not None else {}), **kwargs, } - response = self.client.create(messages=message_dicts, **params) + response = _completion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ) return self._create_chat_result(response) def _create_message_dicts( @@ -626,7 +776,9 @@ class ChatFireworks(BaseChatModel): params["stream_options"] = {"include_usage": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk - async for chunk in self.async_client.acreate(messages=message_dicts, **params): + async for chunk in await _acompletion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ): if not isinstance(chunk, dict): chunk = chunk.model_dump() message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) @@ -673,7 +825,9 @@ class ChatFireworks(BaseChatModel): **({"stream": stream} if stream is not None else {}), **kwargs, } - response = await self.async_client.acreate(messages=message_dicts, **params) + response = await _acompletion_with_retry( + self, run_manager=run_manager, messages=message_dicts, **params + ) return self._create_chat_result(response) @property diff --git a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py index 4acfc9a14b2..1e3d89d62dc 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py @@ -5,11 +5,21 @@ from __future__ import annotations from typing import Any from unittest.mock import MagicMock +import httpx import pytest -from langchain_core.messages import AIMessage, AIMessageChunk +from fireworks.client.error import ( # type: ignore[import-untyped] + AuthenticationError, + FireworksError, + InvalidRequestError, + RateLimitError, + ServiceUnavailableError, +) +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from langchain_fireworks import ChatFireworks from langchain_fireworks.chat_models import ( + _acompletion_with_retry, + _completion_with_retry, _convert_chunk_to_message_chunk, _convert_dict_to_message, _usage_to_metadata, @@ -82,6 +92,433 @@ def test_convert_dict_to_message_without_reasoning_content() -> None: assert "reasoning_content" not in message.additional_kwargs +def _make_llm(max_retries: int | None = 2) -> ChatFireworks: + return ChatFireworks( + model="accounts/fireworks/models/test", + api_key="fake-key", # type: ignore[arg-type] + max_retries=max_retries, + ) + + +def _success_response() -> dict[str, Any]: + return { + "choices": [ + { + "message": {"role": "assistant", "content": "hello"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + + +@pytest.fixture(autouse=True) +def _no_retry_sleep(monkeypatch: pytest.MonkeyPatch) -> None: + """Avoid tenacity's exponential backoff in tests.""" + import asyncio + import time + + monkeypatch.setattr(time, "sleep", lambda _s: None) + + async def _no_async_sleep(_s: float) -> None: + return None + + monkeypatch.setattr(asyncio, "sleep", _no_async_sleep) + + +def test_completion_with_retry_retries_on_retryable_error() -> None: + """Retryable errors trigger retries up to the configured limit.""" + llm = _make_llm(max_retries=2) + mock_client = MagicMock() + mock_client.create.side_effect = [ + RateLimitError("rate limited"), + ServiceUnavailableError("unavailable"), + _success_response(), + ] + llm.client = mock_client + + result = _completion_with_retry(llm, messages=[]) + + assert result == _success_response() + assert mock_client.create.call_count == 3 + + +def test_completion_with_retry_does_not_retry_non_retryable() -> None: + """Non-retryable errors propagate after a single attempt.""" + llm = _make_llm(max_retries=3) + mock_client = MagicMock() + mock_client.create.side_effect = AuthenticationError("bad key") + llm.client = mock_client + + with pytest.raises(AuthenticationError): + _completion_with_retry(llm, messages=[]) + + assert mock_client.create.call_count == 1 + + +def test_completion_with_retry_respects_max_retries_none() -> None: + """`max_retries=None` disables retries.""" + llm = _make_llm(max_retries=None) + mock_client = MagicMock() + mock_client.create.side_effect = RateLimitError("rate limited") + llm.client = mock_client + + with pytest.raises(RateLimitError): + _completion_with_retry(llm, messages=[]) + + assert mock_client.create.call_count == 1 + + +def test_completion_with_retry_exhausts_and_raises() -> None: + """When every attempt fails, the last error is re-raised.""" + llm = _make_llm(max_retries=2) + mock_client = MagicMock() + mock_client.create.side_effect = RateLimitError("rate limited") + llm.client = mock_client + + with pytest.raises(RateLimitError): + _completion_with_retry(llm, messages=[]) + + # 1 initial attempt + 2 retries = 3 total attempts + assert mock_client.create.call_count == 3 + + +def test_completion_with_retry_streaming_retries_on_setup() -> None: + """Streaming errors raised during the first-chunk pull are retried.""" + llm = _make_llm(max_retries=1) + + calls = {"n": 0} + + def _fail_then_stream(**_kwargs: Any) -> Any: + calls["n"] += 1 + if calls["n"] == 1: + + def _failing_gen() -> Any: + msg = "rate limited" + raise RateLimitError(msg) + yield # pragma: no cover + + return _failing_gen() + + def _good_gen() -> Any: + yield {"id": 0, "choices": [{"delta": {"content": "one"}}]} + yield {"id": 1, "choices": [{"delta": {"content": "two"}}]} + + return _good_gen() + + mock_client = MagicMock() + mock_client.create.side_effect = _fail_then_stream + llm.client = mock_client + + chunks = list(_completion_with_retry(llm, messages=[], stream=True)) + + # First chunk is preserved and in order — guards `_prepend_chunk` regression + assert [c["id"] for c in chunks] == [0, 1] + assert calls["n"] == 2 + + +def test_completion_with_retry_streaming_accepts_iterable_only_result() -> None: + """Streaming setup accepts iterable-only custom client wrappers.""" + + class _IterableOnlyStream: + def __iter__(self) -> Any: + yield {"id": 0, "choices": [{"delta": {"content": "one"}}]} + yield {"id": 1, "choices": [{"delta": {"content": "two"}}]} + + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + mock_client.create.return_value = _IterableOnlyStream() + llm.client = mock_client + + chunks = list(_completion_with_retry(llm, messages=[], stream=True)) + + assert [c["id"] for c in chunks] == [0, 1] + assert mock_client.create.call_count == 1 + + +def test_completion_with_retry_retries_on_5xx_http_status_error() -> None: + """5xx `httpx.HTTPStatusError` is promoted and retried.""" + llm = _make_llm(max_retries=1) + mock_client = MagicMock() + response_504 = httpx.Response(status_code=504, request=httpx.Request("POST", "x")) + mock_client.create.side_effect = [ + httpx.HTTPStatusError( + "504", request=response_504.request, response=response_504 + ), + _success_response(), + ] + llm.client = mock_client + + result = _completion_with_retry(llm, messages=[]) + + assert result == _success_response() + assert mock_client.create.call_count == 2 + + +def test_completion_with_retry_does_not_retry_on_4xx_http_status_error() -> None: + """Non-5xx `httpx.HTTPStatusError` passes through unretried.""" + llm = _make_llm(max_retries=3) + mock_client = MagicMock() + response_422 = httpx.Response(status_code=422, request=httpx.Request("POST", "x")) + mock_client.create.side_effect = httpx.HTTPStatusError( + "422", request=response_422.request, response=response_422 + ) + llm.client = mock_client + + with pytest.raises(httpx.HTTPStatusError): + _completion_with_retry(llm, messages=[]) + assert mock_client.create.call_count == 1 + + +def test_completion_with_retry_retries_on_timeout_exception() -> None: + """`httpx.TimeoutException` is in the retryable set.""" + llm = _make_llm(max_retries=1) + mock_client = MagicMock() + mock_client.create.side_effect = [ + httpx.ConnectTimeout("slow"), + _success_response(), + ] + llm.client = mock_client + + result = _completion_with_retry(llm, messages=[]) + + assert result == _success_response() + assert mock_client.create.call_count == 2 + + +def test_completion_with_retry_max_retries_zero_is_single_attempt() -> None: + """`max_retries=0` disables retries (same as `None`).""" + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + mock_client.create.side_effect = RateLimitError("rate limited") + llm.client = mock_client + + with pytest.raises(RateLimitError): + _completion_with_retry(llm, messages=[]) + assert mock_client.create.call_count == 1 + + +def test_completion_with_retry_raises_on_empty_stream() -> None: + """Empty streams surface as a descriptive `FireworksError`.""" + llm = _make_llm(max_retries=0) + mock_client = MagicMock() + + def _empty_gen(**_kwargs: Any) -> Any: + if False: + yield # pragma: no cover + return + + mock_client.create.side_effect = _empty_gen + llm.client = mock_client + + with pytest.raises(FireworksError, match="empty stream"): + list(_completion_with_retry(llm, messages=[], stream=True)) + + +def test_chat_fireworks_invoke_routes_through_retry() -> None: + """`.invoke()` end-to-end exercises the retry helper on `self.client.create`. + + Guards against a regression that bypasses `_completion_with_retry` from + `_generate`. + """ + llm = _make_llm(max_retries=2) + mock_client = MagicMock() + mock_client.create.side_effect = [ + RateLimitError("rate limited"), + _success_response(), + ] + llm.client = mock_client + + result = llm.invoke("hello") + + assert isinstance(result, AIMessage) + assert result.content == "hello" + assert mock_client.create.call_count == 2 + + +async def test_acompletion_with_retry_streaming_retries_on_setup() -> None: + """Async streaming errors during the first-chunk pull are retried.""" + llm = _make_llm(max_retries=1) + calls = {"n": 0} + + def _acreate(**_kwargs: Any) -> Any: + calls["n"] += 1 + if calls["n"] == 1: + + async def _failing_agen() -> Any: + msg = "rate limited" + raise RateLimitError(msg) + yield # pragma: no cover + + return _failing_agen() + + async def _good_agen() -> Any: + yield {"id": 0, "choices": [{"delta": {"content": "one"}}]} + yield {"id": 1, "choices": [{"delta": {"content": "two"}}]} + + return _good_agen() + + mock_async = MagicMock() + mock_async.acreate = _acreate + llm.async_client = mock_async + + agen = await _acompletion_with_retry(llm, messages=[], stream=True) + chunks = [c async for c in agen] + + assert [c["id"] for c in chunks] == [0, 1] + assert calls["n"] == 2 + + +async def test_acompletion_with_retry_streaming_accepts_async_iterable_only_result() -> ( # noqa: E501 + None +): + """Async streaming setup accepts async-iterable-only custom wrappers.""" + + class _AsyncIterableOnlyStream: + def __aiter__(self) -> Any: + async def _aiter() -> Any: + yield {"id": 0, "choices": [{"delta": {"content": "one"}}]} + yield {"id": 1, "choices": [{"delta": {"content": "two"}}]} + + return _aiter() + + llm = _make_llm(max_retries=0) + mock_async = MagicMock() + mock_async.acreate = MagicMock(return_value=_AsyncIterableOnlyStream()) + llm.async_client = mock_async + + agen = await _acompletion_with_retry(llm, messages=[], stream=True) + chunks = [c async for c in agen] + + assert [c["id"] for c in chunks] == [0, 1] + assert mock_async.acreate.call_count == 1 + + +async def test_achat_fireworks_ainvoke_routes_through_retry() -> None: + """`.ainvoke()` end-to-end exercises the async retry helper.""" + llm = _make_llm(max_retries=2) + calls = {"n": 0} + + async def _acreate(**_kwargs: Any) -> dict[str, Any]: + calls["n"] += 1 + if calls["n"] == 1: + msg = "rate limited" + raise RateLimitError(msg) + return _success_response() + + mock_async = MagicMock() + mock_async.acreate = _acreate + llm.async_client = mock_async + + result = await llm.ainvoke("hello") + assert isinstance(result, AIMessage) + assert result.content == "hello" + assert calls["n"] == 2 + + +async def test_acompletion_with_retry_retries_on_retryable_error() -> None: + """Async retries on retryable errors up to the configured limit.""" + llm = _make_llm(max_retries=2) + mock_async = MagicMock() + + call_count = {"n": 0} + + async def _acreate(**_kwargs: Any) -> dict[str, Any]: + call_count["n"] += 1 + if call_count["n"] < 3: + msg = "rate limited" + raise RateLimitError(msg) + return _success_response() + + mock_async.acreate = _acreate + llm.async_client = mock_async + + result = await _acompletion_with_retry(llm, messages=[]) + assert result == _success_response() + assert call_count["n"] == 3 + + +async def test_acompletion_with_retry_does_not_retry_non_retryable() -> None: + """Async does not retry non-retryable errors.""" + llm = _make_llm(max_retries=3) + mock_async = MagicMock() + call_count = {"n": 0} + + async def _acreate(**_kwargs: Any) -> dict[str, Any]: + call_count["n"] += 1 + msg = "bad input" + raise InvalidRequestError(msg) + + mock_async.acreate = _acreate + llm.async_client = mock_async + + with pytest.raises(InvalidRequestError): + await _acompletion_with_retry(llm, messages=[HumanMessage(content="hi")]) + assert call_count["n"] == 1 + + +async def test_acompletion_with_retry_retries_on_5xx_http_status_error() -> None: + """Async 5xx `httpx.HTTPStatusError` is promoted and retried.""" + llm = _make_llm(max_retries=1) + call_count = {"n": 0} + response_504 = httpx.Response(status_code=504, request=httpx.Request("POST", "x")) + + async def _acreate(**_kwargs: Any) -> dict[str, Any]: + call_count["n"] += 1 + if call_count["n"] == 1: + msg = "504" + raise httpx.HTTPStatusError( + msg, request=response_504.request, response=response_504 + ) + return _success_response() + + mock_async = MagicMock() + mock_async.acreate = _acreate + llm.async_client = mock_async + + result = await _acompletion_with_retry(llm, messages=[]) + assert result == _success_response() + assert call_count["n"] == 2 + + +async def test_acompletion_with_retry_raises_on_empty_stream() -> None: + """Async empty streams surface as a descriptive `FireworksError`.""" + llm = _make_llm(max_retries=0) + + def _acreate(**_kwargs: Any) -> Any: + async def _empty_agen() -> Any: + if False: + yield # pragma: no cover + return + + return _empty_agen() + + mock_async = MagicMock() + mock_async.acreate = _acreate + llm.async_client = mock_async + + with pytest.raises(FireworksError, match="empty stream"): + agen = await _acompletion_with_retry(llm, messages=[], stream=True) + async for _ in agen: + pass + + +def test_completion_with_retry_retries_on_transport_error() -> None: + """`httpx.TransportError` is in the retryable set.""" + llm = _make_llm(max_retries=1) + mock_client = MagicMock() + mock_client.create.side_effect = [ + httpx.ConnectError("refused"), + _success_response(), + ] + llm.client = mock_client + + result = _completion_with_retry(llm, messages=[]) + + assert result == _success_response() + assert mock_client.create.call_count == 2 + + class TestUsageToMetadata: """Tests for the `_usage_to_metadata` helper.""" diff --git a/libs/text-splitters/uv.lock b/libs/text-splitters/uv.lock index b2bf206a8e2..05e61dd58de 100644 --- a/libs/text-splitters/uv.lock +++ b/libs/text-splitters/uv.lock @@ -1186,7 +1186,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.3.0" +version = "1.3.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" },