From a003a0baf6fb65ef1d1e1433bd183ab7a86fdbd2 Mon Sep 17 00:00:00 2001 From: Karan V Date: Fri, 28 Jul 2023 06:31:04 +0530 Subject: [PATCH] fix(petals) allows to run models that aren't Bloom (Support for LLama and newer models) (#8356) In this PR: - Removed restricted model loading logic for Petals-Bloom - Removed petals imports (DistributedBloomForCausalLM, BloomTokenizerFast) - Instead imported more generalized versions of loader (AutoDistributedModelForCausalLM, AutoTokenizer) - Updated the Petals example notebook to allow for a successful installation of Petals in Apple Silicon Macs - Tag maintainer: @hwchase17, @baskaryan --------- Co-authored-by: Bagatur --- docs/extras/integrations/llms/petals_example.ipynb | 6 ++++-- libs/langchain/langchain/llms/petals.py | 10 ++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/extras/integrations/llms/petals_example.ipynb b/docs/extras/integrations/llms/petals_example.ipynb index f326888b0e5..8232ecd6c61 100644 --- a/docs/extras/integrations/llms/petals_example.ipynb +++ b/docs/extras/integrations/llms/petals_example.ipynb @@ -16,7 +16,9 @@ "metadata": {}, "source": [ "## Install petals\n", - "The `petals` package is required to use the Petals API. Install `petals` using `pip3 install petals`." + "The `petals` package is required to use the Petals API. 
Install `petals` using `pip3 install petals`.\n", + "\n", + "For Apple Silicon (M1/M2) users, please follow this guide [https://github.com/bigscience-workshop/petals/issues/147#issuecomment-1365379642](https://github.com/bigscience-workshop/petals/issues/147#issuecomment-1365379642) to install petals " ] }, { @@ -62,7 +64,7 @@ }, "outputs": [ { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ " ········\n" diff --git a/libs/langchain/langchain/llms/petals.py b/libs/langchain/langchain/llms/petals.py index 413437dc086..1b74734cd54 100644 --- a/libs/langchain/langchain/llms/petals.py +++ b/libs/langchain/langchain/llms/petals.py @@ -93,12 +93,14 @@ class Petals(LLM): values, "huggingface_api_key", "HUGGINGFACE_API_KEY" ) try: - from petals import DistributedBloomForCausalLM - from transformers import BloomTokenizerFast + from petals import AutoDistributedModelForCausalLM + from transformers import AutoTokenizer model_name = values["model_name"] - values["tokenizer"] = BloomTokenizerFast.from_pretrained(model_name) - values["client"] = DistributedBloomForCausalLM.from_pretrained(model_name) + values["tokenizer"] = AutoTokenizer.from_pretrained(model_name) + values["client"] = AutoDistributedModelForCausalLM.from_pretrained( + model_name + ) values["huggingface_api_key"] = huggingface_api_key except ImportError: