mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
Add OCI Generative AI new model and structured output support (#28754)
- [X] **PR title**: community: Add new model and structured output support - [X] **PR message**: - **Description:** add support for meta llama 3.2 image handling, and JSON mode for structured output - **Issue:** NA - **Dependencies:** NA - **Twitter handle:** NA - [x] **Add tests and docs**: 1. we have updated our unit tests, 2. no changes required for documentation. - [x] **Lint and test**: make format, make lint and make test we run successfully --------- Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com> Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
ef24220d3f
commit
986b752fc8
@ -26,14 +26,14 @@
|
||||
"## Overview\n",
|
||||
"### Integration details\n",
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ |  |  |\n",
|
||||
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/oci_generative_ai) |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: |\n",
|
||||
"| [ChatOCIGenAI](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ |\n",
|
||||
"\n",
|
||||
"### Model features\n",
|
||||
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
|
||||
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | [JSON mode](/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs) | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
|
||||
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
|
||||
"| ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
|
@ -2,12 +2,14 @@ import json
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -32,13 +34,17 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputParser,
|
||||
PydanticOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@ -58,6 +64,10 @@ JSON_TO_PYTHON_TYPES = {
|
||||
}
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
|
||||
|
||||
def _remove_signature_from_tool_description(name: str, description: str) -> str:
|
||||
"""
|
||||
Removes the `{name}{signature} - ` prefix and Args: section from tool description.
|
||||
@ -158,7 +168,7 @@ class CohereProvider(Provider):
|
||||
|
||||
def chat_stream_to_text(self, event_data: Dict) -> str:
|
||||
if "text" in event_data:
|
||||
if "finishedReason" in event_data or "toolCalls" in event_data:
|
||||
if "finishReason" in event_data or "toolCalls" in event_data:
|
||||
return ""
|
||||
else:
|
||||
return event_data["text"]
|
||||
@ -378,7 +388,10 @@ class MetaProvider(Provider):
|
||||
"SYSTEM": models.SystemMessage,
|
||||
"ASSISTANT": models.AssistantMessage,
|
||||
}
|
||||
self.oci_chat_message_content = models.TextContent
|
||||
self.oci_chat_message_content = models.ChatContent
|
||||
self.oci_chat_message_text_content = models.TextContent
|
||||
self.oci_chat_message_image_content = models.ImageContent
|
||||
self.oci_chat_message_image_url = models.ImageUrl
|
||||
self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
|
||||
|
||||
def chat_response_to_text(self, response: Any) -> str:
|
||||
@ -415,19 +428,81 @@ class MetaProvider(Provider):
|
||||
def messages_to_oci_params(
|
||||
self, messages: List[BaseMessage], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
oci_messages = [
|
||||
self.oci_chat_message[self.get_role(msg)](
|
||||
content=[self.oci_chat_message_content(text=msg.content)]
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
oci_params = {
|
||||
"""Convert LangChain messages to OCI chat parameters.
|
||||
|
||||
Args:
|
||||
messages: List of LangChain BaseMessage objects
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
Dict containing OCI chat parameters
|
||||
|
||||
Raises:
|
||||
ValueError: If message content is invalid
|
||||
"""
|
||||
oci_messages = []
|
||||
|
||||
for message in messages:
|
||||
content = self._process_message_content(message.content)
|
||||
oci_message = self.oci_chat_message[self.get_role(message)](content=content)
|
||||
oci_messages.append(oci_message)
|
||||
|
||||
return {
|
||||
"messages": oci_messages,
|
||||
"api_format": self.chat_api_format,
|
||||
"top_k": -1,
|
||||
}
|
||||
|
||||
return oci_params
|
||||
def _process_message_content(
|
||||
self, content: Union[str, List[Union[str, Dict]]]
|
||||
) -> List[Any]:
|
||||
"""Process message content into OCI chat content format.
|
||||
|
||||
Args:
|
||||
content: Message content as string or list
|
||||
|
||||
Returns:
|
||||
List of OCI chat content objects
|
||||
|
||||
Raises:
|
||||
ValueError: If content format is invalid
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return [self.oci_chat_message_text_content(text=content)]
|
||||
|
||||
if not isinstance(content, list):
|
||||
raise ValueError("Message content must be str or list of items")
|
||||
|
||||
processed_content = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
processed_content.append(self.oci_chat_message_text_content(text=item))
|
||||
continue
|
||||
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(
|
||||
f"Content items must be str or dict, got: {type(item)}"
|
||||
)
|
||||
|
||||
if "type" not in item:
|
||||
raise ValueError("Dict content item must have a type key")
|
||||
|
||||
if item["type"] == "image_url":
|
||||
processed_content.append(
|
||||
self.oci_chat_message_image_content(
|
||||
image_url=self.oci_chat_message_image_url(
|
||||
url=item["image_url"]["url"]
|
||||
)
|
||||
)
|
||||
)
|
||||
elif item["type"] == "text":
|
||||
processed_content.append(
|
||||
self.oci_chat_message_text_content(text=item["text"])
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {item['type']}")
|
||||
|
||||
return processed_content
|
||||
|
||||
def convert_to_oci_tool(
|
||||
self,
|
||||
@ -577,7 +652,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict[Any, Any], Type[BaseModel]],
|
||||
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
@ -585,24 +663,86 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
|
||||
Args:
|
||||
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
|
||||
then the model output will be an object of that class. If a dict then
|
||||
the model output will be a dict.
|
||||
the model output will be a dict. With a Pydantic class the returned
|
||||
attributes will be validated, whereas with a dict they will not be. If
|
||||
`method` is "function_calling" and `schema` is a dict, then the dict
|
||||
must match the OCI Generative AI function-calling spec.
|
||||
method:
|
||||
The method for steering model generation, either "function_calling"
|
||||
or "json_mode". If "function_calling" then the schema will be converted
|
||||
to an OCI function and the returned model will make use of the
|
||||
function-calling API. If "json_mode" then Cohere's JSON mode will be
|
||||
used. Note that if using "json_mode" then you must include instructions
|
||||
for formatting the output into the desired schema into the model call.
|
||||
include_raw:
|
||||
If False then only the parsed structured output is returned. If
|
||||
an error occurs during model output parsing it will be raised. If True
|
||||
then both the raw model response (a BaseMessage) and the parsed model
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
Returns:
|
||||
A Runnable that takes any ChatModel input and returns either a dict or
|
||||
Pydantic class as output.
|
||||
"""
|
||||
llm = self.bind_tools([schema], **kwargs)
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
A Runnable that takes any ChatModel input and returns as output:
|
||||
|
||||
If include_raw is True then a dict with keys:
|
||||
raw: BaseMessage
|
||||
parsed: Optional[_DictOrPydantic]
|
||||
parsing_error: Optional[BaseException]
|
||||
|
||||
If include_raw is False then just _DictOrPydantic is returned,
|
||||
where _DictOrPydantic depends on the schema:
|
||||
|
||||
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
|
||||
class.
|
||||
|
||||
If schema is a dict then _DictOrPydantic is a dict.
|
||||
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
llm = self.bind_tools([schema], **kwargs)
|
||||
tool_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], # type: ignore[list-item]
|
||||
first_tool_only=True, # type: ignore[list-item]
|
||||
)
|
||||
else:
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=tool_name, first_tool_only=True
|
||||
)
|
||||
elif method == "json_mode":
|
||||
llm = self.bind(response_format={"type": "json_object"})
|
||||
output_parser = (
|
||||
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
|
||||
if is_pydantic_schema
|
||||
else JsonOutputParser()
|
||||
)
|
||||
else:
|
||||
key_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=key_name, first_tool_only=True
|
||||
raise ValueError(
|
||||
f"Unrecognized method argument. "
|
||||
f"Expected `function_calling` or `json_mode`."
|
||||
f"Received: `{method}`."
|
||||
)
|
||||
|
||||
return llm | output_parser
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user