Harrison/hf pipeline (#780)

Co-authored-by: Parth Chadha <parth29@gmail.com>
This commit is contained in:
Harrison Chase 2023-01-28 08:23:59 -08:00 committed by GitHub
parent c658f0aed3
commit 5bb2952860
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,4 +1,6 @@
"""Wrapper around HuggingFace Pipeline APIs.""" """Wrapper around HuggingFace Pipeline APIs."""
import importlib.util
import logging
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
@@ -10,6 +12,8 @@ DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation" DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation") VALID_TASKS = ("text2text-generation", "text-generation")
logger = logging.getLogger()
class HuggingFacePipeline(LLM, BaseModel): class HuggingFacePipeline(LLM, BaseModel):
"""Wrapper around HuggingFace Pipeline API. """Wrapper around HuggingFace Pipeline API.
@@ -56,6 +60,7 @@ class HuggingFacePipeline(LLM, BaseModel):
cls, cls,
model_id: str, model_id: str,
task: str, task: str,
device: int = -1,
model_kwargs: Optional[dict] = None, model_kwargs: Optional[dict] = None,
**kwargs: Any, **kwargs: Any,
) -> LLM: ) -> LLM:
@@ -68,8 +73,16 @@ class HuggingFacePipeline(LLM, BaseModel):
) )
from transformers import pipeline as hf_pipeline from transformers import pipeline as hf_pipeline
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"Please it install it with `pip install transformers`."
)
_model_kwargs = model_kwargs or {} _model_kwargs = model_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs) tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
try:
if task == "text-generation": if task == "text-generation":
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs) model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
elif task == "text2text-generation": elif task == "text2text-generation":
@@ -79,8 +92,35 @@ class HuggingFacePipeline(LLM, BaseModel):
f"Got invalid task {task}, " f"Got invalid task {task}, "
f"currently only {VALID_TASKS} are supported" f"currently only {VALID_TASKS} are supported"
) )
except ImportError as e:
raise ValueError(
f"Could not load the {task} model due to missing dependencies."
) from e
if importlib.util.find_spec("torch") is not None:
import torch
cuda_device_count = torch.cuda.device_count()
if device < -1 or (device >= cuda_device_count):
raise ValueError(
f"Got device=={device}, "
f"device is required to be within [-1, {cuda_device_count})"
)
if device < 0 and cuda_device_count > 0:
logger.warning(
"Device has %d GPUs available. "
"Provide device={deviceId} to `from_model_id` to use available"
"GPUs for execution. deviceId is -1 (default) for CPU and "
"can be a positive integer associated with CUDA device id.",
cuda_device_count,
)
pipeline = hf_pipeline( pipeline = hf_pipeline(
task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs task=task,
model=model,
tokenizer=tokenizer,
device=device,
model_kwargs=_model_kwargs,
) )
if pipeline.task not in VALID_TASKS: if pipeline.task not in VALID_TASKS:
raise ValueError( raise ValueError(
@@ -93,11 +133,6 @@ class HuggingFacePipeline(LLM, BaseModel):
model_kwargs=_model_kwargs, model_kwargs=_model_kwargs,
**kwargs, **kwargs,
) )
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"Please it install it with `pip install transformers`."
)
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]: