diff --git a/docs/docs/integrations/llms/predibase.ipynb b/docs/docs/integrations/llms/predibase.ipynb index ea2257189c0..fabd36d75fb 100644 --- a/docs/docs/integrations/llms/predibase.ipynb +++ b/docs/docs/integrations/llms/predibase.ipynb @@ -63,7 +63,24 @@ "source": [ "from langchain_community.llms import Predibase\n", "\n", - "# With an adapter, fine-tuned on the specified model\n", + "# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n", + "model = Predibase(\n", + " model=\"mistral-7b\",\n", + " adapter_id=\"e2e_nlg\",\n", + " adapter_version=1,\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.llms import Predibase\n", + "\n", + "# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n", "model = Predibase(\n", " model=\"mistral-7b\",\n", " adapter_id=\"predibase/e2e_nlg\",\n", @@ -109,13 +126,32 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n", + "model = Predibase(\n", + " model=\"mistral-7b\",\n", + " adapter_id=\"e2e_nlg\",\n", + " adapter_version=1,\n", + " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# With an adapter, fine-tuned on the specified model\n", + "# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n", "llm = Predibase(\n", " model=\"mistral-7b\",\n", " adapter_id=\"predibase/e2e_nlg\",\n", @@ -211,7 +247,8 @@ 
"\n", "model = Predibase(\n", " model=\"my-base-LLM\",\n", - " adapter_id=\"my-finetuned-adapter-id\",\n", + " adapter_id=\"my-finetuned-adapter-id\", # Supports both Predibase-hosted and HuggingFace-hosted model repositories.\n", + " # adapter_version=1, # optional (returns the latest, if omitted)\n", " predibase_api_key=os.environ.get(\n", " \"PREDIBASE_API_TOKEN\"\n", " ), # Adapter argument is optional.\n", diff --git a/docs/docs/integrations/providers/predibase.md b/docs/docs/integrations/providers/predibase.md index 3c04512bc03..5a88ff117f3 100644 --- a/docs/docs/integrations/providers/predibase.md +++ b/docs/docs/integrations/providers/predibase.md @@ -23,6 +23,21 @@ response = model("Can you recommend me a nice dry wine?") print(response) ``` +Predibase also supports Predibase-hosted and HuggingFace-hosted adapters that are fine-tuned on the base model given by the `model` argument: + +```python +import os +os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" + +from langchain_community.llms import Predibase + +# The fine-tuned adapter is hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version). +model = Predibase(model="mistral-7b", adapter_id="e2e_nlg", adapter_version=1, predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) + +response = model("Can you recommend me a nice dry wine?") +print(response) +``` + Predibase also supports adapters that are fine-tuned on the base model given by the `model` argument: ```python @@ -31,6 +46,7 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}" from langchain_community.llms import Predibase +# The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored). 
model = Predibase(model="mistral-7b", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")) response = model("Can you recommend me a nice dry wine?") diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index ee9882ed4e6..e3f5da7fd9e 100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -14,14 +14,18 @@ class Predibase(LLM): The `model` parameter is the Predibase "serverless" base_model ID (see https://docs.predibase.com/user-guide/inference/models for the catalog). - An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM - adapter, whose base model is the `model` parameter; the fine-tuned adapter - must be compatible with its base model; otherwise, an error is raised. + An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a + fine-tuned LLM adapter, whose base model is the `model` parameter; the + fine-tuned adapter must be compatible with its base model; + otherwise, an error is raised. If a Predibase ID references the + fine-tuned adapter, then the `adapter_version` in the adapter repository can + be optionally specified; omitting it defaults to the most recent version. 
""" model: str predibase_api_key: SecretStr adapter_id: Optional[str] = None + adapter_version: Optional[int] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) default_options_for_generation: dict = Field( { @@ -45,12 +49,16 @@ class Predibase(LLM): try: from predibase import PredibaseClient from predibase.pql import get_session - from predibase.pql.api import Session + from predibase.pql.api import ( + ServerResponseError, + Session, + ) from predibase.resource.llm.interface import ( HuggingFaceLLM, LLMDeployment, ) from predibase.resource.llm.response import GeneratedResponse + from predibase.resource.model import Model session: Session = get_session( token=self.predibase_api_key.get_secret_value(), @@ -73,7 +81,20 @@ class Predibase(LLM): ) result: GeneratedResponse if self.adapter_id: - adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}") + """ + Attempt to retrieve the fine-tuned adapter from a Predibase repository. + If absent, then load the fine-tuned adapter from a HuggingFace repository. + """ + adapter_model: Union[Model, HuggingFaceLLM] + try: + adapter_model = pc.get_model( + name=self.adapter_id, + version=self.adapter_version, + model_id=None, + ) + except ServerResponseError: + # Predibase does not recognize the adapter ID (query HuggingFace). 
+ adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}") result = base_llm_deployment.with_adapter(model=adapter_model).generate( prompt=prompt, options=options, diff --git a/libs/community/tests/unit_tests/llms/test_predibase.py b/libs/community/tests/unit_tests/llms/test_predibase.py index c875ef1bece..9a9fba7f0ef 100644 --- a/libs/community/tests/unit_tests/llms/test_predibase.py +++ b/libs/community/tests/unit_tests/llms/test_predibase.py @@ -24,9 +24,12 @@ def test_specifying_adapter_id_argument() -> None: assert not llm.adapter_id llm = Predibase( - model="my_llm", predibase_api_key="secret-api-key", adapter_id="my-hf-adapter" + model="my_llm", + predibase_api_key="secret-api-key", + adapter_id="my-hf-adapter", ) assert llm.adapter_id == "my-hf-adapter" + assert llm.adapter_version is None llm = Predibase( model="my_llm", @@ -34,3 +37,27 @@ def test_specifying_adapter_id_argument() -> None: predibase_api_key="secret-api-key", ) assert llm.adapter_id == "my-other-hf-adapter" + assert llm.adapter_version is None + + +def test_specifying_adapter_id_and_adapter_version_arguments() -> None: + llm = Predibase(model="my_llm", predibase_api_key="secret-api-key") + assert not llm.adapter_id + + llm = Predibase( + model="my_llm", + predibase_api_key="secret-api-key", + adapter_id="my-hf-adapter", + adapter_version=None, + ) + assert llm.adapter_id == "my-hf-adapter" + assert llm.adapter_version is None + + llm = Predibase( + model="my_llm", + adapter_id="my-other-hf-adapter", + adapter_version=3, + predibase_api_key="secret-api-key", + ) + assert llm.adapter_id == "my-other-hf-adapter" + assert llm.adapter_version == 3