fix(petals) allows to run models that aren't Bloom (Support for LLama and newer models) (#8356)

In this PR:

- Removed restricted model loading logic for Petals-Bloom
- Removed petals imports (DistributedBloomForCausalLM,
BloomTokenizerFast)
- Instead imported more generalized versions of loader
(AutoDistributedModelForCausalLM, AutoTokenizer)
- Updated the Petals example notebook to allow for a successful
installation of Petals in Apple Silicon Macs

- Tag maintainer: @hwchase17, @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Karan V
2023-07-28 06:31:04 +05:30
committed by GitHub
parent e758e9e7f5
commit a003a0baf6
2 changed files with 10 additions and 6 deletions

View File

@@ -93,12 +93,14 @@ class Petals(LLM):
values, "huggingface_api_key", "HUGGINGFACE_API_KEY"
)
try:
from petals import DistributedBloomForCausalLM
from transformers import BloomTokenizerFast
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
model_name = values["model_name"]
values["tokenizer"] = BloomTokenizerFast.from_pretrained(model_name)
values["client"] = DistributedBloomForCausalLM.from_pretrained(model_name)
values["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
values["client"] = AutoDistributedModelForCausalLM.from_pretrained(
model_name
)
values["huggingface_api_key"] = huggingface_api_key
except ImportError: