mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 20:05:58 +00:00
Add OCI Generative AI new model support (#22880)
- [x] PR title: community: Add OCI Generative AI new model support
- [x] PR message:
  - Description: adding support for new models offered by OCI Generative AI services. This is a moderate update of our initial integration PR 16548 and includes a new integration for our chat models under /langchain_community/chat_models/oci_generative_ai.py
  - Issue: NA
  - Dependencies: no new dependencies, just the latest version of our OCI SDK
  - Twitter handle: NA
- [x] Add tests and docs:
  1. We have updated our unit tests.
  2. We have updated our documentation, including a new ipynb for our new chat integration.
- [x] Lint and test: `make format`, `make lint`, and `make test` run successfully

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
This commit is contained in:
parent
753edf9c80
commit
f5ff7f178b
190
docs/docs/integrations/chat/oci_generative_ai.ipynb
Normal file
@@ -0,0 +1,190 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "afaf8039",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: OCIGenAI\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e49f1e0d",
   "metadata": {},
   "source": [
    "# ChatOCIGenAI\n",
    "\n",
    "This notebook provides a quick overview for getting started with OCIGenAI [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatOCIGenAI features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html).\n",
    "\n",
    "Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n",
    "Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n",
    "\n",
    "\n",
    "## Overview\n",
    "### Integration details\n",
    "\n",
    "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/v0.2/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n",
    "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
    "| [ChatOCIGenAI](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ❌ | ❌ | ❌ |  |  |\n",
    "\n",
    "### Model features\n",
    "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
    "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
    "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |\n",
    "\n",
    "## Setup\n",
    "\n",
    "To access OCIGenAI models you'll need to install the `oci` and `langchain-community` packages.\n",
    "\n",
    "### Credentials\n",
    "\n",
    "The credentials and authentication methods supported for this integration are equivalent to those used with other OCI services and follow the __[standard SDK authentication](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__ methods, specifically API Key, session token, instance principal, and resource principal.\n",
    "\n",
    "API key is the default authentication method used in the examples below. The following example demonstrates how to use a different authentication method (session token)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0730d6a1-c893-4840-9817-5e5251676d5d",
   "metadata": {},
   "source": [
    "### Installation\n",
    "\n",
    "The LangChain OCIGenAI integration lives in the `langchain-community` package and you will also need to install the `oci` package:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "652d6238-1f87-422a-b135-f5abbb8652fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qU langchain-community oci"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a38cde65-254d-4219-a441-068766c0d4b5",
   "metadata": {},
   "source": [
    "## Instantiation\n",
    "\n",
    "Now we can instantiate our model object and generate chat completions:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI\n",
    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
    "\n",
    "chat = ChatOCIGenAI(\n",
    "    model_id=\"cohere.command-r-16k\",\n",
    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
    "    compartment_id=\"MY_OCID\",\n",
    "    model_kwargs={\"temperature\": 0.7, \"max_tokens\": 500},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b4f3e15",
   "metadata": {},
   "source": [
    "## Invocation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62e0dbc3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "messages = [\n",
    "    SystemMessage(content=\"you are an AI assistant.\"),\n",
    "    AIMessage(content=\"Hi there human!\"),\n",
    "    HumanMessage(content=\"tell me a joke.\"),\n",
    "]\n",
    "response = chat.invoke(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(response.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
   "metadata": {},
   "source": [
    "## Chaining\n",
    "\n",
    "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
    "chain = prompt | chat\n",
    "\n",
    "response = chain.invoke({\"topic\": \"dogs\"})\n",
    "print(response.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
   "metadata": {},
   "source": [
    "## API reference\n",
    "\n",
    "For detailed documentation of all ChatOCIGenAI features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
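The Credentials cell above notes that API key is the default and that a session token can be used instead, but the chat notebook itself does not show that variant. A minimal sketch, assuming a `~/.oci/config` profile with a valid session token; "MY_PROFILE" and "MY_OCID" are placeholders:

```python
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI

# Same constructor as in the notebook, but authenticating with a session
# token instead of the default API key. Placeholders: MY_PROFILE, MY_OCID.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
    auth_type="SECURITY_TOKEN",
    auth_profile="MY_PROFILE",
    model_kwargs={"temperature": 0.7, "max_tokens": 500},
)
```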
@@ -14,15 +14,15 @@
     "Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n",
     "Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n",
     "\n",
-    "This notebook explains how to use OCI's Genrative AI models with LangChain."
+    "This notebook explains how to use OCI's Generative AI complete models with LangChain."
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Prerequisite\n",
-    "We will need to install the oci sdk"
+    "## Setup\n",
+    "Ensure that the oci sdk and the langchain-community package are installed"
    ]
   },
   {
@@ -31,31 +31,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -U oci"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### OCI Generative AI API endpoint \n",
-    "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Authentication\n",
-    "The authentication methods supported for this langchain integration are:\n",
-    "\n",
-    "1. API Key\n",
-    "2. Session token\n",
-    "3. Instance principal\n",
-    "4. Resource principal \n",
-    "\n",
-    "These follows the standard SDK authentication methods detailed __[here](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__.\n",
-    " "
+    "!pip install -U oci langchain-community"
    ]
   },
   {
@@ -71,13 +47,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain_community.llms import OCIGenAI\n",
+    "from langchain_community.llms.oci_generative_ai import OCIGenAI\n",
     "\n",
-    "# use default authN method API-key\n",
     "llm = OCIGenAI(\n",
-    "    model_id=\"MY_MODEL\",\n",
+    "    model_id=\"cohere.command\",\n",
     "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
     "    compartment_id=\"MY_OCID\",\n",
+    "    model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n",
     ")\n",
     "\n",
     "response = llm.invoke(\"Tell me one fact about earth\", temperature=0.7)\n",
@@ -85,30 +61,10 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "from langchain.chains import LLMChain\n",
-    "from langchain_core.prompts import PromptTemplate\n",
-    "\n",
-    "# Use Session Token to authN\n",
-    "llm = OCIGenAI(\n",
-    "    model_id=\"MY_MODEL\",\n",
-    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
-    "    compartment_id=\"MY_OCID\",\n",
-    "    auth_type=\"SECURITY_TOKEN\",\n",
-    "    auth_profile=\"MY_PROFILE\", # replace with your profile name\n",
-    "    model_kwargs={\"temperature\": 0.7, \"top_p\": 0.75, \"max_tokens\": 200},\n",
-    ")\n",
-    "\n",
-    "prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n",
-    "\n",
-    "llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
-    "\n",
-    "response = llm_chain.invoke(\"what is the capital of france?\")\n",
-    "print(response)"
+    "#### Chaining with prompt templates"
    ]
   },
   {
@@ -117,49 +73,95 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain_community.embeddings import OCIGenAIEmbeddings\n",
-    "from langchain_community.vectorstores import FAISS\n",
-    "from langchain_core.output_parsers import StrOutputParser\n",
-    "from langchain_core.runnables import RunnablePassthrough\n",
-    "\n",
-    "embeddings = OCIGenAIEmbeddings(\n",
-    "    model_id=\"MY_EMBEDDING_MODEL\",\n",
-    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
-    "    compartment_id=\"MY_OCID\",\n",
-    ")\n",
-    "\n",
-    "vectorstore = FAISS.from_texts(\n",
-    "    [\n",
-    "        \"Larry Ellison co-founded Oracle Corporation in 1977 with Bob Miner and Ed Oates.\",\n",
-    "        \"Oracle Corporation is an American multinational computer technology company headquartered in Austin, Texas, United States.\",\n",
-    "    ],\n",
-    "    embedding=embeddings,\n",
-    ")\n",
-    "\n",
-    "retriever = vectorstore.as_retriever()\n",
-    "\n",
-    "template = \"\"\"Answer the question based only on the following context:\n",
-    "{context}\n",
-    " \n",
-    "Question: {question}\n",
-    "\"\"\"\n",
-    "prompt = PromptTemplate.from_template(template)\n",
+    "from langchain_core.prompts import PromptTemplate\n",
     "\n",
     "llm = OCIGenAI(\n",
-    "    model_id=\"MY_MODEL\",\n",
+    "    model_id=\"cohere.command\",\n",
     "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
     "    compartment_id=\"MY_OCID\",\n",
+    "    model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n",
     ")\n",
     "\n",
-    "chain = (\n",
-    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
-    "    | prompt\n",
-    "    | llm\n",
-    "    | StrOutputParser()\n",
-    ")\n",
-    "\n",
-    "print(chain.invoke(\"when was oracle founded?\"))\n",
-    "print(chain.invoke(\"where is oracle headquartered?\"))"
+    "prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n",
+    "llm_chain = prompt | llm\n",
+    "\n",
+    "response = llm_chain.invoke(\"what is the capital of france?\")\n",
+    "print(response)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = OCIGenAI(\n",
+    "    model_id=\"cohere.command\",\n",
+    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
+    "    compartment_id=\"MY_OCID\",\n",
+    "    model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n",
+    ")\n",
+    "\n",
+    "for chunk in llm.stream(\"Write me a song about sparkling water.\"):\n",
+    "    print(chunk, end=\"\", flush=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Authentication\n",
+    "The authentication methods supported for this LangChain integration are equivalent to those used with other OCI services and follow the __[standard SDK authentication](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__ methods, specifically API Key, session token, instance principal, and resource principal.\n",
+    "\n",
+    "API key is the default authentication method used in the examples above. The following example demonstrates how to use a different authentication method (session token)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = OCIGenAI(\n",
+    "    model_id=\"cohere.command\",\n",
+    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
+    "    compartment_id=\"MY_OCID\",\n",
+    "    auth_type=\"SECURITY_TOKEN\",\n",
+    "    auth_profile=\"MY_PROFILE\",  # replace with your profile name\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Dedicated AI Cluster\n",
+    "To access models hosted in a dedicated AI cluster __[create an endpoint](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/)__ whose assigned OCID (currently prefixed by ‘ocid1.generativeaiendpoint.oc1.us-chicago-1’) is used as your model ID.\n",
+    "\n",
+    "When accessing models hosted in a dedicated AI cluster you will need to initialize the OCIGenAI interface with two extra required params (\"provider\" and \"context_size\")."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = OCIGenAI(\n",
+    "    model_id=\"ocid1.generativeaiendpoint.oc1.us-chicago-1....\",\n",
+    "    service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
+    "    compartment_id=\"DEDICATED_COMPARTMENT_OCID\",\n",
+    "    auth_profile=\"MY_PROFILE\",  # replace with your profile name\n",
+    "    provider=\"MODEL_PROVIDER\",  # e.g., \"cohere\" or \"meta\"\n",
+    "    context_size=\"MODEL_CONTEXT_SIZE\",  # e.g., 128000\n",
+    ")"
    ]
   }
 ],
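For readers migrating their own code, the hunks above replace the legacy `LLMChain` cell with the `prompt | llm` runnable form. A minimal sketch of the same chain with a string output parser appended (the prompt text and OCIDs are placeholders, mirroring the notebook):

```python
from langchain_community.llms.oci_generative_ai import OCIGenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# Equivalent of the removed LLMChain cell in LCEL form; "MY_OCID" is a placeholder.
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
    model_kwargs={"temperature": 0, "max_tokens": 500},
)

prompt = PromptTemplate.from_template("{query}")

# LLMChain(llm=llm, prompt=prompt) becomes a runnable pipeline; the string
# output parser keeps the chain returning plain text.
chain = prompt | llm | StrOutputParser()
print(chain.invoke({"query": "what is the capital of france?"}))
```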
@@ -2,27 +2,29 @@
 
 The `LangChain` integrations related to [Oracle Cloud Infrastructure](https://www.oracle.com/artificial-intelligence/).
 
-## LLMs
+## OCI Generative AI
 
-### OCI Generative AI
 > Oracle Cloud Infrastructure (OCI) [Generative AI](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) is a fully managed service that provides a set of state-of-the-art,
 > customizable large language models (LLMs) that cover a wide range of use cases, and which are available through a single API.
 > Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned
 > custom models based on your own data on dedicated AI clusters.
 
-To use, you should have the latest `oci` python SDK installed.
+To use, you should have the latest `oci` python SDK and the langchain_community package installed.
 
 ```bash
-pip install -U oci
+pip install -U oci langchain-community
 ```
 
-See [usage examples](/docs/integrations/llms/oci_generative_ai).
+See [chat](/docs/integrations/chat/oci_generative_ai), [complete](/docs/integrations/llms/oci_generative_ai), and [embedding](/docs/integrations/text_embedding/oci_generative_ai) usage examples.
 
 ```python
+from langchain_community.chat_models import ChatOCIGenAI
+
 from langchain_community.llms import OCIGenAI
+
+from langchain_community.embeddings import OCIGenAIEmbeddings
 ```
 
-### OCI Data Science Model Deployment Endpoint
+## OCI Data Science Model Deployment Endpoint
 
 > [OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a
 > fully managed and serverless platform for data science teams. Using the OCI Data Science
@@ -47,12 +49,3 @@ from langchain_community.llms import OCIModelDeploymentVLLM
 from langchain_community.llms import OCIModelDeploymentTGI
 ```
 
-## Text Embedding Models
-
-### OCI Generative AI
-
-See [usage examples](/docs/integrations/text_embedding/oci_generative_ai).
-
-```python
-from langchain_community.embeddings import OCIGenAIEmbeddings
-```
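With this commit applied, the provider page now points at three OCI Generative AI integrations that share the same client parameters. A minimal combined sketch, assuming an API-key config in `~/.oci/config`; the compartment OCID and the embedding model id are placeholders:

```python
from langchain_community.chat_models import ChatOCIGenAI
from langchain_community.embeddings import OCIGenAIEmbeddings
from langchain_community.llms import OCIGenAI

ENDPOINT = "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
COMPARTMENT = "MY_OCID"  # placeholder compartment OCID

# Chat, completion, and embedding models take the same endpoint/compartment args.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint=ENDPOINT,
    compartment_id=COMPARTMENT,
)
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint=ENDPOINT,
    compartment_id=COMPARTMENT,
)
embeddings = OCIGenAIEmbeddings(
    model_id="MY_EMBEDDING_MODEL",  # placeholder embedding model id
    service_endpoint=ENDPOINT,
    compartment_id=COMPARTMENT,
)
```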
@@ -46,7 +46,7 @@ mwxml>=0.3.3,<0.4
 newspaper3k>=0.2.8,<0.3
 numexpr>=2.8.6,<3
 nvidia-riva-client>=2.14.0,<3
-oci>=2.119.1,<3
+oci>=2.128.0,<3
 openai<2
 openapi-pydantic>=0.3.2,<0.4
 oracle-ads>=2.9.1,<3
@@ -121,6 +121,9 @@ if TYPE_CHECKING:
     from langchain_community.chat_models.mlx import (
         ChatMLX,
     )
+    from langchain_community.chat_models.oci_generative_ai import (
+        ChatOCIGenAI,  # noqa: F401
+    )
     from langchain_community.chat_models.octoai import ChatOctoAI
     from langchain_community.chat_models.ollama import (
         ChatOllama,
@@ -194,6 +197,7 @@ __all__ = [
     "ChatMLflowAIGateway",
     "ChatMaritalk",
     "ChatMlflow",
+    "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
     "ChatPerplexity",
@@ -248,6 +252,7 @@ _module_lookup = {
     "ChatMaritalk": "langchain_community.chat_models.maritalk",
     "ChatMlflow": "langchain_community.chat_models.mlflow",
     "ChatOctoAI": "langchain_community.chat_models.octoai",
+    "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
     "ChatOllama": "langchain_community.chat_models.ollama",
     "ChatOpenAI": "langchain_community.chat_models.openai",
     "ChatPerplexity": "langchain_community.chat_models.perplexity",
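These hunks register the new class in three places: the `TYPE_CHECKING` import, `__all__`, and `_module_lookup`. The `_module_lookup` entry is what makes `from langchain_community.chat_models import ChatOCIGenAI` work at runtime; the package resolves attributes lazily, roughly along these lines (a simplified sketch of the pattern, not the exact module code):

```python
import importlib
from typing import Any

# Simplified sketch of the lazy-import pattern used by
# langchain_community.chat_models: attribute access on the package
# triggers an import of the module named in _module_lookup.
_module_lookup = {
    "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
}


def __getattr__(name: str) -> Any:
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
```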
@@ -0,0 +1,363 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_community.llms.oci_generative_ai import OCIGenAIBase
from langchain_community.llms.utils import enforce_stop_tokens

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"


class Provider(ABC):
    @property
    @abstractmethod
    def stop_sequence_key(self) -> str:
        ...

    @abstractmethod
    def chat_response_to_text(self, response: Any) -> str:
        ...

    @abstractmethod
    def chat_stream_to_text(self, event_data: Dict) -> str:
        ...

    @abstractmethod
    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
        ...

    @abstractmethod
    def get_role(self, message: BaseMessage) -> str:
        ...

    @abstractmethod
    def messages_to_oci_params(self, messages: Any) -> Dict[str, Any]:
        ...


class CohereProvider(Provider):
    stop_sequence_key = "stop_sequences"

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

        self.oci_chat_request = models.CohereChatRequest
        self.oci_chat_message = {
            "USER": models.CohereUserMessage,
            "CHATBOT": models.CohereChatBotMessage,
            "SYSTEM": models.CohereSystemMessage,
        }
        self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE

    def chat_response_to_text(self, response: Any) -> str:
        return response.data.chat_response.text

    def chat_stream_to_text(self, event_data: Dict) -> str:
        if "text" in event_data and "finishReason" not in event_data:
            return event_data["text"]
        else:
            return ""

    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
        return {
            "finish_reason": response.data.chat_response.finish_reason,
        }

    def get_role(self, message: BaseMessage) -> str:
        if isinstance(message, HumanMessage):
            return "USER"
        elif isinstance(message, AIMessage):
            return "CHATBOT"
        elif isinstance(message, SystemMessage):
            return "SYSTEM"
        else:
            raise ValueError(f"Got unknown type {message}")

    def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]:
        oci_chat_history = [
            self.oci_chat_message[self.get_role(msg)](message=msg.content)
            for msg in messages[:-1]
        ]
        oci_params = {
            "message": messages[-1].content,
            "chat_history": oci_chat_history,
            "api_format": self.chat_api_format,
        }

        return oci_params


class MetaProvider(Provider):
    stop_sequence_key = "stop"

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

        self.oci_chat_request = models.GenericChatRequest
        self.oci_chat_message = {
            "USER": models.UserMessage,
            "SYSTEM": models.SystemMessage,
            "ASSISTANT": models.AssistantMessage,
        }
        self.oci_chat_message_content = models.TextContent
        self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC

    def chat_response_to_text(self, response: Any) -> str:
        return response.data.chat_response.choices[0].message.content[0].text

    def chat_stream_to_text(self, event_data: Dict) -> str:
        if "message" in event_data:
            return event_data["message"]["content"][0]["text"]
        else:
            return ""

    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
        return {
            "finish_reason": response.data.chat_response.choices[0].finish_reason,
            "time_created": str(response.data.chat_response.time_created),
        }

    def get_role(self, message: BaseMessage) -> str:
        # meta only supports alternating user/assistant roles
        if isinstance(message, HumanMessage):
            return "USER"
        elif isinstance(message, AIMessage):
            return "ASSISTANT"
        elif isinstance(message, SystemMessage):
            return "SYSTEM"
        else:
            raise ValueError(f"Got unknown type {message}")

    def messages_to_oci_params(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        oci_messages = [
            self.oci_chat_message[self.get_role(msg)](
                content=[self.oci_chat_message_content(text=msg.content)]
            )
            for msg in messages
        ]
        oci_params = {
            "messages": oci_messages,
            "api_format": self.chat_api_format,
            "top_k": -1,
        }

        return oci_params


class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
    """ChatOCIGenAI chat model integration.

    Setup:
      Install ``langchain-community`` and the ``oci`` sdk.

      .. code-block:: bash

          pip install -U langchain-community oci

    Key init args — completion params:
        model_id: str
            Id of the OCIGenAI chat model to use, e.g., cohere.command-r-16k.
        is_stream: bool
            Whether to stream back partial progress.
        model_kwargs: Optional[Dict]
            Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens.

    Key init args — client params:
        service_endpoint: str
            The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com.
        compartment_id: str
            The compartment OCID.
        auth_type: str
            The authentication type to use, e.g., API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.
        auth_profile: Optional[str]
            The name of the profile in ~/.oci/config; if not specified, DEFAULT will be used.
        provider: str
            Provider name of the model. Defaults to None; if not set, it is derived from the model_id, otherwise user input is required.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_community.chat_models import ChatOCIGenAI

            chat = ChatOCIGenAI(
                model_id="cohere.command-r-16k",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID",
                model_kwargs={"temperature": 0.7, "max_tokens": 500},
            )

    Invoke:
        .. code-block:: python

            messages = [
                SystemMessage(content="you are an AI assistant."),
                AIMessage(content="Hi there human!"),
                HumanMessage(content="tell me a joke."),
            ]
            response = chat.invoke(messages)

    Stream:
        .. code-block:: python

            for r in chat.stream(messages):
                print(r.content, end="", flush=True)

    Response metadata:
        .. code-block:: python

            response = chat.invoke(messages)
            print(response.response_metadata)

    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci_generative_ai_chat"

    @property
    def _provider_map(self) -> Mapping[str, Any]:
        """Get the provider map"""
        return {
            "cohere": CohereProvider(),
            "meta": MetaProvider(),
        }

    @property
    def _provider(self) -> Any:
        """Get the internal provider object"""
        return self._get_provider(provider_map=self._provider_map)

    def _prepare_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        kwargs: Dict[str, Any],
        stream: bool,
    ) -> Dict[str, Any]:
        try:
            from oci.generative_ai_inference import models

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        oci_params = self._provider.messages_to_oci_params(messages)
        oci_params["is_stream"] = stream  # self.is_stream
        _model_kwargs = self.model_kwargs or {}

        if stop is not None:
            _model_kwargs[self._provider.stop_sequence_key] = stop

        chat_params = {**_model_kwargs, **kwargs, **oci_params}

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        request = models.ChatDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            chat_request=self._provider.oci_chat_request(**chat_params),
        )

        return request

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to an OCIGenAI chat model.

        Args:
            messages: list of LangChain messages
            stop: Optional list of stop words to use.

        Returns:
            LangChain ChatResult

        Example:
            .. code-block:: python

                messages = [
                    HumanMessage(content="hello!"),
                    AIMessage(content="Hi there human!"),
                    HumanMessage(content="Meow!"),
                ]

                response = llm.invoke(messages)
        """
        if self.is_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        request = self._prepare_request(messages, stop, kwargs, stream=False)
        response = self.client.chat(request)

        content = self._provider.chat_response_to_text(response)

        if stop is not None:
            content = enforce_stop_tokens(content, stop)

        generation_info = self._provider.chat_generation_info(response)

        llm_output = {
            "model_id": response.data.model_id,
            "model_version": response.data.model_version,
            "request_id": response.request_id,
            "content-length": response.headers["content-length"],
        }

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=content), generation_info=generation_info
                )
            ],
            llm_output=llm_output,
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = self._prepare_request(messages, stop, kwargs, stream=True)
        response = self.client.chat(request)

        for event in response.data.events():
            delta = self._provider.chat_stream_to_text(json.loads(event.data))
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk
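`ChatOCIGenAI` delegates request building to a `Provider` object: the Cohere provider sends the last message as `message` and everything before it as `chat_history`, while the Meta provider sends the whole list as `messages`. A simplified, dependency-free illustration of the Cohere-style split (the real mapping is done by `CohereProvider.messages_to_oci_params` above; the dict layout here is illustrative, not the OCI SDK model classes):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

# Simplified illustration of the Cohere-format split performed by
# CohereProvider.messages_to_oci_params: everything except the last
# message becomes chat_history, the last message becomes the prompt.
messages = [
    SystemMessage(content="you are an AI assistant."),
    AIMessage(content="Hi there human!"),
    HumanMessage(content="tell me a joke."),
]

role_map = {"SystemMessage": "SYSTEM", "AIMessage": "CHATBOT", "HumanMessage": "USER"}

chat_history = [
    {"role": role_map[type(m).__name__], "message": m.content} for m in messages[:-1]
]
request_params = {
    "message": messages[-1].content,  # maps to CohereChatRequest.message
    "chat_history": chat_history,     # maps to CohereChatRequest.chat_history
}
print(request_params)
```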
@@ -1,17 +1,53 @@
 from __future__ import annotations
 
-from abc import ABC
+import json
+from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
-VALID_PROVIDERS = ("cohere", "meta")
+
+
+class Provider(ABC):
+    @property
+    @abstractmethod
+    def stop_sequence_key(self) -> str:
+        ...
+
+    @abstractmethod
+    def completion_response_to_text(self, response: Any) -> str:
+        ...
+
+
+class CohereProvider(Provider):
+    stop_sequence_key = "stop_sequences"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.CohereLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.generated_texts[0].text
+
+
+class MetaProvider(Provider):
+    stop_sequence_key = "stop"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.LlamaLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.choices[0].text
+
+
 class OCIAuthType(Enum):
@@ -33,8 +69,8 @@ class OCIGenAIBase(BaseModel, ABC):
 
     API_KEY,
     SECURITY_TOKEN,
-    INSTANCE_PRINCIPLE,
-    RESOURCE_PRINCIPLE
+    INSTANCE_PRINCIPAL,
+    RESOURCE_PRINCIPAL
 
     If not specified, API_KEY will be used
     """
@@ -65,11 +101,6 @@ class OCIGenAIBase(BaseModel, ABC):
     is_stream: bool = False
     """Whether to stream back partial progress"""
 
-    llm_stop_sequence_mapping: Mapping[str, str] = {
-        "cohere": "stop_sequences",
-        "meta": "stop",
-    }
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that OCI config and python package exists in environment."""
@@ -121,24 +152,28 @@ class OCIGenAIBase(BaseModel, ABC):
                     "signer"
                 ] = oci.auth.signers.get_resource_principals_signer()
             else:
-                raise ValueError("Please provide valid value to auth_type")
+                raise ValueError(
+                    "Please provide valid value to auth_type, "
+                    f"{values['auth_type']} is not valid."
+                )
 
             values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                 **client_kwargs
            )
 
        except ImportError as ex:
-            raise ImportError(
+            raise ModuleNotFoundError(
                 "Could not import oci python package. "
                 "Please make sure you have the oci package installed."
             ) from ex
        except Exception as e:
            raise ValueError(
-                "Could not authenticate with OCI client. "
-                "Please check if ~/.oci/config exists. "
-                "If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, "
-                "Please check the specified "
-                "auth_profile and auth_type are valid."
+                """Could not authenticate with OCI client.
+                Please check if ~/.oci/config exists.
+                If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used,
+                please check the specified
+                auth_profile and auth_type are valid.""",
+                e,
             ) from e
 
        return values
@@ -151,19 +186,19 @@ class OCIGenAIBase(BaseModel, ABC):
             **{"model_kwargs": _model_kwargs},
         }
 
-    def _get_provider(self) -> str:
+    def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
         if self.provider is not None:
             provider = self.provider
         else:
             provider = self.model_id.split(".")[0].lower()
 
-        if provider not in VALID_PROVIDERS:
+        if provider not in provider_map:
             raise ValueError(
                 f"Invalid provider derived from model_id: {self.model_id} "
                 "Please explicitly pass in the supported provider "
                 "when using custom endpoint"
             )
-        return provider
+        return provider_map[provider]
 
 
 class OCIGenAI(LLM, OCIGenAIBase):
@@ -173,7 +208,7 @@ class OCIGenAI(LLM, OCIGenAIBase):
     https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm
 
     The authentifcation method is passed through auth_type and should be one of:
-    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE
+    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
 
     Make sure you have the required policies (profile/roles) to
     access the OCI Generative AI service.
@@ -204,21 +239,29 @@ class OCIGenAI(LLM, OCIGenAIBase):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "oci"
+        return "oci_generative_ai_completion"
+
+    @property
+    def _provider_map(self) -> Mapping[str, Any]:
+        """Get the provider map"""
+        return {
+            "cohere": CohereProvider(),
+            "meta": MetaProvider(),
+        }
+
+    @property
+    def _provider(self) -> Any:
+        """Get the internal provider object"""
+        return self._get_provider(provider_map=self._provider_map)
 
     def _prepare_invocation_object(
         self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
         from oci.generative_ai_inference import models
 
-        oci_llm_request_mapping = {
-            "cohere": models.CohereLlmInferenceRequest,
-            "meta": models.LlamaLlmInferenceRequest,
-        }
-        provider = self._get_provider()
         _model_kwargs = self.model_kwargs or {}
         if stop is not None:
-            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop
+            _model_kwargs[self._provider.stop_sequence_key] = stop
 
         if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
             serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
@@ -232,19 +275,13 @@ class OCIGenAI(LLM, OCIGenAIBase):
         invocation_obj = models.GenerateTextDetails(
             compartment_id=self.compartment_id,
             serving_mode=serving_mode,
-            inference_request=oci_llm_request_mapping[provider](**inference_params),
+            inference_request=self._provider.llm_inference_request(**inference_params),
         )
 
         return invocation_obj
 
     def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
-        provider = self._get_provider()
-        if provider == "cohere":
-            text = response.data.inference_response.generated_texts[0].text
-        elif provider == "meta":
-            text = response.data.inference_response.choices[0].text
-        else:
-            raise ValueError(f"Invalid provider: {provider}")
+        text = self._provider.completion_response_to_text(response)
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
@@ -272,7 +309,51 @@ class OCIGenAI(LLM, OCIGenAIBase):
 
                 response = llm.invoke("Tell me a joke.")
         """
+        if self.is_stream:
+            text = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                text += chunk.text
+            if stop is not None:
+                text = enforce_stop_tokens(text, stop)
+            return text
+
         invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
         response = self.client.generate_text(invocation_obj)
         return self._process_response(response, stop)
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Stream OCIGenAI LLM on given prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            An iterator of GenerationChunks.
+
+        Example:
+            .. code-block:: python
+
+                response = llm.stream("Tell me a joke.")
+        """
+
+        self.is_stream = True
+        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
+        response = self.client.generate_text(invocation_obj)
+
+        for event in response.data.events():
+            json_load = json.loads(event.data)
+            if "text" in json_load:
+                event_data_text = json_load["text"]
+            else:
+                event_data_text = ""
+            chunk = GenerationChunk(text=event_data_text)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
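The new `_stream` method above turns the service's SSE events into `GenerationChunk`s, and `_call` (when `is_stream=True`) concatenates those chunks and then applies the stop sequences. A simplified sketch of that assembly step; `cut_at_stop` is a local stand-in for `enforce_stop_tokens` (the real helper lives in `langchain_community.llms.utils`), and the event payloads are fabricated for illustration:

```python
import json
import re
from typing import Iterator, List, Optional


# Stand-in for enforce_stop_tokens: cut the text at the first stop sequence.
def cut_at_stop(text: str, stop: List[str]) -> str:
    return re.split("|".join(map(re.escape, stop)), text, maxsplit=1)[0]


def assemble_stream(events: Iterator[str], stop: Optional[List[str]] = None) -> str:
    """Mimic _call with is_stream=True: join chunk texts, then apply stop words."""
    text = ""
    for raw in events:
        payload = json.loads(raw)        # each SSE event carries a JSON body
        text += payload.get("text", "")  # same empty-string fallback as _stream
    if stop is not None:
        text = cut_at_stop(text, stop)
    return text


# Illustrative event payloads only.
events = ['{"text": "Hello"}', '{"text": ", world"}', '{"finishReason": "stop"}']
print(assemble_stream(events, stop=[", w"]))  # -> "Hello"
```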
@@ -27,6 +27,7 @@ EXPECTED_ALL = [
     "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatMLX",
+    "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
     "ChatPerplexity",
@@ -0,0 +1,105 @@
"""Test OCI Generative AI LLM service"""

from unittest.mock import MagicMock

import pytest
from langchain_core.messages import HumanMessage
from pytest import MonkeyPatch

from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI


class MockResponseDict(dict):
    def __getattr__(self, val):  # type: ignore[no-untyped-def]
        return self[val]


@pytest.mark.requires("oci")
@pytest.mark.parametrize(
    "test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
)
def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
    """Test valid chat call to OCI Generative AI LLM service."""
    oci_gen_ai_client = MagicMock()
    llm = ChatOCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)

    provider = llm.model_id.split(".")[0].lower()

    def mocked_response(*args):  # type: ignore[no-untyped-def]
        response_text = "Assistant chat reply."
        response = None
        if provider == "cohere":
            response = MockResponseDict(
                {
                    "status": 200,
                    "data": MockResponseDict(
                        {
                            "chat_response": MockResponseDict(
                                {
                                    "text": response_text,
                                    "finish_reason": "completed",
                                }
                            ),
                            "model_id": "cohere.command-r-16k",
                            "model_version": "1.0.0",
                        }
                    ),
                    "request_id": "1234567890",
                    "headers": MockResponseDict(
                        {
                            "content-length": "123",
                        }
                    ),
                }
            )
        elif provider == "meta":
            response = MockResponseDict(
                {
                    "status": 200,
                    "data": MockResponseDict(
                        {
                            "chat_response": MockResponseDict(
                                {
                                    "choices": [
                                        MockResponseDict(
                                            {
                                                "message": MockResponseDict(
                                                    {
                                                        "content": [
                                                            MockResponseDict(
                                                                {
                                                                    "text": response_text,  # noqa: E501
                                                                }
                                                            )
                                                        ]
                                                    }
                                                ),
                                                "finish_reason": "completed",
                                            }
                                        )
                                    ],
                                    "time_created": "2024-09-01T00:00:00Z",
                                }
                            ),
                            "model_id": "cohere.command-r-16k",
                            "model_version": "1.0.0",
                        }
                    ),
                    "request_id": "1234567890",
                    "headers": MockResponseDict(
                        {
                            "content-length": "123",
                        }
                    ),
                }
            )
        return response

    monkeypatch.setattr(llm.client, "chat", mocked_response)

    messages = [
        HumanMessage(content="User message"),
    ]

    expected = "Assistant chat reply."
    actual = llm.invoke(messages, temperature=0.2)
    assert actual.content == expected
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 from pytest import MonkeyPatch
 
-from langchain_community.llms import OCIGenAI
+from langchain_community.llms.oci_generative_ai import OCIGenAI
 
 
 class MockResponseDict(dict):
@@ -16,12 +16,12 @@ class MockResponseDict(dict):
 @pytest.mark.parametrize(
     "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"]
 )
-def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
-    """Test valid call to OCI Generative AI LLM service."""
+def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
+    """Test valid completion call to OCI Generative AI LLM service."""
     oci_gen_ai_client = MagicMock()
     llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
 
-    provider = llm._get_provider()
+    provider = llm.model_id.split(".")[0].lower()
 
     def mocked_response(*args):  # type: ignore[no-untyped-def]
         response_text = "This is the completion."
@@ -71,6 +71,5 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
     )
 
     monkeypatch.setattr(llm.client, "generate_text", mocked_response)
-
     output = llm.invoke("This is a prompt.", temperature=0.2)
     assert output == "This is the completion."