fix(huggingface): Helper logic for init_chat_model with HuggingFace backend (#34259)

This commit is contained in:
Paul
2025-12-12 20:35:16 +05:30
committed by GitHub
parent 5720dea41b
commit bf6a5eb122
3 changed files with 53 additions and 5 deletions

View File

@@ -7,7 +7,11 @@ import json
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from dataclasses import dataclass
from operator import itemgetter
from typing import Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast
if TYPE_CHECKING:
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@@ -599,6 +603,51 @@ class ChatHuggingFace(BaseChatModel):
self.profile = _get_default_model_profile(self.model_id)
return self
@classmethod
def from_model_id(
    cls,
    model_id: str,
    task: str | None = None,
    backend: Literal["pipeline", "endpoint", "text-gen"] = "pipeline",
    **kwargs: Any,
) -> ChatHuggingFace:
    """Construct a `ChatHuggingFace` model from a model ID.

    Builds the underlying LLM for the selected backend and wraps it in the
    chat model. Backend classes are imported lazily so that only the
    dependencies of the chosen backend need to be importable.

    Args:
        model_id: The model ID of the Hugging Face model. For the
            `"endpoint"` backend it is passed as `repo_id`; for the
            `"text-gen"` backend it is passed as `inference_server_url`.
        task: The task to perform (e.g., `"text-generation"`). When omitted
            with the `"pipeline"` backend, `"text-generation"` is used. It
            is not forwarded to the `"text-gen"` backend.
        backend: The backend to use. One of `"pipeline"`, `"endpoint"`,
            `"text-gen"`.
        **kwargs: Additional arguments. They are forwarded unchanged both to
            the backend constructor and to `ChatHuggingFace`.

    Returns:
        A `ChatHuggingFace` instance wrapping the constructed backend LLM.

    Raises:
        ValueError: If `backend` is not one of the supported values.
    """
    # One of HuggingFacePipeline, HuggingFaceEndpoint or
    # HuggingFaceTextGenInference, depending on `backend`.
    llm: Any
    if backend == "pipeline":
        from langchain_huggingface.llms.huggingface_pipeline import (
            HuggingFacePipeline,
        )

        # `typing.cast` performs no runtime conversion, so a missing task
        # used to reach the pipeline factory as `None`. Fall back to an
        # explicit default instead.
        pipeline_task = "text-generation" if task is None else task
        llm = HuggingFacePipeline.from_model_id(
            model_id=model_id, task=pipeline_task, **kwargs
        )
    elif backend == "endpoint":
        from langchain_huggingface.llms.huggingface_endpoint import (
            HuggingFaceEndpoint,
        )

        llm = HuggingFaceEndpoint(repo_id=model_id, task=task, **kwargs)
    elif backend == "text-gen":
        from langchain_community.llms.huggingface_text_gen_inference import (  # type: ignore[import-not-found]
            HuggingFaceTextGenInference,
        )

        llm = HuggingFaceTextGenInference(inference_server_url=model_id, **kwargs)
    else:
        msg = f"Unknown backend: {backend}"
        raise ValueError(msg)

    # NOTE(review): the same kwargs were already given to the backend above;
    # a key accepted by only one of the two constructors may be rejected by
    # the other - confirm whether they should be split.
    return cls(llm=llm, **kwargs)
def _create_chat_result(self, response: dict) -> ChatResult:
generations = []
token_usage = response.get("usage", {})