mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(fireworks): populate usage_metadata on streaming (#36977)
Populate `usage_metadata` on streaming responses. Newer Fireworks models (e.g. Kimi K2 slugs) require an explicit `stream_options.include_usage=True` opt-in and return token counts in a final empty-`choices` chunk; the chunk was previously `continue`-d past, so streaming usage silently came back as `None`.
This commit is contained in:
@@ -216,10 +216,35 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _usage_to_metadata(usage: Mapping[str, Any]) -> dict[str, int]:
    """Convert a Fireworks `usage` payload into LangChain usage metadata.

    Args:
        usage: Mapping that may carry `prompt_tokens`, `completion_tokens`
            and `total_tokens`.

    Returns:
        A dict with `input_tokens`, `output_tokens` and `total_tokens`.
        Counts that are missing or explicitly `None` are reported as `0`.
        When `total_tokens` is missing or `None` it is computed as
        `input_tokens + output_tokens`.
    """
    # `.get(key, 0)` only covers an absent key. A key can also be present with
    # a `None` value (e.g. after a response object is dumped to a dict), which
    # would otherwise leak `None` into the result or break the sum below.
    input_tokens = usage.get("prompt_tokens") or 0
    output_tokens = usage.get("completion_tokens") or 0

    # Use `is None` rather than truthiness so an explicit total of 0 is kept.
    total_tokens = usage.get("total_tokens")
    if total_tokens is None:
        total_tokens = input_tokens + output_tokens

    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
    }
|
||||||
|
|
||||||
|
|
||||||
def _convert_chunk_to_message_chunk(
|
def _convert_chunk_to_message_chunk(
|
||||||
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
choice = chunk["choices"][0]
|
choices = chunk.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
# Final chunk emitted when `stream_options.include_usage=True`:
|
||||||
|
# `choices` is empty and the chunk carries only `usage`.
|
||||||
|
usage = chunk.get("usage")
|
||||||
|
if not usage:
|
||||||
|
logger.debug(
|
||||||
|
"Received stream chunk with no choices and no usage: %s", chunk
|
||||||
|
)
|
||||||
|
usage_metadata = _usage_to_metadata(usage) if usage else None
|
||||||
|
return AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||||
|
response_metadata={"model_provider": "fireworks"},
|
||||||
|
)
|
||||||
|
choice = choices[0]
|
||||||
_dict = choice["delta"]
|
_dict = choice["delta"]
|
||||||
role = cast(str, _dict.get("role"))
|
role = cast(str, _dict.get("role"))
|
||||||
content = cast(str, _dict.get("content") or "")
|
content = cast(str, _dict.get("content") or "")
|
||||||
@@ -245,16 +270,8 @@ def _convert_chunk_to_message_chunk(
|
|||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
if role == "assistant" or default_class == AIMessageChunk:
|
if role == "assistant" or default_class == AIMessageChunk:
|
||||||
if usage := chunk.get("usage"):
|
usage = chunk.get("usage")
|
||||||
input_tokens = usage.get("prompt_tokens", 0)
|
usage_metadata = _usage_to_metadata(usage) if usage else None
|
||||||
output_tokens = usage.get("completion_tokens", 0)
|
|
||||||
usage_metadata = {
|
|
||||||
"input_tokens": input_tokens,
|
|
||||||
"output_tokens": output_tokens,
|
|
||||||
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
usage_metadata = None
|
|
||||||
return AIMessageChunk(
|
return AIMessageChunk(
|
||||||
content=content,
|
content=content,
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
@@ -375,6 +392,23 @@ class ChatFireworks(BaseChatModel):
|
|||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
"""Whether to stream the results or not."""
|
"""Whether to stream the results or not."""
|
||||||
|
|
||||||
|
stream_usage: bool = True
|
||||||
|
"""Whether to include usage metadata in streaming output.
|
||||||
|
|
||||||
|
If `True`, a final empty-content chunk carrying `usage_metadata` is emitted
|
||||||
|
during the stream. Set to `False` if the upstream model/proxy rejects
|
||||||
|
`stream_options`, or pass `stream_options` explicitly via `model_kwargs` or
|
||||||
|
a runtime kwarg to override.
|
||||||
|
|
||||||
|
!!! version-added "Added in `langchain-fireworks` 1.2.0"
|
||||||
|
|
||||||
|
!!! warning "Behavior changed in `langchain-fireworks` 1.2.0"
|
||||||
|
|
||||||
|
Streaming now opts into `stream_options.include_usage` by default, and
|
||||||
|
the final empty-`choices` chunk is surfaced as an `AIMessageChunk` with
|
||||||
|
`usage_metadata` instead of being silently dropped.
|
||||||
|
"""
|
||||||
|
|
||||||
n: int = 1
|
n: int = 1
|
||||||
"""Number of chat completions to generate for each prompt."""
|
"""Number of chat completions to generate for each prompt."""
|
||||||
|
|
||||||
@@ -490,22 +524,24 @@ class ChatFireworks(BaseChatModel):
|
|||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {**params, **kwargs, "stream": True}
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
if self.stream_usage and "stream_options" not in params:
|
||||||
|
params["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||||
for chunk in self.client.create(messages=message_dicts, **params):
|
for chunk in self.client.create(messages=message_dicts, **params):
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.model_dump()
|
chunk = chunk.model_dump()
|
||||||
if len(chunk["choices"]) == 0:
|
|
||||||
continue
|
|
||||||
choice = chunk["choices"][0]
|
|
||||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||||
generation_info = {}
|
generation_info: dict[str, Any] = {}
|
||||||
if finish_reason := choice.get("finish_reason"):
|
logprobs = None
|
||||||
generation_info["finish_reason"] = finish_reason
|
if choices := chunk.get("choices"):
|
||||||
generation_info["model_name"] = self.model_name
|
choice = choices[0]
|
||||||
logprobs = choice.get("logprobs")
|
if finish_reason := choice.get("finish_reason"):
|
||||||
if logprobs:
|
generation_info["finish_reason"] = finish_reason
|
||||||
generation_info["logprobs"] = logprobs
|
generation_info["model_name"] = self.model_name
|
||||||
|
logprobs = choice.get("logprobs")
|
||||||
|
if logprobs:
|
||||||
|
generation_info["logprobs"] = logprobs
|
||||||
default_chunk_class = message_chunk.__class__
|
default_chunk_class = message_chunk.__class__
|
||||||
generation_chunk = ChatGenerationChunk(
|
generation_chunk = ChatGenerationChunk(
|
||||||
message=message_chunk, generation_info=generation_info or None
|
message=message_chunk, generation_info=generation_info or None
|
||||||
@@ -586,22 +622,24 @@ class ChatFireworks(BaseChatModel):
|
|||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {**params, **kwargs, "stream": True}
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
if self.stream_usage and "stream_options" not in params:
|
||||||
|
params["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||||
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
|
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
|
||||||
if not isinstance(chunk, dict):
|
if not isinstance(chunk, dict):
|
||||||
chunk = chunk.model_dump()
|
chunk = chunk.model_dump()
|
||||||
if len(chunk["choices"]) == 0:
|
|
||||||
continue
|
|
||||||
choice = chunk["choices"][0]
|
|
||||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||||
generation_info = {}
|
generation_info: dict[str, Any] = {}
|
||||||
if finish_reason := choice.get("finish_reason"):
|
logprobs = None
|
||||||
generation_info["finish_reason"] = finish_reason
|
if choices := chunk.get("choices"):
|
||||||
generation_info["model_name"] = self.model_name
|
choice = choices[0]
|
||||||
logprobs = choice.get("logprobs")
|
if finish_reason := choice.get("finish_reason"):
|
||||||
if logprobs:
|
generation_info["finish_reason"] = finish_reason
|
||||||
generation_info["logprobs"] = logprobs
|
generation_info["model_name"] = self.model_name
|
||||||
|
logprobs = choice.get("logprobs")
|
||||||
|
if logprobs:
|
||||||
|
generation_info["logprobs"] = logprobs
|
||||||
default_chunk_class = message_chunk.__class__
|
default_chunk_class = message_chunk.__class__
|
||||||
generation_chunk = ChatGenerationChunk(
|
generation_chunk = ChatGenerationChunk(
|
||||||
message=message_chunk, generation_info=generation_info or None
|
message=message_chunk, generation_info=generation_info or None
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
'request_timeout': 60.0,
|
'request_timeout': 60.0,
|
||||||
'stop': list([
|
'stop': list([
|
||||||
]),
|
]),
|
||||||
|
'stream_usage': True,
|
||||||
'temperature': 0.0,
|
'temperature': 0.0,
|
||||||
}),
|
}),
|
||||||
'lc': 1,
|
'lc': 1,
|
||||||
|
|||||||
@@ -2,10 +2,44 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||||
|
|
||||||
from langchain_fireworks import ChatFireworks
|
from langchain_fireworks import ChatFireworks
|
||||||
from langchain_fireworks.chat_models import _convert_dict_to_message
|
from langchain_fireworks.chat_models import (
|
||||||
|
_convert_chunk_to_message_chunk,
|
||||||
|
_convert_dict_to_message,
|
||||||
|
_usage_to_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_NAME = "accounts/fireworks/models/test-model"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_model(**kwargs: Any) -> ChatFireworks:
    """Build a `ChatFireworks` for tests; `kwargs` override the defaults.

    The defaults are the shared `MODEL_NAME` and a placeholder API key.
    """
    init_args: dict[str, Any] = {
        "model": MODEL_NAME,
        "api_key": "fake-key",
        **kwargs,
    }
    return ChatFireworks(**init_args)  # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
_STREAM_CHUNKS: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"choices": [{"delta": {"role": "assistant", "content": ""}, "index": 0}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [{"delta": {"content": "Hello"}, "index": 0}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [{"delta": {}, "finish_reason": "stop", "index": 0}],
|
||||||
|
},
|
||||||
|
# Final usage-only chunk (empty choices)
|
||||||
|
{
|
||||||
|
"choices": [],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_fireworks_model_param() -> None:
|
def test_fireworks_model_param() -> None:
|
||||||
@@ -46,3 +80,139 @@ def test_convert_dict_to_message_without_reasoning_content() -> None:
|
|||||||
assert isinstance(message, AIMessage)
|
assert isinstance(message, AIMessage)
|
||||||
assert message.content == "The answer is 42."
|
assert message.content == "The answer is 42."
|
||||||
assert "reasoning_content" not in message.additional_kwargs
|
assert "reasoning_content" not in message.additional_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageToMetadata:
    """Unit tests covering the `_usage_to_metadata` helper."""

    def test_all_fields_present(self) -> None:
        """All three provider counts map onto their LangChain keys."""
        usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        expected = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        assert _usage_to_metadata(usage) == expected

    def test_total_tokens_fallback_sums_input_and_output(self) -> None:
        """Without `total_tokens`, the total is input plus output."""
        usage = {"prompt_tokens": 7, "completion_tokens": 3}
        expected = {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10}
        assert _usage_to_metadata(usage) == expected

    def test_missing_fields_default_to_zero(self) -> None:
        """An empty usage mapping produces zero for every count."""
        expected = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        assert _usage_to_metadata({}) == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertChunkToMessageChunk:
    """Exercise `_convert_chunk_to_message_chunk` on chunks without choices."""

    def test_empty_choices_with_usage_returns_usage_chunk(self) -> None:
        """An empty `choices` list plus `usage` yields a usage-only chunk."""
        usage = {"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}
        message = _convert_chunk_to_message_chunk(
            {"choices": [], "usage": usage}, AIMessageChunk
        )
        assert isinstance(message, AIMessageChunk)
        assert message.content == ""
        expected = {"input_tokens": 4, "output_tokens": 1, "total_tokens": 5}
        assert message.usage_metadata == expected

    def test_empty_choices_without_usage_logs_and_returns_blank(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """No `choices` and no `usage`: a blank chunk and a debug log record."""
        payload: dict[str, Any] = {"choices": []}
        with caplog.at_level("DEBUG", logger="langchain_fireworks.chat_models"):
            message = _convert_chunk_to_message_chunk(payload, AIMessageChunk)
        assert isinstance(message, AIMessageChunk)
        assert message.content == ""
        assert message.usage_metadata is None
        logged = [record.message for record in caplog.records]
        assert any("no choices and no usage" in text for text in logged)

    def test_missing_choices_key_treated_as_empty(self) -> None:
        """A chunk with no `choices` key at all still surfaces its usage."""
        usage = {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
        message = _convert_chunk_to_message_chunk({"usage": usage}, AIMessageChunk)
        assert isinstance(message, AIMessageChunk)
        expected = {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}
        assert message.usage_metadata == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamUsage:
    """Cover the `stream_usage` field and how `stream_options` is forwarded."""

    @staticmethod
    def _sync_model(**overrides: Any) -> ChatFireworks:
        """Return a model whose sync client replays `_STREAM_CHUNKS`."""
        model = _make_model(**overrides)
        model.client = MagicMock()
        model.client.create.return_value = iter(list(_STREAM_CHUNKS))
        return model

    @staticmethod
    def _async_model() -> ChatFireworks:
        """Return a model whose async client replays `_STREAM_CHUNKS`."""

        async def _replay() -> Any:
            for item in _STREAM_CHUNKS:
                yield item

        model = _make_model()
        model.async_client = MagicMock()
        model.async_client.acreate = MagicMock(return_value=_replay())
        return model

    def test_stream_options_passed_by_default(self) -> None:
        """By default the client call carries `include_usage=True`."""
        model = self._sync_model()
        for _ in model.stream("Hello"):
            pass
        sent = model.client.create.call_args.kwargs
        assert sent["stream_options"] == {"include_usage": True}

    def test_stream_options_not_passed_when_disabled(self) -> None:
        """With `stream_usage=False` no `stream_options` key is sent."""
        model = self._sync_model(stream_usage=False)
        for _ in model.stream("Hello"):
            pass
        sent = model.client.create.call_args.kwargs
        assert "stream_options" not in sent

    def test_user_stream_options_in_model_kwargs_wins(self) -> None:
        """`stream_options` supplied through `model_kwargs` is sent unchanged."""
        custom = {"include_usage": False}
        model = self._sync_model(model_kwargs={"stream_options": custom})
        for _ in model.stream("Hello"):
            pass
        sent = model.client.create.call_args.kwargs
        assert sent["stream_options"] == custom

    def test_usage_only_chunk_emits_usage_metadata(self) -> None:
        """Exactly one streamed chunk carries the usage totals."""
        model = self._sync_model()
        streamed = list(model.stream("Hello"))
        reported = [c.usage_metadata for c in streamed if c.usage_metadata]
        assert reported == [
            {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
        ]

    async def test_astream_options_passed_by_default(self) -> None:
        """Async streaming also sends `include_usage=True` by default."""
        model = self._async_model()
        async for _ in model.astream("Hello"):
            pass
        sent = model.async_client.acreate.call_args.kwargs
        assert sent["stream_options"] == {"include_usage": True}

    async def test_astream_usage_only_chunk_emits_usage_metadata(self) -> None:
        """Exactly one async-streamed chunk carries the usage totals."""
        model = self._async_model()
        streamed = [piece async for piece in model.astream("Hello")]
        reported = [c.usage_metadata for c in streamed if c.usage_metadata]
        assert reported == [
            {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
        ]
|
||||||
|
|||||||
2
libs/partners/fireworks/uv.lock
generated
2
libs/partners/fireworks/uv.lock
generated
@@ -697,7 +697,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "1.3.0a2"
|
version = "1.3.1"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
|
|||||||
Reference in New Issue
Block a user