mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 12:58:59 +00:00
community: Support both Predibase SDK-v1 and SDK-v2 in Predibase-LangChain integration (#20859)
This commit is contained in:
parent
8c95ac3145
commit
12e5ec6de3
@ -63,12 +63,13 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"from langchain_community.llms import Predibase\n",
|
"from langchain_community.llms import Predibase\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n",
|
"# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"mistral-7b\",\n",
|
" model=\"mistral-7b\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
|
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
|
||||||
" adapter_id=\"e2e_nlg\",\n",
|
" adapter_id=\"e2e_nlg\",\n",
|
||||||
" adapter_version=1,\n",
|
" adapter_version=1,\n",
|
||||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -83,8 +84,9 @@
|
|||||||
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
|
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"mistral-7b\",\n",
|
" model=\"mistral-7b\",\n",
|
||||||
" adapter_id=\"predibase/e2e_nlg\",\n",
|
|
||||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
|
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
|
||||||
|
" adapter_id=\"predibase/e2e_nlg\",\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -122,7 +124,9 @@
|
|||||||
"from langchain_community.llms import Predibase\n",
|
"from langchain_community.llms import Predibase\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
|
" model=\"mistral-7b\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
|
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -136,12 +140,13 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# With a fine-tuned adapter hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).\n",
|
"# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"mistral-7b\",\n",
|
" model=\"mistral-7b\",\n",
|
||||||
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
|
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
|
||||||
" adapter_id=\"e2e_nlg\",\n",
|
" adapter_id=\"e2e_nlg\",\n",
|
||||||
" adapter_version=1,\n",
|
" adapter_version=1,\n",
|
||||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -154,8 +159,9 @@
|
|||||||
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
|
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
|
||||||
"llm = Predibase(\n",
|
"llm = Predibase(\n",
|
||||||
" model=\"mistral-7b\",\n",
|
" model=\"mistral-7b\",\n",
|
||||||
" adapter_id=\"predibase/e2e_nlg\",\n",
|
|
||||||
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
|
||||||
|
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
|
||||||
|
" adapter_id=\"predibase/e2e_nlg\",\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -247,13 +253,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"model = Predibase(\n",
|
"model = Predibase(\n",
|
||||||
" model=\"my-base-LLM\",\n",
|
" model=\"my-base-LLM\",\n",
|
||||||
" adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted model repositories.\n",
|
|
||||||
" # adapter_version=1, # optional (returns the latest, if omitted)\n",
|
|
||||||
" predibase_api_key=os.environ.get(\n",
|
" predibase_api_key=os.environ.get(\n",
|
||||||
" \"PREDIBASE_API_TOKEN\"\n",
|
" \"PREDIBASE_API_TOKEN\"\n",
|
||||||
" ), # Adapter argument is optional.\n",
|
" ), # Adapter argument is optional.\n",
|
||||||
|
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
|
||||||
|
" adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted adapter repositories.\n",
|
||||||
|
" adapter_version=1, # required for Predibase-hosted adapters (ignored for HuggingFace-hosted adapters)\n",
|
||||||
")\n",
|
")\n",
|
||||||
"# replace my-finetuned-LLM with the name of your model in Predibase"
|
"# replace my-base-LLM with the name of your choice of a serverless base model in Predibase"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -17,7 +17,11 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
|||||||
|
|
||||||
from langchain_community.llms import Predibase
|
from langchain_community.llms import Predibase
|
||||||
|
|
||||||
model = Predibase(model="mistral-7b"", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
model = Predibase(
|
||||||
|
model="mistral-7b",
|
||||||
|
predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
|
||||||
|
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
|
||||||
|
)
|
||||||
|
|
||||||
response = model("Can you recommend me a nice dry wine?")
|
response = model("Can you recommend me a nice dry wine?")
|
||||||
print(response)
|
print(response)
|
||||||
@ -31,8 +35,14 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
|||||||
|
|
||||||
from langchain_community.llms import Predibase
|
from langchain_community.llms import Predibase
|
||||||
|
|
||||||
# The fine-tuned adapter is hosted at Predibase (adapter_version can be specified; omitting it is equivalent to the most recent version).
|
# The fine-tuned adapter is hosted at Predibase (adapter_version must be specified).
|
||||||
model = Predibase(model="mistral-7b"", adapter_id="e2e_nlg", adapter_version=1, predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
model = Predibase(
|
||||||
|
model="mistral-7b",
|
||||||
|
predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
|
||||||
|
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
|
||||||
|
adapter_id="e2e_nlg",
|
||||||
|
adapter_version=1,
|
||||||
|
)
|
||||||
|
|
||||||
response = model("Can you recommend me a nice dry wine?")
|
response = model("Can you recommend me a nice dry wine?")
|
||||||
print(response)
|
print(response)
|
||||||
@ -47,7 +57,12 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
|
|||||||
from langchain_community.llms import Predibase
|
from langchain_community.llms import Predibase
|
||||||
|
|
||||||
# The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored).
|
# The fine-tuned adapter is hosted at HuggingFace (adapter_version does not apply and will be ignored).
|
||||||
model = Predibase(model="mistral-7b"", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
|
model = Predibase(
|
||||||
|
model="mistral-7b",
|
||||||
|
predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
|
||||||
|
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
|
||||||
|
adapter_id="predibase/e2e_nlg",
|
||||||
|
)
|
||||||
|
|
||||||
response = model("Can you recommend me a nice dry wine?")
|
response = model("Can you recommend me a nice dry wine?")
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
@ -17,13 +18,15 @@ class Predibase(LLM):
|
|||||||
An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
|
An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
|
||||||
fine-tuned LLM adapter, whose base model is the `model` parameter; the
|
fine-tuned LLM adapter, whose base model is the `model` parameter; the
|
||||||
fine-tuned adapter must be compatible with its base model;
|
fine-tuned adapter must be compatible with its base model;
|
||||||
otherwise, an error is raised. If a Predibase ID references the
|
otherwise, an error is raised. If the fine-tuned adapter is hosted at Predibase,
|
||||||
fine-tuned adapter, then the `adapter_version` in the adapter repository can
|
then `adapter_version` in the adapter repository must be specified.
|
||||||
be optionally specified; omitting it defaults to the most recent version.
|
|
||||||
|
An optional `predibase_sdk_version` parameter defaults to latest SDK version.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
predibase_api_key: SecretStr
|
predibase_api_key: SecretStr
|
||||||
|
predibase_sdk_version: Optional[str] = None
|
||||||
adapter_id: Optional[str] = None
|
adapter_id: Optional[str] = None
|
||||||
adapter_version: Optional[int] = None
|
adapter_version: Optional[int] = None
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
@ -46,6 +49,10 @@ class Predibase(LLM):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
options: Dict[str, Union[str, float]] = (
|
||||||
|
self.model_kwargs or self.default_options_for_generation
|
||||||
|
)
|
||||||
|
if self._is_deprecated_sdk_version():
|
||||||
try:
|
try:
|
||||||
from predibase import PredibaseClient
|
from predibase import PredibaseClient
|
||||||
from predibase.pql import get_session
|
from predibase.pql import get_session
|
||||||
@ -73,17 +80,16 @@ class Predibase(LLM):
|
|||||||
) from e
|
) from e
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError("Your API key is not correct. Please try again") from e
|
raise ValueError("Your API key is not correct. Please try again") from e
|
||||||
options: Dict[str, Union[str, float]] = (
|
|
||||||
self.model_kwargs or self.default_options_for_generation
|
|
||||||
)
|
|
||||||
base_llm_deployment: LLMDeployment = pc.LLM(
|
base_llm_deployment: LLMDeployment = pc.LLM(
|
||||||
uri=f"pb://deployments/{self.model}"
|
uri=f"pb://deployments/{self.model}"
|
||||||
)
|
)
|
||||||
result: GeneratedResponse
|
result: GeneratedResponse
|
||||||
if self.adapter_id:
|
if self.adapter_id:
|
||||||
"""
|
"""
|
||||||
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
|
Attempt to retrieve the fine-tuned adapter from a Predibase
|
||||||
If absent, then load the fine-tuned adapter from a HuggingFace repository.
|
repository. If absent, then load the fine-tuned adapter
|
||||||
|
from a HuggingFace repository.
|
||||||
"""
|
"""
|
||||||
adapter_model: Union[Model, HuggingFaceLLM]
|
adapter_model: Union[Model, HuggingFaceLLM]
|
||||||
try:
|
try:
|
||||||
@ -106,9 +112,103 @@ class Predibase(LLM):
|
|||||||
)
|
)
|
||||||
return result.response
|
return result.response
|
||||||
|
|
||||||
|
from predibase import Predibase
|
||||||
|
|
||||||
|
os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
|
||||||
|
predibase: Predibase = Predibase(
|
||||||
|
api_token=self.predibase_api_key.get_secret_value()
|
||||||
|
)
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from lorax.client import Client as LoraxClient
|
||||||
|
from lorax.errors import GenerationError
|
||||||
|
from lorax.types import Response
|
||||||
|
|
||||||
|
lorax_client: LoraxClient = predibase.deployments.client(
|
||||||
|
deployment_ref=self.model
|
||||||
|
)
|
||||||
|
|
||||||
|
response: Response
|
||||||
|
if self.adapter_id:
|
||||||
|
"""
|
||||||
|
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
|
||||||
|
If absent, then load the fine-tuned adapter from a HuggingFace repository.
|
||||||
|
"""
|
||||||
|
if self.adapter_version:
|
||||||
|
# Since the adapter version is provided, query the Predibase repository.
|
||||||
|
pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
|
||||||
|
try:
|
||||||
|
response = lorax_client.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
adapter_id=pb_adapter_id,
|
||||||
|
**options,
|
||||||
|
)
|
||||||
|
except GenerationError as ge:
|
||||||
|
raise ValueError(
|
||||||
|
f"""An adapter with the ID "{pb_adapter_id}" cannot be \
|
||||||
|
found in the Predibase repository of fine-tuned adapters."""
|
||||||
|
) from ge
|
||||||
|
else:
|
||||||
|
# The adapter version is omitted,
|
||||||
|
# hence look for the adapter ID in the HuggingFace repository.
|
||||||
|
try:
|
||||||
|
response = lorax_client.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
adapter_id=self.adapter_id,
|
||||||
|
adapter_source="hub",
|
||||||
|
**options,
|
||||||
|
)
|
||||||
|
except GenerationError as ge:
|
||||||
|
raise ValueError(
|
||||||
|
f"""Either an adapter with the ID "{self.adapter_id}" \
|
||||||
|
cannot be found in a HuggingFace repository, or it is incompatible with the \
|
||||||
|
base model (please make sure that the adapter configuration is consistent).
|
||||||
|
"""
|
||||||
|
) from ge
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
response = lorax_client.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
**options,
|
||||||
|
)
|
||||||
|
except requests.JSONDecodeError as jde:
|
||||||
|
raise ValueError(
|
||||||
|
f"""An LLM with the deployment ID "{self.model}" cannot be found \
|
||||||
|
at Predibase (please refer to \
|
||||||
|
"https://docs.predibase.com/user-guide/inference/models" for the list of \
|
||||||
|
supported models).
|
||||||
|
"""
|
||||||
|
) from jde
|
||||||
|
response_text = response.generated_text
|
||||||
|
|
||||||
|
return response_text
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {
|
return {
|
||||||
**{"model_kwargs": self.model_kwargs},
|
**{"model_kwargs": self.model_kwargs},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _is_deprecated_sdk_version(self) -> bool:
|
||||||
|
try:
|
||||||
|
import semantic_version
|
||||||
|
from predibase.version import __version__ as current_version
|
||||||
|
from semantic_version.base import Version
|
||||||
|
|
||||||
|
sdk_semver_deprecated: Version = semantic_version.Version(
|
||||||
|
version_string="2024.4.8"
|
||||||
|
)
|
||||||
|
actual_current_version: str = self.predibase_sdk_version or current_version
|
||||||
|
sdk_semver_current: Version = semantic_version.Version(
|
||||||
|
version_string=actual_current_version
|
||||||
|
)
|
||||||
|
return not (
|
||||||
|
(sdk_semver_current > sdk_semver_deprecated)
|
||||||
|
or ("+dev" in actual_current_version)
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import Predibase Python package. "
|
||||||
|
"Please install it with `pip install semantic_version predibase`."
|
||||||
|
) from e
|
||||||
|
@ -19,6 +19,22 @@ def test_api_key_masked_when_passed_via_constructor(
|
|||||||
assert captured.out == "**********"
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
|
||||||
|
def test_specifying_predibase_sdk_version_argument() -> None:
|
||||||
|
llm = Predibase(
|
||||||
|
model="my_llm",
|
||||||
|
predibase_api_key="secret-api-key",
|
||||||
|
)
|
||||||
|
assert not llm.predibase_sdk_version
|
||||||
|
|
||||||
|
legacy_predibase_sdk_version = "2024.4.8"
|
||||||
|
llm = Predibase(
|
||||||
|
model="my_llm",
|
||||||
|
predibase_api_key="secret-api-key",
|
||||||
|
predibase_sdk_version=legacy_predibase_sdk_version,
|
||||||
|
)
|
||||||
|
assert llm.predibase_sdk_version == legacy_predibase_sdk_version
|
||||||
|
|
||||||
|
|
||||||
def test_specifying_adapter_id_argument() -> None:
|
def test_specifying_adapter_id_argument() -> None:
|
||||||
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
||||||
assert not llm.adapter_id
|
assert not llm.adapter_id
|
||||||
@ -33,8 +49,8 @@ def test_specifying_adapter_id_argument() -> None:
|
|||||||
|
|
||||||
llm = Predibase(
|
llm = Predibase(
|
||||||
model="my_llm",
|
model="my_llm",
|
||||||
adapter_id="my-other-hf-adapter",
|
|
||||||
predibase_api_key="secret-api-key",
|
predibase_api_key="secret-api-key",
|
||||||
|
adapter_id="my-other-hf-adapter",
|
||||||
)
|
)
|
||||||
assert llm.adapter_id == "my-other-hf-adapter"
|
assert llm.adapter_id == "my-other-hf-adapter"
|
||||||
assert llm.adapter_version is None
|
assert llm.adapter_version is None
|
||||||
@ -55,9 +71,9 @@ def test_specifying_adapter_id_and_adapter_version_arguments() -> None:
|
|||||||
|
|
||||||
llm = Predibase(
|
llm = Predibase(
|
||||||
model="my_llm",
|
model="my_llm",
|
||||||
|
predibase_api_key="secret-api-key",
|
||||||
adapter_id="my-other-hf-adapter",
|
adapter_id="my-other-hf-adapter",
|
||||||
adapter_version=3,
|
adapter_version=3,
|
||||||
predibase_api_key="secret-api-key",
|
|
||||||
)
|
)
|
||||||
assert llm.adapter_id == "my-other-hf-adapter"
|
assert llm.adapter_id == "my-other-hf-adapter"
|
||||||
assert llm.adapter_version == 3
|
assert llm.adapter_version == 3
|
||||||
|
Loading…
Reference in New Issue
Block a user