partners/huggingface[patch]: fix HuggingFacePipeline model_id parameter (#27514)
**Description:** Fixes an issue where the `model_id` field was not initialized correctly when a pre-built `transformers` pipeline was passed to `HuggingFacePipeline`. **Issue:** https://github.com/langchain-ai/langchain/issues/25915
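For illustration, a minimal sketch of the fixed behavior (the `distilgpt2` model choice is an arbitrary assumption, not from the patch; any causal LM works, and running this requires a local `transformers` install):

```python
from transformers import pipeline

from langchain_huggingface import HuggingFacePipeline

# Wrap a model other than the "gpt2" default so the inference is observable.
# (distilgpt2 is an illustrative choice.)
hf_pipe = pipeline("text-generation", model="distilgpt2")

llm = HuggingFacePipeline(pipeline=hf_pipe)

# Before this patch, llm.model_id silently stayed at the "gpt2" default;
# with the validator added below, it is read from the pipeline's model.
assert llm.model_id == "distilgpt2"
```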
@@ -2,12 +2,12 @@ from __future__ import annotations  # type: ignore[import-not-found]
 
 import importlib.util
 import logging
-from typing import Any, Iterator, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from pydantic import ConfigDict
+from pydantic import ConfigDict, model_validator
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
@@ -55,8 +55,10 @@ class HuggingFacePipeline(BaseLLM):
     """
 
     pipeline: Any = None  #: :meta private:
-    model_id: str = DEFAULT_MODEL_ID
-    """Model name to use."""
+    model_id: Optional[str] = None
+    """The model name. If not set explicitly by the user,
+    it will be inferred from the provided pipeline (if available).
+    If neither is provided, the DEFAULT_MODEL_ID will be used."""
     model_kwargs: Optional[dict] = None
     """Keyword arguments passed to the model."""
     pipeline_kwargs: Optional[dict] = None
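Declaring `model_id` as `Optional[str] = None` rather than eagerly defaulting to `DEFAULT_MODEL_ID` makes the intent explicit: the default is only a last-resort fallback, and the validator added in the next hunk fills the field in either from the wrapped pipeline's `model.name_or_path` or, failing that, from `DEFAULT_MODEL_ID`.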
@@ -68,6 +70,17 @@ class HuggingFacePipeline(BaseLLM):
         extra="forbid",
     )
 
+    @model_validator(mode="before")
+    @classmethod
+    def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Ensure model_id is set either by pipeline or user input."""
+        if "model_id" not in values:
+            if "pipeline" in values and values["pipeline"]:
+                values["model_id"] = values["pipeline"].model.name_or_path
+            else:
+                values["model_id"] = DEFAULT_MODEL_ID
+        return values
+
     @classmethod
     def from_model_id(
         cls,
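For context on the pattern: a pydantic-v2 `model_validator(mode="before")` runs on the raw constructor kwargs before field coercion, which is why checking `"model_id" not in values` distinguishes an omitted argument from an explicitly passed one. A self-contained sketch of the same pattern (the `Widget`/`Source` names are illustrative, not from the patch):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, model_validator

DEFAULT_NAME = "default-widget"


class Widget(BaseModel):
    source: Any = None
    name: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def infer_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Runs on the raw input dict, before field validation, so a
        # missing key can be filled in from another field's value.
        if "name" not in values:
            if values.get("source"):
                values["name"] = values["source"].name
            else:
                values["name"] = DEFAULT_NAME
        return values


class Source:
    name = "inferred-widget"


assert Widget(source=Source()).name == "inferred-widget"
assert Widget().name == DEFAULT_NAME
```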
@@ -0,0 +1,50 @@
+from unittest.mock import MagicMock, patch
+
+from langchain_huggingface import HuggingFacePipeline
+
+DEFAULT_MODEL_ID = "gpt2"
+
+
+def test_initialization_default() -> None:
+    """Test default initialization"""
+
+    llm = HuggingFacePipeline()
+
+    assert llm.model_id == DEFAULT_MODEL_ID
+
+
+@patch("transformers.pipeline")
+def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
+    """Test initialization with a pipeline object"""
+
+    mock_pipe = MagicMock()
+    mock_pipe.model.name_or_path = "mock-model-id"
+    mock_pipeline.return_value = mock_pipe
+
+    llm = HuggingFacePipeline(pipeline=mock_pipe)
+
+    assert llm.model_id == "mock-model-id"
+
+
+@patch("transformers.AutoTokenizer.from_pretrained")
+@patch("transformers.AutoModelForCausalLM.from_pretrained")
+@patch("transformers.pipeline")
+def test_initialization_with_from_model_id(
+    mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
+) -> None:
+    """Test initialization with the from_model_id method"""
+
+    mock_tokenizer.return_value = MagicMock(pad_token_id=0)
+    mock_model.return_value = MagicMock()
+
+    mock_pipe = MagicMock()
+    mock_pipe.task = "text-generation"
+    mock_pipe.model = mock_model.return_value
+    mock_pipeline.return_value = mock_pipe
+
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="mock-model-id",
+        task="text-generation",
+    )
+
+    assert llm.model_id == "mock-model-id"
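Note that all three tests patch `transformers.pipeline` (and, for the `from_model_id` case, the `AutoTokenizer`/`AutoModelForCausalLM` loaders) with `MagicMock`, so they exercise only the `model_id` resolution logic and run under plain `pytest` without downloading any model weights.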