diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 416a6503ec5..9e2e4a45a8a 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -216,10 +216,35 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return message_dict +def _usage_to_metadata(usage: Mapping[str, Any]) -> dict[str, int]: + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": usage.get("total_tokens", input_tokens + output_tokens), + } + + def _convert_chunk_to_message_chunk( chunk: Mapping[str, Any], default_class: type[BaseMessageChunk] ) -> BaseMessageChunk: - choice = chunk["choices"][0] + choices = chunk.get("choices") or [] + if not choices: + # Final chunk emitted when `stream_options.include_usage=True`: + # `choices` is empty and the chunk carries only `usage`. + usage = chunk.get("usage") + if not usage: + logger.debug( + "Received stream chunk with no choices and no usage: %s", chunk + ) + usage_metadata = _usage_to_metadata(usage) if usage else None + return AIMessageChunk( + content="", + usage_metadata=usage_metadata, # type: ignore[arg-type] + response_metadata={"model_provider": "fireworks"}, + ) + choice = choices[0] _dict = choice["delta"] role = cast(str, _dict.get("role")) content = cast(str, _dict.get("content") or "") @@ -245,16 +270,8 @@ def _convert_chunk_to_message_chunk( if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) if role == "assistant" or default_class == AIMessageChunk: - if usage := chunk.get("usage"): - input_tokens = usage.get("prompt_tokens", 0) - output_tokens = usage.get("completion_tokens", 0) - usage_metadata = { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": usage.get("total_tokens", input_tokens + output_tokens), - } - else: - usage_metadata = None + usage = chunk.get("usage") + usage_metadata = _usage_to_metadata(usage) if usage else None return AIMessageChunk( content=content, additional_kwargs=additional_kwargs, @@ -375,6 +392,23 @@ class ChatFireworks(BaseChatModel): streaming: bool = False """Whether to stream the results or not.""" + stream_usage: bool = True + """Whether to include usage metadata in streaming output. + + If `True`, a final empty-content chunk carrying `usage_metadata` is emitted + during the stream. Set to `False` if the upstream model/proxy rejects + `stream_options`, or pass `stream_options` explicitly via `model_kwargs` or + a runtime kwarg to override. + + !!! version-added "Added in `langchain-fireworks` 1.2.0" + + !!! warning "Behavior changed in `langchain-fireworks` 1.2.0" + + Streaming now opts into `stream_options.include_usage` by default, and + the final empty-`choices` chunk is surfaced as an `AIMessageChunk` with + `usage_metadata` instead of being silently dropped. + """ + n: int = 1 """Number of chat completions to generate for each prompt.""" @@ -490,22 +524,24 @@ class ChatFireworks(BaseChatModel): ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} + if self.stream_usage and "stream_options" not in params: + params["stream_options"] = {"include_usage": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk for chunk in self.client.create(messages=message_dicts, **params): if not isinstance(chunk, dict): chunk = chunk.model_dump() - if len(chunk["choices"]) == 0: - continue - choice = chunk["choices"][0] message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) - generation_info = {} - if finish_reason := choice.get("finish_reason"): - generation_info["finish_reason"] = finish_reason - generation_info["model_name"] = self.model_name - logprobs = choice.get("logprobs") - if logprobs: - generation_info["logprobs"] = logprobs + generation_info: dict[str, Any] = {} + logprobs = None + if choices := chunk.get("choices"): + choice = choices[0] + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs default_chunk_class = message_chunk.__class__ generation_chunk = ChatGenerationChunk( message=message_chunk, generation_info=generation_info or None @@ -586,22 +622,24 @@ class ChatFireworks(BaseChatModel): ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} + if self.stream_usage and "stream_options" not in params: + params["stream_options"] = {"include_usage": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk async for chunk in self.async_client.acreate(messages=message_dicts, **params): if not isinstance(chunk, dict): chunk = chunk.model_dump() - if len(chunk["choices"]) == 0: - continue - choice = chunk["choices"][0] message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) - generation_info = {} - if finish_reason := choice.get("finish_reason"): - generation_info["finish_reason"] = finish_reason - generation_info["model_name"] = self.model_name - logprobs = choice.get("logprobs") - if logprobs: - generation_info["logprobs"] = logprobs + generation_info: dict[str, Any] = {} + logprobs = None + if choices := chunk.get("choices"): + choice = choices[0] + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs default_chunk_class = message_chunk.__class__ generation_chunk = ChatGenerationChunk( message=message_chunk, generation_info=generation_info or None diff --git a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr index 99ef0e8109a..60efb25f59b 100644 --- a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr +++ b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr @@ -22,6 +22,7 @@ 'request_timeout': 60.0, 'stop': list([ ]), + 'stream_usage': True, 'temperature': 0.0, }), 'lc': 1, diff --git a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py index 764067ac6a7..4acfc9a14b2 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/unit_tests/test_chat_models.py @@ -2,10 +2,44 @@ from __future__ import annotations -from langchain_core.messages import AIMessage +from typing import Any +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import AIMessage, AIMessageChunk from langchain_fireworks import ChatFireworks -from langchain_fireworks.chat_models import _convert_dict_to_message +from langchain_fireworks.chat_models import ( + _convert_chunk_to_message_chunk, + _convert_dict_to_message, + _usage_to_metadata, +) + +MODEL_NAME = "accounts/fireworks/models/test-model" + + +def _make_model(**kwargs: Any) -> ChatFireworks: + defaults: dict[str, Any] = {"model": MODEL_NAME, "api_key": "fake-key"} + defaults.update(kwargs) + return ChatFireworks(**defaults) # type: ignore[arg-type] + + +_STREAM_CHUNKS: list[dict[str, Any]] = [ + { + "choices": [{"delta": {"role": "assistant", "content": ""}, "index": 0}], + }, + { + "choices": [{"delta": {"content": "Hello"}, "index": 0}], + }, + { + "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], + }, + # Final usage-only chunk (empty choices) + { + "choices": [], + "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}, + }, +] def test_fireworks_model_param() -> None: @@ -46,3 +80,139 @@ def test_convert_dict_to_message_without_reasoning_content() -> None: assert isinstance(message, AIMessage) assert message.content == "The answer is 42." assert "reasoning_content" not in message.additional_kwargs + + +class TestUsageToMetadata: + """Tests for the `_usage_to_metadata` helper.""" + + def test_all_fields_present(self) -> None: + result = _usage_to_metadata( + {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + assert result == {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + + def test_total_tokens_fallback_sums_input_and_output(self) -> None: + """When provider omits total_tokens, sum input + output.""" + result = _usage_to_metadata({"prompt_tokens": 7, "completion_tokens": 3}) + assert result == {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10} + + def test_missing_fields_default_to_zero(self) -> None: + result = _usage_to_metadata({}) + assert result == {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + +class TestConvertChunkToMessageChunk: + """Tests for `_convert_chunk_to_message_chunk` empty-choices handling.""" + + def test_empty_choices_with_usage_returns_usage_chunk(self) -> None: + chunk = { + "choices": [], + "usage": {"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}, + } + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + assert isinstance(result, AIMessageChunk) + assert result.content == "" + assert result.usage_metadata == { + "input_tokens": 4, + "output_tokens": 1, + "total_tokens": 5, + } + + def test_empty_choices_without_usage_logs_and_returns_blank( + self, caplog: pytest.LogCaptureFixture + ) -> None: + chunk: dict[str, Any] = {"choices": []} + with caplog.at_level("DEBUG", logger="langchain_fireworks.chat_models"): + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + assert isinstance(result, AIMessageChunk) + assert result.content == "" + assert result.usage_metadata is None + assert any("no choices and no usage" in rec.message for rec in caplog.records) + + def test_missing_choices_key_treated_as_empty(self) -> None: + """Provider may omit `choices` entirely on the final usage frame.""" + chunk = { + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + result = _convert_chunk_to_message_chunk(chunk, AIMessageChunk) + assert isinstance(result, AIMessageChunk) + assert result.usage_metadata == { + "input_tokens": 1, + "output_tokens": 2, + "total_tokens": 3, + } + + +class TestStreamUsage: + """Tests for the `stream_usage` field and `stream_options` plumbing.""" + + def test_stream_options_passed_by_default(self) -> None: + model = _make_model() + model.client = MagicMock() + model.client.create.return_value = iter(list(_STREAM_CHUNKS)) + list(model.stream("Hello")) + call_kwargs = model.client.create.call_args[1] + assert call_kwargs["stream_options"] == {"include_usage": True} + + def test_stream_options_not_passed_when_disabled(self) -> None: + model = _make_model(stream_usage=False) + model.client = MagicMock() + model.client.create.return_value = iter(list(_STREAM_CHUNKS)) + list(model.stream("Hello")) + call_kwargs = model.client.create.call_args[1] + assert "stream_options" not in call_kwargs + + def test_user_stream_options_in_model_kwargs_wins(self) -> None: + """User-provided stream_options via model_kwargs overrides the default.""" + custom = {"include_usage": False} + model = _make_model(model_kwargs={"stream_options": custom}) + model.client = MagicMock() + model.client.create.return_value = iter(list(_STREAM_CHUNKS)) + list(model.stream("Hello")) + call_kwargs = model.client.create.call_args[1] + assert call_kwargs["stream_options"] == custom + + def test_usage_only_chunk_emits_usage_metadata(self) -> None: + """The final empty-choices + usage chunk propagates as usage_metadata.""" + model = _make_model() + model.client = MagicMock() + model.client.create.return_value = iter(list(_STREAM_CHUNKS)) + chunks = list(model.stream("Hello")) + usage_chunks = [c for c in chunks if c.usage_metadata] + assert len(usage_chunks) == 1 + assert usage_chunks[0].usage_metadata == { + "input_tokens": 5, + "output_tokens": 2, + "total_tokens": 7, + } + + async def test_astream_options_passed_by_default(self) -> None: + model = _make_model() + model.async_client = MagicMock() + + async def _aiter() -> Any: + for c in _STREAM_CHUNKS: + yield c + + model.async_client.acreate = MagicMock(return_value=_aiter()) + [chunk async for chunk in model.astream("Hello")] + call_kwargs = model.async_client.acreate.call_args[1] + assert call_kwargs["stream_options"] == {"include_usage": True} + + async def test_astream_usage_only_chunk_emits_usage_metadata(self) -> None: + model = _make_model() + model.async_client = MagicMock() + + async def _aiter() -> Any: + for c in _STREAM_CHUNKS: + yield c + + model.async_client.acreate = MagicMock(return_value=_aiter()) + chunks = [chunk async for chunk in model.astream("Hello")] + usage_chunks = [c for c in chunks if c.usage_metadata] + assert len(usage_chunks) == 1 + assert usage_chunks[0].usage_metadata == { + "input_tokens": 5, + "output_tokens": 2, + "total_tokens": 7, + } diff --git a/libs/partners/fireworks/uv.lock b/libs/partners/fireworks/uv.lock index 83d6ef93ac3..0d957194c9f 100644 --- a/libs/partners/fireworks/uv.lock +++ b/libs/partners/fireworks/uv.lock @@ -697,7 +697,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.3.0a2" +version = "1.3.1" source = { editable = "../../core" } dependencies = [ { name = "jsonpatch" },