feat(fireworks): service_tier init kwarg on ChatFireworks (#37143)

Add a `service_tier` init kwarg to `ChatFireworks`, mirroring the field
on `ChatOpenAI`. Forwards to the Fireworks chat completions API when
set, and echoes the response's tier back onto `response_metadata` and
`llm_output` so callbacks and consumers can read what the server
actually applied.
This commit is contained in:
Mason Daugherty
2026-05-01 16:42:34 -04:00
committed by GitHub
parent 91842db32b
commit 390843bd84
2 changed files with 175 additions and 10 deletions

View File

@@ -974,3 +974,142 @@ class TestStreamUsage:
"output_tokens": 2,
"total_tokens": 7,
}
class TestServiceTier:
    """Tests for the `service_tier` field plumbing.

    The tests check three things: that the init kwarg shows up in
    `_default_params`, that it is forwarded to the client's `create` call,
    and that a tier reported in the response is echoed onto
    `response_metadata` / `llm_output`.
    """

    @staticmethod
    def _completion_payload(**extra: Any) -> dict[str, Any]:
        """Build a minimal non-streaming chat completion response.

        A fresh dict is returned on every call so tests never share a
        payload object.

        Args:
            **extra: Additional top-level response keys (for example
                `service_tier`), placed after the base fields.

        Returns:
            A response dict with one assistant choice and token usage.
        """
        return {
            "choices": [
                {
                    "message": {"role": "assistant", "content": "hi"},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
            **extra,
        }

    def test_service_tier_omitted_by_default(self) -> None:
        """`_default_params` has no `service_tier` key when none was given."""
        model = _make_model()
        assert "service_tier" not in model._default_params

    def test_service_tier_in_default_params_when_set(self) -> None:
        """A tier passed at init appears in `_default_params`."""
        model = _make_model(service_tier="priority")
        assert model._default_params["service_tier"] == "priority"

    def test_service_tier_passed_to_client_when_set(self) -> None:
        """A tier passed at init is forwarded as a kwarg to `client.create`."""
        model = _make_model(service_tier="priority")
        model.client = MagicMock()
        model.client.create.return_value = self._completion_payload()
        model.invoke("Hello")
        call_kwargs = model.client.create.call_args.kwargs
        assert call_kwargs["service_tier"] == "priority"

    def test_service_tier_not_passed_when_unset(self) -> None:
        """No `service_tier` kwarg reaches `client.create` when unset."""
        model = _make_model()
        model.client = MagicMock()
        model.client.create.return_value = self._completion_payload()
        model.invoke("Hello")
        call_kwargs = model.client.create.call_args.kwargs
        assert "service_tier" not in call_kwargs

    def test_service_tier_echoed_in_response_metadata(self) -> None:
        """A tier present in the response lands in `response_metadata`."""
        model = _make_model(service_tier="priority")
        model.client = MagicMock()
        model.client.create.return_value = self._completion_payload(
            service_tier="priority"
        )
        result = model.invoke("Hello")
        assert isinstance(result, AIMessage)
        assert result.response_metadata["service_tier"] == "priority"

    def test_service_tier_echoed_in_stream_chunks(self) -> None:
        """Streamed chunks carrying a tier expose it in `response_metadata`."""
        model = _make_model(service_tier="priority")
        model.client = MagicMock()
        chunks: list[dict[str, Any]] = [
            {
                "choices": [{"delta": {"role": "assistant", "content": "hi"}}],
                "service_tier": "priority",
            },
            {
                "choices": [],
                "usage": {
                    "prompt_tokens": 1,
                    "completion_tokens": 1,
                    "total_tokens": 2,
                },
                "service_tier": "priority",
            },
        ]
        model.client.create.return_value = iter(chunks)
        out = list(model.stream("Hello"))
        # Only chunks that actually carry a tier are checked; at least one
        # such chunk must exist for the test to be meaningful.
        tagged = [c for c in out if c.response_metadata.get("service_tier")]
        assert tagged
        assert all(c.response_metadata["service_tier"] == "priority" for c in tagged)

    def test_service_tier_absent_when_not_in_response(self) -> None:
        """No tier in the response means no key in `response_metadata`."""
        model = _make_model()
        model.client = MagicMock()
        model.client.create.return_value = self._completion_payload()
        result = model.invoke("Hello")
        assert isinstance(result, AIMessage)
        assert "service_tier" not in result.response_metadata

    def test_service_tier_in_llm_output_when_response_carries_it(self) -> None:
        """`_create_chat_result` copies the response tier into `llm_output`."""
        model = _make_model(service_tier="priority")
        chat_result = model._create_chat_result(
            self._completion_payload(service_tier="priority")
        )
        assert chat_result.llm_output is not None
        assert chat_result.llm_output["service_tier"] == "priority"

    def test_service_tier_not_inferred_from_request(self) -> None:
        """Init-set tier must not leak into response_metadata if API omits it."""
        model = _make_model(service_tier="priority")
        model.client = MagicMock()
        model.client.create.return_value = self._completion_payload()
        result = model.invoke("Hello")
        assert isinstance(result, AIMessage)
        assert "service_tier" not in result.response_metadata