From 986b752fc81f00b066f2d797b3f8b6b2aaceeb33 Mon Sep 17 00:00:00 2001
From: Rave Harpaz
Date: Wed, 18 Dec 2024 06:50:25 -0800
Subject: [PATCH] Add OCI Generative AI new model and structured output
 support (#28754)

- [X] **PR title**: community: Add new model and structured output support

- [X] **PR message**:
    - **Description:** Add support for Meta Llama 3.2 image handling and
      JSON mode for structured output
    - **Issue:** NA
    - **Dependencies:** NA
    - **Twitter handle:** NA

- [x] **Add tests and docs**:
  1. we have updated our unit tests;
  2. no changes are required for documentation.

- [x] **Lint and test**: `make format`, `make lint`, and `make test` all run
  successfully.

Usage sketches for the new image-input and structured-output features follow
the diff below.

---------

Co-authored-by: Arthur Cheng
Co-authored-by: ccurme
---
 .../integrations/chat/oci_generative_ai.ipynb |  10 +-
 .../chat_models/oci_generative_ai.py          | 190 +++++++++++++++---
 2 files changed, 170 insertions(+), 30 deletions(-)

diff --git a/docs/docs/integrations/chat/oci_generative_ai.ipynb b/docs/docs/integrations/chat/oci_generative_ai.ipynb
index d5dc26cf25e..49efa3a5752 100644
--- a/docs/docs/integrations/chat/oci_generative_ai.ipynb
+++ b/docs/docs/integrations/chat/oci_generative_ai.ipynb
@@ -26,14 +26,14 @@
     "## Overview\n",
     "### Integration details\n",
     "\n",
-    "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n",
-    "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
-    "| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-oci-generative-ai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-oci-generative-ai?style=flat-square&label=%20) |\n",
+    "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) |\n",
+    "| :--- | :--- | :---: | :---: | :---: |\n",
+    "| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ |\n",
     "\n",
     "### Model features\n",
-    "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
+    "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | [JSON mode](/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs) | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
+    "| ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
     "\n",
     "## Setup\n",
     "\n",
diff --git a/libs/community/langchain_community/chat_models/oci_generative_ai.py b/libs/community/langchain_community/chat_models/oci_generative_ai.py
index 8600dd5e923..f66449bc564 100644
--- a/libs/community/langchain_community/chat_models/oci_generative_ai.py
+++ b/libs/community/langchain_community/chat_models/oci_generative_ai.py
@@ -2,12 +2,14 @@ import json
 import re
 import uuid
 from abc import ABC, abstractmethod
+from operator import itemgetter
 from typing import (
     Any,
     Callable,
     Dict,
     Iterator,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -32,13 +34,17 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.output_parsers import (
+    JsonOutputParser,
+    PydanticOutputParser,
+)
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_function
 from pydantic import BaseModel, ConfigDict
@@ -58,6 +64,10 @@ JSON_TO_PYTHON_TYPES = {
 }


+def _is_pydantic_class(obj: Any) -> bool:
+    return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
 def _remove_signature_from_tool_description(name: str, description: str) -> str:
     """
     Removes the `{name}{signature} - ` prefix and Args: section from tool description.
@@ -158,7 +168,7 @@ class CohereProvider(Provider):

     def chat_stream_to_text(self, event_data: Dict) -> str:
         if "text" in event_data:
-            if "finishedReason" in event_data or "toolCalls" in event_data:
+            if "finishReason" in event_data or "toolCalls" in event_data:
                 return ""
             else:
                 return event_data["text"]
@@ -378,7 +388,10 @@ class MetaProvider(Provider):
             "SYSTEM": models.SystemMessage,
             "ASSISTANT": models.AssistantMessage,
         }
-        self.oci_chat_message_content = models.TextContent
+        self.oci_chat_message_content = models.ChatContent
+        self.oci_chat_message_text_content = models.TextContent
+        self.oci_chat_message_image_content = models.ImageContent
+        self.oci_chat_message_image_url = models.ImageUrl
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC

     def chat_response_to_text(self, response: Any) -> str:
@@ -415,19 +428,81 @@ class MetaProvider(Provider):
     def messages_to_oci_params(
         self, messages: List[BaseMessage], **kwargs: Any
     ) -> Dict[str, Any]:
-        oci_messages = [
-            self.oci_chat_message[self.get_role(msg)](
-                content=[self.oci_chat_message_content(text=msg.content)]
-            )
-            for msg in messages
-        ]
-        oci_params = {
+        """Convert LangChain messages to OCI chat parameters.
+
+        Args:
+            messages: List of LangChain BaseMessage objects
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            Dict containing OCI chat parameters
+
+        Raises:
+            ValueError: If message content is invalid
+        """
+        oci_messages = []
+
+        for message in messages:
+            content = self._process_message_content(message.content)
+            oci_message = self.oci_chat_message[self.get_role(message)](content=content)
+            oci_messages.append(oci_message)
+
+        return {
             "messages": oci_messages,
             "api_format": self.chat_api_format,
             "top_k": -1,
         }

-        return oci_params
+    def _process_message_content(
+        self, content: Union[str, List[Union[str, Dict]]]
+    ) -> List[Any]:
+        """Process message content into OCI chat content format.
+
+        Args:
+            content: Message content as string or list
+
+        Returns:
+            List of OCI chat content objects
+
+        Raises:
+            ValueError: If content format is invalid
+        """
+        if isinstance(content, str):
+            return [self.oci_chat_message_text_content(text=content)]
+
+        if not isinstance(content, list):
+            raise ValueError("Message content must be str or list of items")
+
+        processed_content = []
+        for item in content:
+            if isinstance(item, str):
+                processed_content.append(self.oci_chat_message_text_content(text=item))
+                continue
+
+            if not isinstance(item, dict):
+                raise ValueError(
+                    f"Content items must be str or dict, got: {type(item)}"
+                )
+
+            if "type" not in item:
+                raise ValueError("Dict content item must have a type key")
+
+            if item["type"] == "image_url":
+                processed_content.append(
+                    self.oci_chat_message_image_content(
+                        image_url=self.oci_chat_message_image_url(
+                            url=item["image_url"]["url"]
+                        )
+                    )
+                )
+            elif item["type"] == "text":
+                processed_content.append(
+                    self.oci_chat_message_text_content(text=item["text"])
+                )
+            else:
+                raise ValueError(f"Unsupported content type: {item['type']}")
+
+        return processed_content

     def convert_to_oci_tool(
         self,
@@ -577,7 +652,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):

     def with_structured_output(
         self,
-        schema: Union[Dict[Any, Any], Type[BaseModel]],
+        schema: Optional[Union[Dict, Type[BaseModel]]] = None,
+        *,
+        method: Literal["function_calling", "json_mode"] = "function_calling",
+        include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
         """Model wrapper that returns outputs formatted to match the given schema.
@@ -585,24 +663,86 @@ def with_structured_output(
         Args:
             schema: The output schema as a dict or a Pydantic class. If a Pydantic
                 class then the model output will be an object of that class. If a dict then
-                the model output will be a dict.
+                the model output will be a dict. With a Pydantic class the returned
+                attributes will be validated, whereas with a dict they will not be. If
+                `method` is "function_calling" and `schema` is a dict, then the dict
+                must match the OCI Generative AI function-calling spec.
+            method:
+                The method for steering model generation, either "function_calling"
+                or "json_mode". If "function_calling" then the schema will be converted
+                to an OCI function and the returned model will make use of the
+                function-calling API. If "json_mode" then Cohere's JSON mode will be
+                used. Note that if using "json_mode" then you must include instructions
+                for formatting the output into the desired schema into the model call.
+            include_raw:
+                If False then only the parsed structured output is returned. If
+                an error occurs during model output parsing it will be raised. If True
+                then both the raw model response (a BaseMessage) and the parsed model
+                response will be returned. If an error occurs during output parsing it
+                will be caught and returned as well. The final output is always a dict
+                with keys "raw", "parsed", and "parsing_error".

         Returns:
-            A Runnable that takes any ChatModel input and returns either a dict or
-            Pydantic class as output.
-        """
-        llm = self.bind_tools([schema], **kwargs)
-        if isinstance(schema, type) and issubclass(schema, BaseModel):
-            output_parser: OutputParserLike = PydanticToolsParser(
-                tools=[schema], first_tool_only=True
+            A Runnable that takes any ChatModel input and returns as output:
+
+            If include_raw is True then a dict with keys:
+                raw: BaseMessage
+                parsed: Optional[_DictOrPydantic]
+                parsing_error: Optional[BaseException]
+
+            If include_raw is False then just _DictOrPydantic is returned,
+            where _DictOrPydantic depends on the schema:
+
+            If schema is a Pydantic class then _DictOrPydantic is the Pydantic
+            class.
+
+            If schema is a dict then _DictOrPydantic is a dict.
+
+        """  # noqa: E501
+        if kwargs:
+            raise ValueError(f"Received unsupported arguments {kwargs}")
+        is_pydantic_schema = _is_pydantic_class(schema)
+        if method == "function_calling":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'function_calling'. "
+                    "Received None."
+                )
+            llm = self.bind_tools([schema], **kwargs)
+            tool_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
+            if is_pydantic_schema:
+                output_parser: OutputParserLike = PydanticToolsParser(
+                    tools=[schema],  # type: ignore[list-item]
+                    first_tool_only=True,  # type: ignore[list-item]
+                )
+            else:
+                output_parser = JsonOutputKeyToolsParser(
+                    key_name=tool_name, first_tool_only=True
+                )
+        elif method == "json_mode":
+            llm = self.bind(response_format={"type": "json_object"})
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
             )
         else:
-            key_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
-            output_parser = JsonOutputKeyToolsParser(
-                key_name=key_name, first_tool_only=True
+            raise ValueError(
+                f"Unrecognized method argument. "
+                f"Expected `function_calling` or `json_mode`. "
+                f"Received: `{method}`."
             )
-
-        return llm | output_parser
+        if include_raw:
+            parser_assign = RunnablePassthrough.assign(
+                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
+            )
+            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
+            parser_with_fallback = parser_assign.with_fallbacks(
+                [parser_none], exception_key="parsing_error"
+            )
+            return RunnableMap(raw=llm) | parser_with_fallback
+        else:
+            return llm | output_parser

     def _generate(
         self,
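
The content blocks handled by the new `_process_message_content` follow the standard LangChain multimodal format, so the image support can be exercised as in the sketch below. This is a minimal sketch, not part of the change itself: the model ID, service endpoint, and compartment OCID are placeholders for values from your own tenancy, and a vision-capable Llama 3.2 model is assumed.

```python
from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.messages import HumanMessage

# Placeholders: substitute a vision-capable model and your own
# endpoint/compartment values.
llm = ChatOCIGenAI(
    model_id="meta.llama-3.2-90b-vision-instruct",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
)

# _process_message_content accepts a list mixing "text" and "image_url" items;
# any other item type raises ValueError.
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,<base64-encoded-image>"},
        },
    ]
)
response = llm.invoke([message])
print(response.content)
```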
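Likewise, a sketch of the new `method="json_mode"` path in `with_structured_output`. Per the docstring above, JSON mode rides on Cohere's JSON mode (the model is bound with `response_format={"type": "json_object"}`), so a Cohere-served model is assumed here (`cohere.command-r-plus` is a placeholder ID), and the prompt itself must describe the desired keys:

```python
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


# Same endpoint/compartment placeholders as in the previous sketch.
cohere_llm = ChatOCIGenAI(
    model_id="cohere.command-r-plus",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
)

structured_llm = cohere_llm.with_structured_output(Joke, method="json_mode")
result = structured_llm.invoke(
    "Tell me a joke about parrots. "
    "Respond only with JSON containing `setup` and `punchline` keys."
)
# `result` is a Joke instance, parsed and validated by PydanticOutputParser.
```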
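Finally, `include_raw=True` (with either method) wraps the chain so that parsing failures are captured instead of raised, exactly as described in the docstring. Reusing the hypothetical `Joke` schema and `cohere_llm` from the sketch above:

```python
structured_llm = cohere_llm.with_structured_output(Joke, include_raw=True)
output = structured_llm.invoke("Tell me a joke about parrots.")

output["raw"]            # the raw AIMessage, including any tool calls
output["parsed"]         # a Joke instance, or None if parsing failed
output["parsing_error"]  # None, or the exception raised during parsing
```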