Add OCI Generative AI new model support (#22880)
- [x] PR title: community: Add OCI Generative AI new model support
- [x] PR message:
    - Description: Adds support for new models offered by the OCI Generative AI service. This is a moderate update of our initial integration PR 16548 and includes a new integration for our chat models under `langchain_community/chat_models/oci_generative_ai.py`.
    - Issue: NA
    - Dependencies: No new dependencies, just the latest version of the OCI SDK.
    - Twitter handle: NA
- [x] Add tests and docs:
    1. We have updated our unit tests.
    2. We have updated our documentation, including a new ipynb for the new chat integration.
- [x] Lint and test: `make format`, `make lint`, and `make test` run successfully.

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
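For context, a minimal usage sketch of the new chat integration. The service endpoint, compartment OCID, and parameter values below are illustrative placeholders, not values taken from this PR; `service_endpoint`, `compartment_id`, and `model_kwargs` mirror the constructor arguments of the existing `OCIGenAI` LLM integration:

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI

# Placeholder endpoint and compartment OCID; substitute your own OCI values.
chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
    model_kwargs={"temperature": 0.2, "max_tokens": 256},
)

response = chat.invoke([HumanMessage(content="Tell me one fact about OCI.")])
print(response.content)
```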
@@ -27,6 +27,7 @@ EXPECTED_ALL = [
     "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatMLX",
+    "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
     "ChatPerplexity",
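For context, `EXPECTED_ALL` is the allow-list in the chat-model import test, which checks the package's public exports. A minimal sketch of the pattern, assuming the surrounding test looks roughly like this (abbreviated; the real list and assertion may differ):

```python
from langchain_community.chat_models import __all__

# Abbreviated allow-list; the real one names every exported chat model class.
EXPECTED_ALL = [
    "ChatMLX",
    "ChatOCIGenAI",
    "ChatOllama",
]


def test_all_imports() -> None:
    # The real test compares the full lists; this sketch checks one direction.
    assert set(EXPECTED_ALL).issubset(set(__all__))
```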
@@ -0,0 +1,105 @@
+"""Test OCI Generative AI LLM service"""
+
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import HumanMessage
+from pytest import MonkeyPatch
+
+from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
+
+
+class MockResponseDict(dict):
+    def __getattr__(self, val):  # type: ignore[no-untyped-def]
+        return self[val]
+
+
+@pytest.mark.requires("oci")
+@pytest.mark.parametrize(
+    "test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
+)
+def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
+    """Test valid chat call to OCI Generative AI LLM service."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
+
+    provider = llm.model_id.split(".")[0].lower()
+
+    def mocked_response(*args):  # type: ignore[no-untyped-def]
+        response_text = "Assistant chat reply."
+        response = None
+        if provider == "cohere":
+            response = MockResponseDict(
+                {
+                    "status": 200,
+                    "data": MockResponseDict(
+                        {
+                            "chat_response": MockResponseDict(
+                                {
+                                    "text": response_text,
+                                    "finish_reason": "completed",
+                                }
+                            ),
+                            "model_id": "cohere.command-r-16k",
+                            "model_version": "1.0.0",
+                        }
+                    ),
+                    "request_id": "1234567890",
+                    "headers": MockResponseDict(
+                        {
+                            "content-length": "123",
+                        }
+                    ),
+                }
+            )
+        elif provider == "meta":
+            response = MockResponseDict(
+                {
+                    "status": 200,
+                    "data": MockResponseDict(
+                        {
+                            "chat_response": MockResponseDict(
+                                {
+                                    "choices": [
+                                        MockResponseDict(
+                                            {
+                                                "message": MockResponseDict(
+                                                    {
+                                                        "content": [
+                                                            MockResponseDict(
+                                                                {
+                                                                    "text": response_text,  # noqa: E501
+                                                                }
+                                                            )
+                                                        ]
+                                                    }
+                                                ),
+                                                "finish_reason": "completed",
+                                            }
+                                        )
+                                    ],
+                                    "time_created": "2024-09-01T00:00:00Z",
+                                }
+                            ),
+                            "model_id": "cohere.command-r-16k",
+                            "model_version": "1.0.0",
+                        }
+                    ),
+                    "request_id": "1234567890",
+                    "headers": MockResponseDict(
+                        {
+                            "content-length": "123",
+                        }
+                    ),
+                }
+            )
+        return response
+
+    monkeypatch.setattr(llm.client, "chat", mocked_response)
+
+    messages = [
+        HumanMessage(content="User message"),
+    ]
+
+    expected = "Assistant chat reply."
+    actual = llm.invoke(messages, temperature=0.2)
+    assert actual.content == expected
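A note on the `MockResponseDict` helper above: the OCI SDK exposes response fields as attributes (e.g. `response.data.chat_response.text`), while plain dict literals only support item access. Routing `__getattr__` to dict lookup lets the nested literals stand in for SDK response objects. A standalone illustration (not part of the commit):

```python
class MockResponseDict(dict):
    def __getattr__(self, val):  # fall back to dict lookup for attribute access
        return self[val]


resp = MockResponseDict({"data": MockResponseDict({"text": "hi"})})
assert resp.data.text == "hi"  # reads like an OCI SDK response object
```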
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 from pytest import MonkeyPatch
 
-from langchain_community.llms import OCIGenAI
+from langchain_community.llms.oci_generative_ai import OCIGenAI
 
 
 class MockResponseDict(dict):
@@ -16,12 +16,12 @@ class MockResponseDict(dict):
 @pytest.mark.parametrize(
     "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"]
 )
-def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
-    """Test valid call to OCI Generative AI LLM service."""
+def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
+    """Test valid completion call to OCI Generative AI LLM service."""
     oci_gen_ai_client = MagicMock()
     llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
 
-    provider = llm._get_provider()
+    provider = llm.model_id.split(".")[0].lower()
 
     def mocked_response(*args):  # type: ignore[no-untyped-def]
         response_text = "This is the completion."
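The provider key used to branch the mocked responses is just the model id's vendor prefix:

```python
model_id = "meta.llama-2-70b-chat"
provider = model_id.split(".")[0].lower()  # "meta"; "cohere.command" -> "cohere"
```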
@@ -71,6 +71,5 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
     )
 
     monkeypatch.setattr(llm.client, "generate_text", mocked_response)
-
     output = llm.invoke("This is a prompt.", temperature=0.2)
     assert output == "This is the completion."
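Both tests follow the same offline pattern: a `MagicMock` client plus `monkeypatch.setattr` on the method the integration calls (`chat` or `generate_text`), so no real OCI endpoint is contacted. A condensed, self-contained illustration of the pattern (names here are illustrative, not from the commit):

```python
from unittest.mock import MagicMock

from pytest import MonkeyPatch


def call_service(client):  # stands in for the integration code under test
    return client.generate_text()


def test_offline(monkeypatch: MonkeyPatch) -> None:
    client = MagicMock()
    # Replace the method so the "call" returns canned data instead of hitting OCI.
    monkeypatch.setattr(client, "generate_text", lambda *args: "stubbed")
    assert call_service(client) == "stubbed"
```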