feat(fireworks): service_tier init kwarg on ChatFireworks (#37143)

Add a `service_tier` init kwarg to `ChatFireworks`, mirroring the field
on `ChatOpenAI`. Forwards to the Fireworks chat completions API when
set, and echoes the response's tier back onto `response_metadata` and
`llm_output` so callbacks and consumers can read what the server
actually applied.
This commit is contained in:
Mason Daugherty
2026-05-01 16:42:34 -04:00
committed by GitHub
parent 91842db32b
commit 390843bd84
2 changed files with 175 additions and 10 deletions

View File

@@ -321,6 +321,9 @@ def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk] chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk: ) -> BaseMessageChunk:
choices = chunk.get("choices") or [] choices = chunk.get("choices") or []
response_metadata: dict[str, Any] = {"model_provider": "fireworks"}
if service_tier := chunk.get("service_tier"):
response_metadata["service_tier"] = service_tier
if not choices: if not choices:
# Final chunk emitted when `stream_options.include_usage=True`: # Final chunk emitted when `stream_options.include_usage=True`:
# `choices` is empty and the chunk carries only `usage`. # `choices` is empty and the chunk carries only `usage`.
@@ -333,7 +336,7 @@ def _convert_chunk_to_message_chunk(
return AIMessageChunk( return AIMessageChunk(
content="", content="",
usage_metadata=usage_metadata, # type: ignore[arg-type] usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "fireworks"}, response_metadata=response_metadata,
) )
choice = choices[0] choice = choices[0]
_dict = choice["delta"] _dict = choice["delta"]
@@ -368,7 +371,7 @@ def _convert_chunk_to_message_chunk(
additional_kwargs=additional_kwargs, additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type] usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "fireworks"}, response_metadata=response_metadata,
) )
if role == "system" or default_class == SystemMessageChunk: if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content) return SystemMessageChunk(content=content)
@@ -647,6 +650,21 @@ class ChatFireworks(BaseChatModel):
A value of `None` or `0` disables retries. A value of `None` or `0` disables retries.
""" """
service_tier: str | None = None
"""Service tier for the request.
Forwarded as the `service_tier` field on the Fireworks chat completions
request when set. Pass `'priority'` to opt into Fireworks' priority tier;
leave as `None` to use the default tier.
To use Fireworks' fast mode instead, select a fast-routed `model`; fast mode
is not controlled by this field. See Fireworks'
[serverless product docs](https://docs.fireworks.ai/guides/serverless-products)
for the current list of fast routers and tiers.
!!! version-added "Added in `langchain-fireworks` 1.3.0"
"""
model_config = ConfigDict( model_config = ConfigDict(
populate_by_name=True, populate_by_name=True,
) )
@@ -701,6 +719,8 @@ class ChatFireworks(BaseChatModel):
params["temperature"] = self.temperature params["temperature"] = self.temperature
if self.max_tokens is not None: if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens params["max_tokens"] = self.max_tokens
if self.service_tier is not None:
params["service_tier"] = self.service_tier
return params return params
def _get_ls_params( def _get_ls_params(
@@ -819,9 +839,11 @@ class ChatFireworks(BaseChatModel):
if not isinstance(response, dict): if not isinstance(response, dict):
response = response.model_dump() response = response.model_dump()
token_usage = response.get("usage", {}) token_usage = response.get("usage", {})
service_tier = response.get("service_tier")
for res in response["choices"]: for res in response["choices"]:
message = _convert_dict_to_message(res["message"]) message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage): if isinstance(message, AIMessage):
if token_usage:
message.usage_metadata = { message.usage_metadata = {
"input_tokens": token_usage.get("prompt_tokens", 0), "input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0), "output_tokens": token_usage.get("completion_tokens", 0),
@@ -829,6 +851,8 @@ class ChatFireworks(BaseChatModel):
} }
message.response_metadata["model_provider"] = "fireworks" message.response_metadata["model_provider"] = "fireworks"
message.response_metadata["model_name"] = self.model_name message.response_metadata["model_name"] = self.model_name
if service_tier:
message.response_metadata["service_tier"] = service_tier
generation_info = {"finish_reason": res.get("finish_reason")} generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res: if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"] generation_info["logprobs"] = res["logprobs"]
@@ -841,6 +865,8 @@ class ChatFireworks(BaseChatModel):
"token_usage": token_usage, "token_usage": token_usage,
"system_fingerprint": response.get("system_fingerprint", ""), "system_fingerprint": response.get("system_fingerprint", ""),
} }
if service_tier:
llm_output["service_tier"] = service_tier
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
async def _astream( async def _astream(

View File

@@ -974,3 +974,142 @@ class TestStreamUsage:
"output_tokens": 2, "output_tokens": 2,
"total_tokens": 7, "total_tokens": 7,
} }
class TestServiceTier:
    """Checks that `service_tier` reaches the request and is echoed from responses."""

    @staticmethod
    def _completion(**extra: Any) -> dict[str, Any]:
        """Build a fresh, minimal non-streaming completion payload.

        Keyword arguments are merged in as additional top-level response
        fields (e.g. `service_tier="priority"`).
        """
        return {
            "choices": [
                {
                    "message": {"role": "assistant", "content": "hi"},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
            **extra,
        }

    @staticmethod
    def _model_returning(response: Any, **model_kwargs: Any) -> Any:
        """Create a model whose mocked client's `create` returns `response`."""
        model = _make_model(**model_kwargs)
        model.client = MagicMock()
        model.client.create.return_value = response
        return model

    def test_service_tier_omitted_by_default(self) -> None:
        default_params = _make_model()._default_params
        assert "service_tier" not in default_params

    def test_service_tier_in_default_params_when_set(self) -> None:
        default_params = _make_model(service_tier="priority")._default_params
        assert default_params["service_tier"] == "priority"

    def test_service_tier_passed_to_client_when_set(self) -> None:
        model = self._model_returning(self._completion(), service_tier="priority")

        model.invoke("Hello")

        sent = model.client.create.call_args.kwargs
        assert sent["service_tier"] == "priority"

    def test_service_tier_not_passed_when_unset(self) -> None:
        model = self._model_returning(self._completion())

        model.invoke("Hello")

        sent = model.client.create.call_args.kwargs
        assert "service_tier" not in sent

    def test_service_tier_echoed_in_response_metadata(self) -> None:
        model = self._model_returning(
            self._completion(service_tier="priority"),
            service_tier="priority",
        )

        result = model.invoke("Hello")

        assert isinstance(result, AIMessage)
        assert result.response_metadata["service_tier"] == "priority"

    def test_service_tier_echoed_in_stream_chunks(self) -> None:
        content_chunk: dict[str, Any] = {
            "choices": [{"delta": {"role": "assistant", "content": "hi"}}],
            "service_tier": "priority",
        }
        # Usage-only chunk: `choices` is empty and only `usage` is carried.
        usage_chunk: dict[str, Any] = {
            "choices": [],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            },
            "service_tier": "priority",
        }
        model = self._model_returning(
            iter([content_chunk, usage_chunk]),
            service_tier="priority",
        )

        # Only chunks that actually carry a tier are inspected.
        tiers = [
            tier
            for chunk in model.stream("Hello")
            if (tier := chunk.response_metadata.get("service_tier"))
        ]

        assert tiers
        assert all(tier == "priority" for tier in tiers)

    def test_service_tier_absent_when_not_in_response(self) -> None:
        model = self._model_returning(self._completion())

        result = model.invoke("Hello")

        assert isinstance(result, AIMessage)
        assert "service_tier" not in result.response_metadata

    def test_service_tier_in_llm_output_when_response_carries_it(self) -> None:
        model = _make_model(service_tier="priority")

        chat_result = model._create_chat_result(
            self._completion(service_tier="priority")
        )

        assert chat_result.llm_output is not None
        assert chat_result.llm_output["service_tier"] == "priority"

    def test_service_tier_not_inferred_from_request(self) -> None:
        """A tier set at init must not appear in metadata when the API omits it."""
        model = self._model_returning(self._completion(), service_tier="priority")

        result = model.invoke("Hello")

        assert isinstance(result, AIMessage)
        assert "service_tier" not in result.response_metadata