mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
community[minor]: fix failing Predibase integration (#19776)
- [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** Langchain-Predibase integration was failing, because it was not current with the Predibase SDK; in addition, Predibase integration tests were instantiating the Langchain Community `Predibase` class with one required argument (`model`) missing. This change updates the Predibase SDK usage and fixes the integration tests. - **Twitter handle:** `@alexsherstinsky` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
e9caa22d47
commit
a9bc212bf2
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
@ -15,6 +15,13 @@ class Predibase(LLM):
|
|||||||
model: str
|
model: str
|
||||||
predibase_api_key: SecretStr
|
predibase_api_key: SecretStr
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
default_options_for_generation: dict = Field(
|
||||||
|
{
|
||||||
|
"max_new_tokens": 256,
|
||||||
|
"temperature": 0.1,
|
||||||
|
},
|
||||||
|
const=True,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
@ -29,8 +36,17 @@ class Predibase(LLM):
|
|||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
from predibase import PredibaseClient
|
from predibase import PredibaseClient
|
||||||
|
from predibase.pql import get_session
|
||||||
|
from predibase.pql.api import Session
|
||||||
|
from predibase.resource.llm.interface import LLMDeployment
|
||||||
|
from predibase.resource.llm.response import GeneratedResponse
|
||||||
|
|
||||||
pc = PredibaseClient(token=self.predibase_api_key.get_secret_value())
|
session: Session = get_session(
|
||||||
|
token=self.predibase_api_key.get_secret_value(),
|
||||||
|
gateway="https://api.app.predibase.com/v1",
|
||||||
|
serving_endpoint="serving.app.predibase.com",
|
||||||
|
)
|
||||||
|
pc: PredibaseClient = PredibaseClient(session=session)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import Predibase Python package. "
|
"Could not import Predibase Python package. "
|
||||||
@ -38,9 +54,17 @@ class Predibase(LLM):
|
|||||||
) from e
|
) from e
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError("Your API key is not correct. Please try again") from e
|
raise ValueError("Your API key is not correct. Please try again") from e
|
||||||
# load model and version
|
options: Dict[str, Union[str, float]] = (
|
||||||
results = pc.prompt(prompt, model_name=self.model)
|
kwargs or self.default_options_for_generation
|
||||||
return results[0].response
|
)
|
||||||
|
base_llm_deployment: LLMDeployment = pc.LLM(
|
||||||
|
uri=f"pb://deployments/{self.model}"
|
||||||
|
)
|
||||||
|
result: GeneratedResponse = base_llm_deployment.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
return result.response
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
@ -5,14 +5,14 @@ from langchain_community.llms.predibase import Predibase
|
|||||||
|
|
||||||
|
|
||||||
def test_api_key_is_string() -> None:
|
def test_api_key_is_string() -> None:
|
||||||
llm = Predibase(predibase_api_key="secret-api-key")
|
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
||||||
assert isinstance(llm.predibase_api_key, SecretStr)
|
assert isinstance(llm.predibase_api_key, SecretStr)
|
||||||
|
|
||||||
|
|
||||||
def test_api_key_masked_when_passed_via_constructor(
|
def test_api_key_masked_when_passed_via_constructor(
|
||||||
capsys: CaptureFixture,
|
capsys: CaptureFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
llm = Predibase(predibase_api_key="secret-api-key")
|
llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
|
||||||
print(llm.predibase_api_key, end="") # noqa: T201
|
print(llm.predibase_api_key, end="") # noqa: T201
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user