mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(huggingface): Helper logic for init_chat_model with HuggingFace backend (#34259)
This commit is contained in:
@@ -444,10 +444,9 @@ def _init_chat_model_helper(
|
||||
|
||||
if model_provider == "huggingface":
|
||||
_check_pkg("langchain_huggingface")
|
||||
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
|
||||
from langchain_huggingface import ChatHuggingFace
|
||||
|
||||
llm = HuggingFacePipeline.from_model_id(model_id=model, **kwargs)
|
||||
return ChatHuggingFace(llm=llm)
|
||||
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
|
||||
|
||||
if model_provider == "groq":
|
||||
_check_pkg("langchain_groq")
|
||||
|
||||
@@ -405,7 +405,7 @@ def _init_chat_model_helper(
|
||||
_check_pkg("langchain_huggingface")
|
||||
from langchain_huggingface import ChatHuggingFace
|
||||
|
||||
return ChatHuggingFace(model_id=model, **kwargs)
|
||||
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
|
||||
if model_provider == "groq":
|
||||
_check_pkg("langchain_groq")
|
||||
from langchain_groq import ChatGroq
|
||||
|
||||
@@ -7,7 +7,11 @@ import json
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from typing import Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -599,6 +603,51 @@ class ChatHuggingFace(BaseChatModel):
|
||||
self.profile = _get_default_model_profile(self.model_id)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_model_id(
|
||||
cls,
|
||||
model_id: str,
|
||||
task: str | None = None,
|
||||
backend: Literal["pipeline", "endpoint", "text-gen"] = "pipeline",
|
||||
**kwargs: Any,
|
||||
) -> ChatHuggingFace:
|
||||
"""Construct a ChatHuggingFace model from a model_id.
|
||||
|
||||
Args:
|
||||
model_id: The model ID of the Hugging Face model.
|
||||
task: The task to perform (e.g., "text-generation").
|
||||
backend: The backend to use. One of "pipeline", "endpoint", "text-gen".
|
||||
**kwargs: Additional arguments to pass to the backend or ChatHuggingFace.
|
||||
"""
|
||||
llm: (
|
||||
Any # HuggingFacePipeline, HuggingFaceEndpoint, HuggingFaceTextGenInference
|
||||
)
|
||||
if backend == "pipeline":
|
||||
from langchain_huggingface.llms.huggingface_pipeline import (
|
||||
HuggingFacePipeline,
|
||||
)
|
||||
|
||||
llm = HuggingFacePipeline.from_model_id(
|
||||
model_id=model_id, task=cast(str, task), **kwargs
|
||||
)
|
||||
elif backend == "endpoint":
|
||||
from langchain_huggingface.llms.huggingface_endpoint import (
|
||||
HuggingFaceEndpoint,
|
||||
)
|
||||
|
||||
llm = HuggingFaceEndpoint(repo_id=model_id, task=task, **kwargs)
|
||||
elif backend == "text-gen":
|
||||
from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found]
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
|
||||
llm = HuggingFaceTextGenInference(inference_server_url=model_id, **kwargs)
|
||||
else:
|
||||
msg = f"Unknown backend: {backend}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(llm=llm, **kwargs)
|
||||
|
||||
def _create_chat_result(self, response: dict) -> ChatResult:
|
||||
generations = []
|
||||
token_usage = response.get("usage", {})
|
||||
|
||||
Reference in New Issue
Block a user