Harrison/hf pipeline (#780)

Co-authored-by: Parth Chadha <parth29@gmail.com>

parent c658f0aed3  commit 5bb2952860
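This change adds a device argument to HuggingFacePipeline.from_model_id so the wrapped transformers pipeline can be pinned to a CUDA device, and reworks the error handling around imports and model loading. The diff below is against langchain/llms/huggingface_pipeline.py (the extract does not name the file, but that is where the class lives). A minimal usage sketch of the new parameter, assuming the langchain.llms export path of this era of the repo; it is not part of the commit:

    from langchain.llms import HuggingFacePipeline

    # device=-1 (the new default) stays on CPU; device=0 would select
    # the first CUDA GPU when torch reports one.
    llm = HuggingFacePipeline.from_model_id(
        model_id="gpt2",
        task="text-generation",
        device=-1,
    )
    print(llm("Once upon a time"))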
@@ -1,4 +1,6 @@
 """Wrapper around HuggingFace Pipeline APIs."""
+import importlib.util
+import logging
 from typing import Any, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra
@@ -10,6 +12,8 @@ DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
 VALID_TASKS = ("text2text-generation", "text-generation")
 
+logger = logging.getLogger()
+
 
 class HuggingFacePipeline(LLM, BaseModel):
     """Wrapper around HuggingFace Pipeline API.
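Note: logging.getLogger() with no name returns the root logger, so the GPU hint emitted further down in this diff goes through whatever handlers the host application configures. A one-line host-side setup, using only the standard library, would be:

    import logging

    logging.basicConfig(level=logging.WARNING)  # route root-logger warnings to stderr with a default format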
@@ -56,6 +60,7 @@ class HuggingFacePipeline(LLM, BaseModel):
         cls,
         model_id: str,
         task: str,
+        device: int = -1,
         model_kwargs: Optional[dict] = None,
         **kwargs: Any,
     ) -> LLM:
@@ -68,8 +73,16 @@ class HuggingFacePipeline(LLM, BaseModel):
             )
             from transformers import pipeline as hf_pipeline
 
-            _model_kwargs = model_kwargs or {}
-            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please it install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
             elif task == "text2text-generation":
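This hunk narrows the original try block: a failed transformers import still produces the original message, while tokenizer creation moves out from under it and a new try: opens around the task dispatch (its matching except arrives in the next hunk). One reason for the split is that loading a specific checkpoint can itself raise ImportError for optional dependencies, which should not be reported as a missing transformers package. An abridged, hypothetical sketch of the import half of that flow:

    from typing import Optional


    def load_tokenizer(model_id: str, model_kwargs: Optional[dict] = None):
        # Hypothetical helper condensing the restructured flow in this hunk.
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        _model_kwargs = model_kwargs or {}
        return AutoTokenizer.from_pretrained(model_id, **_model_kwargs)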
@@ -79,25 +92,47 @@ class HuggingFacePipeline(LLM, BaseModel):
                     f"Got invalid task {task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
-            pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
-            )
-            if pipeline.task not in VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {pipeline.task}, "
-                    f"currently only {VALID_TASKS} are supported"
-                )
-            return cls(
-                pipeline=pipeline,
-                model_id=model_id,
-                model_kwargs=_model_kwargs,
-                **kwargs,
-            )
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "Please it install it with `pip install transformers`."
-            )
+        except ImportError as e:
+            raise ValueError(
+                f"Could not load the {task} model due to missing dependencies."
+            ) from e
+
+        if importlib.util.find_spec("torch") is not None:
+            import torch
+
+            cuda_device_count = torch.cuda.device_count()
+            if device < -1 or (device >= cuda_device_count):
+                raise ValueError(
+                    f"Got device=={device}, "
+                    f"device is required to be within [-1, {cuda_device_count})"
+                )
+            if device < 0 and cuda_device_count > 0:
+                logger.warning(
+                    "Device has %d GPUs available. "
+                    "Provide device={deviceId} to `from_model_id` to use available"
+                    "GPUs for execution. deviceId is -1 (default) for CPU and "
+                    "can be a positive integer associated with CUDA device id.",
+                    cuda_device_count,
+                )
+
+        pipeline = hf_pipeline(
+            task=task,
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            model_kwargs=_model_kwargs,
+        )
+        if pipeline.task not in VALID_TASKS:
+            raise ValueError(
+                f"Got invalid task {pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        return cls(
+            pipeline=pipeline,
+            model_id=model_id,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
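The device handling added above follows the transformers pipeline convention: -1 selects CPU and 0..N-1 select CUDA devices, where N is torch.cuda.device_count(). A standalone sketch of the same bounds check (hypothetical helper, not part of the commit):

    import importlib.util


    def validate_device(device: int) -> None:
        # Hypothetical helper mirroring the check added in this commit.
        if importlib.util.find_spec("torch") is None:
            return  # without torch there is no CUDA device count to check against
        import torch

        cuda_device_count = torch.cuda.device_count()
        # Allowed: -1 (CPU) or a CUDA index in [0, cuda_device_count).
        if device < -1 or device >= cuda_device_count:
            raise ValueError(
                f"Got device=={device}, "
                f"device is required to be within [-1, {cuda_device_count})"
            )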