From bf6a5eb122d083a471ae4823af519f07c87dfa8c Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 12 Dec 2025 20:35:16 +0530 Subject: [PATCH] fix(huggingface): Helper logic for init_chat_model with HuggingFace backend (#34259) --- .../langchain_classic/chat_models/base.py | 5 +- .../langchain/chat_models/base.py | 2 +- .../chat_models/huggingface.py | 51 ++++++++++++++++++- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain_classic/chat_models/base.py b/libs/langchain/langchain_classic/chat_models/base.py index 8135d3b103f..e637b866f04 100644 --- a/libs/langchain/langchain_classic/chat_models/base.py +++ b/libs/langchain/langchain_classic/chat_models/base.py @@ -444,10 +444,9 @@ def _init_chat_model_helper( if model_provider == "huggingface": _check_pkg("langchain_huggingface") - from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline + from langchain_huggingface import ChatHuggingFace - llm = HuggingFacePipeline.from_model_id(model_id=model, **kwargs) - return ChatHuggingFace(llm=llm) + return ChatHuggingFace.from_model_id(model_id=model, **kwargs) if model_provider == "groq": _check_pkg("langchain_groq") diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py index b7d425a396b..8c5965cf0ea 100644 --- a/libs/langchain_v1/langchain/chat_models/base.py +++ b/libs/langchain_v1/langchain/chat_models/base.py @@ -405,7 +405,7 @@ def _init_chat_model_helper( _check_pkg("langchain_huggingface") from langchain_huggingface import ChatHuggingFace - return ChatHuggingFace(model_id=model, **kwargs) + return ChatHuggingFace.from_model_id(model_id=model, **kwargs) if model_provider == "groq": _check_pkg("langchain_groq") from langchain_groq import ChatGroq diff --git a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py index 725c40337e9..cedb1327efe 100644 --- a/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py @@ -7,7 +7,11 @@ import json from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from dataclasses import dataclass from operator import itemgetter -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast + +if TYPE_CHECKING: + from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint + from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -599,6 +603,51 @@ class ChatHuggingFace(BaseChatModel): self.profile = _get_default_model_profile(self.model_id) return self + @classmethod + def from_model_id( + cls, + model_id: str, + task: str | None = None, + backend: Literal["pipeline", "endpoint", "text-gen"] = "pipeline", + **kwargs: Any, + ) -> ChatHuggingFace: + """Construct a ChatHuggingFace model from a model_id. + + Args: + model_id: The model ID of the Hugging Face model. + task: The task to perform (e.g., "text-generation"). + backend: The backend to use. One of "pipeline", "endpoint", "text-gen". + **kwargs: Additional arguments to pass to the backend or ChatHuggingFace. + """ + llm: ( + Any # HuggingFacePipeline, HuggingFaceEndpoint, HuggingFaceTextGenInference + ) + if backend == "pipeline": + from langchain_huggingface.llms.huggingface_pipeline import ( + HuggingFacePipeline, + ) + + llm = HuggingFacePipeline.from_model_id( + model_id=model_id, task=cast(str, task), **kwargs + ) + elif backend == "endpoint": + from langchain_huggingface.llms.huggingface_endpoint import ( + HuggingFaceEndpoint, + ) + + llm = HuggingFaceEndpoint(repo_id=model_id, task=task, **kwargs) + elif backend == "text-gen": + from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found] + HuggingFaceTextGenInference, + ) + + llm = HuggingFaceTextGenInference(inference_server_url=model_id, **kwargs) + else: + msg = f"Unknown backend: {backend}" + raise ValueError(msg) + + return cls(llm=llm, **kwargs) + def _create_chat_result(self, response: dict) -> ChatResult: generations = [] token_usage = response.get("usage", {})