"""Experimental tool-calling support for Anthropic chat models.""" from __future__ import annotations import json from typing import ( Any, ) SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question. You may call them like this: $TOOL_NAME <$PARAMETER_NAME>$PARAMETER_VALUE ... Here are the tools available: {formatted_tools} """ # noqa: E501 TOOL_FORMAT = """ {tool_name} {tool_description} {formatted_parameters} """ TOOL_PARAMETER_FORMAT = """ {parameter_name} {parameter_type} {parameter_description} """ def _get_type(parameter: dict[str, Any]) -> str: if "type" in parameter: return parameter["type"] if "anyOf" in parameter: return json.dumps({"anyOf": parameter["anyOf"]}) if "allOf" in parameter: return json.dumps({"allOf": parameter["allOf"]}) return json.dumps(parameter) def get_system_message(tools: list[dict]) -> str: """Generate a system message that describes the available tools.""" tools_data: list[dict] = [ { "tool_name": tool["name"], "tool_description": tool["description"], "formatted_parameters": "\n".join( [ TOOL_PARAMETER_FORMAT.format( parameter_name=name, parameter_type=_get_type(parameter), parameter_description=parameter.get("description"), ) for name, parameter in tool["parameters"]["properties"].items() ], ), } for tool in tools ] tools_formatted = "\n".join( [ TOOL_FORMAT.format( tool_name=tool["tool_name"], tool_description=tool["tool_description"], formatted_parameters=tool["formatted_parameters"], ) for tool in tools_data ], ) return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted) def _xml_to_dict(t: Any) -> str | dict[str, Any]: # Base case: If the element has no children, return its text or an empty string. if len(t) == 0: return t.text or "" # Recursive case: The element has children. Convert them into a dictionary. d: dict[str, Any] = {} for child in t: if child.tag not in d: d[child.tag] = _xml_to_dict(child) else: # Handle multiple children with the same tag if not isinstance(d[child.tag], list): d[child.tag] = [d[child.tag]] # Convert existing entry into a list d[child.tag].append(_xml_to_dict(child)) return d def _xml_to_function_call(invoke: Any, tools: list[dict]) -> dict[str, Any]: name = invoke.find("tool_name").text arguments = _xml_to_dict(invoke.find("parameters")) # make list elements in arguments actually lists filtered_tools = [tool for tool in tools if tool["name"] == name] if len(filtered_tools) > 0 and not isinstance(arguments, str): tool = filtered_tools[0] for key, value in arguments.items(): if ( key in tool["parameters"]["properties"] and "type" in tool["parameters"]["properties"][key] ): if tool["parameters"]["properties"][key][ "type" ] == "array" and not isinstance(value, list): arguments[key] = [value] if ( tool["parameters"]["properties"][key]["type"] != "object" and isinstance(value, dict) and len(value.keys()) == 1 ): arguments[key] = next(iter(value.values())) return { "function": { "name": name, "arguments": json.dumps(arguments), }, "type": "function", } def _xml_to_tool_calls(elem: Any, tools: list[dict]) -> list[dict[str, Any]]: """Convert an XML element and its children into a dictionary of dictionaries.""" invokes = elem.findall("invoke") return [_xml_to_function_call(invoke, tools) for invoke in invokes]