mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 21:20:33 +00:00
community: extend Predibase integration to support fine-tuned LLM adapters (#19979)
- [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Langchain-Predibase integration was failing, because it was not current with the Predibase SDK; in addition, Predibase integration tests were instantiating the Langchain Community `Predibase` class with one required argument (`model`) missing. This change updates the Predibase SDK usage and fixes the integration tests. - **Twitter handle:** `@alexsherstinsky` - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
a27d88f12a
commit
5f563e040a
@ -50,7 +50,24 @@
|
|||||||
"from langchain_community.llms import Predibase\n",
|
"from langchain_community.llms import Predibase\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
" model=\"mistral-7b\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.llms import Predibase\n",
|
||||||
|
"\n",
|
||||||
|
"# With an adapter, fine-tuned on the specified model\n",
|
||||||
|
"model = Predibase(\n",
|
||||||
|
" model=\"mistral-7b\",\n",
|
||||||
|
" adapter_id=\"predibase/e2e_nlg\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -66,19 +83,43 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {
|
||||||
|
"vscode": {
|
||||||
|
"languageId": "plaintext"
|
||||||
|
}
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## Chain Call Setup"
|
"## Chain Call Setup"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"vscode": {
|
||||||
|
"languageId": "plaintext"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain_community.llms import Predibase\n",
|
||||||
|
"\n",
|
||||||
|
"model = Predibase(\n",
|
||||||
|
" model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# With an adapter, fine-tuned on the specified model\n",
|
||||||
"llm = Predibase(\n",
|
"llm = Predibase(\n",
|
||||||
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
" model=\"mistral-7b\",\n",
|
||||||
|
" adapter_id=\"predibase/e2e_nlg\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -169,7 +210,11 @@
|
|||||||
"from langchain_community.llms import Predibase\n",
|
"from langchain_community.llms import Predibase\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"my-finetuned-LLM\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
" model=\"my-base-LLM\",\n",
|
||||||
|
" adapter_id=\"my-finetuned-adapter-id\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\n",
|
||||||
|
" \"PREDIBASE_API_TOKEN\"\n",
|
||||||
|
" ), # Adapter argument is optional.\n",
|
||||||
")\n",
|
")\n",
|
||||||
"# replace my-finetuned-LLM with the name of your model in Predibase"
|
"# replace my-finetuned-LLM with the name of your model in Predibase"
|
||||||
]
|
]
|
||||||
|
@ -17,7 +17,21 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
|||||||
|
|
||||||
from langchain_community.llms import Predibase
|
from langchain_community.llms import Predibase
|
||||||
|
|
||||||
model = Predibase(model = 'vicuna-13b', predibase_api_key=os.environ.get('PREDIBASE_API_TOKEN'))
|
model = Predibase(model="mistral-7b"", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
||||||
|
|
||||||
|
response = model("Can you recommend me a nice dry wine?")
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
Predibase also supports adapters that are fine-tuned on the base model given by the `model` argument:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
||||||
|
|
||||||
|
from langchain_community.llms import Predibase
|
||||||
|
|
||||||
|
model = Predibase(model="mistral-7b"", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
||||||
|
|
||||||
response = model("Can you recommend me a nice dry wine?")
|
response = model("Can you recommend me a nice dry wine?")
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -10,10 +10,18 @@ class Predibase(LLM):
|
|||||||
|
|
||||||
To use, you should have the ``predibase`` python package installed,
|
To use, you should have the ``predibase`` python package installed,
|
||||||
and have your Predibase API key.
|
and have your Predibase API key.
|
||||||
|
|
||||||
|
The `model` parameter is the Predibase "serverless" base_model ID
|
||||||
|
(see https://docs.predibase.com/user-guide/inference/models for the catalog).
|
||||||
|
|
||||||
|
An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
|
||||||
|
adapter, whose base model is the `model` parameter; the fine-tuned adapter
|
||||||
|
must be compatible with its base model; otherwise, an error is raised.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
predibase_api_key: SecretStr
|
predibase_api_key: SecretStr
|
||||||
|
adapter_id: Optional[str] = None
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
default_options_for_generation: dict = Field(
|
default_options_for_generation: dict = Field(
|
||||||
{
|
{
|
||||||
@ -38,7 +46,10 @@ class Predibase(LLM):
|
|||||||
from predibase import PredibaseClient
|
from predibase import PredibaseClient
|
||||||
from predibase.pql import get_session
|
from predibase.pql import get_session
|
||||||
from predibase.pql.api import Session
|
from predibase.pql.api import Session
|
||||||
from predibase.resource.llm.interface import LLMDeployment
|
from predibase.resource.llm.interface import (
|
||||||
|
HuggingFaceLLM,
|
||||||
|
LLMDeployment,
|
||||||
|
)
|
||||||
from predibase.resource.llm.response import GeneratedResponse
|
from predibase.resource.llm.response import GeneratedResponse
|
||||||
|
|
||||||
session: Session = get_session(
|
session: Session = get_session(
|
||||||
@ -55,12 +66,20 @@ class Predibase(LLM):
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError("Your API key is not correct. Please try again") from e
|
raise ValueError("Your API key is not correct. Please try again") from e
|
||||||
options: Dict[str, Union[str, float]] = (
|
options: Dict[str, Union[str, float]] = (
|
||||||
kwargs or self.default_options_for_generation
|
self.model_kwargs or self.default_options_for_generation
|
||||||
)
|
)
|
||||||
base_llm_deployment: LLMDeployment = pc.LLM(
|
base_llm_deployment: LLMDeployment = pc.LLM(
|
||||||
uri=f"pb://deployments/{self.model}"
|
uri=f"pb://deployments/{self.model}"
|
||||||
)
|
)
|
||||||
result: GeneratedResponse = base_llm_deployment.generate(
|
result: GeneratedResponse
|
||||||
|
if self.adapter_id:
|
||||||
|
adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}")
|
||||||
|
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
|
||||||
|
prompt=prompt,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = base_llm_deployment.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
options=options,
|
options=options,
|
||||||
)
|
)
|
||||||
|
@ -17,3 +17,20 @@ def test_api_key_masked_when_passed_via_constructor(
|
|||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
assert captured.out == "**********"
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
|
||||||
|
def test_specifying_adapter_id_argument() -> None:
|
||||||
|
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
||||||
|
assert not llm.adapter_id
|
||||||
|
|
||||||
|
llm = Predibase(
|
||||||
|
model="my_llm", predibase_api_key="secret-api-key", adapter_id="my-hf-adapter"
|
||||||
|
)
|
||||||
|
assert llm.adapter_id == "my-hf-adapter"
|
||||||
|
|
||||||
|
llm = Predibase(
|
||||||
|
model="my_llm",
|
||||||
|
adapter_id="my-other-hf-adapter",
|
||||||
|
predibase_api_key="secret-api-key",
|
||||||
|
)
|
||||||
|
assert llm.adapter_id == "my-other-hf-adapter"
|
Loading…
Reference in New Issue
Block a user