mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
community: Support both Predibase SDK-v1 and SDK-v2 in Predibase-LangChain integration (#20859)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
@@ -17,13 +18,15 @@ class Predibase(LLM):
|
||||
An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
|
||||
fine-tuned LLM adapter, whose base model is the `model` parameter; the
|
||||
fine-tuned adapter must be compatible with its base model;
|
||||
otherwise, an error is raised. If a Predibase ID references the
|
||||
fine-tuned adapter, then the `adapter_version` in the adapter repository can
|
||||
be optionally specified; omitting it defaults to the most recent version.
|
||||
otherwise, an error is raised. If the fine-tuned adapter is hosted at Predibase,
|
||||
then `adapter_version` in the adapter repository must be specified.
|
||||
|
||||
An optional `predibase_sdk_version` parameter defaults to latest SDK version.
|
||||
"""
|
||||
|
||||
model: str
|
||||
predibase_api_key: SecretStr
|
||||
predibase_sdk_version: Optional[str] = None
|
||||
adapter_id: Optional[str] = None
|
||||
adapter_version: Optional[int] = None
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -46,65 +49,139 @@ class Predibase(LLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
from predibase import PredibaseClient
|
||||
from predibase.pql import get_session
|
||||
from predibase.pql.api import (
|
||||
ServerResponseError,
|
||||
Session,
|
||||
)
|
||||
from predibase.resource.llm.interface import (
|
||||
HuggingFaceLLM,
|
||||
LLMDeployment,
|
||||
)
|
||||
from predibase.resource.llm.response import GeneratedResponse
|
||||
from predibase.resource.model import Model
|
||||
|
||||
session: Session = get_session(
|
||||
token=self.predibase_api_key.get_secret_value(),
|
||||
gateway="https://api.app.predibase.com/v1",
|
||||
serving_endpoint="serving.app.predibase.com",
|
||||
)
|
||||
pc: PredibaseClient = PredibaseClient(session=session)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import Predibase Python package. "
|
||||
"Please install it with `pip install predibase`."
|
||||
) from e
|
||||
except ValueError as e:
|
||||
raise ValueError("Your API key is not correct. Please try again") from e
|
||||
options: Dict[str, Union[str, float]] = (
|
||||
self.model_kwargs or self.default_options_for_generation
|
||||
)
|
||||
base_llm_deployment: LLMDeployment = pc.LLM(
|
||||
uri=f"pb://deployments/{self.model}"
|
||||
if self._is_deprecated_sdk_version():
|
||||
try:
|
||||
from predibase import PredibaseClient
|
||||
from predibase.pql import get_session
|
||||
from predibase.pql.api import (
|
||||
ServerResponseError,
|
||||
Session,
|
||||
)
|
||||
from predibase.resource.llm.interface import (
|
||||
HuggingFaceLLM,
|
||||
LLMDeployment,
|
||||
)
|
||||
from predibase.resource.llm.response import GeneratedResponse
|
||||
from predibase.resource.model import Model
|
||||
|
||||
session: Session = get_session(
|
||||
token=self.predibase_api_key.get_secret_value(),
|
||||
gateway="https://api.app.predibase.com/v1",
|
||||
serving_endpoint="serving.app.predibase.com",
|
||||
)
|
||||
pc: PredibaseClient = PredibaseClient(session=session)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import Predibase Python package. "
|
||||
"Please install it with `pip install predibase`."
|
||||
) from e
|
||||
except ValueError as e:
|
||||
raise ValueError("Your API key is not correct. Please try again") from e
|
||||
|
||||
base_llm_deployment: LLMDeployment = pc.LLM(
|
||||
uri=f"pb://deployments/{self.model}"
|
||||
)
|
||||
result: GeneratedResponse
|
||||
if self.adapter_id:
|
||||
"""
|
||||
Attempt to retrieve the fine-tuned adapter from a Predibase
|
||||
repository. If absent, then load the fine-tuned adapter
|
||||
from a HuggingFace repository.
|
||||
"""
|
||||
adapter_model: Union[Model, HuggingFaceLLM]
|
||||
try:
|
||||
adapter_model = pc.get_model(
|
||||
name=self.adapter_id,
|
||||
version=self.adapter_version,
|
||||
model_id=None,
|
||||
)
|
||||
except ServerResponseError:
|
||||
# Predibase does not recognize the adapter ID (query HuggingFace).
|
||||
adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
|
||||
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
else:
|
||||
result = base_llm_deployment.generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
return result.response
|
||||
|
||||
from predibase import Predibase
|
||||
|
||||
os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
|
||||
predibase: Predibase = Predibase(
|
||||
api_token=self.predibase_api_key.get_secret_value()
|
||||
)
|
||||
result: GeneratedResponse
|
||||
|
||||
import requests
|
||||
from lorax.client import Client as LoraxClient
|
||||
from lorax.errors import GenerationError
|
||||
from lorax.types import Response
|
||||
|
||||
lorax_client: LoraxClient = predibase.deployments.client(
|
||||
deployment_ref=self.model
|
||||
)
|
||||
|
||||
response: Response
|
||||
if self.adapter_id:
|
||||
"""
|
||||
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
|
||||
If absent, then load the fine-tuned adapter from a HuggingFace repository.
|
||||
"""
|
||||
adapter_model: Union[Model, HuggingFaceLLM]
|
||||
try:
|
||||
adapter_model = pc.get_model(
|
||||
name=self.adapter_id,
|
||||
version=self.adapter_version,
|
||||
model_id=None,
|
||||
)
|
||||
except ServerResponseError:
|
||||
# Predibase does not recognize the adapter ID (query HuggingFace).
|
||||
adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
|
||||
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
if self.adapter_version:
|
||||
# Since the adapter version is provided, query the Predibase repository.
|
||||
pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
|
||||
try:
|
||||
response = lorax_client.generate(
|
||||
prompt=prompt,
|
||||
adapter_id=pb_adapter_id,
|
||||
**options,
|
||||
)
|
||||
except GenerationError as ge:
|
||||
raise ValueError(
|
||||
f"""An adapter with the ID "{pb_adapter_id}" cannot be \
|
||||
found in the Predibase repository of fine-tuned adapters."""
|
||||
) from ge
|
||||
else:
|
||||
# The adapter version is omitted,
|
||||
# hence look for the adapter ID in the HuggingFace repository.
|
||||
try:
|
||||
response = lorax_client.generate(
|
||||
prompt=prompt,
|
||||
adapter_id=self.adapter_id,
|
||||
adapter_source="hub",
|
||||
**options,
|
||||
)
|
||||
except GenerationError as ge:
|
||||
raise ValueError(
|
||||
f"""Either an adapter with the ID "{self.adapter_id}" \
|
||||
cannot be found in a HuggingFace repository, or it is incompatible with the \
|
||||
base model (please make sure that the adapter configuration is consistent).
|
||||
"""
|
||||
) from ge
|
||||
else:
|
||||
result = base_llm_deployment.generate(
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
return result.response
|
||||
try:
|
||||
response = lorax_client.generate(
|
||||
prompt=prompt,
|
||||
**options,
|
||||
)
|
||||
except requests.JSONDecodeError as jde:
|
||||
raise ValueError(
|
||||
f"""An LLM with the deployment ID "{self.model}" cannot be found \
|
||||
at Predibase (please refer to \
|
||||
"https://docs.predibase.com/user-guide/inference/models" for the list of \
|
||||
supported models).
|
||||
"""
|
||||
) from jde
|
||||
response_text = response.generated_text
|
||||
|
||||
return response_text
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
@@ -112,3 +189,26 @@ class Predibase(LLM):
|
||||
return {
|
||||
**{"model_kwargs": self.model_kwargs},
|
||||
}
|
||||
|
||||
def _is_deprecated_sdk_version(self) -> bool:
|
||||
try:
|
||||
import semantic_version
|
||||
from predibase.version import __version__ as current_version
|
||||
from semantic_version.base import Version
|
||||
|
||||
sdk_semver_deprecated: Version = semantic_version.Version(
|
||||
version_string="2024.4.8"
|
||||
)
|
||||
actual_current_version: str = self.predibase_sdk_version or current_version
|
||||
sdk_semver_current: Version = semantic_version.Version(
|
||||
version_string=actual_current_version
|
||||
)
|
||||
return not (
|
||||
(sdk_semver_current > sdk_semver_deprecated)
|
||||
or ("+dev" in actual_current_version)
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import Predibase Python package. "
|
||||
"Please install it with `pip install semantic_version predibase`."
|
||||
) from e
|
||||
|
@@ -19,6 +19,22 @@ def test_api_key_masked_when_passed_via_constructor(
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_specifying_predibase_sdk_version_argument() -> None:
|
||||
llm = Predibase(
|
||||
model="my_llm",
|
||||
predibase_api_key="secret-api-key",
|
||||
)
|
||||
assert not llm.predibase_sdk_version
|
||||
|
||||
legacy_predibase_sdk_version = "2024.4.8"
|
||||
llm = Predibase(
|
||||
model="my_llm",
|
||||
predibase_api_key="secret-api-key",
|
||||
predibase_sdk_version=legacy_predibase_sdk_version,
|
||||
)
|
||||
assert llm.predibase_sdk_version == legacy_predibase_sdk_version
|
||||
|
||||
|
||||
def test_specifying_adapter_id_argument() -> None:
|
||||
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
||||
assert not llm.adapter_id
|
||||
@@ -33,8 +49,8 @@ def test_specifying_adapter_id_argument() -> None:
|
||||
|
||||
llm = Predibase(
|
||||
model="my_llm",
|
||||
adapter_id="my-other-hf-adapter",
|
||||
predibase_api_key="secret-api-key",
|
||||
adapter_id="my-other-hf-adapter",
|
||||
)
|
||||
assert llm.adapter_id == "my-other-hf-adapter"
|
||||
assert llm.adapter_version is None
|
||||
@@ -55,9 +71,9 @@ def test_specifying_adapter_id_and_adapter_version_arguments() -> None:
|
||||
|
||||
llm = Predibase(
|
||||
model="my_llm",
|
||||
predibase_api_key="secret-api-key",
|
||||
adapter_id="my-other-hf-adapter",
|
||||
adapter_version=3,
|
||||
predibase_api_key="secret-api-key",
|
||||
)
|
||||
assert llm.adapter_id == "my-other-hf-adapter"
|
||||
assert llm.adapter_version == 3
|
||||
|
Reference in New Issue
Block a user