mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 15:35:14 +00:00
community: for Predibase -- enable both Predibase-hosted and HuggingFace-hosted fine-tuned adapter repositories (#20370)
This commit is contained in:
@@ -14,14 +14,18 @@ class Predibase(LLM):
|
||||
The `model` parameter is the Predibase "serverless" base_model ID
|
||||
(see https://docs.predibase.com/user-guide/inference/models for the catalog).
|
||||
|
||||
An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
|
||||
adapter, whose base model is the `model` parameter; the fine-tuned adapter
|
||||
must be compatible with its base model; otherwise, an error is raised.
|
||||
An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
|
||||
fine-tuned LLM adapter, whose base model is the `model` parameter; the
|
||||
fine-tuned adapter must be compatible with its base model;
|
||||
otherwise, an error is raised. If a Predibase ID references the
|
||||
fine-tuned adapter, then the `adapter_version` in the adapter repository can
|
||||
be optionally specified; omitting it defaults to the most recent version.
|
||||
"""
|
||||
|
||||
model: str
|
||||
predibase_api_key: SecretStr
|
||||
adapter_id: Optional[str] = None
|
||||
adapter_version: Optional[int] = None
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
default_options_for_generation: dict = Field(
|
||||
{
|
||||
@@ -45,12 +49,16 @@ class Predibase(LLM):
|
||||
try:
|
||||
from predibase import PredibaseClient
|
||||
from predibase.pql import get_session
|
||||
from predibase.pql.api import Session
|
||||
from predibase.pql.api import (
|
||||
ServerResponseError,
|
||||
Session,
|
||||
)
|
||||
from predibase.resource.llm.interface import (
|
||||
HuggingFaceLLM,
|
||||
LLMDeployment,
|
||||
)
|
||||
from predibase.resource.llm.response import GeneratedResponse
|
||||
from predibase.resource.model import Model
|
||||
|
||||
session: Session = get_session(
|
||||
token=self.predibase_api_key.get_secret_value(),
|
||||
@@ -73,7 +81,20 @@ class Predibase(LLM):
|
||||
)
|
||||
result: GeneratedResponse
|
||||
if self.adapter_id:
|
||||
adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}")
|
||||
"""
|
||||
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
|
||||
If absent, then load the fine-tuned adapter from a HuggingFace repository.
|
||||
"""
|
||||
adapter_model: Union[Model, HuggingFaceLLM]
|
||||
try:
|
||||
adapter_model = pc.get_model(
|
||||
name=self.adapter_id,
|
||||
version=self.adapter_version,
|
||||
model_id=None,
|
||||
)
|
||||
except ServerResponseError:
|
||||
# Predibase does not recognize the adapter ID; fall back to loading it from HuggingFace.
|
||||
adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
|
||||
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
|
Reference in New Issue
Block a user