community[minor]: Add OCI Generative AI integration (#16548)
- **Description:** Adds an Oracle Cloud Infrastructure (OCI) Generative AI integration. OCI Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) covering a wide range of use cases, all available through a single API. Using the service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. See https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm
- **Issue:** None
- **Dependencies:** OCI Python SDK
- **Tests:** Unit tests are provided; integration tests are omitted because Oracle policy prohibits public sharing of API keys. `make format`, `make lint` and `make test` pass.

---------

Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
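For orientation, a minimal sketch of the two entry points this PR adds (the model names and compartment OCID below are placeholders, not values from the PR):

from langchain_community.llms import OCIGenAI
from langchain_community.embeddings import OCIGenAIEmbeddings

endpoint = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"

llm = OCIGenAI(
    model_id="cohere.command",  # placeholder on-demand model id
    service_endpoint=endpoint,
    compartment_id="MY_OCID",  # placeholder compartment OCID
)
emb = OCIGenAIEmbeddings(
    model_id="cohere.embed-english-light-v2.0",
    service_endpoint=endpoint,
    compartment_id="MY_OCID",
)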
This commit is contained in: parent b8768bd6e7, commit c4e9c9ca29
docs/docs/integrations/llms/oci_generative_ai.ipynb (new file, 191 lines)
@@ -0,0 +1,191 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Oracle Cloud Infrastructure Generative AI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n",
    "Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n",
    "\n",
    "This notebook explains how to use OCI's Generative AI models with LangChain."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prerequisite\n",
    "We will need to install the oci sdk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U oci"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OCI Generative AI API endpoint\n",
    "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Authentication\n",
    "The authentication methods supported for this LangChain integration are:\n",
    "\n",
    "1. API Key\n",
    "2. Session token\n",
    "3. Instance principal\n",
    "4. Resource principal\n",
    "\n",
    "These follow the standard SDK authentication methods detailed __[here](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms import OCIGenAI\n",
    "\n",
    "# use the default authN method, API-key\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "response = llm.invoke(\"Tell me one fact about earth\", temperature=0.7)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import LLMChain\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "# use a session token to authN\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    "    auth_type=\"SECURITY_TOKEN\",\n",
    ")\n",
    "\n",
    "prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n",
    "\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
    "\n",
    "response = llm_chain.invoke(\"what is the capital of france?\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnablePassthrough\n",
    "from langchain_community.embeddings import OCIGenAIEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "\n",
    "embeddings = OCIGenAIEmbeddings(\n",
    "    model_id=\"MY_EMBEDDING_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "vectorstore = FAISS.from_texts(\n",
    "    [\n",
    "        \"Larry Ellison co-founded Oracle Corporation in 1977 with Bob Miner and Ed Oates.\",\n",
    "        \"Oracle Corporation is an American multinational computer technology company headquartered in Austin, Texas, United States.\",\n",
    "    ],\n",
    "    embedding=embeddings,\n",
    ")\n",
    "\n",
    "retriever = vectorstore.as_retriever()\n",
    "\n",
    "template = \"\"\"Answer the question based only on the following context:\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "\"\"\"\n",
    "prompt = PromptTemplate.from_template(template)\n",
    "\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "chain = (\n",
    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "print(chain.invoke(\"when was oracle founded?\"))\n",
    "print(chain.invoke(\"where is oracle headquartered?\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oci_langchain",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
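The Authentication section above lists four methods but the notebook cells only exercise the API-key and session-token paths; as a supplemental sketch (not part of the committed notebook; the OCIDs are placeholders), the principal-based methods are selected the same way via auth_type:

from langchain_community.llms import OCIGenAI

# Instance principal: no ~/.oci/config needed; the compute instance's own
# identity is used. RESOURCE_PRINCIPAL works the same way for OCI Functions etc.
llm = OCIGenAI(
    model_id="MY_MODEL",  # placeholder
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",  # placeholder
    auth_type="INSTANCE_PRINCIPAL",
)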
libs/community/langchain_community/embeddings/__init__.py
@@ -65,6 +65,7 @@ from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
 from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings
 from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings
 from langchain_community.embeddings.nlpcloud import NLPCloudEmbeddings
+from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings
 from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings
 from langchain_community.embeddings.ollama import OllamaEmbeddings
 from langchain_community.embeddings.openai import OpenAIEmbeddings
@@ -144,6 +145,7 @@ __all__ = [
     "VoyageEmbeddings",
     "BookendEmbeddings",
     "VolcanoEmbeddings",
+    "OCIGenAIEmbeddings",
 ]
libs/community/langchain_community/embeddings/oci_generative_ai.py (new file, 203 lines)
@@ -0,0 +1,203 @@
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"


class OCIAuthType(Enum):
    API_KEY = 1
    SECURITY_TOKEN = 2
    INSTANCE_PRINCIPAL = 3
    RESOURCE_PRINCIPAL = 4


class OCIGenAIEmbeddings(BaseModel, Embeddings):
    """OCI embedding models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through auth_type and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL

    Make sure you have the required policies (profile/roles) to
    access the OCI Generative AI service. If a specific config profile is used,
    you must pass the name of the profile (~/.oci/config) through auth_profile.

    To use, you must provide the compartment id
    along with the endpoint url, and model id
    as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain.embeddings import OCIGenAIEmbeddings

            embeddings = OCIGenAIEmbeddings(
                model_id="MY_EMBEDDING_MODEL",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID"
            )
    """

    client: Any  #: :meta private:

    service_models: Any  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, could be

    API_KEY,
    SECURITY_TOKEN,
    INSTANCE_PRINCIPAL,
    RESOURCE_PRINCIPAL

    If not specified, API_KEY will be used
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config
    If not specified, DEFAULT will be used
    """

    model_id: str = None
    """Id of the model to call, e.g., cohere.embed-english-light-v2.0"""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: str = None
    """service endpoint url"""

    compartment_id: str = None
    """OCID of compartment"""

    truncate: Optional[str] = "END"
    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # pylint: disable=no-self-argument
        """Validate that OCI config and python package exists in environment."""

        # Skip creating a new client if one was passed in the constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:

                def make_security_token_signer(oci_config):
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
            elif values["auth_type"] == OCIAuthType(4).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.get_resource_principals_signer()
            else:
                raise ValueError("Please provide a valid value for auth_type")

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "Please check if ~/.oci/config exists. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "please check that the specified "
                "auth_profile and auth_type are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCIGenAI's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        from oci.generative_ai_inference import models

        # A model_id starting with the custom-endpoint OCID prefix points at a
        # dedicated endpoint; anything else is served on demand.
        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        invocation_obj = models.EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            truncate=self.truncate,
            inputs=texts,
        )

        response = self.client.embed_text(invocation_obj)

        return response.data.embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to OCIGenAI's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
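A short usage sketch for the class above (not part of the diff; the compartment OCID is a placeholder), showing the two public methods and the truncate option:

from langchain_community.embeddings import OCIGenAIEmbeddings

emb = OCIGenAIEmbeddings(
    model_id="cohere.embed-english-light-v2.0",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",  # placeholder
    truncate="END",  # truncate over-long inputs at the end ("NONE"|"START"|"END")
)

doc_vectors = emb.embed_documents(["first text", "second text"])  # one vector per text
query_vector = emb.embed_query("a query")  # a single vector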
libs/community/langchain_community/llms/__init__.py
@@ -346,6 +346,12 @@ def _import_oci_md_vllm() -> Any:
     return OCIModelDeploymentVLLM


+def _import_oci_gen_ai() -> Any:
+    from langchain_community.llms.oci_generative_ai import OCIGenAI
+
+    return OCIGenAI
+
+
 def _import_octoai_endpoint() -> Any:
     from langchain_community.llms.octoai_endpoint import OctoAIEndpoint

@@ -667,6 +673,8 @@ def __getattr__(name: str) -> Any:
         return _import_oci_md_tgi()
     elif name == "OCIModelDeploymentVLLM":
         return _import_oci_md_vllm()
+    elif name == "OCIGenAI":
+        return _import_oci_gen_ai()
     elif name == "OctoAIEndpoint":
         return _import_octoai_endpoint()
     elif name == "Ollama":
@@ -801,6 +809,7 @@ __all__ = [
     "NLPCloud",
     "OCIModelDeploymentTGI",
     "OCIModelDeploymentVLLM",
+    "OCIGenAI",
     "Ollama",
     "OpenAI",
     "OpenAIChat",
@@ -891,6 +900,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "nlpcloud": _import_nlpcloud,
         "oci_model_deployment_tgi_endpoint": _import_oci_md_tgi,
         "oci_model_deployment_vllm_endpoint": _import_oci_md_vllm,
+        "oci_generative_ai": _import_oci_gen_ai,
         "ollama": _import_ollama,
         "openai": _import_openai,
         "openlm": _import_openlm,
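As a quick illustration (not part of the diff) of what these hooks enable — the lazy attribute import and the string-keyed registry both resolve to the same class:

from langchain_community.llms import OCIGenAI  # resolved lazily via __getattr__
from langchain_community.llms import get_type_to_cls_dict

# get_type_to_cls_dict() maps string keys to loader functions returning classes.
assert get_type_to_cls_dict()["oci_generative_ai"]() is OCIGenAI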
libs/community/langchain_community/llms/oci_generative_ai.py (new file, 276 lines)
@@ -0,0 +1,276 @@
from __future__ import annotations

from abc import ABC
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.llms.utils import enforce_stop_tokens

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
VALID_PROVIDERS = ("cohere", "meta")


class OCIAuthType(Enum):
    API_KEY = 1
    SECURITY_TOKEN = 2
    INSTANCE_PRINCIPAL = 3
    RESOURCE_PRINCIPAL = 4


class OCIGenAIBase(BaseModel, ABC):
    """Base class for OCI GenAI models"""

    client: Any  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, could be

    API_KEY,
    SECURITY_TOKEN,
    INSTANCE_PRINCIPAL,
    RESOURCE_PRINCIPAL

    If not specified, API_KEY will be used
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config
    If not specified, DEFAULT will be used
    """

    model_id: str = None
    """Id of the model to call, e.g., cohere.command"""

    provider: str = None
    """Provider name of the model. Defaults to None,
    in which case it is derived from the model_id;
    otherwise, it must be supplied by the user
    """

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: str = None
    """service endpoint url"""

    compartment_id: str = None
    """OCID of compartment"""

    is_stream: bool = False
    """Whether to stream back partial progress"""

    llm_stop_sequence_mapping: Mapping[str, str] = {
        "cohere": "stop_sequences",
        "meta": "stop",
    }

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that OCI config and python package exists in environment."""

        # Skip creating a new client if one was passed in the constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:

                def make_security_token_signer(oci_config):
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
            elif values["auth_type"] == OCIAuthType(4).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.get_resource_principals_signer()
            else:
                raise ValueError("Please provide a valid value for auth_type")

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "Please check if ~/.oci/config exists. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "please check that the specified "
                "auth_profile and auth_type are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _get_provider(self) -> str:
        if self.provider is not None:
            provider = self.provider
        else:
            provider = self.model_id.split(".")[0].lower()

        if provider not in VALID_PROVIDERS:
            raise ValueError(
                f"Invalid provider derived from model_id: {self.model_id}. "
                "Please explicitly pass in the supported provider "
                "when using a custom endpoint."
            )
        return provider


class OCIGenAI(LLM, OCIGenAIBase):
    """OCI large language models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through auth_type and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL

    Make sure you have the required policies (profile/roles) to
    access the OCI Generative AI service.
    If a specific config profile is used, you must pass
    the name of the profile (from ~/.oci/config) through auth_profile.

    To use, you must provide the compartment id
    along with the endpoint url, and model id
    as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIGenAI

            llm = OCIGenAI(
                model_id="MY_MODEL_ID",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID"
            )
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci"

    def _prepare_invocation_object(
        self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        from oci.generative_ai_inference import models

        oci_llm_request_mapping = {
            "cohere": models.CohereLlmInferenceRequest,
            "meta": models.LlamaLlmInferenceRequest,
        }
        provider = self._get_provider()
        _model_kwargs = self.model_kwargs or {}
        if stop is not None:
            # Each provider names its stop-sequence parameter differently.
            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        inference_params = {**_model_kwargs, **kwargs}
        inference_params["prompt"] = prompt
        inference_params["is_stream"] = self.is_stream

        invocation_obj = models.GenerateTextDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            inference_request=oci_llm_request_mapping[provider](**inference_params),
        )

        return invocation_obj

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        provider = self._get_provider()
        if provider == "cohere":
            text = response.data.inference_response.generated_texts[0].text
        elif provider == "meta":
            text = response.data.inference_response.choices[0].text
        else:
            raise ValueError(f"Invalid provider: {provider}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to OCIGenAI generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """

        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)
        return self._process_response(response, stop)
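To make the request-building path above concrete, a short sketch (not part of the diff; the OCIDs are placeholders) of provider derivation and the per-provider stop-sequence remapping:

from langchain_community.llms import OCIGenAI

endpoint = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"

# The provider is the model_id prefix: "meta.llama-2-70b-chat" -> "meta".
llm = OCIGenAI(
    model_id="meta.llama-2-70b-chat",
    service_endpoint=endpoint,
    compartment_id="MY_OCID",  # placeholder
)

# A dedicated-endpoint OCID carries no provider prefix, so provider must be
# passed explicitly; otherwise _get_provider() raises ValueError.
custom = OCIGenAI(
    model_id="ocid1.generativeaiendpoint.oc1..example",  # hypothetical OCID
    provider="cohere",
    service_endpoint=endpoint,
    compartment_id="MY_OCID",
)

# stop= is mapped to "stop" for meta and "stop_sequences" for cohere
# before the inference request object is built.
text = llm.invoke("List three colors:", stop=["\n\n"])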
libs/community/poetry.lock (generated, 13 lines changed)
@@ -3433,7 +3433,6 @@ files = [
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
-    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -4999,13 +4998,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]

 [[package]]
 name = "oci"
-version = "2.118.0"
+version = "2.119.1"
 description = "Oracle Cloud Infrastructure Python SDK"
 optional = true
 python-versions = "*"
 files = [
-    {file = "oci-2.118.0-py3-none-any.whl", hash = "sha256:766170a9b4c93053ba3fe5ae63c0ab48fdd71b4d17709742a2b45249f0829872"},
-    {file = "oci-2.118.0.tar.gz", hash = "sha256:1004726c4dad6c02f967b7bc4e733ff552451a2914cb542c380756c7d46bb938"},
+    {file = "oci-2.119.1-py3-none-any.whl", hash = "sha256:64b6012f3c2b70cf7fb5f58a1a4b4458d8f4d41ea1b79a5d9f8ca4beb2dfa225"},
+    {file = "oci-2.119.1.tar.gz", hash = "sha256:992df963382f378b93634826956677f3c13407ca1b828c4eaf1cfd18f19fae33"},
 ]

 [package.dependencies]
@@ -6223,6 +6222,7 @@ files = [
     {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
     {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
     {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
+    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@@ -6773,6 +6773,7 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -9226,9 +9227,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p

 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "c03bd15da5fd84ec91adec43e62b06623b6ec51003530a762455f74a4ee3715f"
+content-hash = "18694abbcaec37f026883b07d1c198f9fc3fdb012d7f2be16ce4ad1866913463"
libs/community/pyproject.toml
@@ -89,6 +89,7 @@ oracle-ads = {version = "^2.9.1", optional = true}
 zhipuai = {version = "^1.0.7", optional = true}
 elasticsearch = {version = "^8.12.0", optional = true}
 hdbcli = {version = "^2.19.21", optional = true}
+oci = {version = "^2.119.1", optional = true}

 [tool.poetry.group.test]
 optional = true
@@ -253,6 +254,7 @@ extended_testing = [
     "zhipuai",
     "elasticsearch",
     "hdbcli",
+    "oci"
 ]

 [tool.ruff]
@@ -56,6 +56,7 @@ EXPECTED_ALL = [
     "VoyageEmbeddings",
     "BookendEmbeddings",
     "VolcanoEmbeddings",
+    "OCIGenAIEmbeddings",
 ]
New test file (50 lines):
@@ -0,0 +1,50 @@
"""Test OCI Generative AI embedding service."""
from unittest.mock import MagicMock

import pytest
from pytest import MonkeyPatch

from langchain_community.embeddings import OCIGenAIEmbeddings


class MockResponseDict(dict):
    def __getattr__(self, val):
        return self[val]


@pytest.mark.requires("oci")
@pytest.mark.parametrize(
    "test_model_id", ["cohere.embed-english-light-v3.0", "cohere.embed-english-v3.0"]
)
def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
    """Test valid call to OCI Generative AI embedding service."""
    oci_gen_ai_client = MagicMock()
    embeddings = OCIGenAIEmbeddings(
        model_id=test_model_id,
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        client=oci_gen_ai_client,
    )

    def mocked_response(invocation_obj):
        docs = invocation_obj.inputs

        embeddings = []
        for d in docs:
            if "Hello" in d:
                v = [1.0, 0.0, 0.0]
            elif "World" in d:
                v = [0.0, 1.0, 0.0]
            else:
                v = [0.0, 0.0, 1.0]
            embeddings.append(v)

        return MockResponseDict(
            {"status": 200, "data": MockResponseDict({"embeddings": embeddings})}
        )

    monkeypatch.setattr(embeddings.client, "embed_text", mocked_response)

    output = embeddings.embed_documents(["Hello", "World"])
    correct_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

    assert output == correct_output
@@ -52,6 +52,7 @@ EXPECT_ALL = [
     "Nebula",
     "OCIModelDeploymentTGI",
     "OCIModelDeploymentVLLM",
+    "OCIGenAI",
     "NIBittensorLLM",
     "NLPCloud",
     "Ollama",
New test file (76 lines):
@@ -0,0 +1,76 @@
"""Test OCI Generative AI LLM service"""
from unittest.mock import MagicMock

import pytest
from pytest import MonkeyPatch

from langchain_community.llms import OCIGenAI


class MockResponseDict(dict):
    def __getattr__(self, val):
        return self[val]


@pytest.mark.requires("oci")
@pytest.mark.parametrize(
    "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"]
)
def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
    """Test valid call to OCI Generative AI LLM service."""
    oci_gen_ai_client = MagicMock()
    llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)

    provider = llm._get_provider()

    def mocked_response(*args):
        response_text = "This is the completion."

        if provider == "cohere":
            return MockResponseDict(
                {
                    "status": 200,
                    "data": MockResponseDict(
                        {
                            "inference_response": MockResponseDict(
                                {
                                    "generated_texts": [
                                        MockResponseDict(
                                            {
                                                "text": response_text,
                                            }
                                        )
                                    ]
                                }
                            )
                        }
                    ),
                }
            )

        if provider == "meta":
            return MockResponseDict(
                {
                    "status": 200,
                    "data": MockResponseDict(
                        {
                            "inference_response": MockResponseDict(
                                {
                                    "choices": [
                                        MockResponseDict(
                                            {
                                                "text": response_text,
                                            }
                                        )
                                    ]
                                }
                            )
                        }
                    ),
                }
            )

    monkeypatch.setattr(llm.client, "generate_text", mocked_response)

    output = llm.invoke("This is a prompt.", temperature=0.2)
    assert output == "This is the completion."