Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-01 19:12:42 +00:00
Add OCI Generative AI new model support (#22880)
- [x] PR title: community: Add OCI Generative AI new model support
- [x] PR message:
  - Description: adds support for new models offered by OCI Generative AI services. This is a moderate update of our initial integration PR 16548 and includes a new integration for our chat models under /langchain_community/chat_models/oci_generative_ai.py
  - Issue: NA
  - Dependencies: no new dependencies, just the latest version of our OCI SDK
  - Twitter handle: NA
- [x] Add tests and docs:
  1. we have updated our unit tests
  2. we have updated our documentation, including a new ipynb for our new chat integration
- [x] Lint and test: `make format`, `make lint`, and `make test` run successfully

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
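A quick smoke test of the new chat integration (a minimal sketch, assuming `ChatOCIGenAI` is exported as in this PR; the model ID, service endpoint, and compartment OCID below are illustrative placeholders, not values from this PR):

    from langchain_community.chat_models import ChatOCIGenAI
    from langchain_core.messages import HumanMessage

    chat = ChatOCIGenAI(
        model_id="cohere.command-r-16k",  # placeholder; any supported chat model ID
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="MY_OCID",  # placeholder compartment OCID
    )
    print(chat.invoke([HumanMessage(content="Tell me a joke.")]).content)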
@@ -1,17 +1,53 @@
 from __future__ import annotations
 
-from abc import ABC
+import json
+from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
 
-VALID_PROVIDERS = ("cohere", "meta")
+
+class Provider(ABC):
+    @property
+    @abstractmethod
+    def stop_sequence_key(self) -> str:
+        ...
+
+    @abstractmethod
+    def completion_response_to_text(self, response: Any) -> str:
+        ...
+
+
+class CohereProvider(Provider):
+    stop_sequence_key = "stop_sequences"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.CohereLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.generated_texts[0].text
+
+
+class MetaProvider(Provider):
+    stop_sequence_key = "stop"
+
+    def __init__(self) -> None:
+        from oci.generative_ai_inference import models
+
+        self.llm_inference_request = models.LlamaLlmInferenceRequest
+
+    def completion_response_to_text(self, response: Any) -> str:
+        return response.data.inference_response.choices[0].text
+
+
 class OCIAuthType(Enum):
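The new `Provider` ABC is the heart of this refactor: each provider bundles its stop-sequence key, its OCI inference-request class, and how to extract text from a response, replacing the per-provider `if/elif` branches removed further down. A hypothetical sketch of how a third provider could plug in (`NewCoProvider` and `models.NewCoLlmInferenceRequest` are illustrative placeholders, not real OCI SDK classes):

    from typing import Any

    from langchain_community.llms.oci_generative_ai import Provider  # added in this PR

    class NewCoProvider(Provider):
        stop_sequence_key = "stop"  # request field this provider uses for stop words

        def __init__(self) -> None:
            from oci.generative_ai_inference import models

            # Placeholder: the real SDK request class for the provider goes here.
            self.llm_inference_request = models.NewCoLlmInferenceRequest

        def completion_response_to_text(self, response: Any) -> str:
            # Adapt to wherever this provider's response carries generated text.
            return response.data.inference_response.choices[0].text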
@@ -33,8 +69,8 @@ class OCIGenAIBase(BaseModel, ABC):
 
     API_KEY,
     SECURITY_TOKEN,
-    INSTANCE_PRINCIPLE,
-    RESOURCE_PRINCIPLE
+    INSTANCE_PRINCIPAL,
+    RESOURCE_PRINCIPAL
 
     If not specified, API_KEY will be used
     """
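With the corrected names, selecting a non-default auth method looks like this (a sketch; the endpoint, compartment OCID, and profile name are placeholders):

    from langchain_community.llms.oci_generative_ai import OCIGenAI

    llm = OCIGenAI(
        model_id="meta.llama-2-70b-chat",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="MY_OCID",    # placeholder compartment OCID
        auth_type="SECURITY_TOKEN",  # or API_KEY, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
        auth_profile="MY_PROFILE",   # profile name in ~/.oci/config
    )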
@@ -65,11 +101,6 @@ class OCIGenAIBase(BaseModel, ABC):
     is_stream: bool = False
     """Whether to stream back partial progress"""
 
-    llm_stop_sequence_mapping: Mapping[str, str] = {
-        "cohere": "stop_sequences",
-        "meta": "stop",
-    }
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that OCI config and python package exists in environment."""
@@ -121,24 +152,28 @@ class OCIGenAIBase(BaseModel, ABC):
                     "signer"
                 ] = oci.auth.signers.get_resource_principals_signer()
             else:
-                raise ValueError("Please provide valid value to auth_type")
+                raise ValueError(
+                    "Please provide valid value to auth_type, "
+                    f"{values['auth_type']} is not valid."
+                )
 
             values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                 **client_kwargs
             )
 
         except ImportError as ex:
-            raise ImportError(
+            raise ModuleNotFoundError(
                 "Could not import oci python package. "
                 "Please make sure you have the oci package installed."
             ) from ex
         except Exception as e:
             raise ValueError(
-                "Could not authenticate with OCI client. "
-                "Please check if ~/.oci/config exists. "
-                "If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, "
-                "Please check the specified "
-                "auth_profile and auth_type are valid."
+                """Could not authenticate with OCI client.
+                Please check if ~/.oci/config exists.
+                If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used,
+                please check the specified
+                auth_profile and auth_type are valid.""",
+                e,
             ) from e
 
         return values
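For the default API_KEY path, the validator above reads a profile from ~/.oci/config; a typical profile looks like this (all values are placeholders, following the standard OCI SDK config format):

    [DEFAULT]
    user=ocid1.user.oc1..<unique_id>
    fingerprint=<api_key_fingerprint>
    key_file=~/.oci/oci_api_key.pem
    tenancy=ocid1.tenancy.oc1..<unique_id>
    region=us-chicago-1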
@@ -151,19 +186,19 @@ class OCIGenAIBase(BaseModel, ABC):
             **{"model_kwargs": _model_kwargs},
         }
 
-    def _get_provider(self) -> str:
+    def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
         if self.provider is not None:
             provider = self.provider
         else:
             provider = self.model_id.split(".")[0].lower()
 
-        if provider not in VALID_PROVIDERS:
+        if provider not in provider_map:
             raise ValueError(
                 f"Invalid provider derived from model_id: {self.model_id} "
                 "Please explicitly pass in the supported provider "
                 "when using custom endpoint"
             )
-        return provider
+        return provider_map[provider]
 
 
 class OCIGenAI(LLM, OCIGenAIBase):
@@ -173,7 +208,7 @@ class OCIGenAI(LLM, OCIGenAIBase):
     https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm
 
     The authentication method is passed through auth_type and should be one of:
-    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE
+    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
 
     Make sure you have the required policies (profile/roles) to
     access the OCI Generative AI service.
@@ -204,21 +239,29 @@ class OCIGenAI(LLM, OCIGenAIBase):
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "oci"
+        return "oci_generative_ai_completion"
+
+    @property
+    def _provider_map(self) -> Mapping[str, Any]:
+        """Get the provider map"""
+        return {
+            "cohere": CohereProvider(),
+            "meta": MetaProvider(),
+        }
+
+    @property
+    def _provider(self) -> Any:
+        """Get the internal provider object"""
+        return self._get_provider(provider_map=self._provider_map)
 
     def _prepare_invocation_object(
         self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
     ) -> Dict[str, Any]:
         from oci.generative_ai_inference import models
 
-        oci_llm_request_mapping = {
-            "cohere": models.CohereLlmInferenceRequest,
-            "meta": models.LlamaLlmInferenceRequest,
-        }
-        provider = self._get_provider()
         _model_kwargs = self.model_kwargs or {}
         if stop is not None:
-            _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop
+            _model_kwargs[self._provider.stop_sequence_key] = stop
 
         if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
             serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
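Because `_provider_map` is a property and `_get_provider` now resolves against it, a subclass can register extra providers without touching the base class. Continuing the hypothetical `NewCoProvider` sketch from above ("newco" is an illustrative key only):

    from typing import Any, Mapping

    from langchain_community.llms.oci_generative_ai import OCIGenAI

    class MyOCIGenAI(OCIGenAI):
        @property
        def _provider_map(self) -> Mapping[str, Any]:
            # Extend the built-in map with the sketch provider defined earlier.
            return {**super()._provider_map, "newco": NewCoProvider()}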
@@ -232,19 +275,13 @@ class OCIGenAI(LLM, OCIGenAIBase):
         invocation_obj = models.GenerateTextDetails(
             compartment_id=self.compartment_id,
             serving_mode=serving_mode,
-            inference_request=oci_llm_request_mapping[provider](**inference_params),
+            inference_request=self._provider.llm_inference_request(**inference_params),
         )
 
         return invocation_obj
 
     def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
-        provider = self._get_provider()
-        if provider == "cohere":
-            text = response.data.inference_response.generated_texts[0].text
-        elif provider == "meta":
-            text = response.data.inference_response.choices[0].text
-        else:
-            raise ValueError(f"Invalid provider: {provider}")
+        text = self._provider.completion_response_to_text(response)
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
@@ -272,7 +309,51 @@ class OCIGenAI(LLM, OCIGenAIBase):
 
                 response = llm.invoke("Tell me a joke.")
         """
+        if self.is_stream:
+            text = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                text += chunk.text
+            if stop is not None:
+                text = enforce_stop_tokens(text, stop)
+            return text
+
         invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
         response = self.client.generate_text(invocation_obj)
         return self._process_response(response, stop)
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Stream OCIGenAI LLM on given prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            An iterator of GenerationChunks.
+
+        Example:
+            .. code-block:: python
+
+                response = llm.stream("Tell me a joke.")
+        """
+
+        self.is_stream = True
+        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
+        response = self.client.generate_text(invocation_obj)
+
+        for event in response.data.events():
+            json_load = json.loads(event.data)
+            if "text" in json_load:
+                event_data_text = json_load["text"]
+            else:
+                event_data_text = ""
+            chunk = GenerationChunk(text=event_data_text)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
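The new streaming path can be exercised end to end like so (a sketch; the endpoint and compartment OCID are placeholders). `llm.stream` drives `_stream`, which parses each server-sent event and emits a `GenerationChunk`:

    from langchain_community.llms.oci_generative_ai import OCIGenAI

    llm = OCIGenAI(
        model_id="cohere.command",  # placeholder model ID
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="MY_OCID",   # placeholder compartment OCID
    )

    # Chunks arrive incrementally as the service streams tokens back.
    for chunk in llm.stream("Tell me a one-line joke."):
        print(chunk, end="", flush=True)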