Add OCI Generative AI new model and structured output support (#28754)

- [X] **PR title**: 
 community: Add new model and structured output support


- [X] **PR message**: 
- **Description:** add support for meta llama 3.2 image handling, and
JSON mode for structured output
    - **Issue:** NA
    - **Dependencies:** NA
    - **Twitter handle:** NA


- [x] **Add tests and docs**: 
  1. we have updated our unit tests,
  2. no changes required for documentation.


- [x] **Lint and test**: 
make format, make lint and make test we run successfully

---------

Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Rave Harpaz 2024-12-18 06:50:25 -08:00 committed by GitHub
parent ef24220d3f
commit 986b752fc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 170 additions and 30 deletions

View File

@ -26,14 +26,14 @@
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-oci-generative-ai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-oci-generative-ai?style=flat-square&label=%20) |\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) |\n",
"| :--- | :--- | :---: | :---: | :---: |\n",
"| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | [JSON mode](/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs) | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
"| ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
"\n",
"## Setup\n",
"\n",

View File

@ -2,12 +2,14 @@ import json
import re
import uuid
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
@ -32,13 +34,17 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from pydantic import BaseModel, ConfigDict
@ -58,6 +64,10 @@ JSON_TO_PYTHON_TYPES = {
}
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
def _remove_signature_from_tool_description(name: str, description: str) -> str:
"""
Removes the `{name}{signature} - ` prefix and Args: section from tool description.
@ -158,7 +168,7 @@ class CohereProvider(Provider):
def chat_stream_to_text(self, event_data: Dict) -> str:
if "text" in event_data:
if "finishedReason" in event_data or "toolCalls" in event_data:
if "finishReason" in event_data or "toolCalls" in event_data:
return ""
else:
return event_data["text"]
@ -378,7 +388,10 @@ class MetaProvider(Provider):
"SYSTEM": models.SystemMessage,
"ASSISTANT": models.AssistantMessage,
}
self.oci_chat_message_content = models.TextContent
self.oci_chat_message_content = models.ChatContent
self.oci_chat_message_text_content = models.TextContent
self.oci_chat_message_image_content = models.ImageContent
self.oci_chat_message_image_url = models.ImageUrl
self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
def chat_response_to_text(self, response: Any) -> str:
@ -415,19 +428,81 @@ class MetaProvider(Provider):
def messages_to_oci_params(
self, messages: List[BaseMessage], **kwargs: Any
) -> Dict[str, Any]:
oci_messages = [
self.oci_chat_message[self.get_role(msg)](
content=[self.oci_chat_message_content(text=msg.content)]
)
for msg in messages
]
oci_params = {
"""Convert LangChain messages to OCI chat parameters.
Args:
messages: List of LangChain BaseMessage objects
**kwargs: Additional keyword arguments
Returns:
Dict containing OCI chat parameters
Raises:
ValueError: If message content is invalid
"""
oci_messages = []
for message in messages:
content = self._process_message_content(message.content)
oci_message = self.oci_chat_message[self.get_role(message)](content=content)
oci_messages.append(oci_message)
return {
"messages": oci_messages,
"api_format": self.chat_api_format,
"top_k": -1,
}
return oci_params
def _process_message_content(
self, content: Union[str, List[Union[str, Dict]]]
) -> List[Any]:
"""Process message content into OCI chat content format.
Args:
content: Message content as string or list
Returns:
List of OCI chat content objects
Raises:
ValueError: If content format is invalid
"""
if isinstance(content, str):
return [self.oci_chat_message_text_content(text=content)]
if not isinstance(content, list):
raise ValueError("Message content must be str or list of items")
processed_content = []
for item in content:
if isinstance(item, str):
processed_content.append(self.oci_chat_message_text_content(text=item))
continue
if not isinstance(item, dict):
raise ValueError(
f"Content items must be str or dict, got: {type(item)}"
)
if "type" not in item:
raise ValueError("Dict content item must have a type key")
if item["type"] == "image_url":
processed_content.append(
self.oci_chat_message_image_content(
image_url=self.oci_chat_message_image_url(
url=item["image_url"]["url"]
)
)
)
elif item["type"] == "text":
processed_content.append(
self.oci_chat_message_text_content(text=item["text"])
)
else:
raise ValueError(f"Unsupported content type: {item['type']}")
return processed_content
def convert_to_oci_tool(
self,
@ -577,7 +652,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
def with_structured_output(
self,
schema: Union[Dict[Any, Any], Type[BaseModel]],
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
@ -585,24 +663,86 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict.
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OCI Generative AI function-calling spec.
method:
The method for steering model generation, either "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to an OCI function and the returned model will make use of the
function-calling API. If "json_mode" then Cohere's JSON mode will be
used. Note that if using "json_mode" then you must include instructions
for formatting the output into the desired schema into the model call.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes any ChatModel input and returns either a dict or
Pydantic class as output.
"""
llm = self.bind_tools([schema], **kwargs)
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], **kwargs)
tool_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
else:
key_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
raise ValueError(
f"Unrecognized method argument. "
f"Expected `function_calling` or `json_mode`."
f"Received: `{method}`."
)
return llm | output_parser
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _generate(
self,