mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(fireworks): populate usage_metadata on streaming (#36977)
Populate `usage_metadata` on streaming responses. Newer Fireworks models (e.g. Kimi K2 slugs) require an explicit `stream_options.include_usage=True` opt-in and return token counts in a final empty-`choices` chunk; the chunk was previously `continue`-d past, so streaming usage silently came back as `None`.
This commit is contained in:
@@ -216,10 +216,35 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
return message_dict
|
||||
|
||||
|
||||
def _usage_to_metadata(usage: Mapping[str, Any]) -> dict[str, int]:
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
|
||||
}
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
choice = chunk["choices"][0]
|
||||
choices = chunk.get("choices") or []
|
||||
if not choices:
|
||||
# Final chunk emitted when `stream_options.include_usage=True`:
|
||||
# `choices` is empty and the chunk carries only `usage`.
|
||||
usage = chunk.get("usage")
|
||||
if not usage:
|
||||
logger.debug(
|
||||
"Received stream chunk with no choices and no usage: %s", chunk
|
||||
)
|
||||
usage_metadata = _usage_to_metadata(usage) if usage else None
|
||||
return AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "fireworks"},
|
||||
)
|
||||
choice = choices[0]
|
||||
_dict = choice["delta"]
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
@@ -245,16 +270,8 @@ def _convert_chunk_to_message_chunk(
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
if usage := chunk.get("usage"):
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
usage_metadata = {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
|
||||
}
|
||||
else:
|
||||
usage_metadata = None
|
||||
usage = chunk.get("usage")
|
||||
usage_metadata = _usage_to_metadata(usage) if usage else None
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
@@ -375,6 +392,23 @@ class ChatFireworks(BaseChatModel):
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
stream_usage: bool = True
|
||||
"""Whether to include usage metadata in streaming output.
|
||||
|
||||
If `True`, a final empty-content chunk carrying `usage_metadata` is emitted
|
||||
during the stream. Set to `False` if the upstream model/proxy rejects
|
||||
`stream_options`, or pass `stream_options` explicitly via `model_kwargs` or
|
||||
a runtime kwarg to override.
|
||||
|
||||
!!! version-added "Added in `langchain-fireworks` 1.2.0"
|
||||
|
||||
!!! warning "Behavior changed in `langchain-fireworks` 1.2.0"
|
||||
|
||||
Streaming now opts into `stream_options.include_usage` by default, and
|
||||
the final empty-`choices` chunk is surfaced as an `AIMessageChunk` with
|
||||
`usage_metadata` instead of being silently dropped.
|
||||
"""
|
||||
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
|
||||
@@ -490,22 +524,24 @@ class ChatFireworks(BaseChatModel):
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
if self.stream_usage and "stream_options" not in params:
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.client.create(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
generation_info["model_name"] = self.model_name
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
generation_info: dict[str, Any] = {}
|
||||
logprobs = None
|
||||
if choices := chunk.get("choices"):
|
||||
choice = choices[0]
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
generation_info["model_name"] = self.model_name
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = message_chunk.__class__
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=message_chunk, generation_info=generation_info or None
|
||||
@@ -586,22 +622,24 @@ class ChatFireworks(BaseChatModel):
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
if self.stream_usage and "stream_options" not in params:
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in self.async_client.acreate(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
generation_info["model_name"] = self.model_name
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
generation_info: dict[str, Any] = {}
|
||||
logprobs = None
|
||||
if choices := chunk.get("choices"):
|
||||
choice = choices[0]
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
generation_info["model_name"] = self.model_name
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = message_chunk.__class__
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=message_chunk, generation_info=generation_info or None
|
||||
|
||||
Reference in New Issue
Block a user