diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py
index b1261bcfc4e..3e743a64289 100644
--- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py
+++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_pipeline.py
@@ -2,12 +2,12 @@ from __future__ import annotations  # type: ignore[import-not-found]
 
 import importlib.util
 import logging
-from typing import Any, Iterator, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from pydantic import ConfigDict
+from pydantic import ConfigDict, model_validator
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
@@ -55,8 +55,10 @@ class HuggingFacePipeline(BaseLLM):
     """
 
     pipeline: Any = None  #: :meta private:
-    model_id: str = DEFAULT_MODEL_ID
-    """Model name to use."""
+    model_id: Optional[str] = None
+    """The model name. If not set explicitly by the user,
+    it will be inferred from the provided pipeline (if available).
+    If neither is provided, the DEFAULT_MODEL_ID will be used."""
     model_kwargs: Optional[dict] = None
     """Keyword arguments passed to the model."""
     pipeline_kwargs: Optional[dict] = None
@@ -68,6 +70,17 @@ class HuggingFacePipeline(BaseLLM):
         extra="forbid",
     )
 
+    @model_validator(mode="before")
+    @classmethod
+    def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Ensure model_id is set either by pipeline or user input."""
+        if "model_id" not in values:
+            if "pipeline" in values and values["pipeline"]:
+                values["model_id"] = values["pipeline"].model.name_or_path
+            else:
+                values["model_id"] = DEFAULT_MODEL_ID
+        return values
+
     @classmethod
     def from_model_id(
         cls,
diff --git a/libs/partners/huggingface/tests/unit_tests/test_huggingface_pipeline.py b/libs/partners/huggingface/tests/unit_tests/test_huggingface_pipeline.py
new file mode 100644
index 00000000000..b4e1639493d
--- /dev/null
+++ b/libs/partners/huggingface/tests/unit_tests/test_huggingface_pipeline.py
@@ -0,0 +1,50 @@
+from unittest.mock import MagicMock, patch
+
+from langchain_huggingface import HuggingFacePipeline
+
+DEFAULT_MODEL_ID = "gpt2"
+
+
+def test_initialization_default() -> None:
+    """Test default initialization"""
+
+    llm = HuggingFacePipeline()
+
+    assert llm.model_id == DEFAULT_MODEL_ID
+
+
+@patch("transformers.pipeline")
+def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
+    """Test initialization with a pipeline object"""
+
+    mock_pipe = MagicMock()
+    mock_pipe.model.name_or_path = "mock-model-id"
+    mock_pipeline.return_value = mock_pipe
+
+    llm = HuggingFacePipeline(pipeline=mock_pipe)
+
+    assert llm.model_id == "mock-model-id"
+
+
+@patch("transformers.AutoTokenizer.from_pretrained")
+@patch("transformers.AutoModelForCausalLM.from_pretrained")
+@patch("transformers.pipeline")
+def test_initialization_with_from_model_id(
+    mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
+) -> None:
+    """Test initialization with the from_model_id method"""
+
+    mock_tokenizer.return_value = MagicMock(pad_token_id=0)
+    mock_model.return_value = MagicMock()
+
+    mock_pipe = MagicMock()
+    mock_pipe.task = "text-generation"
+    mock_pipe.model = mock_model.return_value
+    mock_pipeline.return_value = mock_pipe
+
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="mock-model-id",
+        task="text-generation",
+    )
+
+    assert llm.model_id == "mock-model-id"