Add unit tests to test openai tools agent (#15843)

This PR adds unit tests for the OpenAI tools agent.
Eugene Yurtsev 2024-01-10 17:06:30 -05:00 committed by GitHub
parent 21a1538949
commit a06db53c37
2 changed files with 443 additions and 2 deletions


@@ -19,7 +19,6 @@ def create_openai_tools_agent(
Examples:
.. code-block:: python
from langchain import hub
@@ -56,7 +55,6 @@ def create_openai_tools_agent(
A runnable sequence representing an agent. It takes as input all the same input
variables as the prompt passed in does. It returns as output either an
AgentAction or AgentFinish.
"""
missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables)
if missing_vars:
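
The hunk above cuts off before the body of the `if missing_vars:` branch. For reference, a minimal standalone sketch of how that validation presumably surfaces to callers; the exact exception type and message here are assumptions, not taken from the diff:

# Illustrative only: a prompt exposing just {question} has no agent_scratchpad,
# so the set difference is non-empty. The raise below is assumed behavior; the
# hunk truncates before the body of the if-branch.
prompt_input_variables = ["question"]
missing_vars = {"agent_scratchpad"}.difference(prompt_input_variables)
if missing_vars:
    raise ValueError(f"Prompt missing required variables: {missing_vars}")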


@@ -25,8 +25,10 @@ from langchain.agents import (
AgentExecutor,
AgentType,
create_openai_functions_agent,
create_openai_tools_agent,
initialize_agent,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.prompts import ChatPromptTemplate
from langchain.tools import tool
@@ -626,6 +628,140 @@ async def test_runnable_agent_with_function_calls() -> None:
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
async def test_runnable_with_multi_action_per_step() -> None:
"""Test an agent that can make multiple function calls at once."""
# Will alternate between responding with "looking for pet..." and "Found Pet"
infinite_cycle = cycle(
[AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
)
model = GenericFakeChatModel(messages=infinite_cycle)
template = ChatPromptTemplate.from_messages(
[("system", "You are Cat Agent 007"), ("human", "{question}")]
)
parser_responses = cycle(
[
[
AgentAction(
tool="find_pet",
tool_input={
"pet": "cat",
},
log="find_pet()",
),
AgentAction(
tool="pet_pet", # A function that allows you to pet the given pet.
tool_input={
"pet": "cat",
},
log="pet_pet()",
),
],
AgentFinish(
return_values={"foo": "meow"},
log="hard-coded-message",
),
],
)
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
"""A parser."""
return cast(Union[AgentFinish, AgentAction], next(parser_responses))
@tool
def find_pet(pet: str) -> str:
"""Find the given pet."""
if pet != "cat":
raise ValueError("Only cats allowed")
return "Spying from under the bed."
@tool
def pet_pet(pet: str) -> str:
"""Pet the given pet."""
if pet != "cat":
raise ValueError("Only cats should be petted.")
return "purrrr"
agent = template | model | fake_parse
executor = AgentExecutor(agent=agent, tools=[find_pet])
# Invoke
result = executor.invoke({"question": "hello"})
assert result == {"foo": "meow", "question": "hello"}
# ainvoke
result = await executor.ainvoke({"question": "hello"})
assert result == {"foo": "meow", "question": "hello"}
# astream
results = [r async for r in executor.astream({"question": "hello"})]
assert results == [
{
"actions": [
AgentAction(
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
)
],
"messages": [AIMessage(content="find_pet()")],
},
{
"actions": [
AgentAction(tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()")
],
"messages": [AIMessage(content="pet_pet()")],
},
{
# By default, the observation gets converted into a human message.
"messages": [HumanMessage(content="Spying from under the bed.")],
"steps": [
AgentStep(
action=AgentAction(
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
),
observation="Spying from under the bed.",
)
],
},
{
"messages": [
HumanMessage(
content="pet_pet is not a valid tool, try one of [find_pet]."
)
],
"steps": [
AgentStep(
action=AgentAction(
tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()"
),
observation="pet_pet is not a valid tool, try one of [find_pet].",
)
],
},
{"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]},
]
# astream log
messages = []
async for patch in executor.astream_log({"question": "hello"}):
for op in patch.ops:
if op["op"] != "add":
continue
value = op["value"]
if not isinstance(value, AIMessageChunk):
continue
if value.content == "": # Then it's a function invocation message
continue
messages.append(value.content)
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
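
The agent in this test is assembled by piping a prompt, a fake chat model, and a plain Python function (`template | model | fake_parse`); LCEL's `|` operator coerces bare callables into runnables, which is what lets `fake_parse` stand in for an output parser that returns either a list of AgentActions or an AgentFinish. A minimal sketch of that coercion, independent of the test fixtures (the langchain_core import path is an assumption about the environment rather than something shown in the diff):

from langchain_core.runnables import RunnableLambda

# Piping a plain lambda onto a runnable wraps it in a RunnableLambda automatically.
pipeline = RunnableLambda(lambda x: x + 1) | (lambda x: x * 2)
assert pipeline.invoke(3) == 8
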
def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage:
"""Create an AIMessage that represents a function invocation.
@@ -788,3 +924,310 @@ async def test_openai_agent_with_streaming() -> None:
" ",
"bed.",
]
def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMessage:
"""Create an AIMessage that represents a tools invocation.
Args:
name_to_arguments: A dictionary mapping tool names to the arguments of each invocation.
Returns:
AIMessage that represents a request to invoke a tool.
"""
tool_calls = [
{"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx}
for idx, (name, arguments) in enumerate(name_to_arguments.items())
]
return AIMessage(
content="",
additional_kwargs={
"tool_calls": tool_calls,
},
)
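
For clarity, this is the message the helper above builds for the inputs used in the next test; the values follow directly from the function body (real OpenAI tool calls would also carry string ids and a "type": "function" field, which this fake omits):

# Worked example of _make_tools_invocation:
msg = _make_tools_invocation({"find_pet": {"pet": "cat"}, "check_time": {}})
assert msg.content == ""
assert msg.additional_kwargs["tool_calls"] == [
    {"function": {"name": "find_pet", "arguments": '{"pet": "cat"}'}, "id": 0},
    {"function": {"name": "check_time", "arguments": "{}"}, "id": 1},
]
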
async def test_openai_agent_tools_agent() -> None:
"""Test OpenAI tools agent."""
infinite_cycle = cycle(
[
_make_tools_invocation(
{
"find_pet": {"pet": "cat"},
"check_time": {},
}
),
AIMessage(content="The cat is spying from under the bed."),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
@tool
def find_pet(pet: str) -> str:
"""Find the given pet."""
if pet != "cat":
raise ValueError("Only cats allowed")
return "Spying from under the bed."
@tool
def check_time() -> str:
"""Find the given pet."""
return "It's time to pet the cat."
template = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful AI bot. Your name is kitty power meow."),
("human", "{question}"),
MessagesPlaceholder(
variable_name="agent_scratchpad",
),
]
)
# The type: ignore below is needed because of the base tool type; fixing it
# properly would require adjusting the @tool decorator.
agent = create_openai_tools_agent(
model,
[find_pet], # type: ignore[list-item]
template,
)
executor = AgentExecutor(agent=agent, tools=[find_pet])
# Invoke
result = executor.invoke({"question": "hello"})
assert result == {
"output": "The cat is spying from under the bed.",
"question": "hello",
}
# astream
chunks = [chunk async for chunk in executor.astream({"question": "hello"})]
assert chunks == [
{
"actions": [
OpenAIToolAgentAction(
tool="find_pet",
tool_input={"pet": "cat"},
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
content="",
additional_kwargs={
"tool_calls": [
{
"function": {
"name": "find_pet",
"arguments": '{"pet": "cat"}',
},
"id": 0,
},
{
"function": {
"name": "check_time",
"arguments": "{}",
},
"id": 1,
},
]
},
)
],
tool_call_id="0",
)
],
"messages": [
AIMessageChunk(
content="",
additional_kwargs={
"tool_calls": [
{
"function": {
"name": "find_pet",
"arguments": '{"pet": "cat"}',
},
"id": 0,
},
{
"function": {"name": "check_time", "arguments": "{}"},
"id": 1,
},
]
},
)
],
},
{
"actions": [
OpenAIToolAgentAction(
tool="check_time",
tool_input={},
log="\nInvoking: `check_time` with `{}`\n\n\n",
message_log=[
AIMessageChunk(
content="",
additional_kwargs={
"tool_calls": [
{
"function": {
"name": "find_pet",
"arguments": '{"pet": "cat"}',
},
"id": 0,
},
{
"function": {
"name": "check_time",
"arguments": "{}",
},
"id": 1,
},
]
},
)
],
tool_call_id="1",
)
],
"messages": [
AIMessageChunk(
content="",
additional_kwargs={
"tool_calls": [
{
"function": {
"name": "find_pet",
"arguments": '{"pet": "cat"}',
},
"id": 0,
},
{
"function": {"name": "check_time", "arguments": "{}"},
"id": 1,
},
]
},
)
],
},
{
"messages": [
FunctionMessage(content="Spying from under the bed.", name="find_pet")
],
"steps": [
AgentStep(
action=OpenAIToolAgentAction(
tool="find_pet",
tool_input={"pet": "cat"},
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
content="",
additional_kwargs={
"tool_calls": [
{
"function": {
"name": "find_pet",
"arguments": '{"pet": "cat"}',
},
"id": 0,
},
{
"function": {
"name": "check_time",
"arguments": "{}",
},
"id": 1,
},
]
},
)
],
tool_call_id="0",
),
observation="Spying from under the bed.",
)
],
},
{
"messages": [
FunctionMessage(
content="check_time is not a valid tool, try one of [find_pet].",
name="check_time",
)
],
"steps": [
AgentStep(
action=OpenAIToolAgentAction(
tool="check_time",
tool_input={},
log="\nInvoking: `check_time` with `{}`\n\n\n",
message_log=[
AIMessageChunk(
content="",
additional_kwargs={
"tool_calls": [
{
"function": {
"name": "find_pet",
"arguments": '{"pet": "cat"}',
},
"id": 0,
},
{
"function": {
"name": "check_time",
"arguments": "{}",
},
"id": 1,
},
]
},
)
],
tool_call_id="1",
),
observation="check_time is not a valid tool, "
"try one of [find_pet].",
)
],
},
{
"messages": [AIMessage(content="The cat is spying from under the bed.")],
"output": "The cat is spying from under the bed.",
},
]
# astream_log
log_patches = [
log_patch async for log_patch in executor.astream_log({"question": "hello"})
]
# Get the tokens from the astream log response.
messages = []
for log_patch in log_patches:
for op in log_patch.ops:
if op["op"] == "add" and isinstance(op["value"], AIMessageChunk):
value = op["value"]
if value.content: # Filter out function call messages
messages.append(value.content)
assert messages == [
"The",
" ",
"cat",
" ",
"is",
" ",
"spying",
" ",
"from",
" ",
"under",
" ",
"the",
" ",
"bed.",
]
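
A note on the shape of the expected tokens in the assertions above: the fake chat model streams message content word by word, with the whitespace separators emitted as their own chunks, which is why single spaces appear between every word. A rough sketch of that splitting behavior; attributing it to GenericFakeChatModel is an assumption, while the token list itself comes straight from the test:

import re

content = "The cat is spying from under the bed."
# Split on whitespace while keeping the separators; drop empty strings.
tokens = [tok for tok in re.split(r"(\s)", content) if tok]
assert tokens == [
    "The", " ", "cat", " ", "is", " ", "spying", " ", "from",
    " ", "under", " ", "the", " ", "bed.",
]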