diff --git a/docs/docs/integrations/chat/oci_generative_ai.ipynb b/docs/docs/integrations/chat/oci_generative_ai.ipynb
index 4ce58a13fbf..261cf42feb2 100644
--- a/docs/docs/integrations/chat/oci_generative_ai.ipynb
+++ b/docs/docs/integrations/chat/oci_generative_ai.ipynb
@@ -33,7 +33,7 @@
     "### Model features\n",
     "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
     "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
-    "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | \n",
+    "| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | \n",
     "\n",
     "## Setup\n",
     "\n",
diff --git a/libs/community/langchain_community/chat_models/oci_generative_ai.py b/libs/community/langchain_community/chat_models/oci_generative_ai.py
index e1863f120c2..4edbe04456e 100644
--- a/libs/community/langchain_community/chat_models/oci_generative_ai.py
+++ b/libs/community/langchain_community/chat_models/oci_generative_ai.py
@@ -1,8 +1,22 @@
 import json
+import re
+import uuid
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     generate_from_stream,
@@ -14,15 +28,76 @@ from langchain_core.messages import (
     ChatMessage,
     HumanMessage,
     SystemMessage,
+    ToolCall,
+    ToolMessage,
+)
+from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.output_parsers.base import OutputParserLike
+from langchain_core.output_parsers.openai_tools import (
+    JsonOutputKeyToolsParser,
+    PydanticToolsParser,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Extra
+from langchain_core.pydantic_v1 import BaseModel, Extra
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_function
 
 from langchain_community.llms.oci_generative_ai import OCIGenAIBase
 from langchain_community.llms.utils import enforce_stop_tokens
 
 CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
 
+JSON_TO_PYTHON_TYPES = {
+    "string": "str",
+    "number": "float",
+    "boolean": "bool",
+    "integer": "int",
+    "array": "List",
+    "object": "Dict",
+}
+
+
+def _remove_signature_from_tool_description(name: str, description: str) -> str:
+    """
+    Removes the `{name}{signature} - ` prefix and the Args: section from a tool
+    description. The signature is usually present for tools created with the @tool
+    decorator, whereas the Args: section may be present in function doc blocks.
+    """
+    description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description)
+    description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description)
+    return description
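To see what the two substitutions above strip, here is a minimal standalone sketch; the `magic_function` name and description are hypothetical stand-ins for what the `@tool` decorator typically generates, not values from this PR:

```python
import re

name = "magic_function"
# Hypothetical description in the "{name}{signature} - ..." format produced
# by the @tool decorator, followed by an Args: section from the docstring.
description = (
    "magic_function(input: int) -> int - Applies a magic function to an input."
    "\n\nArgs:\n    input: the input number"
)

# The same two substitutions as _remove_signature_from_tool_description:
description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description)
description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description)

print(description)  # Applies a magic function to an input.
```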
+
+
+def _format_oci_tool_calls(
+    tool_calls: Optional[List[Any]] = None,
+) -> List[Dict]:
+    """
+    Formats an OCI GenAI API response into the tool call format used in LangChain.
+    """
+    if not tool_calls:
+        return []
+
+    formatted_tool_calls = []
+    for tool_call in tool_calls:
+        formatted_tool_calls.append(
+            {
+                "id": uuid.uuid4().hex[:],
+                "function": {
+                    "name": tool_call.name,
+                    "arguments": json.dumps(tool_call.parameters),
+                },
+                "type": "function",
+            }
+        )
+    return formatted_tool_calls
+
+
+def _convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
+    """Convert an OCI GenAI tool call into langchain_core.messages.ToolCall."""
+    _id = uuid.uuid4().hex[:]
+    return ToolCall(name=tool_call.name, args=tool_call.parameters, id=_id)
+
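These helpers produce the two tool-call shapes LangChain expects: an OpenAI-style dict stored in `additional_kwargs`, and a `ToolCall` for `AIMessage.tool_calls`. A small sketch, with `SimpleNamespace` standing in for the real OCI SDK tool-call object (the tool name and arguments are invented):

```python
import json
import uuid
from types import SimpleNamespace

from langchain_core.messages import ToolCall

# Stand-in for the SDK tool-call object; the helpers only read .name and
# .parameters from it.
oci_tool_call = SimpleNamespace(name="get_weather", parameters={"city": "Zurich"})

# Shape produced by _format_oci_tool_calls (stored in additional_kwargs):
formatted = {
    "id": uuid.uuid4().hex,
    "function": {
        "name": oci_tool_call.name,
        "arguments": json.dumps(oci_tool_call.parameters),
    },
    "type": "function",
}

# Shape produced by _convert_oci_tool_call_to_langchain (AIMessage.tool_calls):
lc_tool_call = ToolCall(
    name=oci_tool_call.name, args=oci_tool_call.parameters, id=uuid.uuid4().hex
)

print(formatted["function"]["arguments"])  # {"city": "Zurich"}
print(lc_tool_call["args"])                # {'city': 'Zurich'}
```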
 
 class Provider(ABC):
     @property
@@ -35,14 +110,28 @@ class Provider(ABC):
     @abstractmethod
     def chat_stream_to_text(self, event_data: Dict) -> str: ...
 
+    @abstractmethod
+    def is_chat_stream_end(self, event_data: Dict) -> bool: ...
+
     @abstractmethod
     def chat_generation_info(self, response: Any) -> Dict[str, Any]: ...
 
+    @abstractmethod
+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: ...
+
     @abstractmethod
     def get_role(self, message: BaseMessage) -> str: ...
 
     @abstractmethod
-    def messages_to_oci_params(self, messages: Any) -> Dict[str, Any]: ...
+    def messages_to_oci_params(
+        self, messages: Any, **kwargs: Any
+    ) -> Dict[str, Any]: ...
+
+    @abstractmethod
+    def convert_to_oci_tool(
+        self,
+        tool: Union[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+    ) -> Dict[str, Any]: ...
 
 
 class CohereProvider(Provider):
@@ -52,10 +141,15 @@ class CohereProvider(Provider):
         from oci.generative_ai_inference import models
 
         self.oci_chat_request = models.CohereChatRequest
+        self.oci_tool = models.CohereTool
+        self.oci_tool_param = models.CohereParameterDefinition
+        self.oci_tool_result = models.CohereToolResult
+        self.oci_tool_call = models.CohereToolCall
         self.oci_chat_message = {
             "USER": models.CohereUserMessage,
             "CHATBOT": models.CohereChatBotMessage,
             "SYSTEM": models.CohereSystemMessage,
+            "TOOL": models.CohereToolMessage,
         }
         self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE
 
@@ -63,15 +157,54 @@ class CohereProvider(Provider):
         return response.data.chat_response.text
 
     def chat_stream_to_text(self, event_data: Dict) -> str:
-        if "text" in event_data and "finishReason" not in event_data:
+        if "text" in event_data:
             return event_data["text"]
         else:
             return ""
 
+    def is_chat_stream_end(self, event_data: Dict) -> bool:
+        return "finishReason" in event_data
+
     def chat_generation_info(self, response: Any) -> Dict[str, Any]:
-        return {
+        generation_info: Dict[str, Any] = {
+            "documents": response.data.chat_response.documents,
+            "citations": response.data.chat_response.citations,
+            "search_queries": response.data.chat_response.search_queries,
+            "is_search_required": response.data.chat_response.is_search_required,
             "finish_reason": response.data.chat_response.finish_reason,
         }
+        if response.data.chat_response.tool_calls:
+            # Only populate tool_calls when 1) it is present on the response and
+            # 2) it has one or more calls.
+            generation_info["tool_calls"] = _format_oci_tool_calls(
+                response.data.chat_response.tool_calls
+            )
+
+        return generation_info
+
+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
+        generation_info: Dict[str, Any] = {
+            "documents": event_data.get("documents"),
+            "citations": event_data.get("citations"),
+            "finish_reason": event_data.get("finishReason"),
+        }
+        if "toolCalls" in event_data:
+            generation_info["tool_calls"] = []
+            for tool_call in event_data["toolCalls"]:
+                generation_info["tool_calls"].append(
+                    {
+                        "id": uuid.uuid4().hex[:],
+                        "function": {
+                            "name": tool_call["name"],
+                            "arguments": json.dumps(tool_call["parameters"]),
+                        },
+                        "type": "function",
+                    }
+                )
+
+        generation_info = {k: v for k, v in generation_info.items() if v is not None}
+
+        return generation_info
 
     def get_role(self, message: BaseMessage) -> str:
         if isinstance(message, HumanMessage):
@@ -80,21 +213,154 @@ class CohereProvider(Provider):
             return "CHATBOT"
         elif isinstance(message, SystemMessage):
             return "SYSTEM"
+        elif isinstance(message, ToolMessage):
+            return "TOOL"
         else:
             raise ValueError(f"Got unknown type {message}")
 
-    def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]:
-        oci_chat_history = [
-            self.oci_chat_message[self.get_role(msg)](message=msg.content)
-            for msg in messages[:-1]
-        ]
+    def messages_to_oci_params(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> Dict[str, Any]:
+        is_force_single_step = kwargs.get("is_force_single_step") or False
+
+        oci_chat_history = []
+
+        for msg in messages[:-1]:
+            if self.get_role(msg) == "USER" or self.get_role(msg) == "SYSTEM":
+                oci_chat_history.append(
+                    self.oci_chat_message[self.get_role(msg)](message=msg.content)
+                )
+            elif isinstance(msg, AIMessage):
+                if msg.tool_calls and is_force_single_step:
+                    continue
+                tool_calls = (
+                    [
+                        self.oci_tool_call(name=tc["name"], parameters=tc["args"])
+                        for tc in msg.tool_calls
+                    ]
+                    if msg.tool_calls
+                    else None
+                )
+                msg_content = msg.content if msg.content else " "
+                oci_chat_history.append(
+                    self.oci_chat_message[self.get_role(msg)](
+                        message=msg_content, tool_calls=tool_calls
+                    )
+                )
+
+        # Get the messages for the current chat turn
+        current_chat_turn_messages = []
+        for message in messages[::-1]:
+            current_chat_turn_messages.append(message)
+            if isinstance(message, HumanMessage):
+                break
+        current_chat_turn_messages = current_chat_turn_messages[::-1]
+
+        oci_tool_results: Union[List[Any], None] = []
+        for message in current_chat_turn_messages:
+            if isinstance(message, ToolMessage):
+                tool_message = message
+                previous_ai_msgs = [
+                    message
+                    for message in current_chat_turn_messages
+                    if isinstance(message, AIMessage) and message.tool_calls
+                ]
+                if previous_ai_msgs:
+                    previous_ai_msg = previous_ai_msgs[-1]
+                    for lc_tool_call in previous_ai_msg.tool_calls:
+                        if lc_tool_call["id"] == tool_message.tool_call_id:
+                            tool_result = self.oci_tool_result()
+                            tool_result.call = self.oci_tool_call(
+                                name=lc_tool_call["name"],
+                                parameters=lc_tool_call["args"],
+                            )
+                            tool_result.outputs = [{"output": tool_message.content}]
+                            oci_tool_results.append(tool_result)
+
+        if not oci_tool_results:
+            oci_tool_results = None
+
+        message_str = "" if oci_tool_results else messages[-1].content
 
         oci_params = {
-            "message": messages[-1].content,
+            "message": message_str,
             "chat_history": oci_chat_history,
+            "tool_results": oci_tool_results,
             "api_format": self.chat_api_format,
         }
 
-        return oci_params
+        return {k: v for k, v in oci_params.items() if v is not None}
+
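The history handling above splits the transcript at the most recent `HumanMessage`: earlier messages become `chat_history`, and only the current turn is scanned for tool results. A sketch of that walk-backwards loop on a hypothetical transcript (message contents and IDs are invented):

```python
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

messages = [
    HumanMessage(content="What is the weather in Zurich?"),
    AIMessage(
        content="",
        tool_calls=[{"name": "get_weather", "args": {"city": "Zurich"}, "id": "call-1"}],
    ),
    ToolMessage(content="18C and sunny", tool_call_id="call-1"),
]

# Same loop as in messages_to_oci_params: walk backwards until a HumanMessage.
current_chat_turn_messages = []
for message in messages[::-1]:
    current_chat_turn_messages.append(message)
    if isinstance(message, HumanMessage):
        break
current_chat_turn_messages = current_chat_turn_messages[::-1]

# Here the whole transcript is one turn: Human, AI tool call, tool result.
assert len(current_chat_turn_messages) == 3
```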
+    def convert_to_oci_tool(
+        self,
+        tool: Union[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+    ) -> Dict[str, Any]:
+        """
+        Convert a BaseTool instance, JSON schema dict, or BaseModel type to an OCI tool.
+        """
+        if isinstance(tool, BaseTool):
+            return self.oci_tool(
+                name=tool.name,
+                description=_remove_signature_from_tool_description(
+                    tool.name, tool.description
+                ),
+                parameter_definitions={
+                    p_name: self.oci_tool_param(
+                        description=p_def.get("description")
+                        if "description" in p_def
+                        else "",
+                        type=JSON_TO_PYTHON_TYPES.get(
+                            p_def.get("type"), p_def.get("type")
+                        ),
+                        is_required="default" not in p_def,
+                    )
+                    for p_name, p_def in tool.args.items()
+                },
+            )
+        elif isinstance(tool, dict):
+            if not all(k in tool for k in ("title", "description", "properties")):
+                raise ValueError(
+                    "Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type."  # noqa: E501
+                )
+            return self.oci_tool(
+                name=tool.get("title"),
+                description=tool.get("description"),
+                parameter_definitions={
+                    p_name: self.oci_tool_param(
+                        description=p_def.get("description"),
+                        type=JSON_TO_PYTHON_TYPES.get(
+                            p_def.get("type"), p_def.get("type")
+                        ),
+                        is_required="default" not in p_def,
+                    )
+                    for p_name, p_def in tool.get("properties", {}).items()
+                },
+            )
+        elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool):
+            as_json_schema_function = convert_to_openai_function(tool)
+            parameters = as_json_schema_function.get("parameters", {})
+            properties = parameters.get("properties", {})
+            return self.oci_tool(
+                name=as_json_schema_function.get("name"),
+                description=as_json_schema_function.get(
+                    "description",
+                    as_json_schema_function.get("name"),
+                ),
+                parameter_definitions={
+                    p_name: self.oci_tool_param(
+                        description=p_def.get("description"),
+                        type=JSON_TO_PYTHON_TYPES.get(
+                            p_def.get("type"), p_def.get("type")
+                        ),
+                        is_required=p_name in parameters.get("required", []),
+                    )
+                    for p_name, p_def in properties.items()
+                },
+            )
+        else:
+            raise ValueError(
+                f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type."  # noqa: E501
+            )
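For the `BaseModel` branch above, the values fed into `models.CohereTool` can be previewed with `convert_to_openai_function` alone, since that branch goes through it. The `GetWeather` schema is a hypothetical example:

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function


class GetWeather(BaseModel):
    """Get the current weather in a given city."""

    city: str = Field(description="Name of the city")


as_json_schema_function = convert_to_openai_function(GetWeather)
parameters = as_json_schema_function.get("parameters", {})

print(as_json_schema_function["name"])           # GetWeather -> CohereTool.name
print(as_json_schema_function["description"])    # the class docstring
print(parameters["properties"]["city"]["type"])  # "string", mapped to "str"
print("city" in parameters.get("required", []))  # True -> is_required=True
```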
 
 
 class MetaProvider(Provider):
@@ -116,10 +382,10 @@ class MetaProvider(Provider):
         return response.data.chat_response.choices[0].message.content[0].text
 
     def chat_stream_to_text(self, event_data: Dict) -> str:
-        if "message" in event_data:
-            return event_data["message"]["content"][0]["text"]
-        else:
-            return ""
+        return event_data["message"]["content"][0]["text"]
+
+    def is_chat_stream_end(self, event_data: Dict) -> bool:
+        return "message" not in event_data
 
     def chat_generation_info(self, response: Any) -> Dict[str, Any]:
         return {
@@ -127,6 +393,11 @@ class MetaProvider(Provider):
             "time_created": str(response.data.chat_response.time_created),
         }
 
+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
+        return {
+            "finish_reason": event_data["finishReason"],
+        }
+
     def get_role(self, message: BaseMessage) -> str:
         # meta only supports alternating user/assistant roles
         if isinstance(message, HumanMessage):
@@ -138,7 +409,9 @@ class MetaProvider(Provider):
         else:
             raise ValueError(f"Got unknown type {message}")
 
-    def messages_to_oci_params(self, messages: List[BaseMessage]) -> Dict[str, Any]:
+    def messages_to_oci_params(
+        self, messages: List[BaseMessage], **kwargs: Any
+    ) -> Dict[str, Any]:
         oci_messages = [
             self.oci_chat_message[self.get_role(msg)](
                 content=[self.oci_chat_message_content(text=msg.content)]
@@ -153,6 +426,12 @@
 
         return oci_params
 
+    def convert_to_oci_tool(
+        self,
+        tool: Union[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+    ) -> Dict[str, Any]:
+        raise NotImplementedError("Tools not supported for Meta models")
+
 
 class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
     """ChatOCIGenAI chat model integration.
@@ -247,8 +526,8 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]],
-        kwargs: Dict[str, Any],
         stream: bool,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         try:
             from oci.generative_ai_inference import models
@@ -258,8 +537,10 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
                 "Could not import oci python package. "
                 "Please make sure you have the oci package installed."
             ) from ex
-        oci_params = self._provider.messages_to_oci_params(messages)
-        oci_params["is_stream"] = stream  # self.is_stream
+
+        oci_params = self._provider.messages_to_oci_params(messages, **kwargs)
+
+        oci_params["is_stream"] = stream
 
         _model_kwargs = self.model_kwargs or {}
         if stop is not None:
@@ -280,6 +561,43 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
 
         return request
 
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        formatted_tools = [self._provider.convert_to_oci_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
+
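End to end, the new `bind_tools` works like other LangChain tool-calling integrations. A usage sketch; the service endpoint is the documented Chicago-region GenAI endpoint, the compartment OCID is a placeholder, and running this requires the `oci` package plus valid credentials:

```python
from langchain_community.chat_models import ChatOCIGenAI
from langchain_core.tools import tool


@tool
def get_weather(city: str) -> str:
    """Get the current weather in a given city."""
    return f"Sunny in {city}"


llm = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="ocid1.compartment.oc1...",  # placeholder OCID
)
llm_with_tools = llm.bind_tools([get_weather])

ai_msg = llm_with_tools.invoke("What is the weather in Zurich?")
print(ai_msg.tool_calls)
# e.g. [{"name": "get_weather", "args": {"city": "Zurich"}, "id": "..."}]
```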
+    def with_structured_output(
+        self,
+        schema: Union[Dict[Any, Any], Type[BaseModel]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        """Model wrapper that returns outputs formatted to match the given schema.
+
+        Args:
+            schema: The output schema as a dict or a Pydantic class. If a Pydantic
+                class, then the model output will be an object of that class. If a
+                dict, then the model output will be a dict.
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns either a dict or
+            a Pydantic class as output.
+        """
+        llm = self.bind_tools([schema], **kwargs)
+        if isinstance(schema, type) and issubclass(schema, BaseModel):
+            output_parser: OutputParserLike = PydanticToolsParser(
+                tools=[schema], first_tool_only=True
+            )
+        else:
+            key_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
+            output_parser = JsonOutputKeyToolsParser(
+                key_name=key_name, first_tool_only=True
+            )
+
+        return llm | output_parser
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -313,7 +631,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
             )
             return generate_from_stream(stream_iter)
 
-        request = self._prepare_request(messages, stop, kwargs, stream=False)
+        request = self._prepare_request(messages, stop=stop, stream=False, **kwargs)
         response = self.client.chat(request)
 
         content = self._provider.chat_response_to_text(response)
@@ -330,11 +648,22 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
             "content-length": response.headers["content-length"],
         }
 
+        if "tool_calls" in generation_info:
+            tool_calls = [
+                _convert_oci_tool_call_to_langchain(tool_call)
+                for tool_call in response.data.chat_response.tool_calls
+            ]
+        else:
+            tool_calls = []
+
+        message = AIMessage(
+            content=content,
+            additional_kwargs=generation_info,
+            tool_calls=tool_calls,
+        )
         return ChatResult(
             generations=[
-                ChatGeneration(
-                    message=AIMessage(content=content), generation_info=generation_info
-                )
+                ChatGeneration(message=message, generation_info=generation_info)
             ],
             llm_output=llm_output,
         )
@@ -346,12 +675,42 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        request = self._prepare_request(messages, stop, kwargs, stream=True)
+        request = self._prepare_request(messages, stop=stop, stream=True, **kwargs)
         response = self.client.chat(request)
 
         for event in response.data.events():
-            delta = self._provider.chat_stream_to_text(json.loads(event.data))
-            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-            if run_manager:
-                run_manager.on_llm_new_token(delta, chunk=chunk)
-            yield chunk
+            event_data = json.loads(event.data)
+            if not self._provider.is_chat_stream_end(event_data):  # still streaming
+                delta = self._provider.chat_stream_to_text(event_data)
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                if run_manager:
+                    run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk
+            else:  # stream end
+                generation_info = self._provider.chat_stream_generation_info(event_data)
+                tool_call_chunks = []
+                if tool_calls := generation_info.get("tool_calls"):
+                    content = self._provider.chat_stream_to_text(event_data)
+                    try:
+                        tool_call_chunks = [
+                            ToolCallChunk(
+                                name=tool_call["function"].get("name"),
+                                args=tool_call["function"].get("arguments"),
+                                id=tool_call.get("id"),
+                                index=tool_call.get("index"),
+                            )
+                            for tool_call in tool_calls
+                        ]
+                    except KeyError:
+                        pass
+                else:
+                    content = ""
+                message = AIMessageChunk(
+                    content=content,
+                    additional_kwargs=generation_info,
+                    tool_call_chunks=tool_call_chunks,
+                )
+                yield ChatGenerationChunk(
+                    message=message,
+                    generation_info=generation_info,
+                )
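On the streaming side, text deltas arrive as ordinary content chunks, and the closing event carries `generation_info` plus any `tool_call_chunks`. A consumption sketch, reusing the hypothetical `llm_with_tools` from the earlier example:

```python
for chunk in llm_with_tools.stream("What is the weather in Zurich?"):
    if chunk.tool_call_chunks:
        # Final chunk: tool call deltas, with args as a JSON string.
        print(chunk.tool_call_chunks)
    else:
        print(chunk.content, end="")
```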
diff --git a/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py
index c557b15902d..a59893e8b21 100644
--- a/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py
+++ b/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py
@@ -38,6 +38,11 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
                                 {
                                     "text": response_text,
                                     "finish_reason": "completed",
+                                    "is_search_required": None,
+                                    "search_queries": None,
+                                    "citations": None,
+                                    "documents": None,
+                                    "tool_calls": None,
                                 }
                             ),
                             "model_id": "cohere.command-r-16k",
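The mock gains these `None` fields because `chat_generation_info` now reads each of them off the Cohere chat response unconditionally. A stand-in showing the access pattern the fake response must satisfy (`SimpleNamespace` plays the role of the test's mocked response object):

```python
from types import SimpleNamespace

chat_response = SimpleNamespace(
    text="Assistant chat reply.",
    finish_reason="completed",
    is_search_required=None,
    search_queries=None,
    citations=None,
    documents=None,
    tool_calls=None,
)

# Mirrors CohereProvider.chat_generation_info: every field is accessed, so
# the mock must define each one, even if only as None.
generation_info = {
    "documents": chat_response.documents,
    "citations": chat_response.citations,
    "search_queries": chat_response.search_queries,
    "is_search_required": chat_response.is_search_required,
    "finish_reason": chat_response.finish_reason,
}
assert generation_info["finish_reason"] == "completed"
# tool_calls is only added when chat_response.tool_calls is truthy.
assert "tool_calls" not in generation_info
```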