mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
fix(huggingface): add stream_usage support for ChatHuggingFace invoke/stream (#32708)
This commit is contained in:
@@ -499,6 +499,9 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
"""Modify the likelihood of specified tokens appearing in the completion."""
|
"""Modify the likelihood of specified tokens appearing in the completion."""
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
"""Whether to stream the results or not."""
|
"""Whether to stream the results or not."""
|
||||||
|
stream_usage: bool | None = None
|
||||||
|
"""Whether to include usage metadata in streaming output. If True, an additional
|
||||||
|
message chunk will be generated during the stream including usage metadata."""
|
||||||
n: int | None = None
|
n: int | None = None
|
||||||
"""Number of chat completions to generate for each prompt."""
|
"""Number of chat completions to generate for each prompt."""
|
||||||
top_p: float | None = None
|
top_p: float | None = None
|
||||||
@@ -634,14 +637,40 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
)
|
)
|
||||||
return self._to_chat_result(llm_result)
|
return self._to_chat_result(llm_result)
|
||||||
|
|
||||||
|
def _should_stream_usage(
|
||||||
|
self, *, stream_usage: bool | None = None, **kwargs: Any
|
||||||
|
) -> bool | None:
|
||||||
|
"""Determine whether to include usage metadata in streaming output.
|
||||||
|
|
||||||
|
For backwards compatibility, we check for `stream_options` passed
|
||||||
|
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
|
||||||
|
"""
|
||||||
|
stream_usage_sources = [ # order of precedence
|
||||||
|
stream_usage,
|
||||||
|
kwargs.get("stream_options", {}).get("include_usage"),
|
||||||
|
self.model_kwargs.get("stream_options", {}).get("include_usage"),
|
||||||
|
self.stream_usage,
|
||||||
|
]
|
||||||
|
for source in stream_usage_sources:
|
||||||
|
if isinstance(source, bool):
|
||||||
|
return source
|
||||||
|
return self.stream_usage
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
run_manager: CallbackManagerForLLMRun | None = None,
|
run_manager: CallbackManagerForLLMRun | None = None,
|
||||||
|
*,
|
||||||
|
stream_usage: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
if _is_huggingface_endpoint(self.llm):
|
if _is_huggingface_endpoint(self.llm):
|
||||||
|
stream_usage = self._should_stream_usage(
|
||||||
|
stream_usage=stream_usage, **kwargs
|
||||||
|
)
|
||||||
|
if stream_usage:
|
||||||
|
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {**params, **kwargs, "stream": True}
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
@@ -650,7 +679,20 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
messages=message_dicts, **params
|
messages=message_dicts, **params
|
||||||
):
|
):
|
||||||
if len(chunk["choices"]) == 0:
|
if len(chunk["choices"]) == 0:
|
||||||
|
if usage := chunk.get("usage"):
|
||||||
|
usage_msg = AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={},
|
||||||
|
response_metadata={},
|
||||||
|
usage_metadata={
|
||||||
|
"input_tokens": usage.get("prompt_tokens", 0),
|
||||||
|
"output_tokens": usage.get("completion_tokens", 0),
|
||||||
|
"total_tokens": usage.get("total_tokens", 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield ChatGenerationChunk(message=usage_msg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
choice = chunk["choices"][0]
|
choice = chunk["choices"][0]
|
||||||
message_chunk = _convert_chunk_to_message_chunk(
|
message_chunk = _convert_chunk_to_message_chunk(
|
||||||
chunk, default_chunk_class
|
chunk, default_chunk_class
|
||||||
@@ -688,8 +730,13 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||||
|
*,
|
||||||
|
stream_usage: bool | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
stream_usage = self._should_stream_usage(stream_usage=stream_usage, **kwargs)
|
||||||
|
if stream_usage:
|
||||||
|
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {**params, **kwargs, "stream": True}
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
@@ -699,7 +746,20 @@ class ChatHuggingFace(BaseChatModel):
|
|||||||
messages=message_dicts, **params
|
messages=message_dicts, **params
|
||||||
):
|
):
|
||||||
if len(chunk["choices"]) == 0:
|
if len(chunk["choices"]) == 0:
|
||||||
|
if usage := chunk.get("usage"):
|
||||||
|
usage_msg = AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={},
|
||||||
|
response_metadata={},
|
||||||
|
usage_metadata={
|
||||||
|
"input_tokens": usage.get("prompt_tokens", 0),
|
||||||
|
"output_tokens": usage.get("completion_tokens", 0),
|
||||||
|
"total_tokens": usage.get("total_tokens", 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield ChatGenerationChunk(message=usage_msg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
choice = chunk["choices"][0]
|
choice = chunk["choices"][0]
|
||||||
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||||
generation_info = {}
|
generation_info = {}
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from langchain_core.messages import AIMessageChunk
|
||||||
|
|
||||||
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_usage() -> None:
    """Stream with ``stream_usage=True`` and check usage metadata is reported.

    Every streamed chunk must be an ``AIMessageChunk``; the chunks are summed
    and the aggregate is expected to carry non-empty ``usage_metadata``.
    """
    endpoint = HuggingFaceEndpoint(  # type: ignore[call-arg]  # (model is inferred in class)
        repo_id="google/gemma-3-27b-it",
        task="conversational",
        provider="nebius",
    )
    chat_model = ChatHuggingFace(llm=endpoint, stream_usage=True)

    aggregate: AIMessageChunk | None = None
    for piece in chat_model.stream("hello"):
        assert isinstance(piece, AIMessageChunk)
        if aggregate is None:
            # First chunk seeds the running total.
            aggregate = piece
        else:
            aggregate = aggregate + piece

    assert isinstance(aggregate, AIMessageChunk)
    assert aggregate.usage_metadata
|
||||||
Reference in New Issue
Block a user