mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
fix(petals): allow running models that aren't Bloom (support for LLaMA and newer models) (#8356)
In this PR: - Removed the Bloom-restricted model-loading logic from the Petals integration - Removed the Bloom-specific petals imports (DistributedBloomForCausalLM, BloomTokenizerFast) - Imported the more general loaders instead (AutoDistributedModelForCausalLM, AutoTokenizer) - Updated the Petals example notebook so that Petals installs successfully on Apple Silicon Macs - Tag maintainer: @hwchase17, @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -93,12 +93,14 @@ class Petals(LLM):
|
||||
values, "huggingface_api_key", "HUGGINGFACE_API_KEY"
|
||||
)
|
||||
try:
|
||||
from petals import DistributedBloomForCausalLM
|
||||
from transformers import BloomTokenizerFast
|
||||
from petals import AutoDistributedModelForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_name = values["model_name"]
|
||||
values["tokenizer"] = BloomTokenizerFast.from_pretrained(model_name)
|
||||
values["client"] = DistributedBloomForCausalLM.from_pretrained(model_name)
|
||||
values["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
|
||||
values["client"] = AutoDistributedModelForCausalLM.from_pretrained(
|
||||
model_name
|
||||
)
|
||||
values["huggingface_api_key"] = huggingface_api_key
|
||||
|
||||
except ImportError:
|
||||
|
Reference in New Issue
Block a user