mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 23:12:38 +00:00
partners/huggingface[patch]: fix HuggingFacePipeline model_id parameter (#27514)
**Description:** Fixes issue with model parameter not getting initialized correctly when passing transformers pipeline **Issue:** https://github.com/langchain-ai/langchain/issues/25915
This commit is contained in:
@@ -2,12 +2,12 @@ from __future__ import annotations # type: ignore[import-not-found]
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Any, Iterator, List, Mapping, Optional
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
DEFAULT_TASK = "text-generation"
|
||||
@@ -55,8 +55,10 @@ class HuggingFacePipeline(BaseLLM):
|
||||
"""
|
||||
|
||||
pipeline: Any = None #: :meta private:
|
||||
model_id: str = DEFAULT_MODEL_ID
|
||||
"""Model name to use."""
|
||||
model_id: Optional[str] = None
|
||||
"""The model name. If not set explicitly by the user,
|
||||
it will be inferred from the provided pipeline (if available).
|
||||
If neither is provided, the DEFAULT_MODEL_ID will be used."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments passed to the model."""
|
||||
pipeline_kwargs: Optional[dict] = None
|
||||
@@ -68,6 +70,17 @@ class HuggingFacePipeline(BaseLLM):
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Ensure model_id is set either by pipeline or user input."""
|
||||
if "model_id" not in values:
|
||||
if "pipeline" in values and values["pipeline"]:
|
||||
values["model_id"] = values["pipeline"].model.name_or_path
|
||||
else:
|
||||
values["model_id"] = DEFAULT_MODEL_ID
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_model_id(
|
||||
cls,
|
||||
|
Reference in New Issue
Block a user