Add OCI Generative AI new model and structured output support (#28754)

- [X] **PR title**: 
 community: Add new model and structured output support


- [X] **PR message**: 
- **Description:** Add support for Meta Llama 3.2 image handling and
JSON mode for structured output (a usage sketch follows this list)
    - **Issue:** NA
    - **Dependencies:** NA
    - **Twitter handle:** NA
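
A minimal usage sketch of the new image-handling path (hedged: the model ID, endpoint, and compartment OCID below are illustrative placeholders, not values taken from this PR):

```python
# Sketch only: model_id, service_endpoint, and compartment_id are placeholders.
from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.messages import HumanMessage

llm = ChatOCIGenAI(
    model_id="meta.llama-3.2-90b-vision-instruct",  # assumed vision-capable Meta model
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1...",  # your compartment OCID
)

# Llama 3.2 image handling: message content may now mix text and image_url items.
message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
    ]
)
ai_msg = llm.invoke([message])
```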


- [x] **Add tests and docs**: 
  1. unit tests have been updated,
  2. no documentation changes are required.


- [x] **Lint and test**: 
`make format`, `make lint`, and `make test` all run successfully.

---------

Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Rave Harpaz 2024-12-18 06:50:25 -08:00 committed by GitHub
parent ef24220d3f
commit 986b752fc8
2 changed files with 170 additions and 30 deletions

File 1 of 2: ChatOCIGenAI integration notebook (docs)

```diff
@@ -26,14 +26,14 @@
     "## Overview\n",
     "### Integration details\n",
     "\n",
-    "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n",
-    "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
-    "| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-oci-generative-ai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-oci-generative-ai?style=flat-square&label=%20) |\n",
+    "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) |\n",
+    "| :--- | :--- | :---: | :---: | :---: |\n",
+    "| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ |\n",
     "\n",
     "### Model features\n",
-    "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
+    "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | [JSON mode](/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs) | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
+    "| ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
     "\n",
     "## Setup\n",
     "\n",
```

File 2 of 2: langchain_community/chat_models/oci_generative_ai.py

```diff
@@ -2,12 +2,14 @@ import json
 import re
 import uuid
 from abc import ABC, abstractmethod
+from operator import itemgetter
 from typing import (
     Any,
     Callable,
     Dict,
     Iterator,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
```
```diff
@@ -32,13 +34,17 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.output_parsers import (
+    JsonOutputParser,
+    PydanticOutputParser,
+)
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_function
 from pydantic import BaseModel, ConfigDict
```
```diff
@@ -58,6 +64,10 @@ JSON_TO_PYTHON_TYPES = {
 }
 
 
+def _is_pydantic_class(obj: Any) -> bool:
+    return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
 def _remove_signature_from_tool_description(name: str, description: str) -> str:
     """
     Removes the `{name}{signature} - ` prefix and Args: section from tool description.
```
```diff
@@ -158,7 +168,7 @@ class CohereProvider(Provider):
 
     def chat_stream_to_text(self, event_data: Dict) -> str:
         if "text" in event_data:
-            if "finishedReason" in event_data or "toolCalls" in event_data:
+            if "finishReason" in event_data or "toolCalls" in event_data:
                 return ""
             else:
                 return event_data["text"]
```
```diff
@@ -378,7 +388,10 @@ class MetaProvider(Provider):
             "SYSTEM": models.SystemMessage,
             "ASSISTANT": models.AssistantMessage,
         }
-        self.oci_chat_message_content = models.TextContent
+        self.oci_chat_message_content = models.ChatContent
+        self.oci_chat_message_text_content = models.TextContent
+        self.oci_chat_message_image_content = models.ImageContent
+        self.oci_chat_message_image_url = models.ImageUrl
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
 
     def chat_response_to_text(self, response: Any) -> str:
```
```diff
@@ -415,19 +428,81 @@ class MetaProvider(Provider):
     def messages_to_oci_params(
         self, messages: List[BaseMessage], **kwargs: Any
     ) -> Dict[str, Any]:
-        oci_messages = [
-            self.oci_chat_message[self.get_role(msg)](
-                content=[self.oci_chat_message_content(text=msg.content)]
-            )
-            for msg in messages
-        ]
-        oci_params = {
-            "messages": oci_messages,
-            "api_format": self.chat_api_format,
-            "top_k": -1,
-        }
-
-        return oci_params
+        """Convert LangChain messages to OCI chat parameters.
+
+        Args:
+            messages: List of LangChain BaseMessage objects
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            Dict containing OCI chat parameters
+
+        Raises:
+            ValueError: If message content is invalid
+        """
+        oci_messages = []
+        for message in messages:
+            content = self._process_message_content(message.content)
+            oci_message = self.oci_chat_message[self.get_role(message)](content=content)
+            oci_messages.append(oci_message)
+
+        return {
+            "messages": oci_messages,
+            "api_format": self.chat_api_format,
+            "top_k": -1,
+        }
+
+    def _process_message_content(
+        self, content: Union[str, List[Union[str, Dict]]]
+    ) -> List[Any]:
+        """Process message content into OCI chat content format.
+
+        Args:
+            content: Message content as string or list
+
+        Returns:
+            List of OCI chat content objects
+
+        Raises:
+            ValueError: If content format is invalid
+        """
+        if isinstance(content, str):
+            return [self.oci_chat_message_text_content(text=content)]
+
+        if not isinstance(content, list):
+            raise ValueError("Message content must be str or list of items")
+
+        processed_content = []
+        for item in content:
+            if isinstance(item, str):
+                processed_content.append(self.oci_chat_message_text_content(text=item))
+                continue
+
+            if not isinstance(item, dict):
+                raise ValueError(
+                    f"Content items must be str or dict, got: {type(item)}"
+                )
+            if "type" not in item:
+                raise ValueError("Dict content item must have a type key")
+
+            if item["type"] == "image_url":
+                processed_content.append(
+                    self.oci_chat_message_image_content(
+                        image_url=self.oci_chat_message_image_url(
+                            url=item["image_url"]["url"]
+                        )
+                    )
+                )
+            elif item["type"] == "text":
+                processed_content.append(
+                    self.oci_chat_message_text_content(text=item["text"])
+                )
+            else:
+                raise ValueError(f"Unsupported content type: {item['type']}")
+
+        return processed_content
 
     def convert_to_oci_tool(
         self,
```
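
For reference, these are the LangChain-side content shapes the new `_process_message_content` helper accepts (a sketch of inputs only; the OCI SDK objects it produces are elided):

```python
from langchain_core.messages import HumanMessage

# A plain string becomes a single TextContent.
plain = HumanMessage(content="Just text")

# A list may mix bare strings, {"type": "text"} items, and {"type": "image_url"} items.
mixed = HumanMessage(
    content=[
        "bare strings are treated as text",
        {"type": "text", "text": "an explicit text item"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
)

# Everything else raises ValueError: non-str/dict items, dicts missing a
# "type" key, and unsupported type values such as {"type": "audio"}.
```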
```diff
@@ -577,7 +652,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
 
     def with_structured_output(
         self,
-        schema: Union[Dict[Any, Any], Type[BaseModel]],
+        schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+        *,
+        method: Literal["function_calling", "json_mode"] = "function_calling",
+        include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
         """Model wrapper that returns outputs formatted to match the given schema.
```
```diff
@@ -585,24 +663,86 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
         Args:
             schema: The output schema as a dict or a Pydantic class. If a Pydantic class
                 then the model output will be an object of that class. If a dict then
-                the model output will be a dict.
+                the model output will be a dict. With a Pydantic class the returned
+                attributes will be validated, whereas with a dict they will not be. If
+                `method` is "function_calling" and `schema` is a dict, then the dict
+                must match the OCI Generative AI function-calling spec.
+            method:
+                The method for steering model generation, either "function_calling"
+                or "json_mode". If "function_calling" then the schema will be converted
+                to an OCI function and the returned model will make use of the
+                function-calling API. If "json_mode" then Cohere's JSON mode will be
+                used. Note that if using "json_mode" then you must include instructions
+                for formatting the output into the desired schema into the model call.
+            include_raw:
+                If False then only the parsed structured output is returned. If
+                an error occurs during model output parsing it will be raised. If True
+                then both the raw model response (a BaseMessage) and the parsed model
+                response will be returned. If an error occurs during output parsing it
+                will be caught and returned as well. The final output is always a dict
+                with keys "raw", "parsed", and "parsing_error".
 
         Returns:
-            A Runnable that takes any ChatModel input and returns either a dict or
-            Pydantic class as output.
-        """
-        llm = self.bind_tools([schema], **kwargs)
-        if isinstance(schema, type) and issubclass(schema, BaseModel):
-            output_parser: OutputParserLike = PydanticToolsParser(
-                tools=[schema], first_tool_only=True
-            )
-        else:
-            key_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
-            output_parser = JsonOutputKeyToolsParser(
-                key_name=key_name, first_tool_only=True
-            )
-
-        return llm | output_parser
+            A Runnable that takes any ChatModel input and returns as output:
+
+                If include_raw is True then a dict with keys:
+                    raw: BaseMessage
+                    parsed: Optional[_DictOrPydantic]
+                    parsing_error: Optional[BaseException]
+
+                If include_raw is False then just _DictOrPydantic is returned,
+                where _DictOrPydantic depends on the schema:
+
+                If schema is a Pydantic class then _DictOrPydantic is the Pydantic
+                class.
+
+                If schema is a dict then _DictOrPydantic is a dict.
+        """  # noqa: E501
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        is_pydantic_schema = _is_pydantic_class(schema)
+        if method == "function_calling":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'function_calling'. "
+                    "Received None."
+                )
+            llm = self.bind_tools([schema], **kwargs)
+            tool_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
+            if is_pydantic_schema:
+                output_parser: OutputParserLike = PydanticToolsParser(
+                    tools=[schema],  # type: ignore[list-item]
+                    first_tool_only=True,  # type: ignore[list-item]
+                )
+            else:
+                output_parser = JsonOutputKeyToolsParser(
+                    key_name=tool_name, first_tool_only=True
+                )
+        elif method == "json_mode":
+            llm = self.bind(response_format={"type": "json_object"})
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
+        else:
+            raise ValueError(
+                f"Unrecognized method argument. "
+                f"Expected `function_calling` or `json_mode`. "
+                f"Received: `{method}`."
+            )
+
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser
 
     def _generate(
         self,
```
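
A hedged usage sketch for the reworked `with_structured_output` (the `Joke` schema and the `llm` instance are illustrative, not part of this diff):

```python
from pydantic import BaseModel, Field

class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline of the joke")

# Default method="function_calling": the schema is converted to an OCI tool
# and the tool call is parsed back into a validated Joke instance.
structured_llm = llm.with_structured_output(Joke)
joke = structured_llm.invoke("Tell me a joke about parsing")

# include_raw=True catches parsing failures instead of raising; the output is
# always a dict with keys "raw", "parsed", and "parsing_error".
robust_llm = llm.with_structured_output(Joke, include_raw=True)
result = robust_llm.invoke("Tell me a joke about parsing")
```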