Add OCI Generative AI new model support (#22880)

- [x] PR title:
  community: Add OCI Generative AI new model support

- [x] PR message:
    - Description: Adds support for new models offered by the OCI Generative AI service. This is a moderate update of our initial integration (PR #16548) and includes a new chat model integration under /langchain_community/chat_models/oci_generative_ai.py; a usage sketch follows this list.
    - Issue: N/A
    - Dependencies: No new dependencies; only the latest version of the OCI SDK is required.
    - Twitter handle: N/A
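
For reference, a minimal sketch of how the new chat integration can be invoked. The `service_endpoint` and `compartment_id` values below are illustrative placeholders, not values from this PR:

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI

# Endpoint and compartment OCID are placeholders for a real OCI tenancy.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
)

response = chat.invoke([HumanMessage(content="Tell me one fact about OCI.")])
print(response.content)
```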


- [x] Add tests and docs:
  1. We have updated our unit tests.
  2. We have updated our documentation, including a new .ipynb notebook for the new chat integration.

- [x] Lint and test:
  `make format`, `make lint`, and `make test` run successfully.

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Rave Harpaz authored on 2024-06-24 11:48:23 -07:00; committed by GitHub.
parent 753edf9c80 · commit f5ff7f178b
10 changed files with 885 additions and 146 deletions


```diff
@@ -27,6 +27,7 @@ EXPECTED_ALL = [
     "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatMLX",
+    "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
     "ChatPerplexity",
```


@@ -0,0 +1,105 @@
"""Test OCI Generative AI LLM service"""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage
from pytest import MonkeyPatch
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
return self[val]
@pytest.mark.requires("oci")
@pytest.mark.parametrize(
"test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
)
def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
"""Test valid chat call to OCI Generative AI LLM service."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
provider = llm.model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def]
response_text = "Assistant chat reply."
response = None
if provider == "cohere":
response = MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"text": response_text,
"finish_reason": "completed",
}
),
"model_id": "cohere.command-r-16k",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict(
{
"content-length": "123",
}
),
}
)
elif provider == "meta":
response = MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"choices": [
MockResponseDict(
{
"message": MockResponseDict(
{
"content": [
MockResponseDict(
{
"text": response_text, # noqa: E501
}
)
]
}
),
"finish_reason": "completed",
}
)
],
"time_created": "2024-09-01T00:00:00Z",
}
),
"model_id": "cohere.command-r-16k",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict(
{
"content-length": "123",
}
),
}
)
return response
monkeypatch.setattr(llm.client, "chat", mocked_response)
messages = [
HumanMessage(content="User message"),
]
expected = "Assistant chat reply."
actual = llm.invoke(messages, temperature=0.2)
assert actual.content == expected
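
The MockResponseDict helper is a handy pattern beyond this test: a dict whose keys double as attributes can stand in for nested SDK response objects. A standalone sketch of the pattern (the `AttrDict` name is hypothetical, not part of this PR), with a more defensive `__getattr__` that raises `AttributeError` so `hasattr()` behaves correctly:

```python
class AttrDict(dict):
    """Hypothetical helper: a dict whose keys are also readable as
    attributes, useful for faking nested SDK response objects in tests."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            # Raise AttributeError (not KeyError) so hasattr() works as expected.
            raise AttributeError(name) from exc


resp = AttrDict({"status": 200, "data": AttrDict({"text": "hi"})})
assert resp.status == 200
assert resp.data.text == "hi"
```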


```diff
@@ -4,7 +4,7 @@
 from unittest.mock import MagicMock

 import pytest
 from pytest import MonkeyPatch

-from langchain_community.llms import OCIGenAI
+from langchain_community.llms.oci_generative_ai import OCIGenAI

 class MockResponseDict(dict):
@@ -16,12 +16,12 @@ class MockResponseDict(dict):
 @pytest.mark.parametrize(
     "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"]
 )
-def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
-    """Test valid call to OCI Generative AI LLM service."""
+def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
+    """Test valid completion call to OCI Generative AI LLM service."""
     oci_gen_ai_client = MagicMock()
     llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)

-    provider = llm._get_provider()
+    provider = llm.model_id.split(".")[0].lower()

     def mocked_response(*args):  # type: ignore[no-untyped-def]
         response_text = "This is the completion."
@@ -71,6 +71,5 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
     )

     monkeypatch.setattr(llm.client, "generate_text", mocked_response)
     output = llm.invoke("This is a prompt.", temperature=0.2)
-
     assert output == "This is the completion."
```
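
For comparison with the chat sketch earlier, a minimal completion-style call through OCIGenAI. As above, the endpoint and compartment values are placeholders for a real OCI tenancy, not values from this PR:

```python
from langchain_community.llms.oci_generative_ai import OCIGenAI

# Endpoint and compartment OCID are placeholders; substitute your own.
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
)

print(llm.invoke("Write one sentence about Oracle Cloud."))
```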