diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py
index 22bd3661a5e..ee5678c6fc6 100644
--- a/langchain/llms/huggingface_pipeline.py
+++ b/langchain/llms/huggingface_pipeline.py
@@ -1,4 +1,6 @@
 """Wrapper around HuggingFace Pipeline APIs."""
+import importlib.util
+import logging
 from typing import Any, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra
@@ -10,6 +12,8 @@ DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
 VALID_TASKS = ("text2text-generation", "text-generation")
 
+logger = logging.getLogger()
+
 
 class HuggingFacePipeline(LLM, BaseModel):
     """Wrapper around HuggingFace Pipeline API.
@@ -56,6 +60,7 @@ class HuggingFacePipeline(LLM, BaseModel):
         cls,
         model_id: str,
         task: str,
+        device: int = -1,
         model_kwargs: Optional[dict] = None,
         **kwargs: Any,
     ) -> LLM:
@@ -68,8 +73,16 @@ class HuggingFacePipeline(LLM, BaseModel):
             )
             from transformers import pipeline as hf_pipeline
 
-            _model_kwargs = model_kwargs or {}
-            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
             elif task == "text2text-generation":
@@ -79,25 +92,47 @@ class HuggingFacePipeline(LLM, BaseModel):
                     f"Got invalid task {task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
-            pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
-            )
-            if pipeline.task not in VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {pipeline.task}, "
-                    f"currently only {VALID_TASKS} are supported"
-                )
-            return cls(
-                pipeline=pipeline,
-                model_id=model_id,
-                model_kwargs=_model_kwargs,
-                **kwargs,
-            )
-        except ImportError:
+        except ImportError as e:
             raise ValueError(
-                "Could not import transformers python package. "
-                "Please it install it with `pip install transformers`."
+                f"Could not load the {task} model due to missing dependencies."
+            ) from e
+
+        if importlib.util.find_spec("torch") is not None:
+            import torch
+
+            cuda_device_count = torch.cuda.device_count()
+            if device < -1 or (device >= cuda_device_count):
+                raise ValueError(
+                    f"Got device=={device}, "
+                    f"device is required to be within [-1, {cuda_device_count})"
+                )
+            if device < 0 and cuda_device_count > 0:
+                logger.warning(
+                    "Device has %d GPUs available. "
+                    "Provide device={deviceId} to `from_model_id` to use available "
+                    "GPUs for execution. deviceId is -1 (default) for CPU and "
+                    "can be a positive integer associated with CUDA device id.",
+                    cuda_device_count,
+                )
+
+        pipeline = hf_pipeline(
+            task=task,
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            model_kwargs=_model_kwargs,
+        )
+        if pipeline.task not in VALID_TASKS:
+            raise ValueError(
+                f"Got invalid task {pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
             )
+        return cls(
+            pipeline=pipeline,
+            model_id=model_id,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
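
For reference, a minimal usage sketch of the `device` parameter this patch adds. It assumes the patched `langchain` is installed alongside `transformers` and `torch`, and that at least one CUDA device is visible; the model id and prompt are illustrative, not taken from the diff.

```python
# Sketch: run the default gpt2 text-generation pipeline on GPU 0.
# device=-1 (the default) keeps execution on CPU; a non-negative value
# selects the corresponding CUDA device, as validated in from_model_id.
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",           # DEFAULT_MODEL_ID in this module
    task="text-generation",    # must be one of VALID_TASKS
    device=0,                  # assumption: cuda:0 exists; use -1 for CPU
)

print(llm("Once upon a time"))  # LLM instances are callable with a prompt
```

Passing a `device` outside `[-1, torch.cuda.device_count())` raises a ValueError, and leaving the default `-1` on a machine with GPUs only logs a warning, so existing CPU-based callers keep working unchanged.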