Add OCI Generative AI new model support (#22880)

- [x] PR title:
  community: Add OCI Generative AI new model support

- [x] PR message:
    - Description: Adds support for new models offered by the OCI Generative AI service. This is a moderate update of our initial integration PR 16548 and includes a new integration for our chat models under /langchain_community/chat_models/oci_generative_ai.py.
    - Issue: NA
    - Dependencies: No new dependencies, just the latest version of our OCI SDK.
    - Twitter handle: NA

- [x] Add tests and docs:
    1. We have updated our unit tests.
    2. We have updated our documentation, including a new ipynb notebook for our new chat integration.

- [x] Lint and test:
  `make format`, `make lint`, and `make test` run successfully.
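
For reviewers, here is a minimal usage sketch of the new chat integration, assembled from the docstring of the new chat module in this diff; the service endpoint and compartment OCID below are placeholders, not working values.

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.chat_models import ChatOCIGenAI

# Placeholder endpoint/compartment values; substitute your own OCI details.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
    model_kwargs={"temperature": 0.7, "max_tokens": 500},
)

messages = [
    SystemMessage(content="You are an AI assistant."),
    HumanMessage(content="Tell me a joke."),
]

# Single response plus the generation metadata surfaced by the integration.
response = chat.invoke(messages)
print(response.content)
print(response.response_metadata)

# Token-by-token streaming through the same client.
for chunk in chat.stream(messages):
    print(chunk.content, end="", flush=True)
```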

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Authored by Rave Harpaz on 2024-06-24 11:48:23 -07:00, committed by GitHub
parent 753edf9c80
commit f5ff7f178b
10 changed files with 885 additions and 146 deletions

View File

@@ -46,7 +46,7 @@ mwxml>=0.3.3,<0.4
 newspaper3k>=0.2.8,<0.3
 numexpr>=2.8.6,<3
 nvidia-riva-client>=2.14.0,<3
-oci>=2.119.1,<3
+oci>=2.128.0,<3
 openai<2
 openapi-pydantic>=0.3.2,<0.4
 oracle-ads>=2.9.1,<3

View File

@@ -121,6 +121,9 @@ if TYPE_CHECKING:
     from langchain_community.chat_models.mlx import (
         ChatMLX,
     )
+    from langchain_community.chat_models.oci_generative_ai import (
+        ChatOCIGenAI,  # noqa: F401
+    )
     from langchain_community.chat_models.octoai import ChatOctoAI
     from langchain_community.chat_models.ollama import (
         ChatOllama,
@@ -194,6 +197,7 @@ __all__ = [
     "ChatMLflowAIGateway",
     "ChatMaritalk",
     "ChatMlflow",
+    "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
     "ChatPerplexity",
@@ -248,6 +252,7 @@ _module_lookup = {
     "ChatMaritalk": "langchain_community.chat_models.maritalk",
     "ChatMlflow": "langchain_community.chat_models.mlflow",
     "ChatOctoAI": "langchain_community.chat_models.octoai",
+    "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
     "ChatOllama": "langchain_community.chat_models.ollama",
     "ChatOpenAI": "langchain_community.chat_models.openai",
     "ChatPerplexity": "langchain_community.chat_models.perplexity",

View File

@@ -0,0 +1,363 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_community.llms.oci_generative_ai import OCIGenAIBase
from langchain_community.llms.utils import enforce_stop_tokens

CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"


class Provider(ABC):
    @property
    @abstractmethod
    def stop_sequence_key(self) -> str:
        ...

    @abstractmethod
    def chat_response_to_text(self, response: Any) -> str:
        ...

    @abstractmethod
    def chat_stream_to_text(self, event_data: Dict) -> str:
        ...

    @abstractmethod
    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
        ...

    @abstractmethod
    def get_role(self, message: BaseMessage) -> str:
        ...

    @abstractmethod
    def messages_to_oci_params(self, messages: Any) -> Dict[str, Any]:
        ...


class CohereProvider(Provider):
    stop_sequence_key = "stop_sequences"

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

        self.oci_chat_request = models.CohereChatRequest
        self.oci_chat_message = {
            "USER": models.CohereUserMessage,
            "CHATBOT": models.CohereChatBotMessage,
            "SYSTEM": models.CohereSystemMessage,
        }
        self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE

    def chat_response_to_text(self, response: Any) -> str:
        return response.data.chat_response.text

    def chat_stream_to_text(self, event_data: Dict) -> str:
        if "text" in event_data and "finishReason" not in event_data:
            return event_data["text"]
        else:
            return ""

    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
        return {
            "finish_reason": response.data.chat_response.finish_reason,
        }

    def get_role(self, message: BaseMessage) -> str:
        if isinstance(message, HumanMessage):
            return "USER"
        elif isinstance(message, AIMessage):
            return "CHATBOT"
        elif isinstance(message, SystemMessage):
            return "SYSTEM"
        else:
            raise ValueError(f"Got unknown type {message}")

    def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]:
        oci_chat_history = [
            self.oci_chat_message[self.get_role(msg)](message=msg.content)
            for msg in messages[:-1]
        ]
        oci_params = {
            "message": messages[-1].content,
            "chat_history": oci_chat_history,
            "api_format": self.chat_api_format,
        }

        return oci_params


class MetaProvider(Provider):
    stop_sequence_key = "stop"

    def __init__(self) -> None:
        from oci.generative_ai_inference import models

        self.oci_chat_request = models.GenericChatRequest
        self.oci_chat_message = {
            "USER": models.UserMessage,
            "SYSTEM": models.SystemMessage,
            "ASSISTANT": models.AssistantMessage,
        }
        self.oci_chat_message_content = models.TextContent
        self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC

    def chat_response_to_text(self, response: Any) -> str:
        return response.data.chat_response.choices[0].message.content[0].text

    def chat_stream_to_text(self, event_data: Dict) -> str:
        if "message" in event_data:
            return event_data["message"]["content"][0]["text"]
        else:
            return ""

    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
        return {
            "finish_reason": response.data.chat_response.choices[0].finish_reason,
            "time_created": str(response.data.chat_response.time_created),
        }

    def get_role(self, message: BaseMessage) -> str:
        # meta only supports alternating user/assistant roles
        if isinstance(message, HumanMessage):
            return "USER"
        elif isinstance(message, AIMessage):
            return "ASSISTANT"
        elif isinstance(message, SystemMessage):
            return "SYSTEM"
        else:
            raise ValueError(f"Got unknown type {message}")

    def messages_to_oci_params(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        oci_messages = [
            self.oci_chat_message[self.get_role(msg)](
                content=[self.oci_chat_message_content(text=msg.content)]
            )
            for msg in messages
        ]
        oci_params = {
            "messages": oci_messages,
            "api_format": self.chat_api_format,
            "top_k": -1,
        }

        return oci_params


class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
    """ChatOCIGenAI chat model integration.

    Setup:
      Install ``langchain-community`` and the ``oci`` sdk.

      .. code-block:: bash

          pip install -U langchain-community oci

    Key init args — completion params:
        model_id: str
            Id of the OCIGenAI chat model to use, e.g., cohere.command-r-16k.
        is_stream: bool
            Whether to stream back partial progress
        model_kwargs: Optional[Dict]
            Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens.

    Key init args — client params:
        service_endpoint: str
            The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com.
        compartment_id: str
            The compartment OCID.
        auth_type: str
            The authentication type to use, e.g., API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.
        auth_profile: Optional[str]
            The name of the profile in ~/.oci/config, if not specified , DEFAULT will be used.
        provider: str
            Provider name of the model. Default to None, will try to be derived from the model_id otherwise, requires user input.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_community.chat_models import ChatOCIGenAI

            chat = ChatOCIGenAI(
                model_id="cohere.command-r-16k",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID",
                model_kwargs={"temperature": 0.7, "max_tokens": 500},
            )

    Invoke:
        .. code-block:: python

            messages = [
                SystemMessage(content="your are an AI assistant."),
                AIMessage(content="Hi there human!"),
                HumanMessage(content="tell me a joke."),
            ]
            response = chat.invoke(messages)

    Stream:
        .. code-block:: python

            for r in chat.stream(messages):
                print(r.content, end="", flush=True)

    Response metadata
        .. code-block:: python

            response = chat.invoke(messages)
            print(response.response_metadata)

    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci_generative_ai_chat"

    @property
    def _provider_map(self) -> Mapping[str, Any]:
        """Get the provider map"""
        return {
            "cohere": CohereProvider(),
            "meta": MetaProvider(),
        }

    @property
    def _provider(self) -> Any:
        """Get the internal provider object"""
        return self._get_provider(provider_map=self._provider_map)

    def _prepare_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        kwargs: Dict[str, Any],
        stream: bool,
    ) -> Dict[str, Any]:
        try:
            from oci.generative_ai_inference import models
        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex

        oci_params = self._provider.messages_to_oci_params(messages)
        oci_params["is_stream"] = stream  # self.is_stream
        _model_kwargs = self.model_kwargs or {}

        if stop is not None:
            _model_kwargs[self._provider.stop_sequence_key] = stop

        chat_params = {**_model_kwargs, **kwargs, **oci_params}

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        request = models.ChatDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            chat_request=self._provider.oci_chat_request(**chat_params),
        )

        return request

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to a OCIGenAI chat model.

        Args:
            messages: list of LangChain messages
            stop: Optional list of stop words to use.

        Returns:
            LangChain ChatResult

        Example:
            .. code-block:: python

                messages = [
                    HumanMessage(content="hello!"),
                    AIMessage(content="Hi there human!"),
                    HumanMessage(content="Meow!")
                ]
                response = llm.invoke(messages)
        """
        if self.is_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        request = self._prepare_request(messages, stop, kwargs, stream=False)
        response = self.client.chat(request)

        content = self._provider.chat_response_to_text(response)

        if stop is not None:
            content = enforce_stop_tokens(content, stop)

        generation_info = self._provider.chat_generation_info(response)

        llm_output = {
            "model_id": response.data.model_id,
            "model_version": response.data.model_version,
            "request_id": response.request_id,
            "content-length": response.headers["content-length"],
        }

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=content), generation_info=generation_info
                )
            ],
            llm_output=llm_output,
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        request = self._prepare_request(messages, stop, kwargs, stream=True)
        response = self.client.chat(request)

        for event in response.data.events():
            delta = self._provider.chat_stream_to_text(json.loads(event.data))
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk
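
Editorial note, not part of the diff: the provider classes above are what translate a LangChain message list into the OCI SDK's chat parameters. A minimal sketch of that mapping for the Cohere provider, assuming the `oci` package is installed (the provider's `__init__` imports `oci.generative_ai_inference.models`):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.oci_generative_ai import CohereProvider

provider = CohereProvider()

messages = [
    SystemMessage(content="You are an AI assistant."),
    AIMessage(content="Hi there human!"),
    HumanMessage(content="Tell me a joke."),
]

# The last message becomes the Cohere "message" field; everything before it
# becomes chat_history entries built from the role-specific OCI message classes.
params = provider.messages_to_oci_params(messages)
print(params["message"])            # "Tell me a joke."
print(len(params["chat_history"]))  # 2 entries (SYSTEM and CHATBOT)
print(params["api_format"])         # value of models.BaseChatRequest.API_FORMAT_COHERE
```

The Meta provider does the analogous mapping but sends the full message list as `messages` built from `TextContent` parts instead of splitting off a chat history.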

View File

@@ -1,17 +1,53 @@
 from __future__ import annotations
-from abc import ABC
+import json
+from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_community.llms.utils import enforce_stop_tokens
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
 VALID_PROVIDERS = ("cohere", "meta")
+
+
+class Provider(ABC):
+    @property
+    @abstractmethod
+    def stop_sequence_key(self) -> str:
+        ...
+
+    @abstractmethod
+    def completion_response_to_text(self, response: Any) -> str:
+        ...
+
+
+class CohereProvider(Provider):
+    stop_sequence_key = "stop_sequences"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.CohereLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.generated_texts[0].text
+
+
+class MetaProvider(Provider):
+    stop_sequence_key = "stop"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.LlamaLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.choices[0].text
+
+
 class OCIAuthType(Enum):
@@ -33,8 +69,8 @@ class OCIGenAIBase(BaseModel, ABC):
     API_KEY,
     SECURITY_TOKEN,
-    INSTANCE_PRINCIPLE,
-    RESOURCE_PRINCIPLE
+    INSTANCE_PRINCIPAL,
+    RESOURCE_PRINCIPAL
     If not specified, API_KEY will be used
     """
@@ -65,11 +101,6 @@ class OCIGenAIBase(BaseModel, ABC):
     is_stream: bool = False
     """Whether to stream back partial progress"""
-    llm_stop_sequence_mapping: Mapping[str, str] = {
-        "cohere": "stop_sequences",
-        "meta": "stop",
-    }
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that OCI config and python package exists in environment."""
@@ -121,24 +152,28 @@
                     "signer"
                 ] = oci.auth.signers.get_resource_principals_signer()
             else:
-                raise ValueError("Please provide valid value to auth_type")
+                raise ValueError(
+                    "Please provide valid value to auth_type, "
+                    f"{values['auth_type']} is not valid."
+                )
             values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                 **client_kwargs
             )
         except ImportError as ex:
-            raise ImportError(
+            raise ModuleNotFoundError(
                 "Could not import oci python package. "
                 "Please make sure you have the oci package installed."
            ) from ex
         except Exception as e:
             raise ValueError(
-                "Could not authenticate with OCI client. "
-                "Please check if ~/.oci/config exists. "
-                "If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, "
-                "Please check the specified "
-                "auth_profile and auth_type are valid."
+                """Could not authenticate with OCI client.
+                Please check if ~/.oci/config exists.
+                If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used,
+                please check the specified
+                auth_profile and auth_type are valid.""",
+                e,
             ) from e
         return values
@@ -151,19 +186,19 @@
             **{"model_kwargs": _model_kwargs},
         }
-    def _get_provider(self) -> str:
+    def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
         if self.provider is not None:
             provider = self.provider
         else:
             provider = self.model_id.split(".")[0].lower()
-        if provider not in VALID_PROVIDERS:
+        if provider not in provider_map:
             raise ValueError(
                 f"Invalid provider derived from model_id: {self.model_id} "
                 "Please explicitly pass in the supported provider "
                 "when using custom endpoint"
             )
-        return provider
+        return provider_map[provider]
 class OCIGenAI(LLM, OCIGenAIBase):
@@ -173,7 +208,7 @@ class OCIGenAI(LLM, OCIGenAIBase):
     https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm
     The authentifcation method is passed through auth_type and should be one of:
-    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE
+    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
     Make sure you have the required policies (profile/roles) to
     access the OCI Generative AI service.
@@ -204,21 +239,29 @@ class OCIGenAI(LLM, OCIGenAIBase):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "oci"
+        return "oci_generative_ai_completion"
+
+    @property
+    def _provider_map(self) -> Mapping[str, Any]:
+        """Get the provider map"""
+        return {
+            "cohere": CohereProvider(),
+            "meta": MetaProvider(),
+        }
+
+    @property
+    def _provider(self) -> Any:
+        """Get the internal provider object"""
+        return self._get_provider(provider_map=self._provider_map)
     def _prepare_invocation_object(
         self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
         from oci.generative_ai_inference import models
-        oci_llm_request_mapping = {
-            "cohere": models.CohereLlmInferenceRequest,
-            "meta": models.LlamaLlmInferenceRequest,
-        }
-        provider = self._get_provider()
         _model_kwargs = self.model_kwargs or {}
         if stop is not None:
-            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop
+            _model_kwargs[self._provider.stop_sequence_key] = stop
         if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
             serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
@@ -232,19 +275,13 @@ class OCIGenAI(LLM, OCIGenAIBase):
         invocation_obj = models.GenerateTextDetails(
             compartment_id=self.compartment_id,
             serving_mode=serving_mode,
-            inference_request=oci_llm_request_mapping[provider](**inference_params),
+            inference_request=self._provider.llm_inference_request(**inference_params),
         )
         return invocation_obj
     def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
-        provider = self._get_provider()
-        if provider == "cohere":
-            text = response.data.inference_response.generated_texts[0].text
-        elif provider == "meta":
-            text = response.data.inference_response.choices[0].text
-        else:
-            raise ValueError(f"Invalid provider: {provider}")
+        text = self._provider.completion_response_to_text(response)
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
@@ -272,7 +309,51 @@
                 response = llm.invoke("Tell me a joke.")
         """
+        if self.is_stream:
+            text = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                text += chunk.text
+            if stop is not None:
+                text = enforce_stop_tokens(text, stop)
+            return text
+
         invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
         response = self.client.generate_text(invocation_obj)
         return self._process_response(response, stop)
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Stream OCIGenAI LLM on given prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            An iterator of GenerationChunks.
+
+        Example:
+            .. code-block:: python
+
+                response = llm.stream("Tell me a joke.")
+        """
+        self.is_stream = True
+        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
+        response = self.client.generate_text(invocation_obj)
+
+        for event in response.data.events():
+            json_load = json.loads(event.data)
+            if "text" in json_load:
+                event_data_text = json_load["text"]
+            else:
+                event_data_text = ""
+            chunk = GenerationChunk(text=event_data_text)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
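
Editorial note, not part of the diff: the completion-side changes above add a `_stream` method and route `invoke` through it when `is_stream` is set. A usage sketch under the same placeholder endpoint and compartment values as the chat example earlier:

```python
from langchain_community.llms.oci_generative_ai import OCIGenAI

# Placeholder endpoint/compartment values; substitute your own OCI details.
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
)

# Standard completion call.
print(llm.invoke("Tell me a joke.", temperature=0.2))

# Streaming: the new _stream() yields GenerationChunk objects as server-sent
# events arrive from generate_text(); LLM.stream() surfaces them as strings.
for chunk in llm.stream("Tell me a joke."):
    print(chunk, end="", flush=True)
```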

View File

@@ -27,6 +27,7 @@ EXPECTED_ALL = [
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatMLX",
"ChatOCIGenAI",
"ChatOllama",
"ChatOpenAI",
"ChatPerplexity",

View File

@@ -0,0 +1,105 @@
"""Test OCI Generative AI LLM service"""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage
from pytest import MonkeyPatch
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
return self[val]
@pytest.mark.requires("oci")
@pytest.mark.parametrize(
"test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
)
def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
"""Test valid chat call to OCI Generative AI LLM service."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
provider = llm.model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def]
response_text = "Assistant chat reply."
response = None
if provider == "cohere":
response = MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"text": response_text,
"finish_reason": "completed",
}
),
"model_id": "cohere.command-r-16k",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict(
{
"content-length": "123",
}
),
}
)
elif provider == "meta":
response = MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"choices": [
MockResponseDict(
{
"message": MockResponseDict(
{
"content": [
MockResponseDict(
{
"text": response_text, # noqa: E501
}
)
]
}
),
"finish_reason": "completed",
}
)
],
"time_created": "2024-09-01T00:00:00Z",
}
),
"model_id": "cohere.command-r-16k",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict(
{
"content-length": "123",
}
),
}
)
return response
monkeypatch.setattr(llm.client, "chat", mocked_response)
messages = [
HumanMessage(content="User message"),
]
expected = "Assistant chat reply."
actual = llm.invoke(messages, temperature=0.2)
assert actual.content == expected
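
One note on the test helper above: `MockResponseDict` subclasses `dict` and routes attribute access through `__getattr__`, so the nested dict literals can stand in for the OCI SDK response objects that the provider classes read with attribute chains. A tiny self-contained illustration of the idea:

```python
class MockResponseDict(dict):
    def __getattr__(self, val):
        return self[val]


resp = MockResponseDict(
    {"data": MockResponseDict({"chat_response": MockResponseDict({"text": "hi"})})}
)
# Attribute access walks the same path the providers use on real SDK
# responses, e.g. response.data.chat_response.text.
assert resp.data.chat_response.text == "hi"
```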

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 from pytest import MonkeyPatch
-from langchain_community.llms import OCIGenAI
+from langchain_community.llms.oci_generative_ai import OCIGenAI
@@ -16,12 +16,12 @@ class MockResponseDict(dict):
 @pytest.mark.parametrize(
     "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"]
 )
-def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
-    """Test valid call to OCI Generative AI LLM service."""
+def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
+    """Test valid completion call to OCI Generative AI LLM service."""
     oci_gen_ai_client = MagicMock()
     llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
-    provider = llm._get_provider()
+    provider = llm.model_id.split(".")[0].lower()
     def mocked_response(*args):  # type: ignore[no-untyped-def]
         response_text = "This is the completion."
@@ -71,6 +71,5 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
         )
     monkeypatch.setattr(llm.client, "generate_text", mocked_response)
     output = llm.invoke("This is a prompt.", temperature=0.2)
     assert output == "This is the completion."