mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 12:31:49 +00:00
Add unit tests to test openai tools agent (#15843)
This PR adds unit testing to test openai tools agent.
This commit is contained in:
parent
21a1538949
commit
a06db53c37
@ -19,7 +19,6 @@ def create_openai_tools_agent(
|
|||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain import hub
|
from langchain import hub
|
||||||
@ -56,7 +55,6 @@ def create_openai_tools_agent(
|
|||||||
A runnable sequence representing an agent. It takes as input all the same input
|
A runnable sequence representing an agent. It takes as input all the same input
|
||||||
variables as the prompt passed in does. It returns as output either an
|
variables as the prompt passed in does. It returns as output either an
|
||||||
AgentAction or AgentFinish.
|
AgentAction or AgentFinish.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables)
|
missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables)
|
||||||
if missing_vars:
|
if missing_vars:
|
||||||
|
@ -25,8 +25,10 @@ from langchain.agents import (
|
|||||||
AgentExecutor,
|
AgentExecutor,
|
||||||
AgentType,
|
AgentType,
|
||||||
create_openai_functions_agent,
|
create_openai_functions_agent,
|
||||||
|
create_openai_tools_agent,
|
||||||
initialize_agent,
|
initialize_agent,
|
||||||
)
|
)
|
||||||
|
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
@ -626,6 +628,140 @@ async def test_runnable_agent_with_function_calls() -> None:
|
|||||||
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
|
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_runnable_with_multi_action_per_step() -> None:
|
||||||
|
"""Test an agent that can make multiple function calls at once."""
|
||||||
|
# Will alternate between responding with hello and goodbye
|
||||||
|
infinite_cycle = cycle(
|
||||||
|
[AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
|
||||||
|
)
|
||||||
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
|
|
||||||
|
template = ChatPromptTemplate.from_messages(
|
||||||
|
[("system", "You are Cat Agent 007"), ("human", "{question}")]
|
||||||
|
)
|
||||||
|
|
||||||
|
parser_responses = cycle(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
AgentAction(
|
||||||
|
tool="find_pet",
|
||||||
|
tool_input={
|
||||||
|
"pet": "cat",
|
||||||
|
},
|
||||||
|
log="find_pet()",
|
||||||
|
),
|
||||||
|
AgentAction(
|
||||||
|
tool="pet_pet", # A function that allows you to pet the given pet.
|
||||||
|
tool_input={
|
||||||
|
"pet": "cat",
|
||||||
|
},
|
||||||
|
log="pet_pet()",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
AgentFinish(
|
||||||
|
return_values={"foo": "meow"},
|
||||||
|
log="hard-coded-message",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||||
|
"""A parser."""
|
||||||
|
return cast(Union[AgentFinish, AgentAction], next(parser_responses))
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def find_pet(pet: str) -> str:
|
||||||
|
"""Find the given pet."""
|
||||||
|
if pet != "cat":
|
||||||
|
raise ValueError("Only cats allowed")
|
||||||
|
return "Spying from under the bed."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def pet_pet(pet: str) -> str:
|
||||||
|
"""Pet the given pet."""
|
||||||
|
if pet != "cat":
|
||||||
|
raise ValueError("Only cats should be petted.")
|
||||||
|
return "purrrr"
|
||||||
|
|
||||||
|
agent = template | model | fake_parse
|
||||||
|
executor = AgentExecutor(agent=agent, tools=[find_pet])
|
||||||
|
|
||||||
|
# Invoke
|
||||||
|
result = executor.invoke({"question": "hello"})
|
||||||
|
assert result == {"foo": "meow", "question": "hello"}
|
||||||
|
|
||||||
|
# ainvoke
|
||||||
|
result = await executor.ainvoke({"question": "hello"})
|
||||||
|
assert result == {"foo": "meow", "question": "hello"}
|
||||||
|
|
||||||
|
# astream
|
||||||
|
results = [r async for r in executor.astream({"question": "hello"})]
|
||||||
|
assert results == [
|
||||||
|
{
|
||||||
|
"actions": [
|
||||||
|
AgentAction(
|
||||||
|
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"messages": [AIMessage(content="find_pet()")],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"actions": [
|
||||||
|
AgentAction(tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()")
|
||||||
|
],
|
||||||
|
"messages": [AIMessage(content="pet_pet()")],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# By-default observation gets converted into human message.
|
||||||
|
"messages": [HumanMessage(content="Spying from under the bed.")],
|
||||||
|
"steps": [
|
||||||
|
AgentStep(
|
||||||
|
action=AgentAction(
|
||||||
|
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
|
||||||
|
),
|
||||||
|
observation="Spying from under the bed.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(
|
||||||
|
content="pet_pet is not a valid tool, try one of [find_pet]."
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"steps": [
|
||||||
|
AgentStep(
|
||||||
|
action=AgentAction(
|
||||||
|
tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()"
|
||||||
|
),
|
||||||
|
observation="pet_pet is not a valid tool, try one of [find_pet].",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]},
|
||||||
|
]
|
||||||
|
|
||||||
|
# astream log
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
async for patch in executor.astream_log({"question": "hello"}):
|
||||||
|
for op in patch.ops:
|
||||||
|
if op["op"] != "add":
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = op["value"]
|
||||||
|
|
||||||
|
if not isinstance(value, AIMessageChunk):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if value.content == "": # Then it's a function invocation message
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages.append(value.content)
|
||||||
|
|
||||||
|
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
|
||||||
|
|
||||||
|
|
||||||
def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage:
|
def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage:
|
||||||
"""Create an AIMessage that represents a function invocation.
|
"""Create an AIMessage that represents a function invocation.
|
||||||
|
|
||||||
@ -788,3 +924,310 @@ async def test_openai_agent_with_streaming() -> None:
|
|||||||
" ",
|
" ",
|
||||||
"bed.",
|
"bed.",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMessage:
|
||||||
|
"""Create an AIMessage that represents a tools invocation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_to_arguments: A dictionary mapping tool names to an invocation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIMessage that represents a request to invoke a tool.
|
||||||
|
"""
|
||||||
|
tool_calls = [
|
||||||
|
{"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx}
|
||||||
|
for idx, (name, arguments) in enumerate(name_to_arguments.items())
|
||||||
|
]
|
||||||
|
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_agent_tools_agent() -> None:
|
||||||
|
"""Test OpenAI tools agent."""
|
||||||
|
infinite_cycle = cycle(
|
||||||
|
[
|
||||||
|
_make_tools_invocation(
|
||||||
|
{
|
||||||
|
"find_pet": {"pet": "cat"},
|
||||||
|
"check_time": {},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
AIMessage(content="The cat is spying from under the bed."),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def find_pet(pet: str) -> str:
|
||||||
|
"""Find the given pet."""
|
||||||
|
if pet != "cat":
|
||||||
|
raise ValueError("Only cats allowed")
|
||||||
|
return "Spying from under the bed."
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def check_time() -> str:
|
||||||
|
"""Find the given pet."""
|
||||||
|
return "It's time to pet the cat."
|
||||||
|
|
||||||
|
template = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
("system", "You are a helpful AI bot. Your name is kitty power meow."),
|
||||||
|
("human", "{question}"),
|
||||||
|
MessagesPlaceholder(
|
||||||
|
variable_name="agent_scratchpad",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# type error due to base tool type below -- would need to be adjusted on tool
|
||||||
|
# decorator.
|
||||||
|
agent = create_openai_tools_agent(
|
||||||
|
model,
|
||||||
|
[find_pet], # type: ignore[list-item]
|
||||||
|
template,
|
||||||
|
)
|
||||||
|
executor = AgentExecutor(agent=agent, tools=[find_pet])
|
||||||
|
|
||||||
|
# Invoke
|
||||||
|
result = executor.invoke({"question": "hello"})
|
||||||
|
assert result == {
|
||||||
|
"output": "The cat is spying from under the bed.",
|
||||||
|
"question": "hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
# astream
|
||||||
|
chunks = [chunk async for chunk in executor.astream({"question": "hello"})]
|
||||||
|
assert chunks == [
|
||||||
|
{
|
||||||
|
"actions": [
|
||||||
|
OpenAIToolAgentAction(
|
||||||
|
tool="find_pet",
|
||||||
|
tool_input={"pet": "cat"},
|
||||||
|
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||||
|
message_log=[
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "find_pet",
|
||||||
|
"arguments": '{"pet": "cat"}',
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "check_time",
|
||||||
|
"arguments": "{}",
|
||||||
|
},
|
||||||
|
"id": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_call_id="0",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "find_pet",
|
||||||
|
"arguments": '{"pet": "cat"}',
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function": {"name": "check_time", "arguments": "{}"},
|
||||||
|
"id": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"actions": [
|
||||||
|
OpenAIToolAgentAction(
|
||||||
|
tool="check_time",
|
||||||
|
tool_input={},
|
||||||
|
log="\nInvoking: `check_time` with `{}`\n\n\n",
|
||||||
|
message_log=[
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "find_pet",
|
||||||
|
"arguments": '{"pet": "cat"}',
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "check_time",
|
||||||
|
"arguments": "{}",
|
||||||
|
},
|
||||||
|
"id": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_call_id="1",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "find_pet",
|
||||||
|
"arguments": '{"pet": "cat"}',
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function": {"name": "check_time", "arguments": "{}"},
|
||||||
|
"id": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
FunctionMessage(content="Spying from under the bed.", name="find_pet")
|
||||||
|
],
|
||||||
|
"steps": [
|
||||||
|
AgentStep(
|
||||||
|
action=OpenAIToolAgentAction(
|
||||||
|
tool="find_pet",
|
||||||
|
tool_input={"pet": "cat"},
|
||||||
|
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||||
|
message_log=[
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "find_pet",
|
||||||
|
"arguments": '{"pet": "cat"}',
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "check_time",
|
||||||
|
"arguments": "{}",
|
||||||
|
},
|
||||||
|
"id": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_call_id="0",
|
||||||
|
),
|
||||||
|
observation="Spying from under the bed.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
FunctionMessage(
|
||||||
|
content="check_time is not a valid tool, try one of [find_pet].",
|
||||||
|
name="check_time",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"steps": [
|
||||||
|
AgentStep(
|
||||||
|
action=OpenAIToolAgentAction(
|
||||||
|
tool="check_time",
|
||||||
|
tool_input={},
|
||||||
|
log="\nInvoking: `check_time` with `{}`\n\n\n",
|
||||||
|
message_log=[
|
||||||
|
AIMessageChunk(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "find_pet",
|
||||||
|
"arguments": '{"pet": "cat"}',
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"name": "check_time",
|
||||||
|
"arguments": "{}",
|
||||||
|
},
|
||||||
|
"id": 1,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_call_id="1",
|
||||||
|
),
|
||||||
|
observation="check_time is not a valid tool, "
|
||||||
|
"try one of [find_pet].",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [AIMessage(content="The cat is spying from under the bed.")],
|
||||||
|
"output": "The cat is spying from under the bed.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# astream_log
|
||||||
|
log_patches = [
|
||||||
|
log_patch async for log_patch in executor.astream_log({"question": "hello"})
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get the tokens from the astream log response.
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
for log_patch in log_patches:
|
||||||
|
for op in log_patch.ops:
|
||||||
|
if op["op"] == "add" and isinstance(op["value"], AIMessageChunk):
|
||||||
|
value = op["value"]
|
||||||
|
if value.content: # Filter out function call messages
|
||||||
|
messages.append(value.content)
|
||||||
|
|
||||||
|
assert messages == [
|
||||||
|
"The",
|
||||||
|
" ",
|
||||||
|
"cat",
|
||||||
|
" ",
|
||||||
|
"is",
|
||||||
|
" ",
|
||||||
|
"spying",
|
||||||
|
" ",
|
||||||
|
"from",
|
||||||
|
" ",
|
||||||
|
"under",
|
||||||
|
" ",
|
||||||
|
"the",
|
||||||
|
" ",
|
||||||
|
"bed.",
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user