community[patch]: Add OCI Generative AI tool and structured output support (#24693)

- [x] **PR title**: community: Add OCI Generative AI tool and structured output support


- [x] **PR message**:
    - **Description:** Adds tool calling and structured output support for chat models offered by the OCI Generative AI service. This is a follow-up to our previous PR #22880, with changes in `langchain_community/chat_models/oci_generative_ai.py`.
    - **Issue:** N/A
    - **Dependencies:** N/A
    - **Twitter handle:** N/A


- [x] **Add tests and docs**:
    1. We have updated our unit tests.
    2. We have updated our documentation under `docs/docs/integrations/chat/oci_generative_ai.ipynb`.


- [x] **Lint and test**: `make format`, `make lint`, and `make test` all run successfully.
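
For reviewers, a minimal sketch of the new tool-calling surface added by this PR. The service endpoint shown is the Chicago region inference endpoint; the compartment OCID and the `get_weather` tool are placeholders, not part of this change:

```python
from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"Sunny in {city}"


# Placeholder compartment OCID -- substitute your own OCI resources.
llm = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1..example",
)

llm_with_tools = llm.bind_tools([get_weather])
ai_msg = llm_with_tools.invoke("What is the weather in Chicago?")
print(ai_msg.tool_calls)
# e.g. [{'name': 'get_weather', 'args': {'city': 'Chicago'}, 'id': '...'}]
```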

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Committed by Rave Harpaz on 2024-07-25 23:19:00 -07:00 (via GitHub)
parent 2b6a262f84 · commit ee399e3ec5
3 changed files with 395 additions and 31 deletions

File: docs/docs/integrations/chat/oci_generative_ai.ipynb

@@ -33,7 +33,7 @@
     "### Model features\n",
     "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | \n",
+    "| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
     "\n",
     "## Setup\n",
     "\n",

File: langchain_community/chat_models/oci_generative_ai.py

@@ -1,8 +1,22 @@
 import json
+import re
+import uuid
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)

 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     generate_from_stream,
@@ -14,15 +28,76 @@ from langchain_core.messages import (
     ChatMessage,
     HumanMessage,
     SystemMessage,
+    ToolCall,
+    ToolMessage,
+)
+from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Extra
+from langchain_core.pydantic_v1 import BaseModel, Extra
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_function

 from langchain_community.llms.oci_generative_ai import OCIGenAIBase
 from langchain_community.llms.utils import enforce_stop_tokens

 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"

+JSON_TO_PYTHON_TYPES = {
+    "string": "str",
+    "number": "float",
+    "boolean": "bool",
+    "integer": "int",
+    "array": "List",
+    "object": "Dict",
+}
+
+
+def _remove_signature_from_tool_description(name: str, description: str) -> str:
+    """
+    Removes the `{name}{signature} - ` prefix and Args: section from tool description.
+
+    The signature is usually present for tools created with the @tool decorator,
+    whereas the Args: section may be present in function doc blocks.
+    """
+    description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description)
+    description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description)
+    return description
+
+
+def _format_oci_tool_calls(
+    tool_calls: Optional[List[Any]] = None,
+) -> List[Dict]:
+    """
+    Formats an OCI GenAI API response into the tool call format used in LangChain.
+    """
+    if not tool_calls:
+        return []
+
+    formatted_tool_calls = []
+    for tool_call in tool_calls:
+        formatted_tool_calls.append(
+            {
+                "id": uuid.uuid4().hex[:],
+                "function": {
+                    "name": tool_call.name,
+                    "arguments": json.dumps(tool_call.parameters),
+                },
+                "type": "function",
+            }
+        )
+    return formatted_tool_calls
+
+
+def _convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
+    """Convert an OCI GenAI tool call into langchain_core.messages.ToolCall."""
+    _id = uuid.uuid4().hex[:]
+    return ToolCall(name=tool_call.name, args=tool_call.parameters, id=_id)
+

 class Provider(ABC):
     @property
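
To illustrate what `_remove_signature_from_tool_description` strips, a quick sketch (the tool name and docstring are hypothetical, and the import assumes this module is installed):

```python
from langchain_community.chat_models.oci_generative_ai import (
    _remove_signature_from_tool_description,
)

desc = (
    "get_weather(city: str) -> str - Get the current weather.\n\n"
    "Args:\n    city: Name of the city."
)
# The first regex drops the "name(signature) - " prefix added by @tool;
# the second drops the trailing "Args:" section from the docstring.
print(_remove_signature_from_tool_description("get_weather", desc))
# -> "Get the current weather."
```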
@@ -35,14 +110,28 @@ class Provider(ABC):
     @abstractmethod
     def chat_stream_to_text(self, event_data: Dict) -> str: ...

+    @abstractmethod
+    def is_chat_stream_end(self, event_data: Dict) -> bool: ...
+
     @abstractmethod
     def chat_generation_info(self, response: Any) -> Dict[str, Any]: ...

+    @abstractmethod
+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: ...
+
     @abstractmethod
     def get_role(self, message: BaseMessage) -> str: ...

     @abstractmethod
-    def messages_to_oci_params(self, messages: Any) -> Dict[str, Any]: ...
+    def messages_to_oci_params(
+        self, messages: Any, **kwargs: Any
+    ) -> Dict[str, Any]: ...
+
+    @abstractmethod
+    def convert_to_oci_tool(
+        self,
+        tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+    ) -> Dict[str, Any]: ...


 class CohereProvider(Provider):
@@ -52,10 +141,15 @@ class CohereProvider(Provider):
         from oci.generative_ai_inference import models

         self.oci_chat_request = models.CohereChatRequest
+        self.oci_tool = models.CohereTool
+        self.oci_tool_param = models.CohereParameterDefinition
+        self.oci_tool_result = models.CohereToolResult
+        self.oci_tool_call = models.CohereToolCall
         self.oci_chat_message = {
             "USER": models.CohereUserMessage,
             "CHATBOT": models.CohereChatBotMessage,
             "SYSTEM": models.CohereSystemMessage,
+            "TOOL": models.CohereToolMessage,
         }
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE
@@ -63,15 +157,54 @@ class CohereProvider(Provider):
         return response.data.chat_response.text

     def chat_stream_to_text(self, event_data: Dict) -> str:
-        if "text" in event_data and "finishReason" not in event_data:
+        if "text" in event_data:
             return event_data["text"]
         else:
             return ""

+    def is_chat_stream_end(self, event_data: Dict) -> bool:
+        return "finishReason" in event_data
+
     def chat_generation_info(self, response: Any) -> Dict[str, Any]:
-        return {
+        generation_info: Dict[str, Any] = {
+            "documents": response.data.chat_response.documents,
+            "citations": response.data.chat_response.citations,
+            "search_queries": response.data.chat_response.search_queries,
+            "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
+        # Only populate tool_calls when 1) present on the response and
+        # 2) has one or more calls.
+        if response.data.chat_response.tool_calls:
+            generation_info["tool_calls"] = _format_oci_tool_calls(
+                response.data.chat_response.tool_calls
+            )
+        return generation_info
+
+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
+        generation_info: Dict[str, Any] = {
+            "documents": event_data.get("documents"),
+            "citations": event_data.get("citations"),
+            "finish_reason": event_data.get("finishReason"),
+        }
+        if "toolCalls" in event_data:
+            generation_info["tool_calls"] = []
+            for tool_call in event_data["toolCalls"]:
+                generation_info["tool_calls"].append(
+                    {
+                        "id": uuid.uuid4().hex[:],
+                        "function": {
+                            "name": tool_call["name"],
+                            "arguments": json.dumps(tool_call["parameters"]),
+                        },
+                        "type": "function",
+                    }
+                )
+        generation_info = {k: v for k, v in generation_info.items() if v is not None}
+        return generation_info

     def get_role(self, message: BaseMessage) -> str:
         if isinstance(message, HumanMessage):
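
For context, a sketch of how a hypothetical Cohere stream-end event maps through the new `chat_stream_generation_info` (the payload below is invented; key names follow the code above, and constructing `CohereProvider` assumes the `oci` package is installed):

```python
from langchain_community.chat_models.oci_generative_ai import CohereProvider

# Hypothetical stream-end payload from the OCI streaming API.
event_data = {
    "finishReason": "COMPLETE",
    "toolCalls": [{"name": "get_weather", "parameters": {"city": "Chicago"}}],
}
info = CohereProvider().chat_stream_generation_info(event_data)
# {'finish_reason': 'COMPLETE',
#  'tool_calls': [{'id': '<random hex>',
#                  'function': {'name': 'get_weather',
#                               'arguments': '{"city": "Chicago"}'},
#                  'type': 'function'}]}
# 'documents' and 'citations' are absent from this event, so the final
# dict comprehension filters their None values out.
```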
@@ -80,21 +213,154 @@ class CohereProvider(Provider):
             return "CHATBOT"
         elif isinstance(message, SystemMessage):
             return "SYSTEM"
+        elif isinstance(message, ToolMessage):
+            return "TOOL"
         else:
             raise ValueError(f"Got unknown type {message}")

-    def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]:
-        oci_chat_history = [
-            self.oci_chat_message[self.get_role(msg)](message=msg.content)
-            for msg in messages[:-1]
-        ]
+    def messages_to_oci_params(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> Dict[str, Any]:
+        is_force_single_step = kwargs.get("is_force_single_step") or False
+        oci_chat_history = []
+
+        for msg in messages[:-1]:
+            if self.get_role(msg) == "USER" or self.get_role(msg) == "SYSTEM":
+                oci_chat_history.append(
+                    self.oci_chat_message[self.get_role(msg)](message=msg.content)
+                )
+            elif isinstance(msg, AIMessage):
+                if msg.tool_calls and is_force_single_step:
+                    continue
+                tool_calls = (
+                    [
+                        self.oci_tool_call(name=tc["name"], parameters=tc["args"])
+                        for tc in msg.tool_calls
+                    ]
+                    if msg.tool_calls
+                    else None
+                )
+                msg_content = msg.content if msg.content else " "
+                oci_chat_history.append(
+                    self.oci_chat_message[self.get_role(msg)](
+                        message=msg_content, tool_calls=tool_calls
+                    )
+                )
+
+        # Get the messages for the current chat turn
+        current_chat_turn_messages = []
+        for message in messages[::-1]:
+            current_chat_turn_messages.append(message)
+            if isinstance(message, HumanMessage):
+                break
+        current_chat_turn_messages = current_chat_turn_messages[::-1]
+
+        oci_tool_results: Union[List[Any], None] = []
+        for message in current_chat_turn_messages:
+            if isinstance(message, ToolMessage):
+                tool_message = message
+                previous_ai_msgs = [
+                    message
+                    for message in current_chat_turn_messages
+                    if isinstance(message, AIMessage) and message.tool_calls
+                ]
+                if previous_ai_msgs:
+                    previous_ai_msg = previous_ai_msgs[-1]
+                    for lc_tool_call in previous_ai_msg.tool_calls:
+                        if lc_tool_call["id"] == tool_message.tool_call_id:
+                            tool_result = self.oci_tool_result()
+                            tool_result.call = self.oci_tool_call(
+                                name=lc_tool_call["name"],
+                                parameters=lc_tool_call["args"],
+                            )
+                            tool_result.outputs = [{"output": tool_message.content}]
+                            oci_tool_results.append(tool_result)
+        if not oci_tool_results:
+            oci_tool_results = None
+
+        message_str = "" if oci_tool_results else messages[-1].content

         oci_params = {
-            "message": messages[-1].content,
+            "message": message_str,
             "chat_history": oci_chat_history,
+            "tool_results": oci_tool_results,
             "api_format": self.chat_api_format,
         }

-        return oci_params
+        return {k: v for k, v in oci_params.items() if v is not None}
+
+    def convert_to_oci_tool(
+        self,
+        tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+    ) -> Dict[str, Any]:
+        """
+        Convert a BaseTool instance, JSON schema dict, or BaseModel type to an OCI tool.
+        """
+        if isinstance(tool, BaseTool):
+            return self.oci_tool(
+                name=tool.name,
+                description=_remove_signature_from_tool_description(
+                    tool.name, tool.description
+                ),
+                parameter_definitions={
+                    p_name: self.oci_tool_param(
+                        description=p_def.get("description", ""),
+                        type=JSON_TO_PYTHON_TYPES.get(
+                            p_def.get("type"), p_def.get("type")
+                        ),
+                        is_required="default" not in p_def,
+                    )
+                    for p_name, p_def in tool.args.items()
+                },
+            )
+        elif isinstance(tool, dict):
+            if not all(k in tool for k in ("title", "description", "properties")):
+                raise ValueError(
+                    "Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type."  # noqa: E501
+                )
+            return self.oci_tool(
+                name=tool.get("title"),
+                description=tool.get("description"),
+                parameter_definitions={
+                    p_name: self.oci_tool_param(
+                        description=p_def.get("description"),
+                        type=JSON_TO_PYTHON_TYPES.get(
+                            p_def.get("type"), p_def.get("type")
+                        ),
+                        is_required="default" not in p_def,
+                    )
+                    for p_name, p_def in tool.get("properties", {}).items()
+                },
+            )
+        elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
+            as_json_schema_function = convert_to_openai_function(tool)
+            parameters = as_json_schema_function.get("parameters", {})
+            properties = parameters.get("properties", {})
+            return self.oci_tool(
+                name=as_json_schema_function.get("name"),
+                description=as_json_schema_function.get(
+                    "description",
+                    as_json_schema_function.get("name"),
+                ),
+                parameter_definitions={
+                    p_name: self.oci_tool_param(
+                        description=p_def.get("description"),
+                        type=JSON_TO_PYTHON_TYPES.get(
+                            p_def.get("type"), p_def.get("type")
+                        ),
+                        is_required=p_name in parameters.get("required", []),
+                    )
+                    for p_name, p_def in properties.items()
+                },
+            )
+        else:
+            raise ValueError(
+                f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type."  # noqa: E501
+            )
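
To make the Pydantic branch of `convert_to_oci_tool` concrete, a sketch (the `Person` schema is invented for illustration): `convert_to_openai_function` produces a JSON schema whose `properties` are translated into `CohereParameterDefinition`s via `JSON_TO_PYTHON_TYPES`.

```python
from langchain_core.pydantic_v1 import BaseModel, Field


class Person(BaseModel):
    """Record details about a person."""

    name: str = Field(description="The person's name")
    age: int = Field(description="Age in years")


# provider.convert_to_oci_tool(Person) should yield, conceptually:
# CohereTool(
#     name="Person",
#     description="Record details about a person.",
#     parameter_definitions={
#         "name": CohereParameterDefinition(
#             description="The person's name", type="str", is_required=True
#         ),
#         "age": CohereParameterDefinition(
#             description="Age in years", type="int", is_required=True
#         ),
#     },
# )
```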
@@ -116,10 +382,10 @@ class MetaProvider(Provider):
         return response.data.chat_response.choices[0].message.content[0].text

     def chat_stream_to_text(self, event_data: Dict) -> str:
-        if "message" in event_data:
-            return event_data["message"]["content"][0]["text"]
-        else:
-            return ""
+        return event_data["message"]["content"][0]["text"]
+
+    def is_chat_stream_end(self, event_data: Dict) -> bool:
+        return "message" not in event_data

     def chat_generation_info(self, response: Any) -> Dict[str, Any]:
         return {
@@ -127,6 +393,11 @@ class MetaProvider(Provider):
             "time_created": str(response.data.chat_response.time_created),
         }

+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
+        return {
+            "finish_reason": event_data["finishReason"],
+        }
+
     def get_role(self, message: BaseMessage) -> str:
         # meta only supports alternating user/assistant roles
         if isinstance(message, HumanMessage):
@@ -138,7 +409,9 @@ class MetaProvider(Provider):
         else:
             raise ValueError(f"Got unknown type {message}")

-    def messages_to_oci_params(self, messages: List[BaseMessage]) -> Dict[str, Any]:
+    def messages_to_oci_params(
+        self, messages: List[BaseMessage], **kwargs: Any
+    ) -> Dict[str, Any]:
         oci_messages = [
             self.oci_chat_message[self.get_role(msg)](
                 content=[self.oci_chat_message_content(text=msg.content)]
@@ -153,6 +426,12 @@ class MetaProvider(Provider):

         return oci_params

+    def convert_to_oci_tool(
+        self,
+        tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+    ) -> Dict[str, Any]:
+        raise NotImplementedError("Tools not supported for Meta models")
+

 class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
     """ChatOCIGenAI chat model integration.
@@ -247,8 +526,8 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]],
-        kwargs: Dict[str, Any],
         stream: bool,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         try:
             from oci.generative_ai_inference import models
@@ -258,8 +537,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
                 "Could not import oci python package. "
                 "Please make sure you have the oci package installed."
             ) from ex
-        oci_params = self._provider.messages_to_oci_params(messages)
-        oci_params["is_stream"] = stream  # self.is_stream
+
+        oci_params = self._provider.messages_to_oci_params(messages, **kwargs)
+        oci_params["is_stream"] = stream

         _model_kwargs = self.model_kwargs or {}
         if stop is not None:
@@ -280,6 +561,43 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):

         return request

+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        formatted_tools = [self._provider.convert_to_oci_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
+
+    def with_structured_output(
+        self,
+        schema: Union[Dict[Any, Any], Type[BaseModel]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Args:
+            schema: The output schema as a dict or a Pydantic class. If a Pydantic
+                class, then the model output will be an object of that class. If a
+                dict, then the model output will be a dict.
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns either a dict or
+            a Pydantic class as output.
+        """
+        llm = self.bind_tools([schema], **kwargs)
+        if isinstance(schema, type) and issubclass(schema, BaseModel):
+            output_parser: OutputParserLike = PydanticToolsParser(
+                tools=[schema], first_tool_only=True
+            )
+        else:
+            key_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
+            output_parser = JsonOutputKeyToolsParser(
+                key_name=key_name, first_tool_only=True
+            )
+        return llm | output_parser
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -313,7 +631,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
             )
             return generate_from_stream(stream_iter)

-        request = self._prepare_request(messages, stop, kwargs, stream=False)
+        request = self._prepare_request(messages, stop=stop, stream=False, **kwargs)
         response = self.client.chat(request)

         content = self._provider.chat_response_to_text(response)
@@ -330,11 +648,22 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
             "content-length": response.headers["content-length"],
         }

+        if "tool_calls" in generation_info:
+            tool_calls = [
+                _convert_oci_tool_call_to_langchain(tool_call)
+                for tool_call in response.data.chat_response.tool_calls
+            ]
+        else:
+            tool_calls = []
+
+        message = AIMessage(
+            content=content,
+            additional_kwargs=generation_info,
+            tool_calls=tool_calls,
+        )
         return ChatResult(
             generations=[
-                ChatGeneration(
-                    message=AIMessage(content=content), generation_info=generation_info
-                )
+                ChatGeneration(message=message, generation_info=generation_info)
             ],
             llm_output=llm_output,
         )
@@ -346,12 +675,42 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        request = self._prepare_request(messages, stop, kwargs, stream=True)
+        request = self._prepare_request(messages, stop=stop, stream=True, **kwargs)
         response = self.client.chat(request)

         for event in response.data.events():
-            delta = self._provider.chat_stream_to_text(json.loads(event.data))
-            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-            if run_manager:
-                run_manager.on_llm_new_token(delta, chunk=chunk)
-            yield chunk
+            event_data = json.loads(event.data)
+            if not self._provider.is_chat_stream_end(event_data):  # still streaming
+                delta = self._provider.chat_stream_to_text(event_data)
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                if run_manager:
+                    run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk
+            else:  # stream end
+                generation_info = self._provider.chat_stream_generation_info(
+                    event_data
+                )
+                tool_call_chunks = []
+                if tool_calls := generation_info.get("tool_calls"):
+                    content = self._provider.chat_stream_to_text(event_data)
+                    try:
+                        tool_call_chunks = [
+                            ToolCallChunk(
+                                name=tool_call["function"].get("name"),
+                                args=tool_call["function"].get("arguments"),
+                                id=tool_call.get("id"),
+                                index=tool_call.get("index"),
+                            )
+                            for tool_call in tool_calls
+                        ]
+                    except KeyError:
+                        pass
+                else:
+                    content = ""
+                message = AIMessageChunk(
+                    content=content,
+                    additional_kwargs=generation_info,
+                    tool_call_chunks=tool_call_chunks,
+                )
+                yield ChatGenerationChunk(
+                    message=message,
+                    generation_info=generation_info,
+                )
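
Taken together, the new `with_structured_output` path can be exercised like this (a sketch: `Joke` is an invented schema, and `llm` is a `ChatOCIGenAI` instance constructed as in the earlier example):

```python
from langchain_core.pydantic_v1 import BaseModel, Field


class Joke(BaseModel):
    """A joke with a setup and a punchline."""

    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline of the joke")


structured_llm = llm.with_structured_output(Joke)
result = structured_llm.invoke("Tell me a joke about cloud computing")
# `result` is a Joke instance: the model emits a tool call against the
# Joke schema, and PydanticToolsParser turns it back into the class.
```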

File: libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py

@@ -38,6 +38,11 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
                 {
                     "text": response_text,
                     "finish_reason": "completed",
+                    "is_search_required": None,
+                    "search_queries": None,
+                    "citations": None,
+                    "documents": None,
+                    "tool_calls": None,
                 }
             ),
             "model_id": "cohere.command-r-16k",