Add OCI Generative AI new model support (#22880)

- [x] PR title: community: Add OCI Generative AI new model support

- [x] PR message:
    - Description: Adds support for new models offered by the OCI Generative AI service. This is a moderate update of our initial integration (PR #16548) and includes a new integration for our chat models under /langchain_community/chat_models/oci_generative_ai.py (a usage sketch follows this list).
    - Issue: N/A
    - Dependencies: No new dependencies; only the latest version of the OCI SDK is required.
    - Twitter handle: N/A
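
A minimal usage sketch of the new chat integration. The model ID, service endpoint, and compartment OCID below are illustrative placeholders, not values from this PR:

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI

# All IDs below are placeholders; substitute your own tenancy values.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
)

response = chat.invoke([HumanMessage(content="Tell me a joke.")])
print(response.content)
```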


- [x] Add tests and docs:
    1. We have updated our unit tests.
    2. We have updated our documentation, including a new notebook for the new chat integration.


- [x] Lint and test: `make format`, `make lint`, and `make test` run successfully.

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Rave Harpaz authored on 2024-06-24 11:48:23 -07:00; committed by GitHub
parent 753edf9c80 · commit f5ff7f178b
10 changed files with 885 additions and 146 deletions

File: libs/community/langchain_community/llms/oci_generative_ai.py

@@ -1,17 +1,53 @@
 from __future__ import annotations
 
-from abc import ABC
+import json
+from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
 
 VALID_PROVIDERS = ("cohere", "meta")
 
+
+class Provider(ABC):
+    @property
+    @abstractmethod
+    def stop_sequence_key(self) -> str:
+        ...
+
+    @abstractmethod
+    def completion_response_to_text(self, response: Any) -> str:
+        ...
+
+
+class CohereProvider(Provider):
+    stop_sequence_key = "stop_sequences"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.CohereLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.generated_texts[0].text
+
+
+class MetaProvider(Provider):
+    stop_sequence_key = "stop"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.LlamaLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.choices[0].text
+
+
 class OCIAuthType(Enum):
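
As an illustration of the new dispatch pattern, a small sketch that exercises `CohereProvider` against a faked response object. The nesting below simply mirrors the attribute path `data.inference_response.generated_texts[0].text` that the class reads; no OCI call is made, though the constructor still imports the `oci` package, so it must be installed:

```python
from types import SimpleNamespace

# Fake a Cohere-style completion response; the nesting mirrors the
# attribute path read by CohereProvider.completion_response_to_text.
fake_response = SimpleNamespace(
    data=SimpleNamespace(
        inference_response=SimpleNamespace(
            generated_texts=[SimpleNamespace(text="Hello from Cohere")]
        )
    )
)

provider = CohereProvider()  # imports the oci package internally
assert provider.stop_sequence_key == "stop_sequences"
print(provider.completion_response_to_text(fake_response))  # "Hello from Cohere"
```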
@@ -33,8 +69,8 @@ class OCIGenAIBase(BaseModel, ABC):
     API_KEY,
     SECURITY_TOKEN,
-    INSTANCE_PRINCIPLE,
-    RESOURCE_PRINCIPLE
+    INSTANCE_PRINCIPAL,
+    RESOURCE_PRINCIPAL
 
     If not specified, API_KEY will be used
     """
@@ -65,11 +101,6 @@ class OCIGenAIBase(BaseModel, ABC):
     is_stream: bool = False
     """Whether to stream back partial progress"""
 
-    llm_stop_sequence_mapping: Mapping[str, str] = {
-        "cohere": "stop_sequences",
-        "meta": "stop",
-    }
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that OCI config and python package exists in environment."""
@@ -121,24 +152,28 @@ class OCIGenAIBase(BaseModel, ABC):
                     "signer"
                 ] = oci.auth.signers.get_resource_principals_signer()
             else:
-                raise ValueError("Please provide valid value to auth_type")
+                raise ValueError(
+                    "Please provide valid value to auth_type, "
+                    f"{values['auth_type']} is not valid."
+                )
 
             values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                 **client_kwargs
             )
 
         except ImportError as ex:
-            raise ImportError(
+            raise ModuleNotFoundError(
                 "Could not import oci python package. "
                 "Please make sure you have the oci package installed."
             ) from ex
         except Exception as e:
             raise ValueError(
-                "Could not authenticate with OCI client. "
-                "Please check if ~/.oci/config exists. "
-                "If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, "
-                "Please check the specified "
-                "auth_profile and auth_type are valid."
+                """Could not authenticate with OCI client.
+                Please check if ~/.oci/config exists.
+                If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used,
+                please check the specified
+                auth_profile and auth_type are valid.""",
+                e,
             ) from e
 
         return values
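
For reference, a sketch of how these validated fields are supplied on the LLM side. All values are placeholders; `auth_type` must be one of the four names accepted above, otherwise the new, more descriptive ValueError is raised:

```python
# Hypothetical configuration; substitute your own tenancy values.
llm = OCIGenAI(
    model_id="meta.llama-2-70b-chat",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
    auth_type="SECURITY_TOKEN",  # API_KEY (default) | SECURITY_TOKEN | INSTANCE_PRINCIPAL | RESOURCE_PRINCIPAL
    auth_profile="MY_PROFILE",   # profile name in ~/.oci/config
)
```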
@@ -151,19 +186,19 @@ class OCIGenAIBase(BaseModel, ABC):
             **{"model_kwargs": _model_kwargs},
         }
 
-    def _get_provider(self) -> str:
+    def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
         if self.provider is not None:
             provider = self.provider
         else:
             provider = self.model_id.split(".")[0].lower()
 
-        if provider not in VALID_PROVIDERS:
+        if provider not in provider_map:
             raise ValueError(
                 f"Invalid provider derived from model_id: {self.model_id} "
                 "Please explicitly pass in the supported provider "
                 "when using custom endpoint"
             )
-        return provider
+        return provider_map[provider]
 
 
 class OCIGenAI(LLM, OCIGenAIBase):
@@ -173,7 +208,7 @@ class OCIGenAI(LLM, OCIGenAIBase):
     https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm
 
     The authentifcation method is passed through auth_type and should be one of:
-    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE
+    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
 
     Make sure you have the required policies (profile/roles) to
     access the OCI Generative AI service.
@@ -204,21 +239,29 @@ class OCIGenAI(LLM, OCIGenAIBase):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "oci"
+        return "oci_generative_ai_completion"
+
+    @property
+    def _provider_map(self) -> Mapping[str, Any]:
+        """Get the provider map"""
+        return {
+            "cohere": CohereProvider(),
+            "meta": MetaProvider(),
+        }
+
+    @property
+    def _provider(self) -> Any:
+        """Get the internal provider object"""
+        return self._get_provider(provider_map=self._provider_map)
 
     def _prepare_invocation_object(
         self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
         from oci.generative_ai_inference import models
 
-        oci_llm_request_mapping = {
-            "cohere": models.CohereLlmInferenceRequest,
-            "meta": models.LlamaLlmInferenceRequest,
-        }
-        provider = self._get_provider()
         _model_kwargs = self.model_kwargs or {}
         if stop is not None:
-            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop
+            _model_kwargs[self._provider.stop_sequence_key] = stop
 
         if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
             serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
@@ -232,19 +275,13 @@ class OCIGenAI(LLM, OCIGenAIBase):
         invocation_obj = models.GenerateTextDetails(
             compartment_id=self.compartment_id,
             serving_mode=serving_mode,
-            inference_request=oci_llm_request_mapping[provider](**inference_params),
+            inference_request=self._provider.llm_inference_request(**inference_params),
         )
 
         return invocation_obj
 
     def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
-        provider = self._get_provider()
-
-        if provider == "cohere":
-            text = response.data.inference_response.generated_texts[0].text
-        elif provider == "meta":
-            text = response.data.inference_response.choices[0].text
-        else:
-            raise ValueError(f"Invalid provider: {provider}")
+        text = self._provider.completion_response_to_text(response)
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
@@ -272,7 +309,51 @@ class OCIGenAI(LLM, OCIGenAIBase):
                 response = llm.invoke("Tell me a joke.")
         """
+        if self.is_stream:
+            text = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                text += chunk.text
+            if stop is not None:
+                text = enforce_stop_tokens(text, stop)
+            return text
+
         invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
         response = self.client.generate_text(invocation_obj)
         return self._process_response(response, stop)
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Stream OCIGenAI LLM on given prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            An iterator of GenerationChunks.
+
+        Example:
+            .. code-block:: python
+
+                response = llm.stream("Tell me a joke.")
+        """
+        self.is_stream = True
+        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
+        response = self.client.generate_text(invocation_obj)
+
+        for event in response.data.events():
+            json_load = json.loads(event.data)
+            if "text" in json_load:
+                event_data_text = json_load["text"]
+            else:
+                event_data_text = ""
+
+            chunk = GenerationChunk(text=event_data_text)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+            yield chunk
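
Putting the two code paths together, a minimal end-to-end sketch of the completion LLM (endpoint and compartment values are placeholders):

```python
from langchain_community.llms.oci_generative_ai import OCIGenAI

llm = OCIGenAI(
    model_id="cohere.command",  # provider is inferred from the "cohere." prefix
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
)

# Non-streaming path: generate_text -> _process_response
print(llm.invoke("Tell me a joke."))

# Streaming path: the new _stream yields GenerationChunks; the public
# .stream() API surfaces them as plain strings for an LLM subclass.
for token in llm.stream("Tell me a joke."):
    print(token, end="", flush=True)
```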