mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 16:11:02 +00:00
anthropic[patch]: handle lists in function calling (#18609)
This commit is contained in:
parent
1831733c2e
commit
e169ee8863
@ -69,6 +69,16 @@ TOOL_PARAMETER_FORMAT = """<parameter>
|
|||||||
</parameter>"""
|
</parameter>"""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_type(parameter: Dict[str, Any]) -> str:
|
||||||
|
if "type" in parameter:
|
||||||
|
return parameter["type"]
|
||||||
|
if "anyOf" in parameter:
|
||||||
|
return json.dumps({"anyOf": parameter["anyOf"]})
|
||||||
|
if "allOf" in parameter:
|
||||||
|
return json.dumps({"allOf": parameter["allOf"]})
|
||||||
|
return json.dumps(parameter)
|
||||||
|
|
||||||
|
|
||||||
def get_system_message(tools: List[Dict]) -> str:
|
def get_system_message(tools: List[Dict]) -> str:
|
||||||
tools_data: List[Dict] = [
|
tools_data: List[Dict] = [
|
||||||
{
|
{
|
||||||
@ -78,7 +88,7 @@ def get_system_message(tools: List[Dict]) -> str:
|
|||||||
[
|
[
|
||||||
TOOL_PARAMETER_FORMAT.format(
|
TOOL_PARAMETER_FORMAT.format(
|
||||||
parameter_name=name,
|
parameter_name=name,
|
||||||
parameter_type=parameter["type"],
|
parameter_type=_get_type(parameter),
|
||||||
parameter_description=parameter.get("description"),
|
parameter_description=parameter.get("description"),
|
||||||
)
|
)
|
||||||
for name, parameter in tool["parameters"]["properties"].items()
|
for name, parameter in tool["parameters"]["properties"].items()
|
||||||
@ -118,21 +128,44 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def _xml_to_tool_calls(elem: Any) -> List[Dict[str, Any]]:
|
def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]:
|
||||||
|
name = invoke.find("tool_name").text
|
||||||
|
arguments = _xml_to_dict(invoke.find("parameters"))
|
||||||
|
|
||||||
|
# make list elements in arguments actually lists
|
||||||
|
filtered_tools = [tool for tool in tools if tool["name"] == name]
|
||||||
|
if len(filtered_tools) > 0 and not isinstance(arguments, str):
|
||||||
|
tool = filtered_tools[0]
|
||||||
|
for key, value in arguments.items():
|
||||||
|
if key in tool["parameters"]["properties"]:
|
||||||
|
if "type" in tool["parameters"]["properties"][key]:
|
||||||
|
if tool["parameters"]["properties"][key][
|
||||||
|
"type"
|
||||||
|
] == "array" and not isinstance(value, list):
|
||||||
|
arguments[key] = [value]
|
||||||
|
if (
|
||||||
|
tool["parameters"]["properties"][key]["type"] != "object"
|
||||||
|
and isinstance(value, dict)
|
||||||
|
and len(value.keys()) == 1
|
||||||
|
):
|
||||||
|
arguments[key] = list(value.values())[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"arguments": json.dumps(arguments),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Convert an XML element and its children into a dictionary of dictionaries.
|
Convert an XML element and its children into a dictionary of dictionaries.
|
||||||
"""
|
"""
|
||||||
invokes = elem.findall("invoke")
|
invokes = elem.findall("invoke")
|
||||||
return [
|
|
||||||
{
|
return [_xml_to_function_call(invoke, tools) for invoke in invokes]
|
||||||
"function": {
|
|
||||||
"name": invoke.find("tool_name").text,
|
|
||||||
"arguments": json.dumps(_xml_to_dict(invoke.find("parameters"))),
|
|
||||||
},
|
|
||||||
"type": "function",
|
|
||||||
}
|
|
||||||
for invoke in invokes
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@beta()
|
@beta()
|
||||||
@ -262,7 +295,7 @@ class ChatAnthropicTools(ChatAnthropic):
|
|||||||
xml_text = text[start:end]
|
xml_text = text[start:end]
|
||||||
|
|
||||||
xml = self._xmllib.fromstring(xml_text)
|
xml = self._xmllib.fromstring(xml_text)
|
||||||
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml)
|
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
|
||||||
text = ""
|
text = ""
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "langchain-anthropic"
|
name = "langchain-anthropic"
|
||||||
version = "0.1.2"
|
version = "0.1.3"
|
||||||
description = "An integration package connecting AnthropicMessages and LangChain"
|
description = "An integration package connecting AnthropicMessages and LangChain"
|
||||||
authors = []
|
authors = []
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@ -14,7 +14,7 @@ license = "MIT"
|
|||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
langchain-core = "^0.1"
|
langchain-core = "^0.1"
|
||||||
anthropic = ">=0.17.0,<1"
|
anthropic = ">=0.17.0,<1"
|
||||||
defusedxml = {version = "^0.7.1", optional = true}
|
defusedxml = { version = "^0.7.1", optional = true }
|
||||||
|
|
||||||
[tool.poetry.group.test]
|
[tool.poetry.group.test]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
"""Test ChatAnthropic chat model."""
|
"""Test ChatAnthropic chat model."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
from langchain_anthropic.experimental import ChatAnthropicTools
|
from langchain_anthropic.experimental import ChatAnthropicTools
|
||||||
|
|
||||||
@ -129,3 +131,49 @@ def test_with_structured_output() -> None:
|
|||||||
assert isinstance(result, Person)
|
assert isinstance(result, Person)
|
||||||
assert result.name == "Erick"
|
assert result.name == "Erick"
|
||||||
assert result.age == 27
|
assert result.age == 27
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_complex_structured_output() -> None:
|
||||||
|
class ToneEnum(str, Enum):
|
||||||
|
positive = "positive"
|
||||||
|
negative = "negative"
|
||||||
|
|
||||||
|
class Email(BaseModel):
|
||||||
|
"""Relevant information about an email."""
|
||||||
|
|
||||||
|
sender: Optional[str] = Field(
|
||||||
|
None, description="The sender's name, if available"
|
||||||
|
)
|
||||||
|
sender_phone_number: Optional[str] = Field(
|
||||||
|
None, description="The sender's phone number, if available"
|
||||||
|
)
|
||||||
|
sender_address: Optional[str] = Field(
|
||||||
|
None, description="The sender's address, if available"
|
||||||
|
)
|
||||||
|
action_items: List[str] = Field(
|
||||||
|
..., description="A list of action items requested by the email"
|
||||||
|
)
|
||||||
|
topic: str = Field(
|
||||||
|
..., description="High level description of what the email is about"
|
||||||
|
)
|
||||||
|
tone: ToneEnum = Field(..., description="The tone of the email.")
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"human",
|
||||||
|
"What can you tell me about the following email? Make sure to answer in the correct format: {email}", # noqa: E501
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = ChatAnthropicTools(temperature=0, model_name="claude-3-sonnet-20240229")
|
||||||
|
|
||||||
|
extraction_chain = prompt | llm.with_structured_output(Email)
|
||||||
|
|
||||||
|
response = extraction_chain.invoke(
|
||||||
|
{
|
||||||
|
"email": "From: Erick. The email is about the new project. The tone is positive. The action items are to send the report and to schedule a meeting." # noqa: E501
|
||||||
|
}
|
||||||
|
) # noqa: E501
|
||||||
|
assert isinstance(response, Email)
|
||||||
|
Loading…
Reference in New Issue
Block a user