Harrison/hf pipeline (#780)

Co-authored-by: Parth Chadha <parth29@gmail.com>
Harrison Chase authored 2023-01-28 08:23:59 -08:00, committed by GitHub
parent c658f0aed3
commit 5bb2952860

@@ -1,4 +1,6 @@
 """Wrapper around HuggingFace Pipeline APIs."""
+import importlib.util
+import logging
 from typing import Any, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra
@@ -10,6 +12,8 @@ DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
 VALID_TASKS = ("text2text-generation", "text-generation")
 
+logger = logging.getLogger()
+
 
 class HuggingFacePipeline(LLM, BaseModel):
     """Wrapper around HuggingFace Pipeline API.
@@ -56,6 +60,7 @@ class HuggingFacePipeline(LLM, BaseModel):
         cls,
         model_id: str,
         task: str,
+        device: int = -1,
         model_kwargs: Optional[dict] = None,
         **kwargs: Any,
     ) -> LLM:
@@ -68,8 +73,16 @@ class HuggingFacePipeline(LLM, BaseModel):
             )
             from transformers import pipeline as hf_pipeline
 
-            _model_kwargs = model_kwargs or {}
-            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
             elif task == "text2text-generation":
@@ -79,25 +92,47 @@ class HuggingFacePipeline(LLM, BaseModel):
                     f"Got invalid task {task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
-            pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
-            )
-            if pipeline.task not in VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {pipeline.task}, "
-                    f"currently only {VALID_TASKS} are supported"
-                )
-            return cls(
-                pipeline=pipeline,
-                model_id=model_id,
-                model_kwargs=_model_kwargs,
-                **kwargs,
-            )
-        except ImportError:
+        except ImportError as e:
             raise ValueError(
-                "Could not import transformers python package. "
-                "Please it install it with `pip install transformers`."
-            )
+                f"Could not load the {task} model due to missing dependencies."
+            ) from e
+
+        if importlib.util.find_spec("torch") is not None:
+            import torch
+
+            cuda_device_count = torch.cuda.device_count()
+            if device < -1 or (device >= cuda_device_count):
+                raise ValueError(
+                    f"Got device=={device}, "
+                    f"device is required to be within [-1, {cuda_device_count})"
+                )
+            if device < 0 and cuda_device_count > 0:
+                logger.warning(
+                    "Device has %d GPUs available. "
+                    "Provide device={deviceId} to `from_model_id` to use available "
+                    "GPUs for execution. deviceId is -1 (default) for CPU and "
+                    "can be a positive integer associated with CUDA device id.",
+                    cuda_device_count,
+                )
+
+        pipeline = hf_pipeline(
+            task=task,
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            model_kwargs=_model_kwargs,
+        )
+        if pipeline.task not in VALID_TASKS:
+            raise ValueError(
+                f"Got invalid task {pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        return cls(
+            pipeline=pipeline,
+            model_id=model_id,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
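
For reference, a minimal usage sketch of the `device` parameter this commit adds to `from_model_id`. The `langchain.llms` import path and the presence of a CUDA device are assumptions, not shown in this diff:

```python
from langchain.llms import HuggingFacePipeline  # import path assumed

# device=-1 (the default) keeps inference on CPU; a non-negative id selects
# that CUDA device and is validated against torch.cuda.device_count().
llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",          # DEFAULT_MODEL_ID in this file
    task="text-generation",   # must be one of VALID_TASKS
    device=0,                 # first GPU; requires torch.cuda.device_count() > 0
)

print(llm("Once upon a time"))  # LLM instances are callable on a prompt string
```

Per the validation above, an id outside [-1, torch.cuda.device_count()) raises a ValueError, and leaving the default device=-1 on a machine with GPUs only logs a warning.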