From c4e9c9ca2959e2cf756fc499b7c6f35d8786e143 Mon Sep 17 00:00:00 2001
From: Rave Harpaz
Date: Wed, 24 Jan 2024 18:23:50 -0800
Subject: [PATCH] community[minor]: Add OCI Generative AI integration (#16548)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---------

Co-authored-by: Arthur Cheng
Co-authored-by: Bagatur
---
 .../integrations/llms/oci_generative_ai.ipynb      | 191 ++++++++++++
 .../embeddings/__init__.py                         |   2 +
 .../embeddings/oci_generative_ai.py                | 203 +++++++++++++
 .../langchain_community/llms/__init__.py           |  10 +
 .../llms/oci_generative_ai.py                      | 276 ++++++++++++++++++
 libs/community/poetry.lock                         |  13 +-
 libs/community/pyproject.toml                      |   2 +
 .../unit_tests/embeddings/test_imports.py          |   1 +
 .../embeddings/test_oci_gen_ai_embedding.py        |  50 ++++
 .../tests/unit_tests/llms/test_imports.py          |   1 +
 .../unit_tests/llms/test_oci_generative_ai.py      |  76 +++++
 11 files changed, 819 insertions(+), 6 deletions(-)
 create mode 100644 docs/docs/integrations/llms/oci_generative_ai.ipynb
 create mode 100644 libs/community/langchain_community/embeddings/oci_generative_ai.py
 create mode 100644 libs/community/langchain_community/llms/oci_generative_ai.py
 create mode 100644 libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py
 create mode 100644 libs/community/tests/unit_tests/llms/test_oci_generative_ai.py

diff --git a/docs/docs/integrations/llms/oci_generative_ai.ipynb b/docs/docs/integrations/llms/oci_generative_ai.ipynb
new file mode 100644
index 00000000000..200f1038f3e
--- /dev/null
+++ b/docs/docs/integrations/llms/oci_generative_ai.ipynb
@@ -0,0 +1,191 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Oracle Cloud Infrastructure Generative AI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) covering a wide range of use cases, all available through a single API.\n",
    "Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models on dedicated AI clusters, based on your own data. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n",
    "\n",
    "This notebook explains how to use OCI's Generative AI models with LangChain."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prerequisite\n",
    "We will need to install the `oci` SDK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U oci"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OCI Generative AI API endpoint\n",
    "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Authentication\n",
    "The authentication methods supported for this LangChain integration are:\n",
    "\n",
    "1. API Key\n",
    "2. Session token\n",
    "3. Instance principal\n",
    "4. Resource principal\n",
    "\n",
    "These follow the standard SDK authentication methods detailed __[here](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__; a session-token example is sketched below."
   ]
  },
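  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, a minimal sketch of session-token authentication, using the `auth_type` and `auth_profile` parameters this integration defines (the profile name below is a placeholder for one configured in ~/.oci/config):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms import OCIGenAI\n",
    "\n",
    "# sketch: session-token (SECURITY_TOKEN) authN instead of the default API key\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    "    auth_type=\"SECURITY_TOKEN\",\n",
    "    auth_profile=\"MY_PROFILE\",  # placeholder: a profile name from ~/.oci/config\n",
    ")"
   ]
  },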
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms import OCIGenAI\n",
    "\n",
    "# use the default authN method, API key\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "response = llm.invoke(\"Tell me one fact about earth\", temperature=0.7)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import LLMChain\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "# API key authN (the default) is used here as well\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n",
    "\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
    "\n",
    "response = llm_chain.invoke(\"What is the capital of France?\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnablePassthrough\n",
    "from langchain_community.embeddings import OCIGenAIEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "\n",
    "embeddings = OCIGenAIEmbeddings(\n",
    "    model_id=\"MY_EMBEDDING_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "vectorstore = FAISS.from_texts(\n",
    "    [\n",
    "        \"Larry Ellison co-founded Oracle Corporation in 1977 with Bob Miner and Ed Oates.\",\n",
    "        \"Oracle Corporation is an American multinational computer technology company headquartered in Austin, Texas, United States.\",\n",
    "    ],\n",
    "    embedding=embeddings,\n",
    ")\n",
    "\n",
    "retriever = vectorstore.as_retriever()\n",
    "\n",
    "template = \"\"\"Answer the question based only on the following context:\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "\"\"\"\n",
    "prompt = PromptTemplate.from_template(template)\n",
    "\n",
    "llm = OCIGenAI(\n",
    "    model_id=\"MY_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "chain = (\n",
    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "print(chain.invoke(\"when was oracle founded?\"))\n",
    "print(chain.invoke(\"where is oracle headquartered?\"))"
   ]
  },
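  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The embedding model can also be called directly. A minimal sketch using the `embed_documents` and `embed_query` methods this integration adds:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.embeddings import OCIGenAIEmbeddings\n",
    "\n",
    "embeddings = OCIGenAIEmbeddings(\n",
    "    model_id=\"MY_EMBEDDING_MODEL\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    ")\n",
    "\n",
    "document_vectors = embeddings.embed_documents([\"Hello\", \"World\"])  # one vector per text\n",
    "query_vector = embeddings.embed_query(\"Hello\")\n",
    "print(len(document_vectors), len(query_vector))"
   ]
  }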
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index f39787b3556..aaf893ba51b 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -65,6 +65,7 @@ from langchain_community.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddi from langchain_community.embeddings.modelscope_hub import ModelScopeEmbeddings from langchain_community.embeddings.mosaicml import MosaicMLInstructorEmbeddings from langchain_community.embeddings.nlpcloud import NLPCloudEmbeddings +from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings from langchain_community.embeddings.octoai_embeddings import OctoAIEmbeddings from langchain_community.embeddings.ollama import OllamaEmbeddings from langchain_community.embeddings.openai import OpenAIEmbeddings @@ -144,6 +145,7 @@ __all__ = [ "VoyageEmbeddings", "BookendEmbeddings", "VolcanoEmbeddings", + "OCIGenAIEmbeddings", ] diff --git a/libs/community/langchain_community/embeddings/oci_generative_ai.py b/libs/community/langchain_community/embeddings/oci_generative_ai.py new file mode 100644 index 00000000000..6d47fec6f32 --- /dev/null +++ b/libs/community/langchain_community/embeddings/oci_generative_ai.py @@ -0,0 +1,203 @@ +from enum import Enum +from typing import Any, Dict, List, Mapping, Optional + +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" + + +class OCIAuthType(Enum): + API_KEY = 1 + SECURITY_TOKEN = 2 + INSTANCE_PRINCIPAL = 3 + RESOURCE_PRINCIPAL = 4 + + +class OCIGenAIEmbeddings(BaseModel, Embeddings): + """OCI embedding models. + + To authenticate, the OCI client uses the methods described in + https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm + + The authentifcation method is passed through auth_type and should be one of: + API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE + + Make sure you have the required policies (profile/roles) to + access the OCI Generative AI service. If a specific config profile is used, + you must pass the name of the profile (~/.oci/config) through auth_profile. + + To use, you must provide the compartment id + along with the endpoint url, and model id + as named parameters to the constructor. + + Example: + .. 
        .. code-block:: python

            from langchain_community.embeddings import OCIGenAIEmbeddings

            embeddings = OCIGenAIEmbeddings(
                model_id="MY_EMBEDDING_MODEL",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID"
            )
    """

    client: Any  #: :meta private:

    service_models: Any  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, could be

    API_KEY,
    SECURITY_TOKEN,
    INSTANCE_PRINCIPAL,
    RESOURCE_PRINCIPAL

    If not specified, API_KEY will be used
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config
    If not specified, DEFAULT will be used
    """

    model_id: str = None
    """Id of the model to call, e.g., cohere.embed-english-light-v2.0"""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: str = None
    """service endpoint url"""

    compartment_id: str = None
    """OCID of compartment"""

    truncate: Optional[str] = "END"
    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # pylint: disable=no-self-argument
        """Validate that OCI config and python package exists in environment."""

        # Skip creating a new client if one was passed in through the constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                # API_KEY: sign requests with the key configured in ~/.oci/config
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:
                # SECURITY_TOKEN: sign requests with a session token from ~/.oci/config

                def make_security_token_signer(oci_config):
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
            elif values["auth_type"] == OCIAuthType(4).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.get_resource_principals_signer()
            else:
                raise ValueError("Please provide a valid value for auth_type")

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "Please check if ~/.oci/config exists. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "please check that the specified "
                "auth_profile and auth_type are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCIGenAI's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        from oci.generative_ai_inference import models

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            # model_id is a dedicated endpoint OCID rather than an on-demand model name
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        invocation_obj = models.EmbedTextDetails(
            serving_mode=serving_mode,
            compartment_id=self.compartment_id,
            truncate=self.truncate,
            inputs=texts,
        )

        response = self.client.embed_text(invocation_obj)

        return response.data.embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to OCIGenAI's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py
index 13d3607918e..5a08670333e 100644
--- a/libs/community/langchain_community/llms/__init__.py
+++ b/libs/community/langchain_community/llms/__init__.py
@@ -346,6 +346,12 @@ def _import_oci_md_vllm() -> Any:
     return OCIModelDeploymentVLLM


+def _import_oci_gen_ai() -> Any:
+    from langchain_community.llms.oci_generative_ai import OCIGenAI
+
+    return OCIGenAI
+
+
 def _import_octoai_endpoint() -> Any:
     from langchain_community.llms.octoai_endpoint import OctoAIEndpoint

@@ -667,6 +673,8 @@ def __getattr__(name: str) -> Any:
         return _import_oci_md_tgi()
     elif name == "OCIModelDeploymentVLLM":
         return _import_oci_md_vllm()
+    elif name == "OCIGenAI":
+        return _import_oci_gen_ai()
     elif name == "OctoAIEndpoint":
         return _import_octoai_endpoint()
     elif name == "Ollama":
@@ -801,6 +809,7 @@ __all__ = [
     "NLPCloud",
     "OCIModelDeploymentTGI",
     "OCIModelDeploymentVLLM",
+    "OCIGenAI",
     "Ollama",
     "OpenAI",
     "OpenAIChat",
@@ -891,6 +900,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "nlpcloud": _import_nlpcloud,
         "oci_model_deployment_tgi_endpoint": _import_oci_md_tgi,
         "oci_model_deployment_vllm_endpoint": _import_oci_md_vllm,
+        "oci_generative_ai": _import_oci_gen_ai,
         "ollama": _import_ollama,
         "openai": _import_openai,
         "openlm": _import_openlm,
diff --git a/libs/community/langchain_community/llms/oci_generative_ai.py b/libs/community/langchain_community/llms/oci_generative_ai.py
new file mode 100644
index 00000000000..092cbda5548
--- /dev/null
+++ b/libs/community/langchain_community/llms/oci_generative_ai.py
@@ -0,0 +1,276 @@
from __future__ import annotations

from abc import ABC
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.llms.utils import enforce_stop_tokens

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
VALID_PROVIDERS = ("cohere", "meta")


class OCIAuthType(Enum):
    API_KEY = 1
    SECURITY_TOKEN = 2
    INSTANCE_PRINCIPAL = 3
    RESOURCE_PRINCIPAL = 4


class OCIGenAIBase(BaseModel, ABC):
    """Base class for OCI GenAI models"""

    client: Any  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type, could be

    API_KEY,
    SECURITY_TOKEN,
    INSTANCE_PRINCIPAL,
    RESOURCE_PRINCIPAL

    If not specified, API_KEY will be used
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config
    If not specified, DEFAULT will be used
    """

    model_id: str = None
    """Id of the model to call, e.g., cohere.command"""

    provider: str = None
    """Provider name of the model. Defaults to None;
    if not set, it is derived from the model_id prefix,
    otherwise it must be supplied by the user
    """

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: str = None
    """service endpoint url"""

    compartment_id: str = None
    """OCID of compartment"""

    is_stream: bool = False
    """Whether to stream back partial progress"""

    llm_stop_sequence_mapping: Mapping[str, str] = {
        "cohere": "stop_sequences",
        "meta": "stop",
    }
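    # Illustrative note: a `stop` list passed at call time is forwarded to the
    # provider-specific request field named above, e.g.
    # CohereLlmInferenceRequest(stop_sequences=[...]) for "cohere" and
    # LlamaLlmInferenceRequest(stop=[...]) for "meta" (see
    # _prepare_invocation_object in the OCIGenAI subclass below).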

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that OCI config and python package exists in environment."""

        # Skip creating a new client if one was passed in through the constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                # API_KEY: sign requests with the key configured in ~/.oci/config
                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:
                # SECURITY_TOKEN: sign requests with a session token from ~/.oci/config

                def make_security_token_signer(oci_config):
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    profile_name=values["auth_profile"]
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
            elif values["auth_type"] == OCIAuthType(4).name:
                client_kwargs[
                    "signer"
                ] = oci.auth.signers.get_resource_principals_signer()
            else:
                raise ValueError("Please provide a valid value for auth_type")

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "Please check if ~/.oci/config exists. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "please check that the specified "
                "auth_profile and auth_type are valid."
            ) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _get_provider(self) -> str:
        if self.provider is not None:
            provider = self.provider
        else:
            provider = self.model_id.split(".")[0].lower()

        if provider not in VALID_PROVIDERS:
            raise ValueError(
                f"Invalid provider derived from model_id: {self.model_id} "
                "Please explicitly pass in the supported provider "
                "when using custom endpoint"
            )
        return provider
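
# Illustrative provider derivation (comments only, based on _get_provider above):
#   model_id="cohere.command"        -> provider "cohere"
#   model_id="meta.llama-2-70b-chat" -> provider "meta"
#   model_id="ocid1.generativeaiendpoint..." (a dedicated endpoint OCID) has no
#   usable prefix, so provider="cohere" or provider="meta" must be passed
#   explicitly, otherwise a ValueError is raised.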

class OCIGenAI(LLM, OCIGenAIBase):
    """OCI large language models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through auth_type and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL

    Make sure you have the required policies (profile/roles) to
    access the OCI Generative AI service.
    If a specific config profile is used, you must pass
    the name of the profile (from ~/.oci/config) through auth_profile.

    To use, you must provide the compartment OCID, service endpoint URL,
    and model ID as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIGenAI

            llm = OCIGenAI(
                model_id="MY_MODEL_ID",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID"
            )
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci"

    def _prepare_invocation_object(
        self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        from oci.generative_ai_inference import models

        oci_llm_request_mapping = {
            "cohere": models.CohereLlmInferenceRequest,
            "meta": models.LlamaLlmInferenceRequest,
        }
        provider = self._get_provider()
        _model_kwargs = self.model_kwargs or {}
        if stop is not None:
            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            # model_id is a dedicated endpoint OCID rather than an on-demand model name
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        inference_params = {**_model_kwargs, **kwargs}
        inference_params["prompt"] = prompt
        inference_params["is_stream"] = self.is_stream

        invocation_obj = models.GenerateTextDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            inference_request=oci_llm_request_mapping[provider](**inference_params),
        )

        return invocation_obj

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        provider = self._get_provider()
        if provider == "cohere":
            text = response.data.inference_response.generated_texts[0].text
        elif provider == "meta":
            text = response.data.inference_response.choices[0].text
        else:
            raise ValueError(f"Invalid provider: {provider}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to OCIGenAI generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
+ stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = llm.invoke("Tell me a joke.") + """ + + invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs) + response = self.client.generate_text(invocation_obj) + return self._process_response(response, stop) diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index 59c6e780fb0..210fd33046e 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -3433,7 +3433,6 @@ files = [ {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"}, {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"}, {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"}, - {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"}, @@ -4999,13 +4998,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] [[package]] name = "oci" -version = "2.118.0" +version = "2.119.1" description = "Oracle Cloud Infrastructure Python SDK" optional = true python-versions = "*" files = [ - {file = "oci-2.118.0-py3-none-any.whl", hash = "sha256:766170a9b4c93053ba3fe5ae63c0ab48fdd71b4d17709742a2b45249f0829872"}, - {file = "oci-2.118.0.tar.gz", hash = "sha256:1004726c4dad6c02f967b7bc4e733ff552451a2914cb542c380756c7d46bb938"}, + {file = "oci-2.119.1-py3-none-any.whl", hash = "sha256:64b6012f3c2b70cf7fb5f58a1a4b4458d8f4d41ea1b79a5d9f8ca4beb2dfa225"}, + {file = "oci-2.119.1.tar.gz", hash = "sha256:992df963382f378b93634826956677f3c13407ca1b828c4eaf1cfd18f19fae33"}, ] [package.dependencies] @@ -6223,6 +6222,7 @@ files = [ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"}, {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"}, {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"}, + {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"}, @@ -6773,6 +6773,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", 
hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -9226,9 +9227,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "c03bd15da5fd84ec91adec43e62b06623b6ec51003530a762455f74a4ee3715f" +content-hash = "18694abbcaec37f026883b07d1c198f9fc3fdb012d7f2be16ce4ad1866913463" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 14896a667e0..8d8f30aad0e 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -89,6 +89,7 @@ oracle-ads = 
{version = "^2.9.1", optional = true} zhipuai = {version = "^1.0.7", optional = true} elasticsearch = {version = "^8.12.0", optional = true} hdbcli = {version = "^2.19.21", optional = true} +oci = {version = "^2.119.1", optional = true} [tool.poetry.group.test] optional = true @@ -253,6 +254,7 @@ extended_testing = [ "zhipuai", "elasticsearch", "hdbcli", + "oci" ] [tool.ruff] diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index dee9b1ba836..1bb872607d8 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -56,6 +56,7 @@ EXPECTED_ALL = [ "VoyageEmbeddings", "BookendEmbeddings", "VolcanoEmbeddings", + "OCIGenAIEmbeddings", ] diff --git a/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py b/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py new file mode 100644 index 00000000000..12d9c447d5d --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_oci_gen_ai_embedding.py @@ -0,0 +1,50 @@ +"""Test OCI Generative AI embedding service.""" +from unittest.mock import MagicMock + +import pytest +from pytest import MonkeyPatch + +from langchain_community.embeddings import OCIGenAIEmbeddings + + +class MockResponseDict(dict): + def __getattr__(self, val): + return self[val] + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize( + "test_model_id", ["cohere.embed-english-light-v3.0", "cohere.embed-english-v3.0"] +) +def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None: + """Test valid call to OCI Generative AI embedding service.""" + oci_gen_ai_client = MagicMock() + embeddings = OCIGenAIEmbeddings( + model_id=test_model_id, + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + client=oci_gen_ai_client, + ) + + def mocked_response(invocation_obj): + docs = invocation_obj.inputs + + embeddings = [] + for d in docs: + if "Hello" in d: + v = [1.0, 0.0, 0.0] + elif "World" in d: + v = [0.0, 1.0, 0.0] + else: + v = [0.0, 0.0, 1.0] + embeddings.append(v) + + return MockResponseDict( + {"status": 200, "data": MockResponseDict({"embeddings": embeddings})} + ) + + monkeypatch.setattr(embeddings.client, "embed_text", mocked_response) + + output = embeddings.embed_documents(["Hello", "World"]) + correct_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] + + assert output == correct_output diff --git a/libs/community/tests/unit_tests/llms/test_imports.py b/libs/community/tests/unit_tests/llms/test_imports.py index 2e5fed3a70c..9c7abb11f83 100644 --- a/libs/community/tests/unit_tests/llms/test_imports.py +++ b/libs/community/tests/unit_tests/llms/test_imports.py @@ -52,6 +52,7 @@ EXPECT_ALL = [ "Nebula", "OCIModelDeploymentTGI", "OCIModelDeploymentVLLM", + "OCIGenAI", "NIBittensorLLM", "NLPCloud", "Ollama", diff --git a/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py b/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py new file mode 100644 index 00000000000..694d88f0c24 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py @@ -0,0 +1,76 @@ +"""Test OCI Generative AI LLM service""" +from unittest.mock import MagicMock + +import pytest +from pytest import MonkeyPatch + +from langchain_community.llms import OCIGenAI + + +class MockResponseDict(dict): + def __getattr__(self, val): + return self[val] + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize( + "test_model_id", 
["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"] +) +def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None: + """Test valid call to OCI Generative AI LLM service.""" + oci_gen_ai_client = MagicMock() + llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client) + + provider = llm._get_provider() + + def mocked_response(*args): + response_text = "This is the completion." + + if provider == "cohere": + return MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "inference_response": MockResponseDict( + { + "generated_texts": [ + MockResponseDict( + { + "text": response_text, + } + ) + ] + } + ) + } + ), + } + ) + + if provider == "meta": + return MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "inference_response": MockResponseDict( + { + "choices": [ + MockResponseDict( + { + "text": response_text, + } + ) + ] + } + ) + } + ), + } + ) + + monkeypatch.setattr(llm.client, "generate_text", mocked_response) + + output = llm.invoke("This is a prompt.", temperature=0.2) + assert output == "This is the completion."