From 5f563e040a1606b25e53559dbe492671a6011f63 Mon Sep 17 00:00:00 2001 From: Alex Sherstinsky Date: Mon, 8 Apr 2024 11:54:29 -0700 Subject: [PATCH] community: extend Predibase integration to support fine-tuned LLM adapters (#19979) - [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Langchain-Predibase integration was failing, because it was not current with the Predibase SDK; in addition, Predibase integration tests were instantiating the Langchain Community `Predibase` class with one required argument (`model`) missing. This change updates the Predibase SDK usage and fixes the integration tests. - **Twitter handle:** `@alexsherstinsky` - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. 
--------- Co-authored-by: Erick Friis --- docs/docs/integrations/llms/predibase.ipynb | 53 +++++++++++++++++-- docs/docs/integrations/providers/predibase.md | 16 +++++- .../langchain_community/llms/predibase.py | 31 ++++++++--- .../llms/test_predibase.py | 17 ++++++ 4 files changed, 106 insertions(+), 11 deletions(-) rename libs/community/tests/{integration_tests => unit_tests}/llms/test_predibase.py (54%) diff --git a/docs/docs/integrations/llms/predibase.ipynb b/docs/docs/integrations/llms/predibase.ipynb index 750cbf90387..ea2257189c0 100644 --- a/docs/docs/integrations/llms/predibase.ipynb +++ b/docs/docs/integrations/llms/predibase.ipynb @@ -50,7 +50,24 @@ "from langchain_community.llms import Predibase\n", "\n", "model = Predibase(\n", - " model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n", + " model=\"mistral-7b\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.llms import Predibase\n", + "\n", + "# With an adapter, fine-tuned on the specified model\n", + "model = Predibase(\n", + " model=\"mistral-7b\",\n", + " adapter_id=\"predibase/e2e_nlg\",\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", ")" ] }, @@ -66,19 +83,43 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, "source": [ "## Chain Call Setup" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "from langchain_community.llms import Predibase\n", + "\n", + "model = Predibase(\n", + " model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# With an adapter, fine-tuned on the 
specified model\n", "llm = Predibase(\n", - "    model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n", + "    model=\"mistral-7b\",\n", + "    adapter_id=\"predibase/e2e_nlg\",\n", + "    predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", ")" ] }, @@ -169,7 +210,11 @@ "from langchain_community.llms import Predibase\n", "\n", "model = Predibase(\n", - "    model=\"my-finetuned-LLM\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n", + "    model=\"my-base-LLM\",\n", + "    adapter_id=\"my-finetuned-adapter-id\",\n", + "    predibase_api_key=os.environ.get(\n", + "        \"PREDIBASE_API_TOKEN\"\n", + "    ),  # Adapter argument is optional.\n", ")\n", "# replace my-finetuned-LLM with the name of your model in Predibase" ] diff --git a/docs/docs/integrations/providers/predibase.md b/docs/docs/integrations/providers/predibase.md index 31a445e99ad..3c04512bc03 100644 --- a/docs/docs/integrations/providers/predibase.md +++ b/docs/docs/integrations/providers/predibase.md @@ -17,7 +17,21 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase -model = Predibase(model = 'vicuna-13b', predibase_api_key=os.environ.get('PREDIBASE_API_TOKEN')) +model = Predibase(model="mistral-7b", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) + +response = model("Can you recommend me a nice dry wine?") +print(response) +``` + +Predibase also supports adapters that are fine-tuned on the base model given by the `model` argument: + +```python +import os +os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" + +from langchain_community.llms import Predibase + +model = Predibase(model="mistral-7b", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) response = model("Can you recommend me a nice dry wine?") print(response) ``` diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index 182ee0acd37..ee9882ed4e6 
100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -10,10 +10,18 @@ class Predibase(LLM): To use, you should have the ``predibase`` python package installed, and have your Predibase API key. + + The `model` parameter is the Predibase "serverless" base_model ID + (see https://docs.predibase.com/user-guide/inference/models for the catalog). + + An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM + adapter, whose base model is the `model` parameter; the fine-tuned adapter + must be compatible with its base model; otherwise, an error is raised. """ model: str predibase_api_key: SecretStr + adapter_id: Optional[str] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) default_options_for_generation: dict = Field( { @@ -38,7 +46,10 @@ class Predibase(LLM): from predibase import PredibaseClient from predibase.pql import get_session from predibase.pql.api import Session - from predibase.resource.llm.interface import LLMDeployment + from predibase.resource.llm.interface import ( + HuggingFaceLLM, + LLMDeployment, + ) from predibase.resource.llm.response import GeneratedResponse session: Session = get_session( @@ -55,15 +66,23 @@ class Predibase(LLM): except ValueError as e: raise ValueError("Your API key is not correct. 
Please try again") from e options: Dict[str, Union[str, float]] = ( - kwargs or self.default_options_for_generation + self.model_kwargs or self.default_options_for_generation ) base_llm_deployment: LLMDeployment = pc.LLM( uri=f"pb://deployments/{self.model}" ) - result: GeneratedResponse = base_llm_deployment.generate( - prompt=prompt, - options=options, - ) + result: GeneratedResponse + if self.adapter_id: + adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}") + result = base_llm_deployment.with_adapter(model=adapter_model).generate( + prompt=prompt, + options=options, + ) + else: + result = base_llm_deployment.generate( + prompt=prompt, + options=options, + ) return result.response @property diff --git a/libs/community/tests/integration_tests/llms/test_predibase.py b/libs/community/tests/unit_tests/llms/test_predibase.py similarity index 54% rename from libs/community/tests/integration_tests/llms/test_predibase.py rename to libs/community/tests/unit_tests/llms/test_predibase.py index 88ac72cfc89..c875ef1bece 100644 --- a/libs/community/tests/integration_tests/llms/test_predibase.py +++ b/libs/community/tests/unit_tests/llms/test_predibase.py @@ -17,3 +17,20 @@ def test_api_key_masked_when_passed_via_constructor( captured = capsys.readouterr() assert captured.out == "**********" + + +def test_specifying_adapter_id_argument() -> None: + llm = Predibase(model="my_llm", predibase_api_key="secret-api-key") + assert not llm.adapter_id + + llm = Predibase( + model="my_llm", predibase_api_key="secret-api-key", adapter_id="my-hf-adapter" + ) + assert llm.adapter_id == "my-hf-adapter" + + llm = Predibase( + model="my_llm", + adapter_id="my-other-hf-adapter", + predibase_api_key="secret-api-key", + ) + assert llm.adapter_id == "my-other-hf-adapter"