mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-02 21:58:46 +00:00
fix(petals) allow running models that aren't Bloom (support for LLaMA and newer models) (#8356)
In this PR: - Removed the restricted model-loading logic for Petals-Bloom - Removed the Petals-specific imports (DistributedBloomForCausalLM, BloomTokenizerFast) - Imported the more generalized loaders instead (AutoDistributedModelForCausalLM, AutoTokenizer) - Updated the Petals example notebook to allow for a successful installation of Petals on Apple Silicon Macs - Tag maintainer: @hwchase17, @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
e758e9e7f5
commit
a003a0baf6
@ -16,7 +16,9 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install petals\n",
|
||||
"The `petals` package is required to use the Petals API. Install `petals` using `pip3 install petals`."
|
||||
"The `petals` package is required to use the Petals API. Install `petals` using `pip3 install petals`.\n",
|
||||
"\n",
|
||||
"For Apple Silicon(M1/M2) users please follow this guide [https://github.com/bigscience-workshop/petals/issues/147#issuecomment-1365379642](https://github.com/bigscience-workshop/petals/issues/147#issuecomment-1365379642) to install petals "
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -62,7 +64,7 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n"
|
||||
|
@ -93,12 +93,14 @@ class Petals(LLM):
|
||||
values, "huggingface_api_key", "HUGGINGFACE_API_KEY"
|
||||
)
|
||||
try:
|
||||
from petals import DistributedBloomForCausalLM
|
||||
from transformers import BloomTokenizerFast
|
||||
from petals import AutoDistributedModelForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_name = values["model_name"]
|
||||
values["tokenizer"] = BloomTokenizerFast.from_pretrained(model_name)
|
||||
values["client"] = DistributedBloomForCausalLM.from_pretrained(model_name)
|
||||
values["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
|
||||
values["client"] = AutoDistributedModelForCausalLM.from_pretrained(
|
||||
model_name
|
||||
)
|
||||
values["huggingface_api_key"] = huggingface_api_key
|
||||
|
||||
except ImportError:
|
||||
|
Loading…
Reference in New Issue
Block a user