partners/huggingface[patch]: fix HuggingFacePipeline model_id parameter (#27514)

**Description:** Fixes an issue where the `model_id` parameter was not
initialized correctly when a transformers pipeline is passed in.
**Issue:** https://github.com/langchain-ai/langchain/issues/25915
This commit is contained in:
Andrew Effendi
2024-10-29 10:34:46 -04:00
committed by GitHub
parent 0a465b8032
commit 49517cc1e7
2 changed files with 67 additions and 4 deletions

View File

@@ -2,12 +2,12 @@ from __future__ import annotations # type: ignore[import-not-found]
import importlib.util
import logging
from typing import Any, Iterator, List, Mapping, Optional
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from pydantic import ConfigDict
from pydantic import ConfigDict, model_validator
# Fallback model name, used when neither a model_id nor a pipeline is supplied.
DEFAULT_MODEL_ID = "gpt2"
# Presumably the default pipeline task; its consumers are not visible in this excerpt.
DEFAULT_TASK = "text-generation"
@@ -55,8 +55,10 @@ class HuggingFacePipeline(BaseLLM):
"""
pipeline: Any = None #: :meta private:
model_id: str = DEFAULT_MODEL_ID
"""Model name to use."""
model_id: Optional[str] = None
"""The model name. If not set explicitly by the user,
it will be inferred from the provided pipeline (if available).
If neither is provided, the DEFAULT_MODEL_ID will be used."""
model_kwargs: Optional[dict] = None
"""Keyword arguments passed to the model."""
pipeline_kwargs: Optional[dict] = None
@@ -68,6 +70,17 @@ class HuggingFacePipeline(BaseLLM):
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Ensure ``model_id`` is populated before field validation runs.

    Resolution order:

    1. A ``model_id`` supplied by the caller (anything other than ``None``)
       is kept untouched.
    2. Otherwise the name is read from ``pipeline.model.name_or_path`` when
       a pipeline carrying that attribute chain was supplied.
    3. Otherwise ``DEFAULT_MODEL_ID`` is used.

    Args:
        values: Raw constructor input, as handed to a "before" validator.

    Returns:
        The input unchanged when no resolution is needed, otherwise a
        shallow copy of it with ``model_id`` filled in.
    """
    # A "before" validator sees raw input, which is not guaranteed to be a
    # dict; leave anything else for the regular validation to handle.
    if not isinstance(values, dict):
        return values

    # An explicit ``model_id=None`` is treated the same as an omitted one,
    # so the field cannot end up ``None`` after validation.
    if values.get("model_id") is not None:
        return values

    # ``pipeline`` is typed ``Any``: tolerate objects that lack ``model`` or
    # ``name_or_path`` instead of failing construction with AttributeError.
    model = getattr(values.get("pipeline"), "model", None)
    inferred_id = getattr(model, "name_or_path", None)

    # Work on a shallow copy so the caller's dict is not mutated.
    resolved = dict(values)
    resolved["model_id"] = DEFAULT_MODEL_ID if inferred_id is None else inferred_id
    return resolved
@classmethod
def from_model_id(
cls,