mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
community: extend Predibase integration to support fine-tuned LLM adapters (#19979)
- [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Langchain-Predibase integration was failing, because it was not current with the Predibase SDK; in addition, Predibase integration tests were instantiating the Langchain Community `Predibase` class with one required argument (`model`) missing. This change updates the Predibase SDK usage and fixes the integration tests. - **Twitter handle:** `@alexsherstinsky` - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
a27d88f12a
commit
5f563e040a
@ -50,7 +50,24 @@
|
||||
"from langchain_community.llms import Predibase\n",
|
||||
"\n",
|
||||
"model = Predibase(\n",
|
||||
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
||||
" model=\"mistral-7b\",\n",
|
||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import Predibase\n",
|
||||
"\n",
|
||||
"# With an adapter, fine-tuned on the specified model\n",
|
||||
"model = Predibase(\n",
|
||||
" model=\"mistral-7b\",\n",
|
||||
" adapter_id=\"predibase/e2e_nlg\",\n",
|
||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -66,19 +83,43 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "plaintext"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Chain Call Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "plaintext"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import Predibase\n",
|
||||
"\n",
|
||||
"model = Predibase(\n",
|
||||
" model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# With an adapter, fine-tuned on the specified model\n",
|
||||
"llm = Predibase(\n",
|
||||
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
||||
" model=\"mistral-7b\",\n",
|
||||
" adapter_id=\"predibase/e2e_nlg\",\n",
|
||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -169,7 +210,11 @@
|
||||
"from langchain_community.llms import Predibase\n",
|
||||
"\n",
|
||||
"model = Predibase(\n",
|
||||
" model=\"my-finetuned-LLM\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
||||
" model=\"my-base-LLM\",\n",
|
||||
" adapter_id=\"my-finetuned-adapter-id\",\n",
|
||||
" predibase_api_key=os.environ.get(\n",
|
||||
" \"PREDIBASE_API_TOKEN\"\n",
|
||||
" ), # Adapter argument is optional.\n",
|
||||
")\n",
|
||||
"# replace my-finetuned-LLM with the name of your model in Predibase"
|
||||
]
|
||||
|
@ -17,7 +17,21 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
||||
|
||||
from langchain_community.llms import Predibase
|
||||
|
||||
model = Predibase(model = 'vicuna-13b', predibase_api_key=os.environ.get('PREDIBASE_API_TOKEN'))
|
||||
model = Predibase(model="mistral-7b"", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
||||
|
||||
response = model("Can you recommend me a nice dry wine?")
|
||||
print(response)
|
||||
```
|
||||
|
||||
Predibase also supports adapters that are fine-tuned on the base model given by the `model` argument:
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
||||
|
||||
from langchain_community.llms import Predibase
|
||||
|
||||
model = Predibase(model="mistral-7b"", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
||||
|
||||
response = model("Can you recommend me a nice dry wine?")
|
||||
print(response)
|
||||
|
@ -10,10 +10,18 @@ class Predibase(LLM):
|
||||
|
||||
To use, you should have the ``predibase`` python package installed,
|
||||
and have your Predibase API key.
|
||||
|
||||
The `model` parameter is the Predibase "serverless" base_model ID
|
||||
(see https://docs.predibase.com/user-guide/inference/models for the catalog).
|
||||
|
||||
An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
|
||||
adapter, whose base model is the `model` parameter; the fine-tuned adapter
|
||||
must be compatible with its base model; otherwise, an error is raised.
|
||||
"""
|
||||
|
||||
model: str
|
||||
predibase_api_key: SecretStr
|
||||
adapter_id: Optional[str] = None
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
default_options_for_generation: dict = Field(
|
||||
{
|
||||
@ -38,7 +46,10 @@ class Predibase(LLM):
|
||||
from predibase import PredibaseClient
|
||||
from predibase.pql import get_session
|
||||
from predibase.pql.api import Session
|
||||
from predibase.resource.llm.interface import LLMDeployment
|
||||
from predibase.resource.llm.interface import (
|
||||
HuggingFaceLLM,
|
||||
LLMDeployment,
|
||||
)
|
||||
from predibase.resource.llm.response import GeneratedResponse
|
||||
|
||||
session: Session = get_session(
|
||||
@ -55,12 +66,20 @@ class Predibase(LLM):
|
||||
except ValueError as e:
|
||||
raise ValueError("Your API key is not correct. Please try again") from e
|
||||
options: Dict[str, Union[str, float]] = (
|
||||
kwargs or self.default_options_for_generation
|
||||
self.model_kwargs or self.default_options_for_generation
|
||||
)
|
||||
base_llm_deployment: LLMDeployment = pc.LLM(
|
||||
uri=f"pb://deployments/{self.model}"
|
||||
)
|
||||
result: GeneratedResponse = base_llm_deployment.generate(
|
||||
result: GeneratedResponse
|
||||
if self.adapter_id:
|
||||
adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}")
|
||||
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
else:
|
||||
result = base_llm_deployment.generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
|
@ -17,3 +17,20 @@ def test_api_key_masked_when_passed_via_constructor(
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_specifying_adapter_id_argument() -> None:
|
||||
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
||||
assert not llm.adapter_id
|
||||
|
||||
llm = Predibase(
|
||||
model="my_llm", predibase_api_key="secret-api-key", adapter_id="my-hf-adapter"
|
||||
)
|
||||
assert llm.adapter_id == "my-hf-adapter"
|
||||
|
||||
llm = Predibase(
|
||||
model="my_llm",
|
||||
adapter_id="my-other-hf-adapter",
|
||||
predibase_api_key="secret-api-key",
|
||||
)
|
||||
assert llm.adapter_id == "my-other-hf-adapter"
|
Loading…
Reference in New Issue
Block a user