From 0e36185933852c5e293d4b9f3bfb363366008a62 Mon Sep 17 00:00:00 2001 From: Hyejeong Jo <83329561+girlsending0@users.noreply.github.com> Date: Tue, 4 Nov 2025 04:44:32 +0900 Subject: [PATCH] fix(huggingface): add `stream_usage` support for `ChatHuggingFace` invoke/stream (#32708) --- .../chat_models/huggingface.py | 60 +++++++++++++++++++ .../integration_tests/test_chat_models.py | 22 +++++++ 2 files changed, 82 insertions(+) create mode 100644 libs/partners/huggingface/tests/integration_tests/test_chat_models.py diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py index 70e2f275e1f..5316951a052 100644 --- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py @@ -499,6 +499,9 @@ class ChatHuggingFace(BaseChatModel): """Modify the likelihood of specified tokens appearing in the completion.""" streaming: bool = False """Whether to stream the results or not.""" + stream_usage: bool | None = None + """Whether to include usage metadata in streaming output. If True, an additional + message chunk will be generated during the stream including usage metadata.""" n: int | None = None """Number of chat completions to generate for each prompt.""" top_p: float | None = None @@ -634,14 +637,40 @@ class ChatHuggingFace(BaseChatModel): ) return self._to_chat_result(llm_result) + def _should_stream_usage( + self, *, stream_usage: bool | None = None, **kwargs: Any + ) -> bool | None: + """Determine whether to include usage metadata in streaming output. + + For backwards compatibility, we check for `stream_options` passed + explicitly to kwargs or in the model_kwargs and override self.stream_usage. + """ + stream_usage_sources = [ # order of precedence + stream_usage, + kwargs.get("stream_options", {}).get("include_usage"), + self.model_kwargs.get("stream_options", {}).get("include_usage"), + self.stream_usage, + ] + for source in stream_usage_sources: + if isinstance(source, bool): + return source + return self.stream_usage + def _stream( self, messages: list[BaseMessage], stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, + *, + stream_usage: bool | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: if _is_huggingface_endpoint(self.llm): + stream_usage = self._should_stream_usage( + stream_usage=stream_usage, **kwargs + ) + if stream_usage: + kwargs["stream_options"] = {"include_usage": stream_usage} message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} @@ -650,7 +679,20 @@ class ChatHuggingFace(BaseChatModel): messages=message_dicts, **params ): if len(chunk["choices"]) == 0: + if usage := chunk.get("usage"): + usage_msg = AIMessageChunk( + content="", + additional_kwargs={}, + response_metadata={}, + usage_metadata={ + "input_tokens": usage.get("prompt_tokens", 0), + "output_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + ) + yield ChatGenerationChunk(message=usage_msg) continue + choice = chunk["choices"][0] message_chunk = _convert_chunk_to_message_chunk( chunk, default_chunk_class @@ -688,8 +730,13 @@ class ChatHuggingFace(BaseChatModel): messages: list[BaseMessage], stop: list[str] | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None, + *, + stream_usage: bool | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: + stream_usage = self._should_stream_usage(stream_usage=stream_usage, **kwargs) + if stream_usage: + kwargs["stream_options"] = {"include_usage": stream_usage} message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} @@ -699,7 +746,20 @@ class ChatHuggingFace(BaseChatModel): messages=message_dicts, **params ): if len(chunk["choices"]) == 0: + if usage := chunk.get("usage"): + usage_msg = AIMessageChunk( + content="", + additional_kwargs={}, + response_metadata={}, + usage_metadata={ + "input_tokens": usage.get("prompt_tokens", 0), + "output_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + ) + yield ChatGenerationChunk(message=usage_msg) continue + choice = chunk["choices"][0] message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) generation_info = {} diff --git a/libs/partners/huggingface/tests/integration_tests/test_chat_models.py b/libs/partners/huggingface/tests/integration_tests/test_chat_models.py new file mode 100644 index 00000000000..4a64f7f0b5a --- /dev/null +++ b/libs/partners/huggingface/tests/integration_tests/test_chat_models.py @@ -0,0 +1,22 @@ +from langchain_core.messages import AIMessageChunk + +from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint + + +def test_stream_usage() -> None: + """Test we are able to configure stream options on models that require it.""" + llm = HuggingFaceEndpoint( # type: ignore[call-arg] # (model is inferred in class) + repo_id="google/gemma-3-27b-it", + task="conversational", + provider="nebius", + ) + + model = ChatHuggingFace(llm=llm, stream_usage=True) + + full: AIMessageChunk | None = None + for chunk in model.stream("hello"): + assert isinstance(chunk, AIMessageChunk) + full = chunk if full is None else full + chunk + + assert isinstance(full, AIMessageChunk) + assert full.usage_metadata