mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
anthropic[patch]: handle lists in function calling (#18609)
This commit is contained in:
@@ -69,6 +69,16 @@ TOOL_PARAMETER_FORMAT = """<parameter>
|
||||
</parameter>"""
|
||||
|
||||
|
||||
def _get_type(parameter: Dict[str, Any]) -> str:
|
||||
if "type" in parameter:
|
||||
return parameter["type"]
|
||||
if "anyOf" in parameter:
|
||||
return json.dumps({"anyOf": parameter["anyOf"]})
|
||||
if "allOf" in parameter:
|
||||
return json.dumps({"allOf": parameter["allOf"]})
|
||||
return json.dumps(parameter)
|
||||
|
||||
|
||||
def get_system_message(tools: List[Dict]) -> str:
|
||||
tools_data: List[Dict] = [
|
||||
{
|
||||
@@ -78,7 +88,7 @@ def get_system_message(tools: List[Dict]) -> str:
|
||||
[
|
||||
TOOL_PARAMETER_FORMAT.format(
|
||||
parameter_name=name,
|
||||
parameter_type=parameter["type"],
|
||||
parameter_type=_get_type(parameter),
|
||||
parameter_description=parameter.get("description"),
|
||||
)
|
||||
for name, parameter in tool["parameters"]["properties"].items()
|
||||
@@ -118,21 +128,44 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
|
||||
return d
|
||||
|
||||
|
||||
def _xml_to_tool_calls(elem: Any) -> List[Dict[str, Any]]:
|
||||
def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
|
||||
name = invoke.find("tool_name").text
|
||||
arguments = _xml_to_dict(invoke.find("parameters"))
|
||||
|
||||
# make list elements in arguments actually lists
|
||||
filtered_tools = [tool for tool in tools if tool["name"] == name]
|
||||
if len(filtered_tools) > 0 and not isinstance(arguments, str):
|
||||
tool = filtered_tools[0]
|
||||
for key, value in arguments.items():
|
||||
if key in tool["parameters"]["properties"]:
|
||||
if "type" in tool["parameters"]["properties"][key]:
|
||||
if tool["parameters"]["properties"][key][
|
||||
"type"
|
||||
] == "array" and not isinstance(value, list):
|
||||
arguments[key] = [value]
|
||||
if (
|
||||
tool["parameters"]["properties"][key]["type"] != "object"
|
||||
and isinstance(value, dict)
|
||||
and len(value.keys()) == 1
|
||||
):
|
||||
arguments[key] = list(value.values())[0]
|
||||
|
||||
return {
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": json.dumps(arguments),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
|
||||
|
||||
def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert an XML element and its children into a dictionary of dictionaries.
|
||||
"""
|
||||
invokes = elem.findall("invoke")
|
||||
return [
|
||||
{
|
||||
"function": {
|
||||
"name": invoke.find("tool_name").text,
|
||||
"arguments": json.dumps(_xml_to_dict(invoke.find("parameters"))),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
for invoke in invokes
|
||||
]
|
||||
|
||||
return [_xml_to_function_call(invoke, tools) for invoke in invokes]
|
||||
|
||||
|
||||
@beta()
|
||||
@@ -262,7 +295,7 @@ class ChatAnthropicTools(ChatAnthropic):
|
||||
xml_text = text[start:end]
|
||||
|
||||
xml = self._xmllib.fromstring(xml_text)
|
||||
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml)
|
||||
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
|
||||
text = ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user