diff --git a/libs/langchain/langchain/agents/openai_tools/base.py b/libs/langchain/langchain/agents/openai_tools/base.py index 56868d903ad..c1206ea4efc 100644 --- a/libs/langchain/langchain/agents/openai_tools/base.py +++ b/libs/langchain/langchain/agents/openai_tools/base.py @@ -19,7 +19,6 @@ def create_openai_tools_agent( Examples: - .. code-block:: python from langchain import hub @@ -56,7 +55,6 @@ def create_openai_tools_agent( A runnable sequence representing an agent. It takes as input all the same input variables as the prompt passed in does. It returns as output either an AgentAction or AgentFinish. - """ missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables) if missing_vars: diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 7baee92e072..6ba908b6ad0 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -25,8 +25,10 @@ from langchain.agents import ( AgentExecutor, AgentType, create_openai_functions_agent, + create_openai_tools_agent, initialize_agent, ) +from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.prompts import ChatPromptTemplate from langchain.tools import tool @@ -626,6 +628,140 @@ async def test_runnable_agent_with_function_calls() -> None: assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"] +async def test_runnable_with_multi_action_per_step() -> None: + """Test an agent that can make multiple function calls at once.""" + # Will alternate between responding with hello and goodbye + infinite_cycle = cycle( + [AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")] + ) + model = GenericFakeChatModel(messages=infinite_cycle) + + template = ChatPromptTemplate.from_messages( + [("system", "You are Cat Agent 007"), ("human", "{question}")] + ) + + parser_responses = cycle( + [ + [ + AgentAction( + tool="find_pet", + tool_input={ + "pet": "cat", + }, + log="find_pet()", + ), + AgentAction( + tool="pet_pet", # A function that allows you to pet the given pet. + tool_input={ + "pet": "cat", + }, + log="pet_pet()", + ), + ], + AgentFinish( + return_values={"foo": "meow"}, + log="hard-coded-message", + ), + ], + ) + + def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: + """A parser.""" + return cast(Union[AgentFinish, AgentAction], next(parser_responses)) + + @tool + def find_pet(pet: str) -> str: + """Find the given pet.""" + if pet != "cat": + raise ValueError("Only cats allowed") + return "Spying from under the bed." + + @tool + def pet_pet(pet: str) -> str: + """Pet the given pet.""" + if pet != "cat": + raise ValueError("Only cats should be petted.") + return "purrrr" + + agent = template | model | fake_parse + executor = AgentExecutor(agent=agent, tools=[find_pet]) + + # Invoke + result = executor.invoke({"question": "hello"}) + assert result == {"foo": "meow", "question": "hello"} + + # ainvoke + result = await executor.ainvoke({"question": "hello"}) + assert result == {"foo": "meow", "question": "hello"} + + # astream + results = [r async for r in executor.astream({"question": "hello"})] + assert results == [ + { + "actions": [ + AgentAction( + tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" + ) + ], + "messages": [AIMessage(content="find_pet()")], + }, + { + "actions": [ + AgentAction(tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()") + ], + "messages": [AIMessage(content="pet_pet()")], + }, + { + # By-default observation gets converted into human message. + "messages": [HumanMessage(content="Spying from under the bed.")], + "steps": [ + AgentStep( + action=AgentAction( + tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()" + ), + observation="Spying from under the bed.", + ) + ], + }, + { + "messages": [ + HumanMessage( + content="pet_pet is not a valid tool, try one of [find_pet]." + ) + ], + "steps": [ + AgentStep( + action=AgentAction( + tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()" + ), + observation="pet_pet is not a valid tool, try one of [find_pet].", + ) + ], + }, + {"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}, + ] + + # astream log + + messages = [] + async for patch in executor.astream_log({"question": "hello"}): + for op in patch.ops: + if op["op"] != "add": + continue + + value = op["value"] + + if not isinstance(value, AIMessageChunk): + continue + + if value.content == "": # Then it's a function invocation message + continue + + messages.append(value.content) + + assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"] + + def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage: """Create an AIMessage that represents a function invocation. @@ -788,3 +924,310 @@ async def test_openai_agent_with_streaming() -> None: " ", "bed.", ] + + +def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMessage: + """Create an AIMessage that represents a tools invocation. + + Args: + name_to_arguments: A dictionary mapping tool names to an invocation. + + Returns: + AIMessage that represents a request to invoke a tool. + """ + tool_calls = [ + {"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx} + for idx, (name, arguments) in enumerate(name_to_arguments.items()) + ] + + return AIMessage( + content="", + additional_kwargs={ + "tool_calls": tool_calls, + }, + ) + + +async def test_openai_agent_tools_agent() -> None: + """Test OpenAI tools agent.""" + infinite_cycle = cycle( + [ + _make_tools_invocation( + { + "find_pet": {"pet": "cat"}, + "check_time": {}, + } + ), + AIMessage(content="The cat is spying from under the bed."), + ] + ) + + model = GenericFakeChatModel(messages=infinite_cycle) + + @tool + def find_pet(pet: str) -> str: + """Find the given pet.""" + if pet != "cat": + raise ValueError("Only cats allowed") + return "Spying from under the bed." + + @tool + def check_time() -> str: + """Find the given pet.""" + return "It's time to pet the cat." + + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful AI bot. Your name is kitty power meow."), + ("human", "{question}"), + MessagesPlaceholder( + variable_name="agent_scratchpad", + ), + ] + ) + + # type error due to base tool type below -- would need to be adjusted on tool + # decorator. + agent = create_openai_tools_agent( + model, + [find_pet], # type: ignore[list-item] + template, + ) + executor = AgentExecutor(agent=agent, tools=[find_pet]) + + # Invoke + result = executor.invoke({"question": "hello"}) + assert result == { + "output": "The cat is spying from under the bed.", + "question": "hello", + } + + # astream + chunks = [chunk async for chunk in executor.astream({"question": "hello"})] + assert chunks == [ + { + "actions": [ + OpenAIToolAgentAction( + tool="find_pet", + tool_input={"pet": "cat"}, + log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", + message_log=[ + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "function": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + }, + "id": 0, + }, + { + "function": { + "name": "check_time", + "arguments": "{}", + }, + "id": 1, + }, + ] + }, + ) + ], + tool_call_id="0", + ) + ], + "messages": [ + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "function": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + }, + "id": 0, + }, + { + "function": {"name": "check_time", "arguments": "{}"}, + "id": 1, + }, + ] + }, + ) + ], + }, + { + "actions": [ + OpenAIToolAgentAction( + tool="check_time", + tool_input={}, + log="\nInvoking: `check_time` with `{}`\n\n\n", + message_log=[ + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "function": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + }, + "id": 0, + }, + { + "function": { + "name": "check_time", + "arguments": "{}", + }, + "id": 1, + }, + ] + }, + ) + ], + tool_call_id="1", + ) + ], + "messages": [ + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "function": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + }, + "id": 0, + }, + { + "function": {"name": "check_time", "arguments": "{}"}, + "id": 1, + }, + ] + }, + ) + ], + }, + { + "messages": [ + FunctionMessage(content="Spying from under the bed.", name="find_pet") + ], + "steps": [ + AgentStep( + action=OpenAIToolAgentAction( + tool="find_pet", + tool_input={"pet": "cat"}, + log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", + message_log=[ + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "function": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + }, + "id": 0, + }, + { + "function": { + "name": "check_time", + "arguments": "{}", + }, + "id": 1, + }, + ] + }, + ) + ], + tool_call_id="0", + ), + observation="Spying from under the bed.", + ) + ], + }, + { + "messages": [ + FunctionMessage( + content="check_time is not a valid tool, try one of [find_pet].", + name="check_time", + ) + ], + "steps": [ + AgentStep( + action=OpenAIToolAgentAction( + tool="check_time", + tool_input={}, + log="\nInvoking: `check_time` with `{}`\n\n\n", + message_log=[ + AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "function": { + "name": "find_pet", + "arguments": '{"pet": "cat"}', + }, + "id": 0, + }, + { + "function": { + "name": "check_time", + "arguments": "{}", + }, + "id": 1, + }, + ] + }, + ) + ], + tool_call_id="1", + ), + observation="check_time is not a valid tool, " + "try one of [find_pet].", + ) + ], + }, + { + "messages": [AIMessage(content="The cat is spying from under the bed.")], + "output": "The cat is spying from under the bed.", + }, + ] + + # astream_log + log_patches = [ + log_patch async for log_patch in executor.astream_log({"question": "hello"}) + ] + + # Get the tokens from the astream log response. + messages = [] + + for log_patch in log_patches: + for op in log_patch.ops: + if op["op"] == "add" and isinstance(op["value"], AIMessageChunk): + value = op["value"] + if value.content: # Filter out function call messages + messages.append(value.content) + + assert messages == [ + "The", + " ", + "cat", + " ", + "is", + " ", + "spying", + " ", + "from", + " ", + "under", + " ", + "the", + " ", + "bed.", + ]