mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(fireworks): service_tier init kwarg on ChatFireworks (#37143)
Add a `service_tier` init kwarg to `ChatFireworks`, mirroring the field on `ChatOpenAI`. Forwards to the Fireworks chat completions API when set, and echoes the response's tier back onto `response_metadata` and `llm_output` so callbacks and consumers can read what the server actually applied.
This commit is contained in:
@@ -321,6 +321,9 @@ def _convert_chunk_to_message_chunk(
|
||||
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
choices = chunk.get("choices") or []
|
||||
response_metadata: dict[str, Any] = {"model_provider": "fireworks"}
|
||||
if service_tier := chunk.get("service_tier"):
|
||||
response_metadata["service_tier"] = service_tier
|
||||
if not choices:
|
||||
# Final chunk emitted when `stream_options.include_usage=True`:
|
||||
# `choices` is empty and the chunk carries only `usage`.
|
||||
@@ -333,7 +336,7 @@ def _convert_chunk_to_message_chunk(
|
||||
return AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "fireworks"},
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
choice = choices[0]
|
||||
_dict = choice["delta"]
|
||||
@@ -368,7 +371,7 @@ def _convert_chunk_to_message_chunk(
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "fireworks"},
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
@@ -647,6 +650,21 @@ class ChatFireworks(BaseChatModel):
|
||||
A value of `None` or `0` disables retries.
|
||||
"""
|
||||
|
||||
service_tier: str | None = None
|
||||
"""Service tier for the request.
|
||||
|
||||
Forwarded as the `service_tier` field on the Fireworks chat completions
|
||||
request when set. Pass `'priority'` to opt into Fireworks' priority tier;
|
||||
leave as `None` to use the default tier.
|
||||
|
||||
To use Fireworks' fast mode instead, select a fast-routed `model`; fast mode
|
||||
is not controlled by this field. See Fireworks'
|
||||
[serverless product docs](https://docs.fireworks.ai/guides/serverless-products)
|
||||
for the current list of fast routers and tiers.
|
||||
|
||||
!!! version-added "Added in `langchain-fireworks` 1.3.0"
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
@@ -701,6 +719,8 @@ class ChatFireworks(BaseChatModel):
|
||||
params["temperature"] = self.temperature
|
||||
if self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
if self.service_tier is not None:
|
||||
params["service_tier"] = self.service_tier
|
||||
return params
|
||||
|
||||
def _get_ls_params(
|
||||
@@ -819,9 +839,11 @@ class ChatFireworks(BaseChatModel):
|
||||
if not isinstance(response, dict):
|
||||
response = response.model_dump()
|
||||
token_usage = response.get("usage", {})
|
||||
service_tier = response.get("service_tier")
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
if token_usage and isinstance(message, AIMessage):
|
||||
if isinstance(message, AIMessage):
|
||||
if token_usage:
|
||||
message.usage_metadata = {
|
||||
"input_tokens": token_usage.get("prompt_tokens", 0),
|
||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||
@@ -829,6 +851,8 @@ class ChatFireworks(BaseChatModel):
|
||||
}
|
||||
message.response_metadata["model_provider"] = "fireworks"
|
||||
message.response_metadata["model_name"] = self.model_name
|
||||
if service_tier:
|
||||
message.response_metadata["service_tier"] = service_tier
|
||||
generation_info = {"finish_reason": res.get("finish_reason")}
|
||||
if "logprobs" in res:
|
||||
generation_info["logprobs"] = res["logprobs"]
|
||||
@@ -841,6 +865,8 @@ class ChatFireworks(BaseChatModel):
|
||||
"token_usage": token_usage,
|
||||
"system_fingerprint": response.get("system_fingerprint", ""),
|
||||
}
|
||||
if service_tier:
|
||||
llm_output["service_tier"] = service_tier
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def _astream(
|
||||
|
||||
@@ -974,3 +974,142 @@ class TestStreamUsage:
|
||||
"output_tokens": 2,
|
||||
"total_tokens": 7,
|
||||
}
|
||||
|
||||
|
||||
class TestServiceTier:
|
||||
"""Tests for the `service_tier` field plumbing."""
|
||||
|
||||
def test_service_tier_omitted_by_default(self) -> None:
|
||||
model = _make_model()
|
||||
assert "service_tier" not in model._default_params
|
||||
|
||||
def test_service_tier_in_default_params_when_set(self) -> None:
|
||||
model = _make_model(service_tier="priority")
|
||||
assert model._default_params["service_tier"] == "priority"
|
||||
|
||||
def test_service_tier_passed_to_client_when_set(self) -> None:
|
||||
model = _make_model(service_tier="priority")
|
||||
model.client = MagicMock()
|
||||
model.client.create.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
model.invoke("Hello")
|
||||
call_kwargs = model.client.create.call_args[1]
|
||||
assert call_kwargs["service_tier"] == "priority"
|
||||
|
||||
def test_service_tier_not_passed_when_unset(self) -> None:
|
||||
model = _make_model()
|
||||
model.client = MagicMock()
|
||||
model.client.create.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
model.invoke("Hello")
|
||||
call_kwargs = model.client.create.call_args[1]
|
||||
assert "service_tier" not in call_kwargs
|
||||
|
||||
def test_service_tier_echoed_in_response_metadata(self) -> None:
|
||||
model = _make_model(service_tier="priority")
|
||||
model.client = MagicMock()
|
||||
model.client.create.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
"service_tier": "priority",
|
||||
}
|
||||
result = model.invoke("Hello")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.response_metadata["service_tier"] == "priority"
|
||||
|
||||
def test_service_tier_echoed_in_stream_chunks(self) -> None:
|
||||
model = _make_model(service_tier="priority")
|
||||
model.client = MagicMock()
|
||||
chunks: list[dict[str, Any]] = [
|
||||
{
|
||||
"choices": [{"delta": {"role": "assistant", "content": "hi"}}],
|
||||
"service_tier": "priority",
|
||||
},
|
||||
{
|
||||
"choices": [],
|
||||
"usage": {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
},
|
||||
"service_tier": "priority",
|
||||
},
|
||||
]
|
||||
model.client.create.return_value = iter(chunks)
|
||||
out = list(model.stream("Hello"))
|
||||
tagged = [c for c in out if c.response_metadata.get("service_tier")]
|
||||
assert tagged
|
||||
assert all(c.response_metadata["service_tier"] == "priority" for c in tagged)
|
||||
|
||||
def test_service_tier_absent_when_not_in_response(self) -> None:
|
||||
model = _make_model()
|
||||
model.client = MagicMock()
|
||||
model.client.create.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
result = model.invoke("Hello")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert "service_tier" not in result.response_metadata
|
||||
|
||||
def test_service_tier_in_llm_output_when_response_carries_it(self) -> None:
|
||||
model = _make_model(service_tier="priority")
|
||||
chat_result = model._create_chat_result(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 2,
|
||||
},
|
||||
"service_tier": "priority",
|
||||
}
|
||||
)
|
||||
assert chat_result.llm_output is not None
|
||||
assert chat_result.llm_output["service_tier"] == "priority"
|
||||
|
||||
def test_service_tier_not_inferred_from_request(self) -> None:
|
||||
"""Init-set tier must not leak into response_metadata if API omits it."""
|
||||
model = _make_model(service_tier="priority")
|
||||
model.client = MagicMock()
|
||||
model.client.create.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
result = model.invoke("Hello")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert "service_tier" not in result.response_metadata
|
||||
|
||||
Reference in New Issue
Block a user