mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-30 23:29:54 +00:00 
			
		
		
		
	Thank you for contributing to LangChain!
- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"
- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!
- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.
- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/
Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.
If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
		
	
		
			
				
	
	
		
			220 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			220 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| from typing import Any, Dict, List, Mapping, Optional, Union
 | |
| 
 | |
| from langchain_core.callbacks import CallbackManagerForLLMRun
 | |
| from langchain_core.language_models.llms import LLM
 | |
| from langchain_core.pydantic_v1 import Field, SecretStr
 | |
| 
 | |
| 
 | |
class Predibase(LLM):
    """Use your Predibase models with Langchain.

    To use, you should have the ``predibase`` python package installed,
    and have your Predibase API key.

    The `model` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the catalog).

    An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of a
    fine-tuned LLM adapter, whose base model is the `model` parameter; the
    fine-tuned adapter must be compatible with its base model;
    otherwise, an error is raised.  If the fine-tuned adapter is hosted at Predibase,
    then `adapter_version` in the adapter repository must be specified.

    An optional `predibase_sdk_version` parameter defaults to latest SDK version.
    """

    # Predibase "serverless" base model deployment ID.
    model: str
    # Predibase API key; wrapped in SecretStr so it is not leaked in repr/logs.
    predibase_api_key: SecretStr
    # Pin a specific Predibase SDK version; None means "use the installed one".
    predibase_sdk_version: Optional[str] = None
    # Optional fine-tuned adapter (Predibase model name or HuggingFace repo ID).
    adapter_id: Optional[str] = None
    # Required when adapter_id refers to a Predibase-hosted adapter repository.
    adapter_version: Optional[int] = None
    # Per-instance generation options; override the class defaults below.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    # Immutable (const=True, pydantic v1) lowest-priority generation defaults.
    default_options_for_generation: dict = Field(
        {
            "max_new_tokens": 256,
            "temperature": 0.1,
        },
        const=True,
    )

    @property
    def _llm_type(self) -> str:
        """Return the LLM type identifier used by LangChain serialization."""
        return "predibase"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` via Predibase.

        Dispatches to the legacy ``PredibaseClient`` path for deprecated SDK
        versions (<= 2024.4.8, see ``_is_deprecated_sdk_version``), and to the
        lorax-based client otherwise.

        Raises:
            ImportError: if the ``predibase`` package is not installed.
            ValueError: for a bad API key, an unknown adapter, or an unknown
                base-model deployment.
        """
        # Merge generation options by increasing precedence:
        # class defaults < instance model_kwargs < per-call kwargs.
        # (Previously model_kwargs was unpacked first, so the constant
        # defaults clobbered any user-supplied model_kwargs.)
        options: Dict[str, Union[str, float]] = {
            **self.default_options_for_generation,
            **(self.model_kwargs or {}),
            **(kwargs or {}),
        }
        if self._is_deprecated_sdk_version():
            # ---- Legacy SDK path (predibase <= 2024.4.8) ----
            try:
                from predibase import PredibaseClient
                from predibase.pql import get_session
                from predibase.pql.api import (
                    ServerResponseError,
                    Session,
                )
                from predibase.resource.llm.interface import (
                    HuggingFaceLLM,
                    LLMDeployment,
                )
                from predibase.resource.llm.response import GeneratedResponse
                from predibase.resource.model import Model

                session: Session = get_session(
                    token=self.predibase_api_key.get_secret_value(),
                    gateway="https://api.app.predibase.com/v1",
                    serving_endpoint="serving.app.predibase.com",
                )
                pc: PredibaseClient = PredibaseClient(session=session)
            except ImportError as e:
                raise ImportError(
                    "Could not import Predibase Python package. "
                    "Please install it with `pip install predibase`."
                ) from e
            except ValueError as e:
                raise ValueError("Your API key is not correct. Please try again") from e

            base_llm_deployment: LLMDeployment = pc.LLM(
                uri=f"pb://deployments/{self.model}"
            )
            result: GeneratedResponse
            if self.adapter_id:
                # Attempt to retrieve the fine-tuned adapter from a Predibase
                # repository.  If absent, then load the fine-tuned adapter
                # from a HuggingFace repository.
                adapter_model: Union[Model, HuggingFaceLLM]
                try:
                    adapter_model = pc.get_model(
                        name=self.adapter_id,
                        version=self.adapter_version,
                        model_id=None,
                    )
                except ServerResponseError:
                    # Predibase does not recognize the adapter ID (query HuggingFace).
                    adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
                result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                    prompt=prompt,
                    options=options,
                )
            else:
                result = base_llm_deployment.generate(
                    prompt=prompt,
                    options=options,
                )
            return result.response

        # ---- Current SDK path (lorax-based client) ----
        from predibase import Predibase

        os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
        predibase: Predibase = Predibase(
            api_token=self.predibase_api_key.get_secret_value()
        )

        import requests
        from lorax.client import Client as LoraxClient
        from lorax.errors import GenerationError
        from lorax.types import Response

        lorax_client: LoraxClient = predibase.deployments.client(
            deployment_ref=self.model
        )

        response: Response
        if self.adapter_id:
            # Attempt to retrieve the fine-tuned adapter from a Predibase repository.
            # If absent, then load the fine-tuned adapter from a HuggingFace repository.
            if self.adapter_version:
                # Since the adapter version is provided, query the Predibase repository.
                pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
                options.pop(
                    "api_token", None
                )  # The "api_token" is not used for Predibase-hosted models.
                try:
                    response = lorax_client.generate(
                        prompt=prompt,
                        adapter_id=pb_adapter_id,
                        **options,
                    )
                except GenerationError as ge:
                    raise ValueError(
                        f"""An adapter with the ID "{pb_adapter_id}" cannot be \
found in the Predibase repository of fine-tuned adapters."""
                    ) from ge
            else:
                # The adapter version is omitted,
                # hence look for the adapter ID in the HuggingFace repository.
                try:
                    response = lorax_client.generate(
                        prompt=prompt,
                        adapter_id=self.adapter_id,
                        adapter_source="hub",
                        **options,
                    )
                except GenerationError as ge:
                    raise ValueError(
                        f"""Either an adapter with the ID "{self.adapter_id}" \
cannot be found in a HuggingFace repository, or it is incompatible with the \
base model (please make sure that the adapter configuration is consistent).
"""
                    ) from ge
        else:
            try:
                response = lorax_client.generate(
                    prompt=prompt,
                    **options,
                )
            except requests.JSONDecodeError as jde:
                # An unknown deployment returns a non-JSON body; surface a
                # descriptive error instead of the raw decode failure.
                raise ValueError(
                    f"""An LLM with the deployment ID "{self.model}" cannot be found \
at Predibase (please refer to \
"https://docs.predibase.com/user-guide/inference/models" for the list of \
supported models).
"""
                ) from jde
        response_text = response.generated_text

        return response_text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_kwargs": self.model_kwargs},
        }

    def _is_deprecated_sdk_version(self) -> bool:
        """Return True if the installed (or pinned) Predibase SDK is legacy.

        An SDK is considered legacy when its version is <= 2024.4.8 and it is
        not a "+dev" build; legacy SDKs use the ``PredibaseClient`` API.

        Raises:
            ImportError: if ``predibase`` or ``semantic_version`` is missing.
        """
        try:
            import semantic_version
            from predibase.version import __version__ as current_version
            from semantic_version.base import Version

            sdk_semver_deprecated: Version = semantic_version.Version(
                version_string="2024.4.8"
            )
            # An explicit predibase_sdk_version pin wins over the installed one.
            actual_current_version: str = self.predibase_sdk_version or current_version
            sdk_semver_current: Version = semantic_version.Version(
                version_string=actual_current_version
            )
            return not (
                (sdk_semver_current > sdk_semver_deprecated)
                or ("+dev" in actual_current_version)
            )
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install semantic_version predibase`."
            ) from e