community: for Predibase -- enable both Predibase-hosted and HuggingFace-hosted fine-tuned adapter repositories (#20370)

This commit is contained in:
Alex Sherstinsky
2024-04-12 08:32:00 -07:00
committed by GitHub
parent 5395c409cb
commit fad0962643
4 changed files with 110 additions and 9 deletions

View File

@@ -14,14 +14,18 @@ class Predibase(LLM):
The `model` parameter is the Predibase "serverless" base_model ID
(see https://docs.predibase.com/user-guide/inference/models for the catalog).
An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
adapter, whose base model is the `model` parameter; the fine-tuned adapter
must be compatible with its base model; otherwise, an error is raised.
An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
fine-tuned LLM adapter whose base model is the `model` parameter. The
fine-tuned adapter must be compatible with its base model; otherwise, an
error is raised. If `adapter_id` is a Predibase ID, the `adapter_version`
within the adapter repository can optionally be specified; omitting it
defaults to the most recent version.
"""
model: str
predibase_api_key: SecretStr
adapter_id: Optional[str] = None
adapter_version: Optional[int] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
default_options_for_generation: dict = Field(
{
@@ -45,12 +49,16 @@ class Predibase(LLM):
try:
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import Session
from predibase.pql.api import (
ServerResponseError,
Session,
)
from predibase.resource.llm.interface import (
HuggingFaceLLM,
LLMDeployment,
)
from predibase.resource.llm.response import GeneratedResponse
from predibase.resource.model import Model
session: Session = get_session(
token=self.predibase_api_key.get_secret_value(),
@@ -73,7 +81,20 @@ class Predibase(LLM):
)
result: GeneratedResponse
if self.adapter_id:
adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}")
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
adapter_model: Union[Model, HuggingFaceLLM]
try:
adapter_model = pc.get_model(
name=self.adapter_id,
version=self.adapter_version,
model_id=None,
)
except ServerResponseError:
# Predibase does not recognize the adapter ID (query HuggingFace).
adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
prompt=prompt,
options=options,