mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(fireworks): service_tier init kwarg on ChatFireworks (#37143)
Add a `service_tier` init kwarg to `ChatFireworks`, mirroring the field on `ChatOpenAI`. When set, the value is forwarded to the Fireworks chat completions API as the `service_tier` request field. When the response reports a service tier, that tier is echoed back onto `response_metadata` and `llm_output` so callbacks and consumers can read what the server actually applied.
This commit is contained in:
@@ -321,6 +321,9 @@ def _convert_chunk_to_message_chunk(
|
||||
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
choices = chunk.get("choices") or []
|
||||
response_metadata: dict[str, Any] = {"model_provider": "fireworks"}
|
||||
if service_tier := chunk.get("service_tier"):
|
||||
response_metadata["service_tier"] = service_tier
|
||||
if not choices:
|
||||
# Final chunk emitted when `stream_options.include_usage=True`:
|
||||
# `choices` is empty and the chunk carries only `usage`.
|
||||
@@ -333,7 +336,7 @@ def _convert_chunk_to_message_chunk(
|
||||
return AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "fireworks"},
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
choice = choices[0]
|
||||
_dict = choice["delta"]
|
||||
@@ -368,7 +371,7 @@ def _convert_chunk_to_message_chunk(
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "fireworks"},
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
@@ -647,6 +650,21 @@ class ChatFireworks(BaseChatModel):
|
||||
A value of `None` or `0` disables retries.
|
||||
"""
|
||||
|
||||
service_tier: str | None = None
|
||||
"""Service tier for the request.
|
||||
|
||||
Forwarded as the `service_tier` field on the Fireworks chat completions
|
||||
request when set. Pass `'priority'` to opt into Fireworks' priority tier;
|
||||
leave as `None` to use the default tier.
|
||||
|
||||
To use Fireworks' fast mode instead, select a fast-routed `model`; fast mode
|
||||
is not controlled by this field. See Fireworks'
|
||||
[serverless product docs](https://docs.fireworks.ai/guides/serverless-products)
|
||||
for the current list of fast routers and tiers.
|
||||
|
||||
!!! version-added "Added in `langchain-fireworks` 1.3.0"
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
@@ -701,6 +719,8 @@ class ChatFireworks(BaseChatModel):
|
||||
params["temperature"] = self.temperature
|
||||
if self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
if self.service_tier is not None:
|
||||
params["service_tier"] = self.service_tier
|
||||
return params
|
||||
|
||||
def _get_ls_params(
|
||||
@@ -819,16 +839,20 @@ class ChatFireworks(BaseChatModel):
|
||||
if not isinstance(response, dict):
|
||||
response = response.model_dump()
|
||||
token_usage = response.get("usage", {})
|
||||
service_tier = response.get("service_tier")
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
if token_usage and isinstance(message, AIMessage):
|
||||
message.usage_metadata = {
|
||||
"input_tokens": token_usage.get("prompt_tokens", 0),
|
||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||
"total_tokens": token_usage.get("total_tokens", 0),
|
||||
}
|
||||
message.response_metadata["model_provider"] = "fireworks"
|
||||
message.response_metadata["model_name"] = self.model_name
|
||||
if isinstance(message, AIMessage):
|
||||
if token_usage:
|
||||
message.usage_metadata = {
|
||||
"input_tokens": token_usage.get("prompt_tokens", 0),
|
||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||
"total_tokens": token_usage.get("total_tokens", 0),
|
||||
}
|
||||
message.response_metadata["model_provider"] = "fireworks"
|
||||
message.response_metadata["model_name"] = self.model_name
|
||||
if service_tier:
|
||||
message.response_metadata["service_tier"] = service_tier
|
||||
generation_info = {"finish_reason": res.get("finish_reason")}
|
||||
if "logprobs" in res:
|
||||
generation_info["logprobs"] = res["logprobs"]
|
||||
@@ -841,6 +865,8 @@ class ChatFireworks(BaseChatModel):
|
||||
"token_usage": token_usage,
|
||||
"system_fingerprint": response.get("system_fingerprint", ""),
|
||||
}
|
||||
if service_tier:
|
||||
llm_output["service_tier"] = service_tier
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def _astream(
|
||||
|
||||
Reference in New Issue
Block a user