diff --git a/libs/partners/anthropic/langchain_anthropic/experimental.py b/libs/partners/anthropic/langchain_anthropic/experimental.py index 1cec829dd5b..d39e8c27bf8 100644 --- a/libs/partners/anthropic/langchain_anthropic/experimental.py +++ b/libs/partners/anthropic/langchain_anthropic/experimental.py @@ -69,6 +69,16 @@ TOOL_PARAMETER_FORMAT = """ """ +def _get_type(parameter: Dict[str, Any]) -> str: + if "type" in parameter: + return parameter["type"] + if "anyOf" in parameter: + return json.dumps({"anyOf": parameter["anyOf"]}) + if "allOf" in parameter: + return json.dumps({"allOf": parameter["allOf"]}) + return json.dumps(parameter) + + def get_system_message(tools: List[Dict]) -> str: tools_data: List[Dict] = [ { @@ -78,7 +88,7 @@ def get_system_message(tools: List[Dict]) -> str: [ TOOL_PARAMETER_FORMAT.format( parameter_name=name, - parameter_type=parameter["type"], + parameter_type=_get_type(parameter), parameter_description=parameter.get("description"), ) for name, parameter in tool["parameters"]["properties"].items() @@ -118,21 +128,44 @@ def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]: return d -def _xml_to_tool_calls(elem: Any) -> List[Dict[str, Any]]: +def _xml_to_function_call(invoke: Any, tools: List[Dict]) -> Dict[str, Any]: + name = invoke.find("tool_name").text + arguments = _xml_to_dict(invoke.find("parameters")) + + # make list elements in arguments actually lists + filtered_tools = [tool for tool in tools if tool["name"] == name] + if len(filtered_tools) > 0 and not isinstance(arguments, str): + tool = filtered_tools[0] + for key, value in arguments.items(): + if key in tool["parameters"]["properties"]: + if "type" in tool["parameters"]["properties"][key]: + if tool["parameters"]["properties"][key][ + "type" + ] == "array" and not isinstance(value, list): + arguments[key] = [value] + if ( + tool["parameters"]["properties"][key]["type"] != "object" + and isinstance(value, dict) + and len(value.keys()) == 1 + ): + arguments[key] = list(value.values())[0] + + return { + "function": { + "name": name, + "arguments": json.dumps(arguments), + }, + "type": "function", + } + + +def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]: """ Convert an XML element and its children into a dictionary of dictionaries. """ invokes = elem.findall("invoke") - return [ - { - "function": { - "name": invoke.find("tool_name").text, - "arguments": json.dumps(_xml_to_dict(invoke.find("parameters"))), - }, - "type": "function", - } - for invoke in invokes - ] + + return [_xml_to_function_call(invoke, tools) for invoke in invokes] @beta() @@ -262,7 +295,7 @@ class ChatAnthropicTools(ChatAnthropic): xml_text = text[start:end] xml = self._xmllib.fromstring(xml_text) - additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml) + additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools) text = "" except Exception: pass diff --git a/libs/partners/anthropic/pyproject.toml b/libs/partners/anthropic/pyproject.toml index b15646479fc..ee8f855da11 100644 --- a/libs/partners/anthropic/pyproject.toml +++ b/libs/partners/anthropic/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-anthropic" -version = "0.1.2" +version = "0.1.3" description = "An integration package connecting AnthropicMessages and LangChain" authors = [] readme = "README.md" @@ -14,7 +14,7 @@ license = "MIT" python = ">=3.8.1,<4.0" langchain-core = "^0.1" anthropic = ">=0.17.0,<1" -defusedxml = {version = "^0.7.1", optional = true} +defusedxml = { version = "^0.7.1", optional = true } [tool.poetry.group.test] optional = true diff --git a/libs/partners/anthropic/tests/integration_tests/test_experimental.py b/libs/partners/anthropic/tests/integration_tests/test_experimental.py index 59a9288f293..d6116d0fa82 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_experimental.py +++ b/libs/partners/anthropic/tests/integration_tests/test_experimental.py @@ -1,9 +1,11 @@ """Test ChatAnthropic chat model.""" import json +from enum import Enum +from typing import List, Optional from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_anthropic.experimental import ChatAnthropicTools @@ -129,3 +131,49 @@ def test_with_structured_output() -> None: assert isinstance(result, Person) assert result.name == "Erick" assert result.age == 27 + + +def test_anthropic_complex_structured_output() -> None: + class ToneEnum(str, Enum): + positive = "positive" + negative = "negative" + + class Email(BaseModel): + """Relevant information about an email.""" + + sender: Optional[str] = Field( + None, description="The sender's name, if available" + ) + sender_phone_number: Optional[str] = Field( + None, description="The sender's phone number, if available" + ) + sender_address: Optional[str] = Field( + None, description="The sender's address, if available" + ) + action_items: List[str] = Field( + ..., description="A list of action items requested by the email" + ) + topic: str = Field( + ..., description="High level description of what the email is about" + ) + tone: ToneEnum = Field(..., description="The tone of the email.") + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "human", + "What can you tell me about the following email? Make sure to answer in the correct format: {email}", # noqa: E501 + ), + ] + ) + + llm = ChatAnthropicTools(temperature=0, model_name="claude-3-sonnet-20240229") + + extraction_chain = prompt | llm.with_structured_output(Email) + + response = extraction_chain.invoke( + { + "email": "From: Erick. The email is about the new project. The tone is positive. The action items are to send the report and to schedule a meeting." # noqa: E501 + } + ) # noqa: E501 + assert isinstance(response, Email)