feat(fireworks): service_tier init kwarg on ChatFireworks (#37143)

Add a `service_tier` init kwarg to `ChatFireworks`, mirroring the field
on `ChatOpenAI`. Forwards to the Fireworks chat completions API when
set, and echoes the response's tier back onto `response_metadata` and
`llm_output` so callbacks and consumers can read what the server
actually applied.
This commit is contained in:
Mason Daugherty
2026-05-01 16:42:34 -04:00
committed by GitHub
parent 91842db32b
commit 390843bd84
2 changed files with 175 additions and 10 deletions

View File

@@ -321,6 +321,9 @@ def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
choices = chunk.get("choices") or []
response_metadata: dict[str, Any] = {"model_provider": "fireworks"}
if service_tier := chunk.get("service_tier"):
response_metadata["service_tier"] = service_tier
if not choices:
# Final chunk emitted when `stream_options.include_usage=True`:
# `choices` is empty and the chunk carries only `usage`.
@@ -333,7 +336,7 @@ def _convert_chunk_to_message_chunk(
return AIMessageChunk(
content="",
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "fireworks"},
response_metadata=response_metadata,
)
choice = choices[0]
_dict = choice["delta"]
@@ -368,7 +371,7 @@ def _convert_chunk_to_message_chunk(
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "fireworks"},
response_metadata=response_metadata,
)
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
@@ -647,6 +650,21 @@ class ChatFireworks(BaseChatModel):
A value of `None` or `0` disables retries.
"""
service_tier: str | None = None
"""Service tier for the request.
Forwarded as the `service_tier` field on the Fireworks chat completions
request when set. Pass `'priority'` to opt into Fireworks' priority tier;
leave as `None` to use the default tier.
To use Fireworks' fast mode instead, select a fast-routed `model`; fast mode
is not controlled by this field. See Fireworks'
[serverless product docs](https://docs.fireworks.ai/guides/serverless-products)
for the current list of fast routers and tiers.
!!! version-added "Added in `langchain-fireworks` 1.3.0"
"""
model_config = ConfigDict(
populate_by_name=True,
)
@@ -701,6 +719,8 @@ class ChatFireworks(BaseChatModel):
params["temperature"] = self.temperature
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.service_tier is not None:
params["service_tier"] = self.service_tier
return params
def _get_ls_params(
@@ -819,16 +839,20 @@ class ChatFireworks(BaseChatModel):
if not isinstance(response, dict):
response = response.model_dump()
token_usage = response.get("usage", {})
service_tier = response.get("service_tier")
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
message.usage_metadata = {
"input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
message.response_metadata["model_provider"] = "fireworks"
message.response_metadata["model_name"] = self.model_name
if isinstance(message, AIMessage):
if token_usage:
message.usage_metadata = {
"input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
message.response_metadata["model_provider"] = "fireworks"
message.response_metadata["model_name"] = self.model_name
if service_tier:
message.response_metadata["service_tier"] = service_tier
generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
@@ -841,6 +865,8 @@ class ChatFireworks(BaseChatModel):
"token_usage": token_usage,
"system_fingerprint": response.get("system_fingerprint", ""),
}
if service_tier:
llm_output["service_tier"] = service_tier
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(