diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 5a4b7d7e1e9..b05c83f3cdd 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -321,6 +321,9 @@ def _convert_chunk_to_message_chunk( chunk: Mapping[str, Any], default_class: type[BaseMessageChunk] ) -> BaseMessageChunk: choices = chunk.get("choices") or [] + response_metadata: dict[str, Any] = {"model_provider": "fireworks"} + if service_tier := chunk.get("service_tier"): + response_metadata["service_tier"] = service_tier if not choices: # Final chunk emitted when `stream_options.include_usage=True`: # `choices` is empty and the chunk carries only `usage`. @@ -333,7 +336,7 @@ def _convert_chunk_to_message_chunk( return AIMessageChunk( content="", usage_metadata=usage_metadata, # type: ignore[arg-type] - response_metadata={"model_provider": "fireworks"}, + response_metadata=response_metadata, ) choice = choices[0] _dict = choice["delta"] @@ -368,7 +371,7 @@ def _convert_chunk_to_message_chunk( additional_kwargs=additional_kwargs, tool_call_chunks=tool_call_chunks, usage_metadata=usage_metadata, # type: ignore[arg-type] - response_metadata={"model_provider": "fireworks"}, + response_metadata=response_metadata, ) if role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) @@ -647,6 +650,21 @@ class ChatFireworks(BaseChatModel): A value of `None` or `0` disables retries. """ + service_tier: str | None = None + """Service tier for the request. + + Forwarded as the `service_tier` field on the Fireworks chat completions + request when set. Pass `'priority'` to opt into Fireworks' priority tier; + leave as `None` to use the default tier. + + To use Fireworks' fast mode instead, select a fast-routed `model`; fast mode + is not controlled by this field. See Fireworks' + [serverless product docs](https://docs.fireworks.ai/guides/serverless-products) + for the current list of fast routers and tiers. + + !!! version-added "Added in `langchain-fireworks` 1.3.0" + """ + model_config = ConfigDict( populate_by_name=True, ) @@ -701,6 +719,8 @@ class ChatFireworks(BaseChatModel): params["temperature"] = self.temperature if self.max_tokens is not None: params["max_tokens"] = self.max_tokens + if self.service_tier is not None: + params["service_tier"] = self.service_tier return params def _get_ls_params( @@ -819,16 +839,20 @@ class ChatFireworks(BaseChatModel): if not isinstance(response, dict): response = response.model_dump() token_usage = response.get("usage", {}) + service_tier = response.get("service_tier") for res in response["choices"]: message = _convert_dict_to_message(res["message"]) - if token_usage and isinstance(message, AIMessage): - message.usage_metadata = { - "input_tokens": token_usage.get("prompt_tokens", 0), - "output_tokens": token_usage.get("completion_tokens", 0), - "total_tokens": token_usage.get("total_tokens", 0), - } - message.response_metadata["model_provider"] = "fireworks" - message.response_metadata["model_name"] = self.model_name + if isinstance(message, AIMessage): + if token_usage: + message.usage_metadata = { + "input_tokens": token_usage.get("prompt_tokens", 0), + "output_tokens": token_usage.get("completion_tokens", 0), + "total_tokens": token_usage.get("total_tokens", 0), + } + message.response_metadata["model_provider"] = "fireworks" + message.response_metadata["model_name"] = self.model_name + if service_tier: + message.response_metadata["service_tier"] = service_tier generation_info = {"finish_reason": res.get("finish_reason")} if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] @@ -841,6 +865,8 @@ class ChatFireworks(BaseChatModel): "token_usage": token_usage, "system_fingerprint": response.get("system_fingerprint", ""), } + if service_tier: + llm_output["service_tier"] = service_tier return ChatResult(generations=generations, llm_output=llm_output) async def _astream( diff --git a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py index 23d0f1520e6..785928214fe 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py @@ -974,3 +974,142 @@ class TestStreamUsage: "output_tokens": 2, "total_tokens": 7, } + + +class TestServiceTier: + """Tests for the `service_tier` field plumbing.""" + + def test_service_tier_omitted_by_default(self) -> None: + model = _make_model() + assert "service_tier" not in model._default_params + + def test_service_tier_in_default_params_when_set(self) -> None: + model = _make_model(service_tier="priority") + assert model._default_params["service_tier"] == "priority" + + def test_service_tier_passed_to_client_when_set(self) -> None: + model = _make_model(service_tier="priority") + model.client = MagicMock() + model.client.create.return_value = { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + model.invoke("Hello") + call_kwargs = model.client.create.call_args[1] + assert call_kwargs["service_tier"] == "priority" + + def test_service_tier_not_passed_when_unset(self) -> None: + model = _make_model() + model.client = MagicMock() + model.client.create.return_value = { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + model.invoke("Hello") + call_kwargs = model.client.create.call_args[1] + assert "service_tier" not in call_kwargs + + def test_service_tier_echoed_in_response_metadata(self) -> None: + model = _make_model(service_tier="priority") + model.client = MagicMock() + model.client.create.return_value = { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + "service_tier": "priority", + } + result = model.invoke("Hello") + assert isinstance(result, AIMessage) + assert result.response_metadata["service_tier"] == "priority" + + def test_service_tier_echoed_in_stream_chunks(self) -> None: + model = _make_model(service_tier="priority") + model.client = MagicMock() + chunks: list[dict[str, Any]] = [ + { + "choices": [{"delta": {"role": "assistant", "content": "hi"}}], + "service_tier": "priority", + }, + { + "choices": [], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + }, + "service_tier": "priority", + }, + ] + model.client.create.return_value = iter(chunks) + out = list(model.stream("Hello")) + tagged = [c for c in out if c.response_metadata.get("service_tier")] + assert tagged + assert all(c.response_metadata["service_tier"] == "priority" for c in tagged) + + def test_service_tier_absent_when_not_in_response(self) -> None: + model = _make_model() + model.client = MagicMock() + model.client.create.return_value = { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + result = model.invoke("Hello") + assert isinstance(result, AIMessage) + assert "service_tier" not in result.response_metadata + + def test_service_tier_in_llm_output_when_response_carries_it(self) -> None: + model = _make_model(service_tier="priority") + chat_result = model._create_chat_result( + { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + }, + "service_tier": "priority", + } + ) + assert chat_result.llm_output is not None + assert chat_result.llm_output["service_tier"] == "priority" + + def test_service_tier_not_inferred_from_request(self) -> None: + """Init-set tier must not leak into response_metadata if API omits it.""" + model = _make_model(service_tier="priority") + model.client = MagicMock() + model.client.create.return_value = { + "choices": [ + { + "message": {"role": "assistant", "content": "hi"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + } + result = model.invoke("Hello") + assert isinstance(result, AIMessage) + assert "service_tier" not in result.response_metadata