diff --git a/cookbook/tool_call_messages.ipynb b/cookbook/tool_call_messages.ipynb new file mode 100644 index 00000000000..c7253e44bdc --- /dev/null +++ b/cookbook/tool_call_messages.ipynb @@ -0,0 +1,423 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c48812ed-35bd-4fbe-9a2c-6c7335e5645e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/chestercurme/repos/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `bind_tools` is in beta. It is actively being worked on, so the API may change.\n", + " warn_beta(\n" + ] + } + ], + "source": [ + "from langchain_anthropic import ChatAnthropic\n", + "from langchain_core.runnables import ConfigurableField\n", + "from langchain_core.tools import tool\n", + "from langchain_openai import ChatOpenAI\n", + "\n", + "\n", + "@tool\n", + "def multiply(x: float, y: float) -> float:\n", + " \"\"\"Multiply 'x' times 'y'.\"\"\"\n", + " return x * y\n", + "\n", + "\n", + "@tool\n", + "def exponentiate(x: float, y: float) -> float:\n", + " \"\"\"Raise 'x' to the 'y'.\"\"\"\n", + " return x**y\n", + "\n", + "\n", + "@tool\n", + "def add(x: float, y: float) -> float:\n", + " \"\"\"Add 'x' and 'y'.\"\"\"\n", + " return x + y\n", + "\n", + "\n", + "tools = [multiply, exponentiate, add]\n", + "\n", + "gpt35 = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0).bind_tools(tools)\n", + "claude3 = ChatAnthropic(model=\"claude-3-sonnet-20240229\").bind_tools(tools)\n", + "llm_with_tools = gpt35.configurable_alternatives(\n", + " ConfigurableField(id=\"llm\"), default_key=\"gpt35\", claude3=claude3\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4719ebdb-ad50-468e-9b30-fb5fb086e140", + "metadata": {}, + "source": [ + "# AgentExecutor" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b98feaa5-8c2d-4125-9519-67114a1fef31", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Tuple, Union\n", + "\n", + "from langchain.agents import AgentExecutor\n", + "from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction\n", + "from langchain_core.agents import AgentFinish, _convert_agent_action_to_messages\n", + "from langchain_core.messages import (\n", + " AIMessage,\n", + " BaseMessage,\n", + " ToolMessage,\n", + ")\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_core.runnables import RunnablePassthrough\n", + "\n", + "\n", + "def actions_observations_to_messages(\n", + " steps: List[Tuple[OpenAIToolAgentAction, str]],\n", + ") -> List[BaseMessage]:\n", + " messages = []\n", + " for action, observation in steps:\n", + " messages.extend([m for m in action.message_log if m not in messages])\n", + " messages.append(ToolMessage(observation, tool_call_id=action.tool_call_id))\n", + " return messages\n", + "\n", + "\n", + "def messages_to_action(\n", + " msg: AIMessage,\n", + ") -> Union[List[OpenAIToolAgentAction], AgentFinish]:\n", + " if isinstance(msg, AIMessage) and msg.tool_calls is not None:\n", + " actions = []\n", + " for tool_call in msg.tool_calls:\n", + " actions.append(\n", + " OpenAIToolAgentAction(\n", + " tool=tool_call.name,\n", + " tool_input=tool_call.args,\n", + " tool_call_id=tool_call.id,\n", + " message_log=[msg],\n", + " log=\"\",\n", + " )\n", + " )\n", + " return actions\n", + " else:\n", + " return AgentFinish(return_values={\"output\": msg.content}, log=\"\")\n", + "\n", + "\n", + "prompt = 
ChatPromptTemplate.from_messages(\n", + " [\n", + " (\"system\", \"You're a helpful assistant with access to tools\"),\n", + " (\"human\", \"{input}\"),\n", + " (\"placeholder\", \"{agent_scratchpad}\"),\n", + " ]\n", + ")\n", + "\n", + "agent = (\n", + " RunnablePassthrough.assign(\n", + " agent_scratchpad=lambda x: actions_observations_to_messages(\n", + " x[\"intermediate_steps\"]\n", + " ),\n", + " )\n", + " | prompt\n", + " | llm_with_tools\n", + " | messages_to_action\n", + ")\n", + "\n", + "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b4c0fc7a-80bb-4bb8-a87b-7388291ae8b6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m300.03770462067547\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n", + " 'output': 'The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.'}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke(\n", + " {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "41a3a3c8-185d-4861-b6f0-7592668feb62", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", + " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m82.65606421491815\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", + " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m85.65606421491815\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n", + " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: 
UserWarning: stream: Tool use is not yet supported in streaming mode.\n", + " warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n", + " 'output': 'Therefore, 17.24 - 918.1241 = -900.8841'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor = AgentExecutor(\n", + " agent=agent.with_config(configurable={\"llm\": \"claude3\"}), tools=tools, verbose=True\n", + ")\n", + "agent_executor.invoke(\n", + " {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9c186263-1b98-4cb2-b6d1-71f65eb0d811", + "metadata": {}, + "source": [ + "# LangGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "28fc2c60-7dbc-428a-8983-1a6a15ea30d2", + "metadata": {}, + "outputs": [], + "source": [ + "import operator\n", + "from typing import Annotated, Sequence, TypedDict\n", + "\n", + "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage\n", + "from langchain_core.runnables import RunnableLambda\n", + "from langgraph.graph import END, StateGraph\n", + "\n", + "\n", + "class AgentState(TypedDict):\n", + " messages: Annotated[Sequence[BaseMessage], operator.add]\n", + "\n", + "\n", + "def should_continue(state):\n", + " return \"continue\" if state[\"messages\"][-1].tool_calls is not None else \"end\"\n", + "\n", + "\n", + "def call_model(state, config):\n", + " return {\"messages\": [llm_with_tools.invoke(state[\"messages\"], config=config)]}\n", + "\n", + "\n", + "def _invoke_tool(tool_call):\n", + " tool = {tool.name: tool for tool in tools}[tool_call.name]\n", + " return ToolMessage(tool.invoke(tool_call.args), tool_call_id=tool_call.id)\n", + "\n", + "\n", + "tool_executor = RunnableLambda(_invoke_tool)\n", + "\n", + "\n", + "def call_tools(state):\n", + " last_message = state[\"messages\"][-1]\n", + " return {\"messages\": tool_executor.batch(last_message.tool_calls)}\n", + "\n", + "\n", + "workflow = StateGraph(AgentState)\n", + "workflow.add_node(\"agent\", call_model)\n", + "workflow.add_node(\"action\", call_tools)\n", + "workflow.set_entry_point(\"agent\")\n", + "workflow.add_conditional_edges(\n", + " \"agent\",\n", + " should_continue,\n", + " {\n", + " \"continue\": \"action\",\n", + " \"end\": END,\n", + " },\n", + ")\n", + "workflow.add_edge(\"action\", \"agent\")\n", + "graph = workflow.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "24463798-74e6-4c39-8092-7a1524d83225", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. 
also what's 17.24 - 918.1241\"),\n", + " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_kbBUUeqK75fZZqDTvu8aif7Z', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_pBD8daSyXidXnrIyG4vG5C9O', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8e1d9687-611c-4c8e-9fcd-ef6e48bd22a6-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 8, 'y': 2.743}, id='call_kbBUUeqK75fZZqDTvu8aif7Z'), ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='call_pBD8daSyXidXnrIyG4vG5C9O')]),\n", + " ToolMessage(content='300.03770462067547', tool_call_id='call_kbBUUeqK75fZZqDTvu8aif7Z'),\n", + " ToolMessage(content='-900.8841', tool_call_id='call_pBD8daSyXidXnrIyG4vG5C9O'),\n", + " AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-47fe5cbc-3f25-44c3-85b2-6540c3054a77-0')]}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graph.invoke(\n", + " {\n", + " \"messages\": [\n", + " HumanMessage(\n", + " \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"\n", + " )\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "073c074e-d722-42e0-85ec-c62c079207e4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. 
also what's 17.24 - 918.1241\"),\n", + " AIMessage(content=[{'text': \"Okay, let's break this down into steps:\", 'type': 'text'}, {'id': 'toolu_01DJkSDpB8ztmJx2DLNbc3eW', 'input': {'x': 5, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_01KuVNohyJr24cPhJkY3XVtt', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 84}}, id='run-336cdfb6-0fe4-4d7a-9946-9f01c2eb41ae-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 5, 'y': 2.743}, id='toolu_01DJkSDpB8ztmJx2DLNbc3eW', index=1)]),\n", + " ToolMessage(content='82.65606421491815', tool_call_id='toolu_01DJkSDpB8ztmJx2DLNbc3eW'),\n", + " AIMessage(content=[{'text': 'To get 5 raised to the 2.743 power.', 'type': 'text'}, {'id': 'toolu_01MKQqnDw5CtyuKjQP8YG1FX', 'input': {'x': 3, 'y': 82.65606421491815}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01UBsKkvA4StUR4NEvoFFFep', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 552, 'output_tokens': 91}}, id='run-9d25b7bd-58aa-47dd-933f-15459b24b2c2-0', tool_calls=[ToolCall(name='add', args={'x': 3, 'y': 82.65606421491815}, id='toolu_01MKQqnDw5CtyuKjQP8YG1FX', index=1)]),\n", + " ToolMessage(content='85.65606421491815', tool_call_id='toolu_01MKQqnDw5CtyuKjQP8YG1FX'),\n", + " AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 85.656.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_019Wb2zPouCR3dw2bSKvCRUL', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01Y2H2L8FWcDtVkCtuosie2P', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 661, 'output_tokens': 105}}, id='run-e553c1e3-24ba-4e1b-93ba-6f1985932db4-0', tool_calls=[ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='toolu_019Wb2zPouCR3dw2bSKvCRUL', index=1)]),\n", + " ToolMessage(content='-900.8841', tool_call_id='toolu_019Wb2zPouCR3dw2bSKvCRUL'),\n", + " AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01Q14dqvaCD2eA4zwrUvxTcF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 782, 'output_tokens': 24}}, id='run-f6b6e525-2df6-4617-9bb3-b39d5cc963a9-0')]}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graph.invoke(\n", + " {\n", + " \"messages\": [\n", + " HumanMessage(\n", + " \"what's 3 plus 5 raised to the 2.743. 
also what's 17.24 - 918.1241\"\n", + " )\n", + " ]\n", + " },\n", + " config={\"configurable\": {\"llm\": \"claude3\"}},\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py index 2680a052cb9..f823215244f 100644 --- a/libs/core/langchain_core/messages/__init__.py +++ b/libs/core/langchain_core/messages/__init__.py @@ -15,7 +15,13 @@ """ # noqa: E501 -from langchain_core.messages.ai import AIMessage, AIMessageChunk +from langchain_core.messages.ai import ( + AIMessage, + AIMessageChunk, + InvalidToolCall, + ToolCall, + ToolCallChunk, +) from langchain_core.messages.base import ( BaseMessage, BaseMessageChunk, @@ -50,9 +56,12 @@ __all__ = [ "FunctionMessageChunk", "HumanMessage", "HumanMessageChunk", + "InvalidToolCall", "MessageLikeRepresentation", "SystemMessage", "SystemMessageChunk", + "ToolCall", + "ToolCallChunk", "ToolMessage", "ToolMessageChunk", "_message_from_dict", diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 22740326e8b..e83e31b88f2 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, List, Literal from langchain_core.messages.base import ( @@ -5,7 +6,18 @@ from langchain_core.messages.base import ( BaseMessageChunk, merge_content, ) -from langchain_core.utils._merge import merge_dicts +from langchain_core.messages.tool import ( + InvalidToolCall, + ToolCall, + ToolCallChunk, + default_tool_chunk_parser, + default_tool_parser, +) +from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils._merge import merge_dicts, merge_lists +from langchain_core.utils.json import ( + parse_partial_json, +) class AIMessage(BaseMessage): @@ -16,6 +28,11 @@ class AIMessage(BaseMessage): conversation. """ + tool_calls: List[ToolCall] = [] + """If provided, tool calls associated with the message.""" + invalid_tool_calls: List[InvalidToolCall] = [] + """If provided, tool calls with parsing errors associated with the message.""" + type: Literal["ai"] = "ai" @classmethod @@ -23,6 +40,34 @@ class AIMessage(BaseMessage): """Get the namespace of the langchain object.""" return ["langchain", "schema", "messages"] + @root_validator + def _backwards_compat_tool_calls(cls, values: dict) -> dict: + raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls") + tool_calls = ( + values.get("tool_calls") + or values.get("invalid_tool_calls") + or values.get("tool_call_chunks") + ) + if raw_tool_calls and not tool_calls: + warnings.warn( + "New langchain packages are available that more efficiently handle " + "tool calling. Please upgrade your packages to versions that set " + "message tool calls. e.g., `pip install --upgrade langchain-anthropic" + "`, pip install--upgrade langchain-openai`, etc." 
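+            # Best-effort upgrade path (below): legacy OpenAI-format
+            # additional_kwargs["tool_calls"] dicts are re-parsed into the
+            # typed tool_call_chunks / tool_calls / invalid_tool_calls
+            # fields; any parse failure is swallowed so old payloads load.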
+ ) + try: + if issubclass(cls, AIMessageChunk): # type: ignore + values["tool_call_chunks"] = default_tool_chunk_parser( + raw_tool_calls + ) + else: + tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls) + values["tool_calls"] = tool_calls + values["invalid_tool_calls"] = invalid_tool_calls + except Exception: + pass + return values + AIMessage.update_forward_refs() @@ -35,11 +80,48 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): # non-chunk variant. type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 + tool_call_chunks: List[ToolCallChunk] = [] + """If provided, tool call chunks associated with the message.""" + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "messages"] + @root_validator() + def init_tool_calls(cls, values: dict) -> dict: + if not values["tool_call_chunks"]: + values["tool_calls"] = [] + values["invalid_tool_calls"] = [] + return values + tool_calls = [] + invalid_tool_calls = [] + for chunk in values["tool_call_chunks"]: + try: + args_ = parse_partial_json(chunk["args"]) + if isinstance(args_, dict): + tool_calls.append( + ToolCall( + name=chunk["name"] or "", + args=args_, + id=chunk["id"], + ) + ) + else: + raise ValueError("Malformed args.") + except Exception: + invalid_tool_calls.append( + InvalidToolCall( + name=chunk["name"], + args=chunk["args"], + id=chunk["id"], + error="Malformed args.", + ) + ) + values["tool_calls"] = tool_calls + values["invalid_tool_calls"] = invalid_tool_calls + return values + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, AIMessageChunk): if self.example != other.example: @@ -47,15 +129,41 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): "Cannot concatenate AIMessageChunks with different example values." 
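+                # `example` marks messages passed as part of an example
+                # conversation; chunks with different flags come from
+                # different generations and must not be concatenated.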
) + content = merge_content(self.content, other.content) + additional_kwargs = merge_dicts( + self.additional_kwargs, other.additional_kwargs + ) + response_metadata = merge_dicts( + self.response_metadata, other.response_metadata + ) + + # Merge tool call chunks + if self.tool_call_chunks or other.tool_call_chunks: + raw_tool_calls = merge_lists( + self.tool_call_chunks, + other.tool_call_chunks, + ) + if raw_tool_calls: + tool_call_chunks = [ + ToolCallChunk( + name=rtc.get("name"), + args=rtc.get("args"), + index=rtc.get("index"), + id=rtc.get("id"), + ) + for rtc in raw_tool_calls + ] + else: + tool_call_chunks = [] + else: + tool_call_chunks = [] + return self.__class__( example=self.example, - content=merge_content(self.content, other.content), - additional_kwargs=merge_dicts( - self.additional_kwargs, other.additional_kwargs - ), - response_metadata=merge_dicts( - self.response_metadata, other.response_metadata - ), + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, + response_metadata=response_metadata, id=self.id, ) diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 03e333e2ede..169e6856ae9 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -1,4 +1,7 @@ -from typing import Any, List, Literal +import json +from typing import Any, Dict, List, Literal, Optional, Tuple + +from typing_extensions import TypedDict from langchain_core.messages.base import ( BaseMessage, @@ -61,3 +64,112 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk): ) return super().__add__(other) + + +class ToolCall(TypedDict): + """A call to a tool. + + Attributes: + name: (str) the name of the tool to be called + args: (dict) the arguments to the tool call + id: (str) if provided, an identifier associated with the tool call + """ + + name: str + args: Dict[str, Any] + id: Optional[str] + + +class ToolCallChunk(TypedDict): + """A chunk of a tool call (e.g., as part of a stream). + + When merging ToolCallChunks (e.g., via AIMessageChunk.__add__), + all string attributes are concatenated. Chunks are only merged if their + values of `index` are equal and not None. + + Example: + + .. code-block:: python + + left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)] + right_chunks = [ToolCallChunk(name=None, args='1}', index=0)] + ( + AIMessageChunk(content="", tool_call_chunks=left_chunks) + + AIMessageChunk(content="", tool_call_chunks=right_chunks) + ).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)] + + Attributes: + name: (str) if provided, a substring of the name of the tool to be called + args: (str) if provided, a JSON substring of the arguments to the tool call + id: (str) if provided, a substring of an identifier for the tool call + index: (int) if provided, the index of the tool call in a sequence + """ + + name: Optional[str] + args: Optional[str] + id: Optional[str] + index: Optional[int] + + +class InvalidToolCall(TypedDict): + """Allowance for errors made by LLM. + + Here we add an `error` key to surface errors made during generation + (e.g., invalid JSON arguments.) 
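+
+    Attributes:
+        name: (str) if provided, the name of the tool that was called
+        args: (str) if provided, the raw (unparsed) argument string
+        id: (str) if provided, an identifier associated with the tool call
+        error: (str) if provided, the error encountered during parsing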
+ """ + + name: Optional[str] + args: Optional[str] + id: Optional[str] + error: Optional[str] + + +def default_tool_parser( + raw_tool_calls: List[dict], +) -> Tuple[List[ToolCall], List[InvalidToolCall]]: + """Best-effort parsing of tools.""" + tool_calls = [] + invalid_tool_calls = [] + for tool_call in raw_tool_calls: + if "function" not in tool_call: + continue + else: + function_name = tool_call["function"]["name"] + try: + function_args = json.loads(tool_call["function"]["arguments"]) + parsed = ToolCall( + name=function_name or "", + args=function_args or {}, + id=tool_call.get("id"), + ) + tool_calls.append(parsed) + except json.JSONDecodeError: + invalid_tool_calls.append( + InvalidToolCall( + name=function_name, + args=tool_call["function"]["arguments"], + id=tool_call.get("id"), + error="Malformed args.", + ) + ) + return tool_calls, invalid_tool_calls + + +def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]: + """Best-effort parsing of tool chunks.""" + tool_call_chunks = [] + for tool_call in raw_tool_calls: + if "function" not in tool_call: + function_args = None + function_name = None + else: + function_args = tool_call["function"]["arguments"] + function_name = tool_call["function"]["name"] + parsed = ToolCallChunk( + name=function_name, + args=function_args, + id=tool_call.get("id"), + index=tool_call.get("index"), + ) + tool_call_chunks.append(parsed) + return tool_call_chunks diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 386e75c1ba9..8f8957a9baf 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -1,6 +1,9 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -from langchain_core.messages.ai import AIMessage, AIMessageChunk +from langchain_core.messages.ai import ( + AIMessage, + AIMessageChunk, +) from langchain_core.messages.base import ( BaseMessage, BaseMessageChunk, @@ -119,8 +122,11 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage: if not isinstance(chunk, BaseMessageChunk): return chunk # chunk classes always have the equivalent non-chunk class as their first parent + ignore_keys = ["type"] + if isinstance(chunk, AIMessageChunk): + ignore_keys.append("tool_call_chunks") return chunk.__class__.__mro__[1]( - **{k: v for k, v in chunk.__dict__.items() if k != "type"} + **{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys} ) diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 5d8298986b3..9652fde424e 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -1,9 +1,8 @@ from __future__ import annotations import json -import re from json import JSONDecodeError -from typing import Any, Callable, List, Optional, Type, TypeVar, Union +from typing import Any, List, Optional, Type, TypeVar, Union import jsonpatch # type: ignore[import] import pydantic # pydantic: ignore @@ -12,6 +11,11 @@ from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.outputs import Generation +from langchain_core.utils.json import ( + parse_and_check_json_markdown, + parse_json_markdown, + parse_partial_json, +) from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION if 
PYDANTIC_MAJOR_VERSION < 2: @@ -26,182 +30,6 @@ else: TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) -def _replace_new_line(match: re.Match[str]) -> str: - value = match.group(2) - value = re.sub(r"\n", r"\\n", value) - value = re.sub(r"\r", r"\\r", value) - value = re.sub(r"\t", r"\\t", value) - value = re.sub(r'(? str: - """ - The LLM response for `action_input` may be a multiline - string containing unescaped newlines, tabs or quotes. This function - replaces those characters with their escaped counterparts. - (newlines in JSON must be double-escaped: `\\n`) - """ - if isinstance(multiline_string, (bytes, bytearray)): - multiline_string = multiline_string.decode() - - multiline_string = re.sub( - r'("action_input"\:\s*")(.*?)(")', - _replace_new_line, - multiline_string, - flags=re.DOTALL, - ) - - return multiline_string - - -# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py -# MIT License -def parse_partial_json(s: str, *, strict: bool = False) -> Any: - """Parse a JSON string that may be missing closing braces. - - Args: - s: The JSON string to parse. - strict: Whether to use strict parsing. Defaults to False. - - Returns: - The parsed JSON object as a Python dictionary. - """ - # Attempt to parse the string as-is. - try: - return json.loads(s, strict=strict) - except json.JSONDecodeError: - pass - - # Initialize variables. - new_s = "" - stack = [] - is_inside_string = False - escaped = False - - # Process each character in the string one at a time. - for char in s: - if is_inside_string: - if char == '"' and not escaped: - is_inside_string = False - elif char == "\n" and not escaped: - char = "\\n" # Replace the newline character with the escape sequence. - elif char == "\\": - escaped = not escaped - else: - escaped = False - else: - if char == '"': - is_inside_string = True - escaped = False - elif char == "{": - stack.append("}") - elif char == "[": - stack.append("]") - elif char == "}" or char == "]": - if stack and stack[-1] == char: - stack.pop() - else: - # Mismatched closing character; the input is malformed. - return None - - # Append the processed character to the new string. - new_s += char - - # If we're still inside a string at the end of processing, - # we need to close the string. - if is_inside_string: - new_s += '"' - - # Try to parse mods of string until we succeed or run out of characters. - while new_s: - final_s = new_s - - # Close any remaining open structures in the reverse - # order that they were opened. - for closing_char in reversed(stack): - final_s += closing_char - - # Attempt to parse the modified string as JSON. - try: - return json.loads(final_s, strict=strict) - except json.JSONDecodeError: - # If we still can't parse the string as JSON, - # try removing the last character - new_s = new_s[:-1] - - # If we got here, we ran out of characters to remove - # and still couldn't parse the string as JSON, so return the parse error - # for the original string. - return json.loads(s, strict=strict) - - -def parse_json_markdown( - json_string: str, *, parser: Callable[[str], Any] = parse_partial_json -) -> dict: - """ - Parse a JSON string from a Markdown string. - - Args: - json_string: The Markdown string. - - Returns: - The parsed JSON object as a Python dictionary. 
- """ - try: - return _parse_json(json_string, parser=parser) - except json.JSONDecodeError: - # Try to find JSON string within triple backticks - match = re.search(r"```(json)?(.*)", json_string, re.DOTALL) - - # If no match found, assume the entire string is a JSON string - if match is None: - json_str = json_string - else: - # If match found, use the content within the backticks - json_str = match.group(2) - return _parse_json(json_str, parser=parser) - - -def _parse_json( - json_str: str, *, parser: Callable[[str], Any] = parse_partial_json -) -> dict: - # Strip whitespace and newlines from the start and end - json_str = json_str.strip().strip("`") - - # handle newlines and other special characters inside the returned value - json_str = _custom_parser(json_str) - - # Parse the JSON string into a Python dictionary - return parser(json_str) - - -def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: - """ - Parse a JSON string from a Markdown string and check that it - contains the expected keys. - - Args: - text: The Markdown string. - expected_keys: The expected keys in the JSON string. - - Returns: - The parsed JSON object as a Python dictionary. - """ - try: - json_obj = parse_json_markdown(text) - except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") - for key in expected_keys: - if key not in json_obj: - raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " - f"to be present, but got {json_obj}" - ) - return json_obj - - class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): """Parse the output of an LLM call to a JSON object. @@ -267,3 +95,5 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): # For backwards compatibility SimpleJsonOutputParser = JsonOutputParser +parse_partial_json = parse_partial_json +parse_and_check_json_markdown = parse_and_check_json_markdown diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index fb1f88aca24..da1f638588d 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -1,13 +1,89 @@ import copy import json from json import JSONDecodeError -from typing import Any, List, Type +from typing import Any, Dict, List, Optional, Type from langchain_core.exceptions import OutputParserException +from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.output_parsers import BaseCumulativeTransformOutputParser -from langchain_core.output_parsers.json import parse_partial_json from langchain_core.outputs import ChatGeneration, Generation from langchain_core.pydantic_v1 import BaseModel, ValidationError +from langchain_core.utils.json import parse_partial_json + + +def parse_tool_call( + raw_tool_call: Dict[str, Any], + *, + partial: bool = False, + strict: bool = False, + return_id: bool = True, +) -> Optional[Dict[str, Any]]: + """Parse a single tool call.""" + if "function" not in raw_tool_call: + return None + if partial: + try: + function_args = parse_partial_json( + raw_tool_call["function"]["arguments"], strict=strict + ) + except (JSONDecodeError, TypeError): # None args raise TypeError + return None + else: + try: + function_args = json.loads( + raw_tool_call["function"]["arguments"], strict=strict + ) + except JSONDecodeError as e: + raise OutputParserException( + f"Function {raw_tool_call['function']['name']} arguments:\n\n" + 
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. " + f"Received JSONDecodeError {e}" + ) + parsed = { + "name": raw_tool_call["function"]["name"] or "", + "args": function_args or {}, + } + if return_id: + parsed["id"] = raw_tool_call["id"] + return parsed + + +def make_invalid_tool_call( + raw_tool_call: Dict[str, Any], + error_msg: Optional[str], +) -> InvalidToolCall: + """Create an InvalidToolCall from a raw tool call.""" + return InvalidToolCall( + name=raw_tool_call["function"]["name"], + args=raw_tool_call["function"]["arguments"], + id=raw_tool_call.get("id"), + error=error_msg, + ) + + +def parse_tool_calls( + raw_tool_calls: List[dict], + *, + partial: bool = False, + strict: bool = False, + return_id: bool = True, +) -> List[dict]: + """Parse a list of tool calls.""" + final_tools = [] + exceptions = [] + for tool_call in raw_tool_calls: + try: + parsed = parse_tool_call( + tool_call, partial=partial, strict=strict, return_id=return_id + ) + if parsed: + final_tools.append(parsed) + except OutputParserException as e: + exceptions.append(str(e)) + continue + if exceptions: + raise OutputParserException("\n\n".join(exceptions)) + return final_tools class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]): @@ -40,47 +116,29 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]): "This output parser can only be used with a chat generation." ) message = generation.message - try: - tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"]) - except KeyError: - return [] + if isinstance(message, AIMessage) and message.tool_calls: + tool_calls = [dict(tc) for tc in message.tool_calls] + for tool_call in tool_calls: + if not self.return_id: + _ = tool_call.pop("id") + else: + try: + raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"]) + except KeyError: + return [] + tool_calls = parse_tool_calls( + raw_tool_calls, + partial=partial, + strict=self.strict, + return_id=self.return_id, + ) + # for backwards compatibility + for tc in tool_calls: + tc["type"] = tc.pop("name") - final_tools = [] - exceptions = [] - for tool_call in tool_calls: - if "function" not in tool_call: - continue - if partial: - try: - function_args = parse_partial_json( - tool_call["function"]["arguments"], strict=self.strict - ) - except JSONDecodeError: - continue - else: - try: - function_args = json.loads( - tool_call["function"]["arguments"], strict=self.strict - ) - except JSONDecodeError as e: - exceptions.append( - f"Function {tool_call['function']['name']} arguments:\n\n" - f"{tool_call['function']['arguments']}\n\nare not valid JSON. 
" - f"Received JSONDecodeError {e}" - ) - continue - parsed = { - "type": tool_call["function"]["name"], - "args": function_args, - } - if self.return_id: - parsed["id"] = tool_call["id"] - final_tools.append(parsed) - if exceptions: - raise OutputParserException("\n\n".join(exceptions)) if self.first_tool_only: - return final_tools[0] if final_tools else None - return final_tools + return tool_calls[0] if tool_calls else None + return tool_calls def parse(self, text: str) -> Any: raise NotImplementedError() diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index 27dbbdd5ac5..b6f3ab25d43 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Dict, List, Optional def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: @@ -33,22 +33,7 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: elif isinstance(merged[right_k], dict): merged[right_k] = merge_dicts(merged[right_k], right_v) elif isinstance(merged[right_k], list): - merged[right_k] = merged[right_k].copy() - for e in right_v: - if isinstance(e, dict) and "index" in e and isinstance(e["index"], int): - to_merge = [ - i - for i, e_left in enumerate(merged[right_k]) - if e_left["index"] == e["index"] - ] - if to_merge: - merged[right_k][to_merge[0]] = merge_dicts( - merged[right_k][to_merge[0]], e - ) - else: - merged[right_k] = merged[right_k] + [e] - else: - merged[right_k] = merged[right_k] + [e] + merged[right_k] = merge_lists(merged[right_k], right_v) elif merged[right_k] == right_v: continue else: @@ -57,3 +42,27 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: f"value has unsupported type {type(merged[right_k])}." ) return merged + + +def merge_lists(left: Optional[List], right: Optional[List]) -> Optional[List]: + """Add two lists, handling None.""" + if left is None and right is None: + return None + elif left is None or right is None: + return left or right + else: + merged = left.copy() + for e in right: + if isinstance(e, dict) and "index" in e and isinstance(e["index"], int): + to_merge = [ + i + for i, e_left in enumerate(merged) + if e_left["index"] == e["index"] + ] + if to_merge: + merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e) + else: + merged = merged + [e] + else: + merged = merged + [e] + return merged diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py new file mode 100644 index 00000000000..e7867a3a828 --- /dev/null +++ b/libs/core/langchain_core/utils/json.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Callable, List + +from langchain_core.exceptions import OutputParserException + + +def _replace_new_line(match: re.Match[str]) -> str: + value = match.group(2) + value = re.sub(r"\n", r"\\n", value) + value = re.sub(r"\r", r"\\r", value) + value = re.sub(r"\t", r"\\t", value) + value = re.sub(r'(? str: + """ + The LLM response for `action_input` may be a multiline + string containing unescaped newlines, tabs or quotes. This function + replaces those characters with their escaped counterparts. 
+ (newlines in JSON must be double-escaped: `\\n`) + """ + if isinstance(multiline_string, (bytes, bytearray)): + multiline_string = multiline_string.decode() + + multiline_string = re.sub( + r'("action_input"\:\s*")(.*?)(")', + _replace_new_line, + multiline_string, + flags=re.DOTALL, + ) + + return multiline_string + + +# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py +# MIT License + + +def parse_partial_json(s: str, *, strict: bool = False) -> Any: + """Parse a JSON string that may be missing closing braces. + + Args: + s: The JSON string to parse. + strict: Whether to use strict parsing. Defaults to False. + + Returns: + The parsed JSON object as a Python dictionary. + """ + # Attempt to parse the string as-is. + try: + return json.loads(s, strict=strict) + except json.JSONDecodeError: + pass + + # Initialize variables. + new_s = "" + stack = [] + is_inside_string = False + escaped = False + + # Process each character in the string one at a time. + for char in s: + if is_inside_string: + if char == '"' and not escaped: + is_inside_string = False + elif char == "\n" and not escaped: + char = "\\n" # Replace the newline character with the escape sequence. + elif char == "\\": + escaped = not escaped + else: + escaped = False + else: + if char == '"': + is_inside_string = True + escaped = False + elif char == "{": + stack.append("}") + elif char == "[": + stack.append("]") + elif char == "}" or char == "]": + if stack and stack[-1] == char: + stack.pop() + else: + # Mismatched closing character; the input is malformed. + return None + + # Append the processed character to the new string. + new_s += char + + # If we're still inside a string at the end of processing, + # we need to close the string. + if is_inside_string: + new_s += '"' + + # Try to parse mods of string until we succeed or run out of characters. + while new_s: + final_s = new_s + + # Close any remaining open structures in the reverse + # order that they were opened. + for closing_char in reversed(stack): + final_s += closing_char + + # Attempt to parse the modified string as JSON. + try: + return json.loads(final_s, strict=strict) + except json.JSONDecodeError: + # If we still can't parse the string as JSON, + # try removing the last character + new_s = new_s[:-1] + + # If we got here, we ran out of characters to remove + # and still couldn't parse the string as JSON, so return the parse error + # for the original string. + return json.loads(s, strict=strict) + + +def parse_json_markdown( + json_string: str, *, parser: Callable[[str], Any] = parse_partial_json +) -> dict: + """ + Parse a JSON string from a Markdown string. + + Args: + json_string: The Markdown string. + + Returns: + The parsed JSON object as a Python dictionary. 
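+
+    Example:
+
+        .. code-block:: python
+
+            parse_json_markdown('```json\\n{"a": 1}\\n```')
+            # -> {'a': 1}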
+ """ + try: + return _parse_json(json_string, parser=parser) + except json.JSONDecodeError: + # Try to find JSON string within triple backticks + match = re.search(r"```(json)?(.*)", json_string, re.DOTALL) + + # If no match found, assume the entire string is a JSON string + if match is None: + json_str = json_string + else: + # If match found, use the content within the backticks + json_str = match.group(2) + return _parse_json(json_str, parser=parser) + + +def _parse_json( + json_str: str, *, parser: Callable[[str], Any] = parse_partial_json +) -> dict: + # Strip whitespace and newlines from the start and end + json_str = json_str.strip().strip("`") + + # handle newlines and other special characters inside the returned value + json_str = _custom_parser(json_str) + + # Parse the JSON string into a Python dictionary + return parser(json_str) + + +def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: + """ + Parse a JSON string from a Markdown string and check that it + contains the expected keys. + + Args: + text: The Markdown string. + expected_keys: The expected keys in the JSON string. + + Returns: + The parsed JSON object as a Python dictionary. + """ + try: + json_obj = parse_json_markdown(text) + except json.JSONDecodeError as e: + raise OutputParserException(f"Got invalid JSON object. Error: {e}") + for key in expected_keys: + if key not in json_obj: + raise OutputParserException( + f"Got invalid return object. Expected key `{key}` " + f"to be present, but got {json_obj}" + ) + return json_obj diff --git a/libs/core/tests/unit_tests/messages/test_imports.py b/libs/core/tests/unit_tests/messages/test_imports.py index 7e549f5b43b..eb4e141ce82 100644 --- a/libs/core/tests/unit_tests/messages/test_imports.py +++ b/libs/core/tests/unit_tests/messages/test_imports.py @@ -14,8 +14,11 @@ EXPECTED_ALL = [ "FunctionMessageChunk", "HumanMessage", "HumanMessageChunk", + "InvalidToolCall", "SystemMessage", "SystemMessageChunk", + "ToolCall", + "ToolCallChunk", "ToolMessage", "ToolMessageChunk", "convert_to_messages", diff --git a/libs/core/tests/unit_tests/output_parsers/test_json.py b/libs/core/tests/unit_tests/output_parsers/test_json.py index 7f4437f4329..d29b5fc3def 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_json.py +++ b/libs/core/tests/unit_tests/output_parsers/test_json.py @@ -5,11 +5,10 @@ import pytest from langchain_core.output_parsers.json import ( SimpleJsonOutputParser, - parse_json_markdown, - parse_partial_json, ) from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils.function_calling import convert_to_openai_function +from langchain_core.utils.json import parse_json_markdown, parse_partial_json GOOD_JSON = """```json { diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index 0ba52d4ff01..cd7f9f52dd4 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -1,6 +1,6 @@ from typing import Any, AsyncIterator, Iterator, List -from langchain_core.messages import AIMessageChunk, BaseMessage +from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, JsonOutputToolsParser, @@ -300,6 +300,28 @@ STREAMED_MESSAGES: list = [ ] +STREAMED_MESSAGES_WITH_TOOL_CALLS = [] +for message in STREAMED_MESSAGES: + if message.additional_kwargs: + 
STREAMED_MESSAGES_WITH_TOOL_CALLS.append( + AIMessageChunk( + content=message.content, + additional_kwargs=message.additional_kwargs, + tool_call_chunks=[ + ToolCallChunk( + name=chunk["function"].get("name"), + args=chunk["function"].get("arguments"), + id=chunk.get("id"), + index=chunk["index"], + ) + for chunk in message.additional_kwargs["tool_calls"] + ], + ) + ) + else: + STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message) + + EXPECTED_STREAMED_JSON = [ {}, {"names": ["suz"]}, @@ -330,101 +352,118 @@ EXPECTED_STREAMED_JSON = [ ] -def test_partial_json_output_parser() -> None: +def _get_iter(use_tool_calls: bool = False) -> Any: + if use_tool_calls: + list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS + else: + list_to_iter = STREAMED_MESSAGES + def input_iter(_: Any) -> Iterator[BaseMessage]: - for msg in STREAMED_MESSAGES: + for msg in list_to_iter: yield msg - chain = input_iter | JsonOutputToolsParser() + return input_iter - actual = list(chain.stream(None)) - expected: list = [[]] + [ - [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON - ] - assert actual == expected + +def _get_aiter(use_tool_calls: bool = False) -> Any: + if use_tool_calls: + list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS + else: + list_to_iter = STREAMED_MESSAGES + + async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: + for msg in list_to_iter: + yield msg + + return input_iter + + +def test_partial_json_output_parser() -> None: + for use_tool_calls in [False, True]: + input_iter = _get_iter(use_tool_calls) + chain = input_iter | JsonOutputToolsParser() + + actual = list(chain.stream(None)) + expected: list = [[]] + [ + [{"type": "NameCollector", "args": chunk}] + for chunk in EXPECTED_STREAMED_JSON + ] + assert actual == expected async def test_partial_json_output_parser_async() -> None: - async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: - for token in STREAMED_MESSAGES: - yield token + for use_tool_calls in [False, True]: + input_iter = _get_aiter(use_tool_calls) + chain = input_iter | JsonOutputToolsParser() - chain = input_iter | JsonOutputToolsParser() - - actual = [p async for p in chain.astream(None)] - expected: list = [[]] + [ - [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON - ] - assert actual == expected + actual = [p async for p in chain.astream(None)] + expected: list = [[]] + [ + [{"type": "NameCollector", "args": chunk}] + for chunk in EXPECTED_STREAMED_JSON + ] + assert actual == expected def test_partial_json_output_parser_return_id() -> None: - def input_iter(_: Any) -> Iterator[BaseMessage]: - for msg in STREAMED_MESSAGES: - yield msg + for use_tool_calls in [False, True]: + input_iter = _get_iter(use_tool_calls) + chain = input_iter | JsonOutputToolsParser(return_id=True) - chain = input_iter | JsonOutputToolsParser(return_id=True) - - actual = list(chain.stream(None)) - expected: list = [[]] + [ - [ - { - "type": "NameCollector", - "args": chunk, - "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl", - } + actual = list(chain.stream(None)) + expected: list = [[]] + [ + [ + { + "type": "NameCollector", + "args": chunk, + "id": "call_OwL7f5PEPJTYzw9sQlNJtCZl", + } + ] + for chunk in EXPECTED_STREAMED_JSON ] - for chunk in EXPECTED_STREAMED_JSON - ] - assert actual == expected + assert actual == expected def test_partial_json_output_key_parser() -> None: - def input_iter(_: Any) -> Iterator[BaseMessage]: - for msg in STREAMED_MESSAGES: - yield msg + for use_tool_calls in [False, True]: + input_iter = _get_iter(use_tool_calls) + 
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") - chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") - - actual = list(chain.stream(None)) - expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] - assert actual == expected + actual = list(chain.stream(None)) + expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] + assert actual == expected async def test_partial_json_output_parser_key_async() -> None: - async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: - for token in STREAMED_MESSAGES: - yield token + for use_tool_calls in [False, True]: + input_iter = _get_aiter(use_tool_calls) - chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") + chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") - actual = [p async for p in chain.astream(None)] - expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] - assert actual == expected + actual = [p async for p in chain.astream(None)] + expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] + assert actual == expected def test_partial_json_output_key_parser_first_only() -> None: - def input_iter(_: Any) -> Iterator[BaseMessage]: - for msg in STREAMED_MESSAGES: - yield msg + for use_tool_calls in [False, True]: + input_iter = _get_iter(use_tool_calls) - chain = input_iter | JsonOutputKeyToolsParser( - key_name="NameCollector", first_tool_only=True - ) + chain = input_iter | JsonOutputKeyToolsParser( + key_name="NameCollector", first_tool_only=True + ) - assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON + assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON async def test_partial_json_output_parser_key_async_first_only() -> None: - async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: - for token in STREAMED_MESSAGES: - yield token + for use_tool_calls in [False, True]: + input_iter = _get_aiter(use_tool_calls) - chain = input_iter | JsonOutputKeyToolsParser( - key_name="NameCollector", first_tool_only=True - ) + chain = input_iter | JsonOutputKeyToolsParser( + key_name="NameCollector", first_tool_only=True + ) - assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON class Person(BaseModel): @@ -458,26 +497,24 @@ EXPECTED_STREAMED_PYDANTIC = [ def test_partial_pydantic_output_parser() -> None: - def input_iter(_: Any) -> Iterator[BaseMessage]: - for msg in STREAMED_MESSAGES: - yield msg + for use_tool_calls in [False, True]: + input_iter = _get_iter(use_tool_calls) - chain = input_iter | PydanticToolsParser( - tools=[NameCollector], first_tool_only=True - ) + chain = input_iter | PydanticToolsParser( + tools=[NameCollector], first_tool_only=True + ) - actual = list(chain.stream(None)) - assert actual == EXPECTED_STREAMED_PYDANTIC + actual = list(chain.stream(None)) + assert actual == EXPECTED_STREAMED_PYDANTIC async def test_partial_pydantic_output_parser_async() -> None: - async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: - for token in STREAMED_MESSAGES: - yield token + for use_tool_calls in [False, True]: + input_iter = _get_aiter(use_tool_calls) - chain = input_iter | PydanticToolsParser( - tools=[NameCollector], first_tool_only=True - ) + chain = input_iter | PydanticToolsParser( + tools=[NameCollector], first_tool_only=True + ) - actual = [p async for p in chain.astream(None)] - assert actual == EXPECTED_STREAMED_PYDANTIC + actual = [p async for p in chain.astream(None)] + assert 
actual == EXPECTED_STREAMED_PYDANTIC diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 120e57319f2..a890456c094 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -5299,6 +5299,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -5307,6 +5316,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -5545,6 +5563,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'StringPromptValue': dict({ 'description': 'String prompt value.', 'properties': dict({ @@ -5625,6 +5671,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for passing the result of executing a tool back to a model.', 'properties': dict({ @@ -5765,6 +5834,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -5773,6 +5851,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -6011,6 +6098,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'StringPromptValue': dict({ 'description': 'String prompt value.', 'properties': dict({ @@ -6091,6 +6206,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ 
+ 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for passing the result of executing a tool back to a model.', 'properties': dict({ @@ -6215,6 +6353,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -6223,6 +6370,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -6414,6 +6570,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'SystemMessage': dict({ 'description': ''' Message for priming AI behavior, usually passed in as the first of a sequence @@ -6472,6 +6656,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for passing the result of executing a tool back to a model.', 'properties': dict({ @@ -6584,6 +6791,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -6592,6 +6808,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -6830,6 +7055,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'StringPromptValue': dict({ 'description': 'String prompt value.', 'properties': dict({ @@ -6910,6 +7163,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for 
passing the result of executing a tool back to a model.', 'properties': dict({ @@ -7022,6 +7298,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -7030,6 +7315,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -7268,6 +7562,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'StringPromptValue': dict({ 'description': 'String prompt value.', 'properties': dict({ @@ -7348,6 +7670,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for passing the result of executing a tool back to a model.', 'properties': dict({ @@ -7452,6 +7797,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -7460,6 +7814,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -7698,6 +8061,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'PromptTemplateOutput': dict({ 'anyOf': list([ dict({ @@ -7789,6 +8180,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for passing the result of executing a tool back to a model.', 'properties': dict({ @@ -7920,6 +8334,15 @@ 'title': 'Id', 'type': 'string', }), + 'invalid_tool_calls': dict({ + 'default': list([ + ]), + 
'items': dict({ + '$ref': '#/definitions/InvalidToolCall', + }), + 'title': 'Invalid Tool Calls', + 'type': 'array', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -7928,6 +8351,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'tool_calls': dict({ + 'default': list([ + ]), + 'items': dict({ + '$ref': '#/definitions/ToolCall', + }), + 'title': 'Tool Calls', + 'type': 'array', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -8119,6 +8551,34 @@ 'title': 'HumanMessage', 'type': 'object', }), + 'InvalidToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'string', + }), + 'error': dict({ + 'title': 'Error', + 'type': 'string', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + 'error', + ]), + 'title': 'InvalidToolCall', + 'type': 'object', + }), 'SystemMessage': dict({ 'description': ''' Message for priming AI behavior, usually passed in as the first of a sequence @@ -8177,6 +8637,29 @@ 'title': 'SystemMessage', 'type': 'object', }), + 'ToolCall': dict({ + 'properties': dict({ + 'args': dict({ + 'title': 'Args', + 'type': 'object', + }), + 'id': dict({ + 'title': 'Id', + 'type': 'string', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + }), + 'required': list([ + 'name', + 'args', + 'id', + ]), + 'title': 'ToolCall', + 'type': 'object', + }), 'ToolMessage': dict({ 'description': 'Message for passing the result of executing a tool back to a model.', 'properties': dict({ diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index f4a6a5ee2a0..fe98da71a72 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -206,6 +206,27 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: {"$ref": "#/definitions/ToolMessage"}, ], "definitions": { + "ToolCall": { + "title": "ToolCall", + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + "args": {"title": "Args", "type": "object"}, + "id": {"title": "Id", "type": "string"}, + }, + "required": ["name", "args", "id"], + }, + "InvalidToolCall": { + "title": "InvalidToolCall", + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + "args": {"title": "Args", "type": "string"}, + "id": {"title": "Id", "type": "string"}, + "error": {"title": "Error", "type": "string"}, + }, + "required": ["name", "args", "id", "error"], + }, "AIMessage": { "title": "AIMessage", "description": "Message from an AI.", @@ -240,13 +261,25 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: "enum": ["ai"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, "example": { "title": "Example", "default": False, "type": "boolean", }, + "tool_calls": { + "title": "Tool Calls", + "default": [], + "type": "array", + "items": {"$ref": "#/definitions/ToolCall"}, + }, + "invalid_tool_calls": { + "title": "Invalid Tool Calls", + "default": [], + "type": "array", + "items": {"$ref": "#/definitions/InvalidToolCall"}, + }, }, "required": ["content"], }, @@ -284,8 +317,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: "enum": ["human"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + 
"id": {"title": "Id", "type": "string"}, "example": { "title": "Example", "default": False, @@ -328,8 +361,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: "enum": ["chat"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, "role": {"title": "Role", "type": "string"}, }, "required": ["content", "role"], @@ -368,8 +401,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: "enum": ["system"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, }, "required": ["content"], }, @@ -407,8 +440,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: "enum": ["function"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, }, "required": ["content", "name"], }, @@ -446,8 +479,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: "enum": ["tool"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, "tool_call_id": { "title": "Tool Call Id", "type": "string", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 3a860642c3a..7a0524a90b8 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -357,6 +357,27 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: } }, "definitions": { + "ToolCall": { + "title": "ToolCall", + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + "args": {"title": "Args", "type": "object"}, + "id": {"title": "Id", "type": "string"}, + }, + "required": ["name", "args", "id"], + }, + "InvalidToolCall": { + "title": "InvalidToolCall", + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + "args": {"title": "Args", "type": "string"}, + "id": {"title": "Id", "type": "string"}, + "error": {"title": "Error", "type": "string"}, + }, + "required": ["name", "args", "id", "error"], + }, "AIMessage": { "title": "AIMessage", "description": "Message from an AI.", @@ -388,13 +409,25 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "enum": ["ai"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, "example": { "title": "Example", "default": False, "type": "boolean", }, + "tool_calls": { + "title": "Tool Calls", + "default": [], + "type": "array", + "items": {"$ref": "#/definitions/ToolCall"}, + }, + "invalid_tool_calls": { + "title": "Invalid Tool Calls", + "default": [], + "type": "array", + "items": {"$ref": "#/definitions/InvalidToolCall"}, + }, }, "required": ["content"], }, @@ -429,8 +462,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "enum": ["human"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, "example": { "title": "Example", "default": False, @@ -470,8 +503,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "enum": ["chat"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": 
"string"}, "role": {"title": "Role", "type": "string"}, }, "required": ["content", "role"], @@ -507,8 +540,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "enum": ["system"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, }, "required": ["content"], }, @@ -543,8 +576,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "enum": ["function"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, }, "required": ["content", "name"], }, @@ -579,8 +612,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "enum": ["tool"], "type": "string", }, - "id": {"title": "Id", "type": "string"}, "name": {"title": "Name", "type": "string"}, + "id": {"title": "Id", "type": "string"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, }, "required": ["content", "tool_call_id"], diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index aeb480e2063..beb4cf4b0fb 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -13,6 +13,8 @@ from langchain_core.messages import ( HumanMessage, HumanMessageChunk, SystemMessage, + ToolCall, + ToolCallChunk, ToolMessage, convert_to_messages, get_buffer_string, @@ -20,6 +22,7 @@ from langchain_core.messages import ( messages_from_dict, messages_to_dict, ) +from langchain_core.utils._merge import merge_lists def test_message_chunks() -> None: @@ -68,6 +71,55 @@ def test_message_chunks() -> None: ) ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501 + # Test tool calls + assert ( + AIMessageChunk( + content="", + tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)], + ) + + AIMessageChunk( + content="", + tool_call_chunks=[ + ToolCallChunk(name=None, args='{"arg1": "val', id=None, index=0) + ], + ) + + AIMessageChunk( + content="", + tool_call_chunks=[ToolCallChunk(name=None, args='ue}"', id=None, index=0)], + ) + ) == AIMessageChunk( + content="", + tool_call_chunks=[ + ToolCallChunk(name="tool1", args='{"arg1": "value}"', id="1", index=0) + ], + ) + + assert ( + AIMessageChunk( + content="", + tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)], + ) + + AIMessageChunk( + content="", + tool_call_chunks=[ToolCallChunk(name="tool1", args="a", id=None, index=1)], + ) + # Don't merge if `index` field does not match. 
+ ) == AIMessageChunk( + content="", + tool_call_chunks=[ + ToolCallChunk(name="tool1", args="", id="1", index=0), + ToolCallChunk(name="tool1", args="a", id=None, index=1), + ], + ) + + ai_msg_chunk = AIMessageChunk(content="") + tool_calls_msg_chunk = AIMessageChunk( + content="", + tool_call_chunks=[ToolCallChunk(name="tool1", args="a", id=None, index=1)], + ) + assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk + assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk + def test_chat_message_chunks() -> None: assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk( @@ -128,6 +180,7 @@ class TestGetBufferString(unittest.TestCase): self.func_msg = FunctionMessage(name="func", content="function") self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool") self.chat_msg = ChatMessage(role="Chat", content="chat") + self.tool_calls_msg = AIMessage(content="tool") def test_empty_input(self) -> None: self.assertEqual(get_buffer_string([]), "") @@ -163,6 +216,7 @@ class TestGetBufferString(unittest.TestCase): self.func_msg, self.tool_msg, self.chat_msg, + self.tool_calls_msg, ] expected_output = "\n".join( [ @@ -172,6 +226,7 @@ class TestGetBufferString(unittest.TestCase): "Function: function", "Tool: tool", "Chat: chat", + "AI: tool", ] ) self.assertEqual( @@ -192,6 +247,19 @@ def test_multiple_msg() -> None: ] assert messages_from_dict(messages_to_dict(msgs)) == msgs + # Test with tool calls + msgs = [ + AIMessage( + content="", + tool_calls=[ToolCall(name="a", args={"b": 1}, id=None)], + ), + AIMessage( + content="", + tool_calls=[ToolCall(name="c", args={"c": 2}, id=None)], + ), + ] + assert messages_from_dict(messages_to_dict(msgs)) == msgs + def test_multiple_msg_with_name() -> None: human_msg = HumanMessage( @@ -222,6 +290,30 @@ def test_message_chunk_to_message() -> None: FunctionMessageChunk(name="hello", content="I am") ) == FunctionMessage(name="hello", content="I am") + chunk = AIMessageChunk( + content="I am", + tool_call_chunks=[ + ToolCallChunk(name="tool1", args='{"a": 1}', id="1", index=0), + ToolCallChunk(name="tool2", args='{"b": ', id="2", index=0), + ToolCallChunk(name="tool3", args=None, id="3", index=0), + ToolCallChunk(name="tool4", args="abc", id="4", index=0), + ], + ) + expected = AIMessage( + content="I am", + tool_calls=[ + {"name": "tool1", "args": {"a": 1}, "id": "1"}, + {"name": "tool2", "args": {}, "id": "2"}, + ], + invalid_tool_calls=[ + {"name": "tool3", "args": None, "id": "3", "error": "Malformed args."}, + {"name": "tool4", "args": "abc", "id": "4", "error": "Malformed args."}, + ], + ) + assert message_chunk_to_message(chunk) == expected + assert AIMessage(**expected.dict()) == expected + assert AIMessageChunk(**chunk.dict()) == chunk + def test_tool_calls_merge() -> None: chunks: List[dict] = [ @@ -542,3 +634,35 @@ def test_message_name_chat(MessageClass: Type) -> None: msg3 = MessageClass(content="foo", role="user") assert msg3.name is None + + +def test_merge_tool_calls() -> None: + tool_call_1 = ToolCallChunk(name="tool1", args="", id="1", index=0) + tool_call_2 = ToolCallChunk(name=None, args='{"arg1": "val', id=None, index=0) + tool_call_3 = ToolCallChunk(name=None, args='ue}"', id=None, index=0) + merged = merge_lists([tool_call_1], [tool_call_2]) + assert merged is not None + assert merged == [{"name": "tool1", "args": '{"arg1": "val', "id": "1", "index": 0}] + merged = merge_lists(merged, [tool_call_3]) + assert merged is not None + assert merged == [ + {"name": "tool1", "args": 
'{"arg1": "value}"', "id": "1", "index": 0} + ] + + left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id="1", index=None) + right = ToolCallChunk(name="tool2", args='{"arg2": "value2"}', id="1", index=None) + merged = merge_lists([left], [right]) + assert merged is not None + assert len(merged) == 2 + + left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id=None, index=None) + right = ToolCallChunk(name="tool1", args='{"arg2": "value2"}', id=None, index=None) + merged = merge_lists([left], [right]) + assert merged is not None + assert len(merged) == 2 + + left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id="1", index=0) + right = ToolCallChunk(name="tool2", args='{"arg2": "value2"}', id=None, index=1) + merged = merge_lists([left], [right]) + assert merged is not None + assert len(merged) == 2 diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index b0263889daa..20b06e3bcaa 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -1,5 +1,7 @@ from langchain_core.output_parsers.json import ( SimpleJsonOutputParser, +) +from langchain_core.utils.json import ( parse_and_check_json_markdown, parse_json_markdown, parse_partial_json, diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 7825a165189..20394227420 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1,3 +1,4 @@ +import json import os import re import warnings @@ -54,7 +55,7 @@ from langchain_core.utils import ( ) from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_anthropic.output_parsers import ToolsOutputParser +from langchain_anthropic.output_parsers import ToolsOutputParser, extract_tool_calls _message_type_lookups = { "human": "user", @@ -347,7 +348,24 @@ class ChatAnthropic(BaseChatModel): result = self._generate( messages, stop=stop, run_manager=run_manager, **kwargs ) - yield cast(ChatGenerationChunk, result.generations[0]) + message = result.generations[0].message + if isinstance(message, AIMessage) and message.tool_calls is not None: + tool_call_chunks = [ + { + "name": tool_call["name"], + "args": json.dumps(tool_call["args"]), + "id": tool_call["id"], + "index": idx, + } + for idx, tool_call in enumerate(message.tool_calls) + ] + message_chunk = AIMessageChunk( + content=message.content, + tool_call_chunks=tool_call_chunks, + ) + yield ChatGenerationChunk(message=message_chunk) + else: + yield cast(ChatGenerationChunk, result.generations[0]) return with self._client.messages.stream(**params) as stream: for text in stream.text_stream: @@ -369,7 +387,24 @@ class ChatAnthropic(BaseChatModel): result = await self._agenerate( messages, stop=stop, run_manager=run_manager, **kwargs ) - yield cast(ChatGenerationChunk, result.generations[0]) + message = result.generations[0].message + if isinstance(message, AIMessage) and message.tool_calls is not None: + tool_call_chunks = [ + { + "name": tool_call["name"], + "args": json.dumps(tool_call["args"]), + "id": tool_call["id"], + "index": idx, + } + for idx, tool_call in enumerate(message.tool_calls) + ] + message_chunk = AIMessageChunk( + content=message.content, + tool_call_chunks=tool_call_chunks, + ) + yield ChatGenerationChunk(message=message_chunk) + else: + yield cast(ChatGenerationChunk, result.generations[0]) return 
async with self._async_client.messages.stream(**params) as stream: async for text in stream.text_stream: @@ -386,6 +421,12 @@ class ChatAnthropic(BaseChatModel): } if len(content) == 1 and content[0]["type"] == "text": msg = AIMessage(content=content[0]["text"]) + elif any(block["type"] == "tool_use" for block in content): + tool_calls = extract_tool_calls(content) + msg = AIMessage( + content=content, + tool_calls=tool_calls, + ) else: msg = AIMessage(content=content) return ChatResult( diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index 7d3d05f85e7..84840591f33 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -1,18 +1,11 @@ -from typing import Any, List, Optional, Type, TypedDict, cast +from typing import Any, List, Optional, Type -from langchain_core.messages import BaseMessage +from langchain_core.messages import ToolCall from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.outputs import ChatGeneration, Generation from langchain_core.pydantic_v1 import BaseModel -class _ToolCall(TypedDict): - name: str - args: dict - id: str - index: int - - class ToolsOutputParser(BaseGenerationOutputParser): first_tool_only: bool = False args_only: bool = False @@ -33,7 +26,19 @@ class ToolsOutputParser(BaseGenerationOutputParser): """ if not result or not isinstance(result[0], ChatGeneration): return None if self.first_tool_only else [] - tool_calls: List = _extract_tool_calls(result[0].message) + message = result[0].message + if isinstance(message.content, str): + tool_calls: List = [] + else: + content: List = message.content + _tool_calls = [dict(tc) for tc in extract_tool_calls(content)] + # Map tool call id to index + id_to_index = { + block["id"]: i + for i, block in enumerate(content) + if block["type"] == "tool_use" + } + tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls] if self.pydantic_schemas: tool_calls = [self._pydantic_parse(tc) for tc in tool_calls] elif self.args_only: @@ -44,23 +49,21 @@ class ToolsOutputParser(BaseGenerationOutputParser): if self.first_tool_only: return tool_calls[0] if tool_calls else None else: - return tool_calls + return [tool_call for tool_call in tool_calls] - def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel: + def _pydantic_parse(self, tool_call: dict) -> BaseModel: cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[ tool_call["name"] ] return cls_(**tool_call["args"]) -def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]: - if isinstance(msg.content, str): - return [] +def extract_tool_calls(content: List[dict]) -> List[ToolCall]: tool_calls = [] - for i, block in enumerate(cast(List[dict], msg.content)): + for block in content: if block["type"] != "tool_use": continue tool_calls.append( - _ToolCall(name=block["name"], args=block["input"], id=block["id"], index=i) + ToolCall(name=block["name"], args=block["input"], id=block["id"]) ) return tool_calls diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 8021bdc1954..7737a1df99c 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -1,9 +1,15 @@ """Test ChatAnthropic chat model.""" +import json from typing import 
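`extract_tool_calls` is now public and operates directly on Anthropic content blocks, so `_format_output` and `ToolsOutputParser` can share it. A usage example on a hand-written content list (the `toolu_01` id is a placeholder, not a real Anthropic id):

```python
from langchain_anthropic.output_parsers import extract_tool_calls

content = [
    {"type": "text", "text": "Let me check the weather."},
    {
        "type": "tool_use",
        "id": "toolu_01",  # placeholder id
        "name": "get_weather",
        "input": {"location": "San Francisco, CA"},
    },
]

tool_calls = extract_tool_calls(content)
# Text blocks are skipped; the one "tool_use" block becomes
# ToolCall(name="get_weather", args={"location": "San Francisco, CA"}, id="toolu_01").
```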
List from langchain_core.callbacks import CallbackManager -from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, +) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate @@ -234,6 +240,28 @@ def test_tool_use() -> None: response = llm_with_tools.invoke("what's the weather in san francisco, ca") assert isinstance(response, AIMessage) assert isinstance(response.content, list) + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "get_weather" + assert isinstance(tool_call["args"], dict) + assert "location" in tool_call["args"] + + # Test streaming + first = True + for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"): + if first: + gathered = chunk + first = False + else: + gathered = gathered + chunk # type: ignore + assert isinstance(gathered, AIMessageChunk) + assert isinstance(gathered.tool_call_chunks, list) + assert len(gathered.tool_call_chunks) == 1 + tool_call_chunk = gathered.tool_call_chunks[0] + assert tool_call_chunk["name"] == "get_weather" + assert isinstance(tool_call_chunk["args"], str) + assert "location" in json.loads(tool_call_chunk["args"]) def test_with_structured_output() -> None: diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index cc5a862c4f3..fc5960eea98 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -56,6 +56,8 @@ from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator @@ -94,9 +96,23 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: additional_kwargs: Dict = {} if function_call := _dict.get("function_call"): additional_kwargs["function_call"] = dict(function_call) - if tool_calls := _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls - return AIMessage(content=content, additional_kwargs=additional_kwargs) + tool_calls = [] + invalid_tool_calls = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append( + dict(make_invalid_tool_call(raw_tool_call, str(e))) + ) + return AIMessage( + content=content, + additional_kwargs=additional_kwargs, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + ) elif role == "system": return SystemMessage(content=_dict.get("content", "")) elif role == "function": @@ -174,13 +190,31 @@ def _convert_delta_to_message_chunk( if "name" in function_call and function_call["name"] is None: function_call["name"] = "" additional_kwargs["function_call"] = function_call - if _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = _dict["tool_calls"] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + 
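The streaming half of this test leans on `AIMessageChunk.__add__` to do the accumulation, and the same gather loop works for every provider touched in this diff. A sketch (the `get_weather` tool is a stand-in for the one the test uses, and running it requires Anthropic credentials):

```python
import json

from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool


@tool
def get_weather(location: str) -> str:
    """Get the current weather at a location."""
    return "Sunny."


llm_with_tools = ChatAnthropic(model="claude-3-sonnet-20240229").bind_tools(
    [get_weather]
)

gathered = None
for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"):
    gathered = chunk if gathered is None else gathered + chunk

# Fragments with a matching `index` were merged chunk by chunk, so each
# accumulated args string is now complete, parseable JSON.
for tool_call_chunk in gathered.tool_call_chunks:
    print(tool_call_chunk["name"], json.loads(tool_call_chunk["args"]))
```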
tool_call_chunks = [ + { + "name": rtc["function"].get("name"), + "args": rtc["function"].get("arguments"), + "id": rtc.get("id"), + "index": rtc["index"], + } + for rtc in raw_tool_calls + ] + except KeyError: + pass + else: + tool_call_chunks = [] if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, + ) elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) elif role == "function" or default_class == FunctionMessageChunk: diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py index 27c38b29f1e..f485c9ad03d 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -47,6 +47,11 @@ def test_tool_choice() -> None: "name": "Erick", } assert tool_call["type"] == "function" + assert isinstance(resp.tool_calls, list) + assert len(resp.tool_calls) == 1 + tool_call = resp.tool_calls[0] + assert tool_call["name"] == "MyTool" + assert tool_call["args"] == {"age": 27, "name": "Erick"} def test_tool_choice_bool() -> None: diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index e557eb26a56..5b58b36e6ea 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -58,6 +58,8 @@ from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator @@ -278,9 +280,20 @@ class ChatGroq(BaseChatModel): chat_result = self._create_chat_result(response) generation = chat_result.generations[0] message = generation.message + tool_call_chunks = [ + { + "name": rtc["function"].get("name"), + "args": rtc["function"].get("arguments"), + "id": rtc.get("id"), + "index": rtc.get("index"), + } + for rtc in message.additional_kwargs["tool_calls"] + ] chunk_ = ChatGenerationChunk( message=AIMessageChunk( - content=message.content, additional_kwargs=message.additional_kwargs + content=message.content, + additional_kwargs=message.additional_kwargs, + tool_call_chunks=tool_call_chunks, ), generation_info=generation.generation_info, ) @@ -338,9 +351,20 @@ class ChatGroq(BaseChatModel): chat_result = self._create_chat_result(response) generation = chat_result.generations[0] message = generation.message + tool_call_chunks = [ + { + "name": rtc["function"].get("name"), + "args": rtc["function"].get("arguments"), + "id": rtc.get("id"), + "index": rtc.get("index"), + } + for rtc in message.additional_kwargs["tool_calls"] + ] chunk_ = ChatGenerationChunk( message=AIMessageChunk( - content=message.content, additional_kwargs=message.additional_kwargs + content=message.content, + additional_kwargs=message.additional_kwargs, + tool_call_chunks=tool_call_chunks, ), generation_info=generation.generation_info, ) @@ -883,9 +907,24 @@ def _convert_dict_to_message(_dict: 
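The `_convert_dict_to_message` change above introduces the pattern this diff then repeats for Groq, Mistral, and OpenAI: raw OpenAI-format tool calls are partitioned so that unparseable arguments become an `InvalidToolCall` instead of an exception. Factored out as a sketch:

```python
from typing import Any, Dict, List, Tuple

from langchain_core.output_parsers.openai_tools import (
    make_invalid_tool_call,
    parse_tool_call,
)


def split_raw_tool_calls(
    raw_tool_calls: List[Dict[str, Any]],
) -> Tuple[List[dict], List[dict]]:
    """Partition raw tool calls into (parsed, invalid) without raising (sketch)."""
    tool_calls: List[dict] = []
    invalid_tool_calls: List[dict] = []
    for raw_tool_call in raw_tool_calls:
        try:
            # parse_tool_call decodes the JSON arguments string into a dict.
            tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
        except Exception as e:
            # The raw name/args/id survive alongside the error text.
            invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
    return tool_calls, invalid_tool_calls
```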
Mapping[str, Any]) -> BaseMessage: additional_kwargs: Dict = {} if function_call := _dict.get("function_call"): additional_kwargs["function_call"] = dict(function_call) - if tool_calls := _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls - return AIMessage(content=content, id=id_, additional_kwargs=additional_kwargs) + tool_calls = [] + invalid_tool_calls = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append( + make_invalid_tool_call(raw_tool_call, str(e)) + ) + return AIMessage( + content=content, + id=id_, + additional_kwargs=additional_kwargs, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + ) elif role == "system": return SystemMessage(content=_dict.get("content", "")) elif role == "function": diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index a5445b6c07c..047497c5d1d 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -247,6 +247,12 @@ def test_tool_choice() -> None: } assert tool_call["type"] == "function" + assert isinstance(resp.tool_calls, list) + assert len(resp.tool_calls) == 1 + tool_call = resp.tool_calls[0] + assert tool_call["name"] == "MyTool" + assert tool_call["args"] == {"name": "Erick", "age": 27} + @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call") def test_tool_choice_bool() -> None: @@ -302,6 +308,14 @@ def test_streaming_tool_call() -> None: } assert tool_call["type"] == "function" + assert isinstance(chunk, AIMessageChunk) + assert isinstance(chunk.tool_call_chunks, list) + assert len(chunk.tool_call_chunks) == 1 + tool_call_chunk = chunk.tool_call_chunks[0] + assert tool_call_chunk["name"] == "MyTool" + assert isinstance(tool_call_chunk["args"], str) + assert json.loads(tool_call_chunk["args"]) == {"name": "Erick", "age": 27} + @pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call") async def test_astreaming_tool_call() -> None: @@ -332,6 +346,14 @@ async def test_astreaming_tool_call() -> None: } assert tool_call["type"] == "function" + assert isinstance(chunk, AIMessageChunk) + assert isinstance(chunk.tool_call_chunks, list) + assert len(chunk.tool_call_chunks) == 1 + tool_call_chunk = chunk.tool_call_chunks[0] + assert tool_call_chunk["name"] == "MyTool" + assert isinstance(tool_call_chunk["args"], str) + assert json.loads(tool_call_chunk["args"]) == {"name": "Erick", "age": 27} + @pytest.mark.scheduled def test_json_mode_structured_output() -> None: diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index 35c50ab9a7b..2764814ad7b 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -11,7 +11,9 @@ from langchain_core.messages import ( AIMessage, FunctionMessage, HumanMessage, + InvalidToolCall, SystemMessage, + ToolCall, ) from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message @@ -56,6 +58,73 @@ def test__convert_dict_to_message_ai() -> None: assert result == expected_output +def test__convert_dict_to_message_tool_call() -> None: + raw_tool_call = { + "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", + 
"function": { + "arguments": '{"name":"Sally","hair_color":"green"}', + "name": "GenerateUsername", + }, + "type": "function", + } + message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]} + result = _convert_dict_to_message(message) + expected_output = AIMessage( + content="", + additional_kwargs={"tool_calls": [raw_tool_call]}, + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id="call_wm0JY6CdwOMZ4eTxHWUThDNz", + ) + ], + ) + assert result == expected_output + + # Test malformed tool call + raw_tool_calls = [ + { + "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", + "function": { + "arguments": "oops", + "name": "GenerateUsername", + }, + "type": "function", + }, + { + "id": "call_abc123", + "function": { + "arguments": '{"name":"Sally","hair_color":"green"}', + "name": "GenerateUsername", + }, + "type": "function", + }, + ] + message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls} + result = _convert_dict_to_message(message) + expected_output = AIMessage( + content="", + additional_kwargs={"tool_calls": raw_tool_calls}, + invalid_tool_calls=[ + InvalidToolCall( + name="GenerateUsername", + args="oops", + id="call_wm0JY6CdwOMZ4eTxHWUThDNz", + error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 + ), + ], + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id="call_abc123", + ), + ], + ) + assert result == expected_output + + def test__convert_dict_to_message_system() -> None: message = {"role": "system", "content": "foo"} result = _convert_dict_to_message(message) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 6e48d67623e..c1d4c642e24 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -49,6 +49,8 @@ from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator @@ -82,9 +84,31 @@ def _convert_mistral_chat_message_to_message( content = cast(str, _message["content"]) additional_kwargs: Dict = {} - if tool_calls := _message.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls - return AIMessage(content=content, additional_kwargs=additional_kwargs) + tool_calls = [] + invalid_tool_calls = [] + if raw_tool_calls := _message.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + parsed: dict = cast( + dict, parse_tool_call(raw_tool_call, return_id=False) + ) + tool_calls.append( + { + **parsed, + **{"id": None}, + }, + ) + except Exception as e: + invalid_tool_calls.append( + dict(make_invalid_tool_call(raw_tool_call, str(e))) + ) + return AIMessage( + content=content, + additional_kwargs=additional_kwargs, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, + ) async def _aiter_sse( @@ -133,9 +157,27 @@ def _convert_delta_to_message_chunk( return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: additional_kwargs: Dict = {} - 
if tool_calls := _delta.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + if raw_tool_calls := _delta.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + { + "name": rtc["function"].get("name"), + "args": rtc["function"].get("arguments"), + "id": rtc.get("id"), + "index": rtc.get("index"), + } + for rtc in raw_tool_calls + ] + except KeyError: + pass + else: + tool_call_chunks = [] + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, + ) elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) elif role or default_class == ChatMessageChunk: @@ -163,7 +205,7 @@ def _convert_message_to_mistral_chat_message( for tc in message.additional_kwargs["tool_calls"] ] else: - tool_calls = None + tool_calls = [] return { "role": "assistant", "content": message.content, diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index 6607531c5cd..e5e78c91086 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -3,7 +3,13 @@ import json from typing import Any -from langchain_core.messages import AIMessageChunk, HumanMessage +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + HumanMessage, + ToolCall, + ToolCallChunk, +) from langchain_core.pydantic_v1 import BaseModel from langchain_mistralai.chat_models import ChatMistralAI @@ -151,6 +157,22 @@ def test_streaming_structured_output() -> None: chunk_num += 1 +def test_tool_call() -> None: + llm = ChatMistralAI(model="mistral-large", temperature=0) + + class Person(BaseModel): + name: str + age: int + + tool_llm = llm.bind_tools([Person]) + + result = tool_llm.invoke("Erick, 27 years old") + assert isinstance(result, AIMessage) + assert result.tool_calls == [ + ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None) + ] + + def test_streaming_tool_call() -> None: llm = ChatMistralAI(model="mistral-large", temperature=0) @@ -178,6 +200,13 @@ def test_streaming_tool_call() -> None: "age": 27, } + assert isinstance(chunk, AIMessageChunk) + assert chunk.tool_call_chunks == [ + ToolCallChunk( + name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None + ) + ] + # where it doesn't call the tool strm = tool_llm.stream("What is 2+2?") acc: Any = None diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 2ee2565e546..18fca396bb7 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -11,13 +11,16 @@ from langchain_core.messages import ( BaseMessage, ChatMessage, HumanMessage, + InvalidToolCall, SystemMessage, + ToolCall, ) from langchain_core.pydantic_v1 import SecretStr from langchain_mistralai.chat_models import ( # type: ignore[import] ChatMistralAI, _convert_message_to_mistral_chat_message, + _convert_mistral_chat_message_to_message, ) os.environ["MISTRAL_API_KEY"] = "foo" @@ -52,7 +55,7 @@ def test_mistralai_initialization() -> None: ), ( AIMessage(content="Hello"), - dict(role="assistant", content="Hello", tool_calls=None), + dict(role="assistant", content="Hello", tool_calls=[]), ), ( 
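One caveat in this hunk (shared with the Fireworks version): `tool_call_chunks` is bound only inside the `if`/`try`, so a delta that has `tool_calls` but raises `KeyError` (say, a missing `function` key) would leave the name undefined; the OpenAI version later in this diff initializes the list up front instead. A sketch using that safer ordering:

```python
from typing import Any, Dict, List


def delta_to_tool_call_chunks(raw_tool_calls: List[Dict[str, Any]]) -> List[dict]:
    """Convert a streaming delta's raw tool calls to ToolCallChunk dicts (sketch)."""
    tool_call_chunks: List[dict] = []  # bound up front, so a KeyError degrades to []
    try:
        tool_call_chunks = [
            {
                "name": rtc["function"].get("name"),
                "args": rtc["function"].get("arguments"),
                "id": rtc.get("id"),
                "index": rtc.get("index"),
            }
            for rtc in raw_tool_calls
        ]
    except KeyError:
        pass  # the comprehension never assigned, so the empty list stands
    return tool_call_chunks
```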
ChatMessage(role="assistant", content="Hello"), @@ -121,3 +124,66 @@ async def test_astream_with_callback() -> None: chat = ChatMistralAI(callbacks=[callback]) async for token in chat.astream("Hello"): assert callback.last_token == token.content + + +def test__convert_dict_to_message_tool_call() -> None: + raw_tool_call = { + "function": { + "arguments": '{"name":"Sally","hair_color":"green"}', + "name": "GenerateUsername", + }, + } + message = {"role": "assistant", "content": "", "tool_calls": [raw_tool_call]} + result = _convert_mistral_chat_message_to_message(message) + expected_output = AIMessage( + content="", + additional_kwargs={"tool_calls": [raw_tool_call]}, + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id=None, + ) + ], + ) + assert result == expected_output + assert _convert_message_to_mistral_chat_message(expected_output) == message + + # Test malformed tool call + raw_tool_calls = [ + { + "function": { + "arguments": "oops", + "name": "GenerateUsername", + }, + }, + { + "function": { + "arguments": '{"name":"Sally","hair_color":"green"}', + "name": "GenerateUsername", + }, + }, + ] + message = {"role": "assistant", "content": "", "tool_calls": raw_tool_calls} + result = _convert_mistral_chat_message_to_message(message) + expected_output = AIMessage( + content="", + additional_kwargs={"tool_calls": raw_tool_calls}, + invalid_tool_calls=[ + InvalidToolCall( + name="GenerateUsername", + args="oops", + error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 + id=None, + ), + ], + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id=None, + ), + ], + ) + assert result == expected_output + assert _convert_message_to_mistral_chat_message(expected_output) == message diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 7a108e7b648..09e6ac1052a 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -63,6 +63,8 @@ from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator @@ -103,10 +105,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: additional_kwargs: Dict = {} if function_call := _dict.get("function_call"): additional_kwargs["function_call"] = dict(function_call) - if tool_calls := _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = tool_calls + tool_calls = [] + invalid_tool_calls = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + for raw_tool_call in raw_tool_calls: + try: + tool_calls.append(parse_tool_call(raw_tool_call, return_id=True)) + except Exception as e: + invalid_tool_calls.append( + make_invalid_tool_call(raw_tool_call, str(e)) + ) return AIMessage( - content=content, additional_kwargs=additional_kwargs, name=name, id=id_ + content=content, + additional_kwargs=additional_kwargs, + name=name, + id=id_, + tool_calls=tool_calls, + invalid_tool_calls=invalid_tool_calls, ) elif role == "system": 
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_) @@ -188,14 +204,30 @@ def _convert_delta_to_message_chunk( if "name" in function_call and function_call["name"] is None: function_call["name"] = "" additional_kwargs["function_call"] = function_call - if _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = _dict["tool_calls"] + tool_call_chunks = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + { + "name": rtc["function"].get("name"), + "args": rtc["function"].get("arguments"), + "id": rtc.get("id"), + "index": rtc["index"], + } + for rtc in raw_tool_calls + ] + except KeyError: + pass if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content, id=id_) elif role == "assistant" or default_class == AIMessageChunk: return AIMessageChunk( - content=content, additional_kwargs=additional_kwargs, id=id_ + content=content, + additional_kwargs=additional_kwargs, + id=id_, + tool_call_chunks=tool_call_chunks, ) elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content, id=id_) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index b65e0f7ce55..c400831e9aa 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -5,6 +5,7 @@ import pytest from langchain_core.callbacks import CallbackManager from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, BaseMessageChunk, HumanMessage, @@ -482,6 +483,28 @@ def test_tool_use() -> None: llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True) msgs: List = [HumanMessage("Sally has green hair, what would her username be?")] ai_msg = llm_with_tool.invoke(msgs) + + assert isinstance(ai_msg, AIMessage) + assert isinstance(ai_msg.tool_calls, list) + assert len(ai_msg.tool_calls) == 1 + tool_call = ai_msg.tool_calls[0] + assert "args" in tool_call + + # Test streaming + ai_messages = llm_with_tool.stream(msgs) + first = True + for message in ai_messages: + if first: + gathered = message + first = False + else: + gathered = gathered + message # type: ignore + assert isinstance(gathered, AIMessageChunk) + assert isinstance(gathered.tool_call_chunks, list) + assert len(gathered.tool_call_chunks) == 1 + tool_call_chunk = gathered.tool_call_chunks[0] + assert "args" in tool_call_chunk + tool_msg = ToolMessage( "sally_green_hair", tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"] ) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 4a9a6498057..1b8668c9551 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -9,7 +9,9 @@ from langchain_core.messages import ( AIMessage, FunctionMessage, HumanMessage, + InvalidToolCall, SystemMessage, + ToolCall, ToolMessage, ) @@ -98,6 +100,75 @@ def test__convert_dict_to_message_tool() -> None: assert _convert_message_to_dict(expected_output) == message +def test__convert_dict_to_message_tool_call() -> None: + raw_tool_call = { + "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", + "function": { + "arguments": '{"name":"Sally","hair_color":"green"}', + "name": "GenerateUsername", + }, + "type": "function", + 
} + message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]} + result = _convert_dict_to_message(message) + expected_output = AIMessage( + content="", + additional_kwargs={"tool_calls": [raw_tool_call]}, + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id="call_wm0JY6CdwOMZ4eTxHWUThDNz", + ) + ], + ) + assert result == expected_output + assert _convert_message_to_dict(expected_output) == message + + # Test malformed tool call + raw_tool_calls = [ + { + "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", + "function": { + "arguments": "oops", + "name": "GenerateUsername", + }, + "type": "function", + }, + { + "id": "call_abc123", + "function": { + "arguments": '{"name":"Sally","hair_color":"green"}', + "name": "GenerateUsername", + }, + "type": "function", + }, + ] + message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls} + result = _convert_dict_to_message(message) + expected_output = AIMessage( + content="", + additional_kwargs={"tool_calls": raw_tool_calls}, + invalid_tool_calls=[ + InvalidToolCall( + name="GenerateUsername", + args="oops", + id="call_wm0JY6CdwOMZ4eTxHWUThDNz", + error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 + ), + ], + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id="call_abc123", + ), + ], + ) + assert result == expected_output + assert _convert_message_to_dict(expected_output) == message + + @pytest.fixture def mock_completion() -> dict: return {