mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(fireworks): service_tier init kwarg on ChatFireworks (#37143)
Add a `service_tier` init kwarg to `ChatFireworks`, mirroring the field on `ChatOpenAI`. Forwards to the Fireworks chat completions API when set, and echoes the response's tier back onto `response_metadata` and `llm_output` so callbacks and consumers can read what the server actually applied.
This commit is contained in:
@@ -974,3 +974,142 @@ class TestStreamUsage:
|
||||
"output_tokens": 2,
|
||||
"total_tokens": 7,
|
||||
}
|
||||
|
||||
|
||||
class TestServiceTier:
    """Exercise `service_tier` handling on requests and responses."""

    @staticmethod
    def _completion(**extra: Any) -> dict[str, Any]:
        """Build a minimal non-streaming chat completion payload.

        Args:
            **extra: Additional top-level fields to merge into the payload
                (for example `service_tier="priority"`).

        Returns:
            A fresh dict with one assistant choice, a usage block, and any
            extra fields appended after them.
        """
        return {
            "choices": [
                {
                    "message": {"role": "assistant", "content": "hi"},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
            **extra,
        }

    @staticmethod
    def _model_returning(response: Any, **model_kwargs: Any) -> Any:
        """Create a model whose mocked client's `create` returns `response`.

        Args:
            response: Value handed back by `model.client.create(...)`.
            **model_kwargs: Keyword arguments forwarded to `_make_model`.

        Returns:
            The model, with `model.client` replaced by a `MagicMock`.
        """
        model = _make_model(**model_kwargs)
        # Swap in the mock first, then configure it through the model
        # attribute, so the object the model holds is the one configured.
        model.client = MagicMock()
        model.client.create.return_value = response
        return model

    def test_service_tier_omitted_by_default(self) -> None:
        """Without the init kwarg, the default params carry no tier key."""
        default_params = _make_model()._default_params
        assert "service_tier" not in default_params

    def test_service_tier_in_default_params_when_set(self) -> None:
        """The init kwarg value shows up in the default params."""
        default_params = _make_model(service_tier="priority")._default_params
        assert default_params["service_tier"] == "priority"

    def test_service_tier_passed_to_client_when_set(self) -> None:
        """A configured tier is sent as a keyword argument to `create`."""
        model = self._model_returning(self._completion(), service_tier="priority")
        model.invoke("Hello")
        sent = model.client.create.call_args.kwargs
        assert sent["service_tier"] == "priority"

    def test_service_tier_not_passed_when_unset(self) -> None:
        """No tier keyword reaches `create` when none was configured."""
        model = self._model_returning(self._completion())
        model.invoke("Hello")
        sent = model.client.create.call_args.kwargs
        assert "service_tier" not in sent

    def test_service_tier_echoed_in_response_metadata(self) -> None:
        """A tier present in the response lands in `response_metadata`."""
        model = self._model_returning(
            self._completion(service_tier="priority"),
            service_tier="priority",
        )
        message = model.invoke("Hello")
        assert isinstance(message, AIMessage)
        assert message.response_metadata["service_tier"] == "priority"

    def test_service_tier_echoed_in_stream_chunks(self) -> None:
        """Streamed chunks that carry a tier all report the response's tier."""
        content_chunk: dict[str, Any] = {
            "choices": [{"delta": {"role": "assistant", "content": "hi"}}],
            "service_tier": "priority",
        }
        usage_chunk: dict[str, Any] = {
            "choices": [],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            },
            "service_tier": "priority",
        }
        model = self._model_returning(
            iter([content_chunk, usage_chunk]),
            service_tier="priority",
        )
        emitted = list(model.stream("Hello"))
        # Only chunks with a truthy tier are checked; at least one must exist.
        with_tier = [
            chunk for chunk in emitted if chunk.response_metadata.get("service_tier")
        ]
        assert with_tier
        assert all(
            chunk.response_metadata["service_tier"] == "priority"
            for chunk in with_tier
        )

    def test_service_tier_absent_when_not_in_response(self) -> None:
        """No tier in the response means no tier key in `response_metadata`."""
        model = self._model_returning(self._completion())
        message = model.invoke("Hello")
        assert isinstance(message, AIMessage)
        assert "service_tier" not in message.response_metadata

    def test_service_tier_in_llm_output_when_response_carries_it(self) -> None:
        """`_create_chat_result` copies the response's tier into `llm_output`."""
        model = _make_model(service_tier="priority")
        payload: dict[str, Any] = {
            "choices": [
                {
                    "message": {"role": "assistant", "content": "hi"},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            },
            "service_tier": "priority",
        }
        outcome = model._create_chat_result(payload)
        assert outcome.llm_output is not None
        assert outcome.llm_output["service_tier"] == "priority"

    def test_service_tier_not_inferred_from_request(self) -> None:
        """A tier set at init must not appear in metadata if the API omits it."""
        model = self._model_returning(self._completion(), service_tier="priority")
        message = model.invoke("Hello")
        assert isinstance(message, AIMessage)
        assert "service_tier" not in message.response_metadata
|
||||
|
||||
Reference in New Issue
Block a user