feat(fireworks): migrate to fireworks-ai 1.x SDK (#37581)

Closes #37172

---

Bumps `langchain-fireworks` to the rewritten `fireworks-ai` 1.x SDK
(currently 1.2.0a*; Stainless-generated, pure-httpx, no
`grpcio`/`protobuf`/`googleapis-common-protos`).

The motivating bug is a startup crash in self-hosted LangGraph
environments that also import `langchain-google-vertexai`. Importing
`fireworks-ai` 0.19.x eagerly loads vendored grpcio protobuf modules
under `fireworks.control_plane.generated.protos_grpcio.*`, which
register `google/rpc/status.proto`, `google/api/*.proto`, and
`google/longrunning/*.proto` in the default protobuf descriptor pool.
When `langchain-google-vertexai` later triggers
`google.api_core.exceptions` → `grpc_status.rpc_status` →
`google.rpc.status_pb2`, the pool already holds a byte-different
descriptor for `google/rpc/status.proto` and startup dies with:

```
TypeError: Couldn't build proto file into descriptor pool:
duplicate file name google/rpc/status.proto
```
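
The collision can be pictured with a toy model of a descriptor pool. This is only a sketch of the behavior described above, not protobuf's implementation: `ToyDescriptorPool` and `add_serialized_file` are hypothetical names, and the byte strings stand in for serialized file descriptors.

```python
class ToyDescriptorPool:
    """Toy stand-in for a process-wide protobuf descriptor pool (hypothetical)."""

    def __init__(self) -> None:
        self._files: dict[str, bytes] = {}

    def add_serialized_file(self, name: str, serialized: bytes) -> None:
        existing = self._files.get(name)
        if existing is None:
            self._files[name] = serialized
        elif existing != serialized:
            # Same file name, different bytes: the pool refuses the second copy.
            msg = (
                "Couldn't build proto file into descriptor pool: "
                f"duplicate file name {name}"
            )
            raise TypeError(msg)
        # Re-registering byte-identical content is treated as a no-op here.


pool = ToyDescriptorPool()
# The first importer (the vendored copy in this scenario) registers its build.
pool.add_serialized_file("google/rpc/status.proto", b"vendored-build")
try:
    # A later importer ships a byte-different build of the same file name.
    pool.add_serialized_file("google/rpc/status.proto", b"upstream-build")
except TypeError as exc:
    error_message = str(exc)
```

Whichever package imports first wins, which is why the crash only shows up when both packages are loaded in the same process.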

Fleet has been working around this by routing Fireworks through
`ChatOpenAI` against the OpenAI-compatible endpoint. That works for
inference, but Fireworks `ModelProfile` data never loads, so Kimi
K2.6's ~262k context window goes unrecognized and summarization triggers
well below the real limit.

The 1.x SDK does not vendor protobuf at all. The control-plane gRPC code
path is gone; chat inference goes over httpx. Verified locally that
`import langchain_fireworks` and `from langchain_fireworks import
ChatFireworks` load zero `_pb2` / `google.*` modules.
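
A check along these lines can be reproduced with a small helper. This is a sketch, not the exact script used for the verification above; `loaded_protobuf_modules` is a hypothetical name, and the name filter (suffix `_pb2`, prefix `google.`) is an assumption about what counts as a protobuf-related module.

```python
import sys


def loaded_protobuf_modules(modules=None):
    """Return sorted names of loaded modules matching `_pb2` or `google.*`."""
    # Default to the interpreter's module table; accept a list for testing.
    names = list(sys.modules) if modules is None else list(modules)
    return sorted(
        name
        for name in names
        if name.endswith("_pb2") or name.startswith("google.")
    )


# Usage, in a fresh interpreter so earlier imports do not skew the result:
#   import langchain_fireworks
#   print(loaded_protobuf_modules())  # expected to be empty per the check above
```

Run it in a clean process; other packages already imported in a long-lived session can put `google.*` modules into `sys.modules` on their own.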

## What changed in `ChatFireworks`

- Imports switch from `fireworks.client` to the top-level `fireworks`
package.
- The async path now awaits `client.chat.completions.create(...)`; the 0.x
`acreate` shim is no longer used.
- Error classes remapped to the 1.x hierarchy. `InvalidRequestError` →
`BadRequestError`. `BadGatewayError` and `ServiceUnavailableError` no
longer exist (1.x maps all `>=500` to `InternalServerError`) and were
dropped from the retryable set with no loss of coverage.
`FireworksContextOverflowError`'s parent class becomes
`BadRequestError`.
- `stream_options` is moved into the SDK's `extra_body` because the
Stainless-generated `create()` signature does not model it as a typed
kwarg. Top-level `stream_options` is preserved as a caller convenience;
if a caller supplies both `extra_body["stream_options"]` and a top-level
value, `extra_body` wins and the discarded value is logged.
- The 0.x `(connect, read)` tuple form of `request_timeout` is
normalized to an `httpx.Timeout` so existing user code keeps working.
- The SDK's built-in retry layer is suppressed via `max_retries=0` on
client construction so retries remain owned by
`create_base_retry_decorator` and surface through the LangChain
`run_manager`.
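
Because the wrapper owns retries, the meaning of `max_retries` matters: in the accompanying diff it counts retries after the initial attempt, so the attempt budget is offset by one. The sketch below illustrates that arithmetic with a plain loop. It is a stand-in under stated assumptions: the real code delegates to `create_base_retry_decorator` (tenacity) with backoff, and `total_attempts`, `call_with_retries`, and `flaky` are hypothetical names.

```python
def total_attempts(max_retries):
    """Translate "retries after the first call" into "total attempts"."""
    # `None` and `0` both mean a single attempt with no retries.
    return (max_retries + 1) if max_retries else 1


def call_with_retries(fn, max_retries, retryable=(ConnectionError,)):
    """Call `fn` until it succeeds or the attempt budget is spent (no backoff)."""
    attempts = total_attempts(max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except retryable:
            if attempt == attempts:
                raise  # budget exhausted: surface the last error


calls = []


def flaky():
    """Fail twice with a retryable error, then succeed."""
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("transient")
    return "ok"


result = call_with_retries(flaky, max_retries=2)  # 1 initial call + 2 retries
```

With `max_retries=0` on the SDK client and this accounting in the wrapper, a request is attempted at most `max_retries + 1` times in total rather than being multiplied by two retry layers.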

## Lifecycle methods

Adds `close()` and `aclose()` on `ChatFireworks`. The 1.x
`AsyncFireworks` client defaults to `httpx_aiohttp.HttpxAiohttpClient`,
whose underlying aiohttp `ClientSession` is created lazily on first
request. Sync-only paths therefore never open a session — which fixes
the "Unclosed client session" warnings from #37172 at the source.
Callers using async paths can now release the connector
deterministically rather than relying on GC after the event loop has
stopped. An autouse fixture in the integration `conftest.py` calls
`aclose()` between tests to silence the corresponding `Unclosed
connector` warning that surfaces under `pytest-asyncio`.

## Relation to #37227

Supersedes #37227. That PR monkey-patched
`fireworks._util.is_running_in_async_context` and
`fireworks.client.api_client.is_running_in_async_context` to suppress
the 0.x SDK's eager `aiohttp.ClientSession` creation in async contexts.
Both module paths are removed in 1.x; the SDK's lazy-session behavior
makes the suppression unnecessary, and the explicit `aclose()` provides
the cleaner long-term lifecycle hook. Thanks to @keenborder786 for
surfacing the failure mode.

## Installation note

`fireworks-ai` 1.x is currently published as an alpha (`1.2.0a*`); a
stable 1.x is not yet out. `pip install langchain-fireworks` / `uv pip
install langchain-fireworks` will need `--pre` (or `--prerelease=allow`)
until Fireworks GAs 1.x. The `pyproject.toml` adds `[tool.uv] prerelease
= "allow"` so the in-repo dev environment resolves cleanly. The package
version is bumped to `1.4.0` — the public surface (`ChatFireworks`,
`Fireworks`, `FireworksEmbeddings`) is unchanged; the breakage is
confined to internal error classes and the transitive SDK.
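
For environments that need to confirm which SDK generation was resolved, a runtime check can help. This is a rough sketch: `describe_sdk_version` and `installed_sdk_version` are hypothetical helpers, the pre-release pattern is a simplified reading of PEP 440 (it ignores `.dev` and `.post` segments), and the version strings in the comments are illustrative only.

```python
import re
from importlib import metadata

# Simplified PEP 440 pre-release suffix: a, b, or rc followed by digits.
_PRE_RELEASE = re.compile(r"(a|b|rc)\d+$")


def describe_sdk_version(version: str) -> tuple[int, bool]:
    """Return (major version, is_pre_release) for a PEP 440-style string."""
    major = int(version.split(".", 1)[0])
    return major, bool(_PRE_RELEASE.search(version))


def installed_sdk_version(dist_name: str = "fireworks-ai"):
    """Return the installed version string, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Illustrative inputs: an alpha of the 1.x line and a 0.19.x release.
#   describe_sdk_version("1.2.0a1")  -> major 1, pre-release
#   describe_sdk_version("0.19.0")   -> major 0, not a pre-release
```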

Author: Mason Daugherty
Date: 2026-05-20 16:39:01 -05:00
Committed by GitHub
Parent: ac41199338
Commit: d39950cb18
5 changed files with 380 additions and 226 deletions


@@ -15,15 +15,14 @@ from typing import (
 )

 import httpx
-from fireworks.client import AsyncFireworks, Fireworks  # type: ignore[import-untyped]
-from fireworks.client.error import (  # type: ignore[import-untyped]
+from fireworks import (
     APITimeoutError,
-    BadGatewayError,
+    AsyncFireworks,
+    BadRequestError,
+    Fireworks,
     FireworksError,
     InternalServerError,
-    InvalidRequestError,
     RateLimitError,
-    ServiceUnavailableError,
 )
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -94,6 +93,7 @@ from pydantic import (
     BaseModel,
     ConfigDict,
     Field,
+    PrivateAttr,
     SecretStr,
     model_validator,
 )
@@ -410,20 +410,19 @@ def _convert_chunk_to_message_chunk(
 class _RetryableHTTPStatusError(FireworksError):
     """Internal marker for 5xx `httpx.HTTPStatusError` responses.

-    The Fireworks SDK maps a subset of status codes (500, 502, 503) to typed
-    exceptions but lets others (504, 507-511, Cloudflare-edge 520-599)
-    propagate as raw `httpx.HTTPStatusError`. Promoting those to this marker
-    inside `_call` keeps the retryable set expressible as a list of classes
-    for `create_base_retry_decorator`, preserving parity with `ChatMistralAI`.
+    The 1.x SDK wraps every status response into a typed `APIStatusError`
+    subclass, so this path is defense-in-depth: it only fires when a raw
+    `httpx.HTTPStatusError` escapes the SDK (e.g., a custom `http_client` or
+    monkey-patched transport raises one directly). Promoting it here keeps the
+    retryable set expressible as a list of classes for
+    `create_base_retry_decorator`.
     """


 _RETRYABLE_ERRORS: tuple[type[BaseException], ...] = (
     APITimeoutError,
-    BadGatewayError,
     InternalServerError,
     RateLimitError,
-    ServiceUnavailableError,
     httpx.TimeoutException,
     httpx.TransportError,
     _RetryableHTTPStatusError,
@@ -438,14 +437,16 @@ def _promote_http_status_error(exc: httpx.HTTPStatusError) -> NoReturn:
     raise exc


-class FireworksContextOverflowError(InvalidRequestError, ContextOverflowError):
-    """`InvalidRequestError` raised when input exceeds Fireworks's context limit."""
+class FireworksContextOverflowError(BadRequestError, ContextOverflowError):
+    """`BadRequestError` raised when input exceeds Fireworks's context limit."""


-def _handle_fireworks_invalid_request(e: InvalidRequestError) -> NoReturn:
+def _handle_fireworks_invalid_request(e: BadRequestError) -> NoReturn:
     """Promote prompt-too-long errors to `FireworksContextOverflowError`."""
     if "prompt is too long" in str(e):
-        raise FireworksContextOverflowError(str(e)) from e
+        raise FireworksContextOverflowError(
+            str(e), response=e.response, body=e.body
+        ) from e
     raise e
@@ -461,14 +462,13 @@ def _create_retry_decorator(
 ) -> Callable[[Any], Any]:
     """Return a tenacity retry decorator for Fireworks SDK calls.

-    Retries are implemented here because the pinned Fireworks SDK 0.x does
-    not honor its own `_max_retries` attribute on completion resources.
+    Retries live here rather than in the SDK so each attempt is visible to the
+    LangChain `run_manager.on_retry` callback. The SDK's own retry layer is
+    suppressed via `max_retries=0` on the client; see `validate_environment`.
     """
     # `max_retries` counts retries *after* the initial attempt.
     # `create_base_retry_decorator` forwards its `max_retries` to
     # `stop_after_attempt`, which counts total attempts — so offset by 1.
-    # Note: this diverges from `ChatMistralAI`, which passes the raw value;
-    # the fireworks field docstring is the source of truth here.
     # `None` and `0` both mean "single attempt, no retries".
     attempts = (llm.max_retries + 1) if llm.max_retries else 1
     return create_base_retry_decorator(
@@ -478,6 +478,36 @@ def _create_retry_decorator(
     )


+def _prepare_sdk_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
+    """Move fields the 1.x SDK does not model into `extra_body`.
+
+    The Stainless-generated `chat.completions.create` signature has a fixed set
+    of typed parameters. Fireworks accepts additional fields on the wire (notably
+    `stream_options.include_usage`) that the SDK schema does not declare. The
+    SDK exposes `extra_body` precisely for this — merge anything that looks
+    extra-body-shaped into it so it lands in the JSON request body.
+
+    If a caller supplies both `extra_body={"stream_options": ...}` and a
+    top-level `stream_options=...`, the value already in `extra_body` wins
+    (callers using `extra_body` are presumed to want explicit control); the
+    discarded top-level value is logged.
+    """
+    extra_body = dict(kwargs.pop("extra_body", None) or {})
+    top_level_stream_options = kwargs.pop("stream_options", None)
+    if top_level_stream_options is not None:
+        if "stream_options" in extra_body:
+            logger.warning(
+                "Both `extra_body['stream_options']` and a top-level "
+                "`stream_options` were supplied; using `extra_body`'s value "
+                "and discarding the top-level value.",
+            )
+        else:
+            extra_body["stream_options"] = top_level_stream_options
+    if extra_body:
+        kwargs["extra_body"] = extra_body
+    return kwargs
+
+
 def _completion_with_retry(
     llm: ChatFireworks,
     run_manager: CallbackManagerForLLMRun | None = None,
@@ -485,6 +515,7 @@ def _completion_with_retry(
 ) -> Any:
     """Retry the sync completion call, including stream setup."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    kwargs = _prepare_sdk_kwargs(kwargs)

     @retry_decorator
     def _call() -> Any:
@@ -517,12 +548,17 @@ async def _acompletion_with_retry(
 ) -> Any:
     """Retry the async completion call, including stream setup."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    kwargs = _prepare_sdk_kwargs(kwargs)

     @retry_decorator
     async def _call() -> Any:
         if kwargs.get("stream"):
             try:
-                result = llm.async_client.acreate(**kwargs)
+                # 1.x async `create()` is a coroutine that resolves to an
+                # `AsyncStream` when `stream=True`. Await it, then advance the
+                # async iterator once inside the retry boundary so transport
+                # errors surface here rather than at first downstream consumer.
+                result = await llm.async_client.create(**kwargs)
                 agen = result.__aiter__()
                 first = await agen.__anext__()
             except StopAsyncIteration:
@@ -531,7 +567,7 @@ async def _acompletion_with_retry(
                 _promote_http_status_error(e)
             return _aprepend_chunk(first, agen)
         try:
-            return await llm.async_client.acreate(**kwargs)
+            return await llm.async_client.create(**kwargs)
         except httpx.HTTPStatusError as e:
             _promote_http_status_error(e)
@@ -549,11 +585,6 @@ async def _aprepend_chunk(first: Any, rest: AsyncIterator[Any]) -> AsyncIterator
         yield item


-# This is basically a copy and replace for ChatFireworks, except
-# - I needed to gut out tiktoken and some of the token estimation logic
-# (not sure how important it is)
-# - Environment variable is different
-# we should refactor into some OpenAI-like class in the future
 class ChatFireworks(BaseChatModel):
     """`Fireworks` Chat large language models API.
@@ -598,8 +629,28 @@ class ChatFireworks(BaseChatModel):
         return True

     client: Any = Field(default=None, exclude=True)
+    """Internal `fireworks.Fireworks().chat.completions` resource.
+
+    Constructed with `max_retries=0` so retries are owned by
+    `_create_retry_decorator` (which surfaces each attempt to the LangChain
+    `run_manager`). Callers reaching for this directly should set their own
+    retry layer.
+    """

     async_client: Any = Field(default=None, exclude=True)
+    """Internal `fireworks.AsyncFireworks().chat.completions` resource.
+
+    Constructed with `max_retries=0`; see `client`.
+    """
+
+    _sdk_client: Any = PrivateAttr(default=None)
+    """Owning `fireworks.Fireworks` instance, retained so `close()` can call
+    into the underlying HTTPX client. The 1.x SDK does not expose lifecycle
+    methods on the `chat.completions` resource itself.
+    """
+
+    _async_sdk_client: Any = PrivateAttr(default=None)
+    """Owning `fireworks.AsyncFireworks` instance; see `_sdk_client`."""

     model_name: str = Field(alias="model")
     """Model name to use."""
@@ -720,18 +771,58 @@ class ChatFireworks(BaseChatModel):
             msg = "n must be 1 when streaming."
             raise ValueError(msg)

-        client_params = {
-            "api_key": self.fireworks_api_key.get_secret_value(),
-            "base_url": self.fireworks_api_base,
-            "timeout": self.request_timeout,
-        }
+        api_key = self.fireworks_api_key.get_secret_value()
+        base_url = self.fireworks_api_base
+        # 0.x accepted a `(connect, read)` tuple. 1.x's SDK only accepts a
+        # float, `httpx.Timeout`, or `None` — normalize so existing user code
+        # keeps working.
+        if isinstance(self.request_timeout, tuple):
+            connect, read = self.request_timeout
+            timeout: Any = httpx.Timeout(read, connect=connect)
+        else:
+            timeout = self.request_timeout
+        # `langchain-fireworks` owns retry/backoff via `_create_retry_decorator`
+        # so the LangChain `run_manager` sees each attempt. Suppress the
+        # SDK's built-in retry layer to avoid double-retrying.
         if not self.client:
-            self.client = Fireworks(**client_params).chat.completions
+            self._sdk_client = Fireworks(
+                api_key=api_key,
+                base_url=base_url,
+                timeout=timeout,
+                max_retries=0,
+            )
+            self.client = self._sdk_client.chat.completions
         if not self.async_client:
-            self.async_client = AsyncFireworks(**client_params).chat.completions
+            self._async_sdk_client = AsyncFireworks(
+                api_key=api_key,
+                base_url=base_url,
+                timeout=timeout,
+                max_retries=0,
+            )
+            self.async_client = self._async_sdk_client.chat.completions
         return self

+    def close(self) -> None:
+        """Close the underlying sync HTTP client.
+
+        After calling, sync invocations on this model will raise. Async
+        invocations remain available until `aclose()` is also called. Safe to
+        call multiple times.
+        """
+        if self._sdk_client is not None:
+            self._sdk_client.close()
+
+    async def aclose(self) -> None:
+        """Close the underlying async HTTP client.
+
+        Releases the aiohttp-backed connector that the 1.x SDK uses by
+        default. Without this, transient `ChatFireworks` instances can leak
+        an `Unclosed connector` warning at GC if the event loop has already
+        stopped. Safe to call multiple times.
+        """
+        if self._async_sdk_client is not None:
+            await self._async_sdk_client.close()
+
     def _resolve_model_profile(self) -> ModelProfile | None:
         return _get_default_model_profile(self.model_name) or None
@@ -808,7 +899,7 @@ class ChatFireworks(BaseChatModel):
             stream = _completion_with_retry(
                 self, run_manager=run_manager, messages=message_dicts, **params
             )
-        except InvalidRequestError as e:
+        except BadRequestError as e:
             _handle_fireworks_invalid_request(e)
         for chunk in stream:
             if not isinstance(chunk, dict):
@@ -858,7 +949,7 @@ class ChatFireworks(BaseChatModel):
             response = _completion_with_retry(
                 self, run_manager=run_manager, messages=message_dicts, **params
             )
-        except InvalidRequestError as e:
+        except BadRequestError as e:
             _handle_fireworks_invalid_request(e)
         return self._create_chat_result(response)
@@ -923,7 +1014,7 @@ class ChatFireworks(BaseChatModel):
             stream = await _acompletion_with_retry(
                 self, run_manager=run_manager, messages=message_dicts, **params
             )
-        except InvalidRequestError as e:
+        except BadRequestError as e:
             _handle_fireworks_invalid_request(e)
         async for chunk in stream:
             if not isinstance(chunk, dict):
@@ -976,7 +1067,7 @@ class ChatFireworks(BaseChatModel):
             response = await _acompletion_with_retry(
                 self, run_manager=run_manager, messages=message_dicts, **params
             )
-        except InvalidRequestError as e:
+        except BadRequestError as e:
             _handle_fireworks_invalid_request(e)
         return self._create_chat_result(response)