community: for Predibase -- enable both Predibase-hosted and HuggingFace-hosted fine-tuned adapter repositories (#20370)
This commit is contained in: parent 5395c409cb, commit fad0962643
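In short, after this change `adapter_id` can reference either a Predibase-hosted adapter repository or a HuggingFace-hosted one. A minimal sketch of the two modes, using the `mistral-7b` / `e2e_nlg` examples from the docs changed below (illustrative only, not a verbatim excerpt from the diff; variable names are hypothetical):

```python
import os

from langchain_community.llms import Predibase

# Assumes a valid Predibase API token is available in the environment.
api_token = os.environ.get("PREDIBASE_API_TOKEN")

# Adapter hosted in a Predibase repository: adapter_version is optional and
# defaults to the most recent version when omitted.
predibase_hosted = Predibase(
    model="mistral-7b",
    adapter_id="e2e_nlg",
    adapter_version=1,
    predibase_api_key=api_token,
)

# Adapter hosted on HuggingFace: adapter_version does not apply and is ignored.
hf_hosted = Predibase(
    model="mistral-7b",
    adapter_id="predibase/e2e_nlg",
    predibase_api_key=api_token,
)

response = predibase_hosted("Can you recommend me a nice dry wine?")
print(response)
```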
@@ -63,7 +63,24 @@
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"# With an adapter, fine-tuned on the specified model\n",
"# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it defaults to the most recent version).\n",
"model = Predibase(\n",
"    model=\"mistral-7b\",\n",
"    adapter_id=\"e2e_nlg\",\n",
"    adapter_version=1,\n",
"    predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
"model = Predibase(\n",
"    model=\"mistral-7b\",\n",
"    adapter_id=\"predibase/e2e_nlg\",\n",
@@ -109,13 +126,32 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it defaults to the most recent version).\n",
"model = Predibase(\n",
"    model=\"mistral-7b\",\n",
"    adapter_id=\"e2e_nlg\",\n",
"    adapter_version=1,\n",
"    predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With an adapter, fine-tuned on the specified model\n",
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
"llm = Predibase(\n",
"    model=\"mistral-7b\",\n",
"    adapter_id=\"predibase/e2e_nlg\",\n",
@@ -211,7 +247,8 @@
"\n",
"model = Predibase(\n",
"    model=\"my-base-LLM\",\n",
"    adapter_id=\"my-finetuned-adapter-id\",\n",
"    adapter_id=\"my-finetuned-adapter-id\",  # Supports both Predibase-hosted and HuggingFace-hosted model repositories.\n",
"    # adapter_version=1,  # optional (defaults to the latest version, if omitted)\n",
"    predibase_api_key=os.environ.get(\n",
"        \"PREDIBASE_API_TOKEN\"\n",
"    ),  # Adapter argument is optional.\n",
@@ -23,6 +23,21 @@ response = model("Can you recommend me a nice dry wine?")
print(response)
```

Predibase also supports Predibase-hosted and HuggingFace-hosted adapters that are fine-tuned on the base model given by the `model` argument:

```python
import os
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from langchain_community.llms import Predibase

# The fine-tuned adapter is hosted at Predibase (adapter_version can be specified; omitting it defaults to the most recent version).
model = Predibase(model="mistral-7b", adapter_id="e2e_nlg", adapter_version=1, predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))

response = model("Can you recommend me a nice dry wine?")
print(response)
```

Predibase also supports adapters that are fine-tuned on the base model given by the `model` argument:

```python
@@ -31,6 +46,7 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from langchain_community.llms import Predibase

# The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored).
model = Predibase(model="mistral-7b", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))

response = model("Can you recommend me a nice dry wine?")
@@ -14,14 +14,18 @@ class Predibase(LLM):
    The `model` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the catalog).

    An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
    adapter, whose base model is the `model` parameter; the fine-tuned adapter
    must be compatible with its base model; otherwise, an error is raised.
    An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
    fine-tuned LLM adapter, whose base model is the `model` parameter; the
    fine-tuned adapter must be compatible with its base model;
    otherwise, an error is raised. If a Predibase ID references the
    fine-tuned adapter, then the `adapter_version` in the adapter repository can
    be optionally specified; omitting it defaults to the most recent version.
    """

    model: str
    predibase_api_key: SecretStr
    adapter_id: Optional[str] = None
    adapter_version: Optional[int] = None
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    default_options_for_generation: dict = Field(
        {
@@ -45,12 +49,16 @@ class Predibase(LLM):
        try:
            from predibase import PredibaseClient
            from predibase.pql import get_session
            from predibase.pql.api import Session
            from predibase.pql.api import (
                ServerResponseError,
                Session,
            )
            from predibase.resource.llm.interface import (
                HuggingFaceLLM,
                LLMDeployment,
            )
            from predibase.resource.llm.response import GeneratedResponse
            from predibase.resource.model import Model

            session: Session = get_session(
                token=self.predibase_api_key.get_secret_value(),
@@ -73,7 +81,20 @@
        )
        result: GeneratedResponse
        if self.adapter_id:
            adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}")
            """
            Attempt to retrieve the fine-tuned adapter from a Predibase repository.
            If absent, then load the fine-tuned adapter from a HuggingFace repository.
            """
            adapter_model: Union[Model, HuggingFaceLLM]
            try:
                adapter_model = pc.get_model(
                    name=self.adapter_id,
                    version=self.adapter_version,
                    model_id=None,
                )
            except ServerResponseError:
                # Predibase does not recognize the adapter ID (query HuggingFace).
                adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
            result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                prompt=prompt,
                options=options,
@@ -24,9 +24,12 @@ def test_specifying_adapter_id_argument() -> None:
    assert not llm.adapter_id

    llm = Predibase(
        model="my_llm", predibase_api_key="secret-api-key", adapter_id="my-hf-adapter"
        model="my_llm",
        predibase_api_key="secret-api-key",
        adapter_id="my-hf-adapter",
    )
    assert llm.adapter_id == "my-hf-adapter"
    assert llm.adapter_version is None

    llm = Predibase(
        model="my_llm",
@@ -34,3 +37,27 @@ def test_specifying_adapter_id_argument() -> None:
        predibase_api_key="secret-api-key",
    )
    assert llm.adapter_id == "my-other-hf-adapter"
    assert llm.adapter_version is None


def test_specifying_adapter_id_and_adapter_version_arguments() -> None:
    llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
    assert not llm.adapter_id

    llm = Predibase(
        model="my_llm",
        predibase_api_key="secret-api-key",
        adapter_id="my-hf-adapter",
        adapter_version=None,
    )
    assert llm.adapter_id == "my-hf-adapter"
    assert llm.adapter_version is None

    llm = Predibase(
        model="my_llm",
        adapter_id="my-other-hf-adapter",
        adapter_version=3,
        predibase_api_key="secret-api-key",
    )
    assert llm.adapter_id == "my-other-hf-adapter"
    assert llm.adapter_version == 3