fix(huggingface): add stream_usage support for ChatHuggingFace invoke/stream (#32708)

This commit is contained in:
Hyejeong Jo
2025-11-04 04:44:32 +09:00
committed by GitHub
parent 6617865440
commit 0e36185933
2 changed files with 82 additions and 0 deletions

View File

@@ -499,6 +499,9 @@ class ChatHuggingFace(BaseChatModel):
"""Modify the likelihood of specified tokens appearing in the completion.""" """Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
stream_usage: bool | None = None
"""Whether to include usage metadata in streaming output. If True, an additional
message chunk will be generated during the stream including usage metadata."""
n: int | None = None n: int | None = None
"""Number of chat completions to generate for each prompt.""" """Number of chat completions to generate for each prompt."""
top_p: float | None = None top_p: float | None = None
@@ -634,14 +637,40 @@ class ChatHuggingFace(BaseChatModel):
) )
return self._to_chat_result(llm_result) return self._to_chat_result(llm_result)
def _should_stream_usage(
self, *, stream_usage: bool | None = None, **kwargs: Any
) -> bool | None:
"""Determine whether to include usage metadata in streaming output.
For backwards compatibility, we check for `stream_options` passed
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
"""
stream_usage_sources = [ # order of precedence
stream_usage,
kwargs.get("stream_options", {}).get("include_usage"),
self.model_kwargs.get("stream_options", {}).get("include_usage"),
self.stream_usage,
]
for source in stream_usage_sources:
if isinstance(source, bool):
return source
return self.stream_usage
def _stream( def _stream(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
stop: list[str] | None = None, stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None, run_manager: CallbackManagerForLLMRun | None = None,
*,
stream_usage: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
if _is_huggingface_endpoint(self.llm): if _is_huggingface_endpoint(self.llm):
stream_usage = self._should_stream_usage(
stream_usage=stream_usage, **kwargs
)
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
@@ -650,7 +679,20 @@ class ChatHuggingFace(BaseChatModel):
messages=message_dicts, **params messages=message_dicts, **params
): ):
if len(chunk["choices"]) == 0: if len(chunk["choices"]) == 0:
if usage := chunk.get("usage"):
usage_msg = AIMessageChunk(
content="",
additional_kwargs={},
response_metadata={},
usage_metadata={
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
},
)
yield ChatGenerationChunk(message=usage_msg)
continue continue
choice = chunk["choices"][0] choice = chunk["choices"][0]
message_chunk = _convert_chunk_to_message_chunk( message_chunk = _convert_chunk_to_message_chunk(
chunk, default_chunk_class chunk, default_chunk_class
@@ -688,8 +730,13 @@ class ChatHuggingFace(BaseChatModel):
messages: list[BaseMessage], messages: list[BaseMessage],
stop: list[str] | None = None, stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None,
*,
stream_usage: bool | None = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
stream_usage = self._should_stream_usage(stream_usage=stream_usage, **kwargs)
if stream_usage:
kwargs["stream_options"] = {"include_usage": stream_usage}
message_dicts, params = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
@@ -699,7 +746,20 @@ class ChatHuggingFace(BaseChatModel):
messages=message_dicts, **params messages=message_dicts, **params
): ):
if len(chunk["choices"]) == 0: if len(chunk["choices"]) == 0:
if usage := chunk.get("usage"):
usage_msg = AIMessageChunk(
content="",
additional_kwargs={},
response_metadata={},
usage_metadata={
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
},
)
yield ChatGenerationChunk(message=usage_msg)
continue continue
choice = chunk["choices"][0] choice = chunk["choices"][0]
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
generation_info = {} generation_info = {}

View File

@@ -0,0 +1,22 @@
from langchain_core.messages import AIMessageChunk
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
def test_stream_usage() -> None:
    """Check that streaming with `stream_usage=True` yields usage metadata.

    Streams a short prompt through a `ChatHuggingFace` model backed by a
    `HuggingFaceEndpoint`, sums every chunk, and verifies that the aggregated
    message carries non-empty `usage_metadata`.
    """
    endpoint = HuggingFaceEndpoint(  # type: ignore[call-arg]  # (model is inferred in class)
        repo_id="google/gemma-3-27b-it",
        task="conversational",
        provider="nebius",
    )
    chat = ChatHuggingFace(llm=endpoint, stream_usage=True)

    aggregate: AIMessageChunk | None = None
    for piece in chat.stream("hello"):
        assert isinstance(piece, AIMessageChunk)
        if aggregate is None:
            aggregate = piece
        else:
            aggregate = aggregate + piece

    # The merged message must exist and expose the usage reported by the stream.
    assert isinstance(aggregate, AIMessageChunk)
    assert aggregate.usage_metadata