community: add OCI Endpoint (#14250)
- **Description:** [OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models on Oracle Cloud Infrastructure. This PR adds an integration for using LangChain with an LLM hosted on an [OCI Data Science Model Deployment](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm). For authentication, [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) is used to automatically load the credentials needed to invoke the endpoint.
- **Issue:** None
- **Dependencies:** `oracle-ads`
- **Tag maintainer:** @baskaryan
- **Twitter handle:** None

Co-authored-by: Erick Friis <erick@langchain.dev>
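A minimal usage sketch, assuming the classes are exposed from `langchain_community.llms` as registered in the diff below; the `https://<MD_OCID>/predict` endpoint is a placeholder, and credentials load through the `oracle-ads` default signer (the endpoint can alternatively come from the `OCI_LLM_ENDPOINT` environment variable):

from langchain_community.llms import OCIModelDeploymentTGI, OCIModelDeploymentVLLM

# TGI-served model: only the endpoint is required.
tgi_llm = OCIModelDeploymentTGI(endpoint="https://<MD_OCID>/predict")

# vLLM-served model: the served model name is required as well.
vllm_llm = OCIModelDeploymentVLLM(
    endpoint="https://<MD_OCID>/predict",
    model="mymodel",
)

print(tgi_llm("Tell me a joke."))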
libs/community/langchain_community/llms/__init__.py
@@ -324,6 +324,22 @@ def _import_nlpcloud() -> Any:
     return NLPCloud
 
 
+def _import_oci_md_tgi() -> Any:
+    from langchain_community.llms.oci_data_science_model_deployment_endpoint import (
+        OCIModelDeploymentTGI,
+    )
+
+    return OCIModelDeploymentTGI
+
+
+def _import_oci_md_vllm() -> Any:
+    from langchain_community.llms.oci_data_science_model_deployment_endpoint import (
+        OCIModelDeploymentVLLM,
+    )
+
+    return OCIModelDeploymentVLLM
+
+
 def _import_octoai_endpoint() -> Any:
     from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
 
@@ -639,6 +655,10 @@ def __getattr__(name: str) -> Any:
         return _import_mosaicml()
     elif name == "NLPCloud":
         return _import_nlpcloud()
+    elif name == "OCIModelDeploymentTGI":
+        return _import_oci_md_tgi()
+    elif name == "OCIModelDeploymentVLLM":
+        return _import_oci_md_vllm()
     elif name == "OctoAIEndpoint":
         return _import_octoai_endpoint()
     elif name == "Ollama":
@@ -770,6 +790,8 @@ __all__ = [
     "Nebula",
     "NIBittensorLLM",
     "NLPCloud",
+    "OCIModelDeploymentTGI",
+    "OCIModelDeploymentVLLM",
     "Ollama",
     "OpenAI",
     "OpenAIChat",
@@ -857,6 +879,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "nebula": _import_symblai_nebula,
         "nibittensor": _import_bittensor,
         "nlpcloud": _import_nlpcloud,
+        "oci_model_deployment_tgi_endpoint": _import_oci_md_tgi,
+        "oci_model_deployment_vllm_endpoint": _import_oci_md_vllm,
         "ollama": _import_ollama,
         "openai": _import_openai,
         "openlm": _import_openlm,
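With the `__getattr__` branches and registry entries above, the new classes resolve lazily on first attribute access; a one-line sketch of the resulting import path:

# Attribute lookup goes through __getattr__, which calls
# _import_oci_md_tgi() and defers the module import until first use.
from langchain_community.llms import OCIModelDeploymentTGI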
libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py
@@ -0,0 +1,362 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_TIME_OUT = 300
+DEFAULT_CONTENT_TYPE_JSON = "application/json"
+
+
+class OCIModelDeploymentLLM(LLM):
+    """Base class for LLM deployed on OCI Data Science Model Deployment."""
+
+    auth: dict = Field(default_factory=dict, exclude=True)
+    """ADS auth dictionary for OCI authentication:
+    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
+    This can be generated by calling `ads.common.auth.api_keys()`
+    or `ads.common.auth.resource_principal()`. If this is not
+    provided then the `ads.common.default_signer()` will be used."""
+
+    max_tokens: int = 256
+    """Denotes the number of tokens to predict per generation."""
+
+    temperature: float = 0.2
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    k: int = 0
+    """Number of most likely tokens to consider at each step."""
+
+    p: float = 0.75
+    """Total probability mass of tokens to consider at each step."""
+
+    endpoint: str = ""
+    """The URI of the endpoint from the deployed Model Deployment model."""
+
+    best_of: int = 1
+    """Generates best_of completions server-side and returns the "best"
+    (the one with the highest log probability per token).
+    """
+
+    stop: Optional[List[str]] = None
+    """Stop words to use when generating. Model output is cut off
+    at the first occurrence of any of these substrings."""
+
+    @root_validator()
+    def validate_environment(  # pylint: disable=no-self-argument
+        cls, values: Dict
+    ) -> Dict:
+        """Validate that the python package exists in the environment."""
+        try:
+            import ads
+
+        except ImportError as ex:
+            raise ImportError(
+                "Could not import ads python package. "
+                "Please install it with `pip install oracle_ads`."
+            ) from ex
+        if not values.get("auth", None):
+            values["auth"] = ads.common.auth.default_signer()
+        values["endpoint"] = get_from_dict_or_env(
+            values,
+            "endpoint",
+            "OCI_LLM_ENDPOINT",
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Default parameters for the model."""
+        raise NotImplementedError
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"endpoint": self.endpoint},
+            **self._default_params,
+        }
+
+    def _construct_json_body(self, prompt: str, params: dict) -> dict:
+        """Constructs the request body as a dictionary (JSON)."""
+        raise NotImplementedError
+
+    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
+        """Combines the invocation parameters with default parameters."""
+        params = self._default_params
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            params["stop"] = self.stop
+        elif stop is not None:
+            params["stop"] = stop
+        else:
+            # Don't set "stop" in params as None. It should be a list.
+            params["stop"] = []
+
+        return {**params, **kwargs}
+
+    def _process_response(self, response_json: dict) -> str:
+        raise NotImplementedError
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to OCI Data Science Model Deployment endpoint.
+
+        Args:
+            prompt (str):
+                The prompt to pass into the model.
+            stop (List[str], Optional):
+                List of stop words to use when generating.
+            kwargs:
+                requests_kwargs:
+                    Additional ``**kwargs`` to pass to requests.post
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = oci_md("Tell me a joke.")
+
+        """
+        requests_kwargs = kwargs.pop("requests_kwargs", {})
+        params = self._invocation_params(stop, **kwargs)
+        body = self._construct_json_body(prompt, params)
+        logger.info(f"LLM API Request:\n{prompt}")
+        response = self._send_request(
+            data=body, endpoint=self.endpoint, **requests_kwargs
+        )
+        completion = self._process_response(response)
+        logger.info(f"LLM API Completion:\n{completion}")
+        return completion
+
+    def _send_request(
+        self,
+        data: Any,
+        endpoint: str,
+        header: Optional[dict] = {},
+        **kwargs: Any,
+    ) -> Dict:
+        """Sends a request to the OCI Data Science Model Deployment endpoint.
+
+        Args:
+            data (JSON serializable):
+                The data to be sent to the endpoint.
+            endpoint (str):
+                The model HTTP endpoint.
+            header (dict, optional):
+                A dictionary of HTTP headers to send to the specified url.
+                Defaults to {}.
+            kwargs:
+                Additional ``**kwargs`` to pass to requests.post.
+
+        Raises:
+            Exception:
+                Raised when invocation fails.
+
+        Returns:
+            A JSON representation of a requests.Response object.
+        """
+        if not header:
+            header = {}
+        header["Content-Type"] = (
+            header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
+            or DEFAULT_CONTENT_TYPE_JSON
+        )
+        request_kwargs = {"json": data}
+        request_kwargs["headers"] = header
+        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
+
+        attempts = 0
+        while attempts < 2:
+            request_kwargs["auth"] = self.auth.get("signer")
+            response = requests.post(
+                endpoint, timeout=timeout, **request_kwargs, **kwargs
+            )
+            if response.status_code == 401:
+                self._refresh_signer()
+                attempts += 1
+                continue
+            break
+
+        try:
+            response.raise_for_status()
+            response_json = response.json()
+
+        except Exception:
+            logger.error(
+                "DEBUG INFO: request_kwargs=%s, status_code=%s, content=%s",
+                request_kwargs,
+                response.status_code,
+                response.content,
+            )
+            raise
+
+        return response_json
+
+    def _refresh_signer(self) -> None:
+        if self.auth.get("signer", None) and hasattr(
+            self.auth["signer"], "refresh_security_token"
+        ):
+            self.auth["signer"].refresh_security_token()
+
+
+class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
+    """OCI Data Science Model Deployment TGI Endpoint.
+
+    To use, you must provide the model HTTP endpoint from your deployed
+    model, e.g. https://<MD_OCID>/predict.
+
+    To authenticate, `oracle-ads` is used to automatically load
+    credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
+
+    Make sure to have the required policies to access the OCI Data
+    Science Model Deployment endpoint. See:
+    https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OCIModelDeploymentTGI
+
+            oci_md = OCIModelDeploymentTGI(endpoint="https://<MD_OCID>/predict")
+
+    """
+
+    do_sample: bool = True
+    """If set to True, this parameter enables decoding strategies such as
+    multinomial sampling, beam-search multinomial sampling, top-k
+    sampling and top-p sampling.
+    """
+
+    watermark = True
+    """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
+    Defaults to True."""
+
+    return_full_text = False
+    """Whether to prepend the prompt to the generated text. Defaults to False."""
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "oci_model_deployment_tgi_endpoint"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for invoking OCI model deployment TGI endpoint."""
+        return {
+            "best_of": self.best_of,
+            "max_new_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_k": self.k
+            if self.k > 0
+            else None,  # `top_k` must be strictly positive
+            "top_p": self.p,
+            "do_sample": self.do_sample,
+            "return_full_text": self.return_full_text,
+            "watermark": self.watermark,
+        }
+
+    def _construct_json_body(self, prompt: str, params: dict) -> dict:
+        return {
+            "inputs": prompt,
+            "parameters": params,
+        }
+
+    def _process_response(self, response_json: dict) -> str:
+        return str(response_json.get("generated_text", response_json)) + "\n"
+
+
+class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
+    """vLLM deployed on OCI Data Science Model Deployment.
+
+    To use, you must provide the model HTTP endpoint from your deployed
+    model, e.g. https://<MD_OCID>/predict.
+
+    To authenticate, `oracle-ads` is used to automatically load
+    credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
+
+    Make sure to have the required policies to access the OCI Data
+    Science Model Deployment endpoint. See:
+    https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import OCIModelDeploymentVLLM
+
+            oci_md = OCIModelDeploymentVLLM(
+                endpoint="https://<MD_OCID>/predict",
+                model="mymodel",
+            )
+
+    """
+
+    model: str
+    """The name of the model."""
+
+    n: int = 1
+    """Number of output sequences to return for the given prompt."""
+
+    k: int = -1
+    """Number of most likely tokens to consider at each step."""
+
+    frequency_penalty: float = 0.0
+    """Penalizes repeated tokens according to frequency. Between 0 and 1."""
+
+    presence_penalty: float = 0.0
+    """Penalizes repeated tokens. Between 0 and 1."""
+
+    use_beam_search: bool = False
+    """Whether to use beam search instead of sampling."""
+
+    ignore_eos: bool = False
+    """Whether to ignore the EOS token and continue generating tokens after
+    the EOS token is generated."""
+
+    logprobs: Optional[int] = None
+    """Number of log probabilities to return per output token."""
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "oci_model_deployment_vllm_endpoint"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling vLLM."""
+        return {
+            "best_of": self.best_of,
+            "frequency_penalty": self.frequency_penalty,
+            "ignore_eos": self.ignore_eos,
+            "logprobs": self.logprobs,
+            "max_tokens": self.max_tokens,
+            "model": self.model,
+            "n": self.n,
+            "presence_penalty": self.presence_penalty,
+            "stop": self.stop,
+            "temperature": self.temperature,
+            "top_k": self.k,
+            "top_p": self.p,
+            "use_beam_search": self.use_beam_search,
+        }
+
+    def _construct_json_body(self, prompt: str, params: dict) -> dict:
+        return {
+            "prompt": prompt,
+            **params,
+        }
+
+    def _process_response(self, response_json: dict) -> str:
+        return response_json["choices"][0]["text"]
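The base class deliberately leaves `_default_params`, `_construct_json_body`, and `_process_response` unimplemented, so supporting another serving container only means overriding those hooks plus `_llm_type`. A hypothetical sketch (the `MyContainerLLM` name and its request/response schema are illustrative assumptions, not part of this PR):

from typing import Any, Dict

from langchain_community.llms.oci_data_science_model_deployment_endpoint import (
    OCIModelDeploymentLLM,
)


class MyContainerLLM(OCIModelDeploymentLLM):
    """Hypothetical LLM for a custom serving container."""

    @property
    def _llm_type(self) -> str:
        return "oci_model_deployment_custom_endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        # Only the generation knobs the imaginary container understands.
        return {"max_tokens": self.max_tokens, "temperature": self.temperature}

    def _construct_json_body(self, prompt: str, params: dict) -> dict:
        # Assumed request schema; adapt to the container's actual contract.
        return {"prompt": prompt, **params}

    def _process_response(self, response_json: dict) -> str:
        # Assumed response schema: {"generated_text": "..."}.
        return response_json["generated_text"]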