mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-11 07:25:29 +00:00
48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
from langchain_huggingface import HuggingFacePipeline
|
|
|
|
# Model id that a HuggingFacePipeline constructed without any arguments is
# expected to report (asserted in test_initialization_default).
DEFAULT_MODEL_ID: str = "gpt2"
|
|
|
|
|
|
def test_initialization_default() -> None:
    """A pipeline built with no arguments falls back to the default model id."""
    # Nothing is passed, so model_id must come from the library's default.
    assert HuggingFacePipeline().model_id == DEFAULT_MODEL_ID
|
|
|
|
|
|
@patch("transformers.pipeline")
def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
    """Test initialization with a pre-built pipeline object.

    The model id should be taken from the supplied pipeline's
    ``model.name_or_path``. The supplied pipeline should be kept as-is, and
    the patched ``transformers.pipeline`` factory must not be used to build
    a new one.
    """
    mock_pipe = MagicMock()
    mock_pipe.model.name_or_path = "mock-model-id"
    # Intentionally no ``mock_pipeline.return_value = mock_pipe``: if a new
    # pipeline were built, it would be a fresh MagicMock whose model id does
    # not match, so the accidental build cannot masquerade as success.

    llm = HuggingFacePipeline(pipeline=mock_pipe)

    assert llm.model_id == "mock-model-id"
    assert llm.pipeline is mock_pipe
    # The patch is only meaningful if we check it: a caller-supplied pipeline
    # must not trigger construction of another one.
    mock_pipeline.assert_not_called()
|
|
|
|
|
|
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
@patch("transformers.pipeline")
def test_initialization_with_from_model_id(
    mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
) -> None:
    """Test initialization with the from_model_id method.

    The ``transformers`` tokenizer/model loaders and the pipeline factory are
    replaced with mocks, so no real tokenizer or model is loaded. ``@patch``
    decorators apply bottom-up, so the mocks arrive innermost-first:
    pipeline, model, tokenizer.
    """
    mock_tokenizer.return_value = MagicMock(pad_token_id=0)
    mock_model.return_value = MagicMock()

    mock_pipe = MagicMock()
    mock_pipe.task = "text-generation"
    mock_pipe.model = mock_model.return_value
    mock_pipeline.return_value = mock_pipe

    llm = HuggingFacePipeline.from_model_id(
        model_id="mock-model-id",
        task="text-generation",
    )

    assert llm.model_id == "mock-model-id"
    # model_id is passed straight through, so the check above would still pass
    # if the loaders were bypassed. Also verify each patched loader ran once
    # and that the pipeline produced by the factory is the one that was kept.
    mock_tokenizer.assert_called_once()
    mock_model.assert_called_once()
    mock_pipeline.assert_called_once()
    assert llm.pipeline is mock_pipe
|