mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
Add unit tests to test openai tools agent (#15843)
This PR adds unit testing to test openai tools agent.
This commit is contained in:
parent
21a1538949
commit
a06db53c37
@ -19,7 +19,6 @@ def create_openai_tools_agent(
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import hub
|
||||
@ -56,7 +55,6 @@ def create_openai_tools_agent(
|
||||
A runnable sequence representing an agent. It takes as input all the same input
|
||||
variables as the prompt passed in does. It returns as output either an
|
||||
AgentAction or AgentFinish.
|
||||
|
||||
"""
|
||||
missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables)
|
||||
if missing_vars:
|
||||
|
@ -25,8 +25,10 @@ from langchain.agents import (
|
||||
AgentExecutor,
|
||||
AgentType,
|
||||
create_openai_functions_agent,
|
||||
create_openai_tools_agent,
|
||||
initialize_agent,
|
||||
)
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.tools import tool
|
||||
@ -626,6 +628,140 @@ async def test_runnable_agent_with_function_calls() -> None:
|
||||
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
|
||||
|
||||
|
||||
async def test_runnable_with_multi_action_per_step() -> None:
|
||||
"""Test an agent that can make multiple function calls at once."""
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle(
|
||||
[AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[("system", "You are Cat Agent 007"), ("human", "{question}")]
|
||||
)
|
||||
|
||||
parser_responses = cycle(
|
||||
[
|
||||
[
|
||||
AgentAction(
|
||||
tool="find_pet",
|
||||
tool_input={
|
||||
"pet": "cat",
|
||||
},
|
||||
log="find_pet()",
|
||||
),
|
||||
AgentAction(
|
||||
tool="pet_pet", # A function that allows you to pet the given pet.
|
||||
tool_input={
|
||||
"pet": "cat",
|
||||
},
|
||||
log="pet_pet()",
|
||||
),
|
||||
],
|
||||
AgentFinish(
|
||||
return_values={"foo": "meow"},
|
||||
log="hard-coded-message",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return cast(Union[AgentFinish, AgentAction], next(parser_responses))
|
||||
|
||||
@tool
|
||||
def find_pet(pet: str) -> str:
|
||||
"""Find the given pet."""
|
||||
if pet != "cat":
|
||||
raise ValueError("Only cats allowed")
|
||||
return "Spying from under the bed."
|
||||
|
||||
@tool
|
||||
def pet_pet(pet: str) -> str:
|
||||
"""Pet the given pet."""
|
||||
if pet != "cat":
|
||||
raise ValueError("Only cats should be petted.")
|
||||
return "purrrr"
|
||||
|
||||
agent = template | model | fake_parse
|
||||
executor = AgentExecutor(agent=agent, tools=[find_pet])
|
||||
|
||||
# Invoke
|
||||
result = executor.invoke({"question": "hello"})
|
||||
assert result == {"foo": "meow", "question": "hello"}
|
||||
|
||||
# ainvoke
|
||||
result = await executor.ainvoke({"question": "hello"})
|
||||
assert result == {"foo": "meow", "question": "hello"}
|
||||
|
||||
# astream
|
||||
results = [r async for r in executor.astream({"question": "hello"})]
|
||||
assert results == [
|
||||
{
|
||||
"actions": [
|
||||
AgentAction(
|
||||
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
|
||||
)
|
||||
],
|
||||
"messages": [AIMessage(content="find_pet()")],
|
||||
},
|
||||
{
|
||||
"actions": [
|
||||
AgentAction(tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()")
|
||||
],
|
||||
"messages": [AIMessage(content="pet_pet()")],
|
||||
},
|
||||
{
|
||||
# By-default observation gets converted into human message.
|
||||
"messages": [HumanMessage(content="Spying from under the bed.")],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=AgentAction(
|
||||
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
|
||||
),
|
||||
observation="Spying from under the bed.",
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage(
|
||||
content="pet_pet is not a valid tool, try one of [find_pet]."
|
||||
)
|
||||
],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=AgentAction(
|
||||
tool="pet_pet", tool_input={"pet": "cat"}, log="pet_pet()"
|
||||
),
|
||||
observation="pet_pet is not a valid tool, try one of [find_pet].",
|
||||
)
|
||||
],
|
||||
},
|
||||
{"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]},
|
||||
]
|
||||
|
||||
# astream log
|
||||
|
||||
messages = []
|
||||
async for patch in executor.astream_log({"question": "hello"}):
|
||||
for op in patch.ops:
|
||||
if op["op"] != "add":
|
||||
continue
|
||||
|
||||
value = op["value"]
|
||||
|
||||
if not isinstance(value, AIMessageChunk):
|
||||
continue
|
||||
|
||||
if value.content == "": # Then it's a function invocation message
|
||||
continue
|
||||
|
||||
messages.append(value.content)
|
||||
|
||||
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
|
||||
|
||||
|
||||
def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage:
|
||||
"""Create an AIMessage that represents a function invocation.
|
||||
|
||||
@ -788,3 +924,310 @@ async def test_openai_agent_with_streaming() -> None:
|
||||
" ",
|
||||
"bed.",
|
||||
]
|
||||
|
||||
|
||||
def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMessage:
|
||||
"""Create an AIMessage that represents a tools invocation.
|
||||
|
||||
Args:
|
||||
name_to_arguments: A dictionary mapping tool names to an invocation.
|
||||
|
||||
Returns:
|
||||
AIMessage that represents a request to invoke a tool.
|
||||
"""
|
||||
tool_calls = [
|
||||
{"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx}
|
||||
for idx, (name, arguments) in enumerate(name_to_arguments.items())
|
||||
]
|
||||
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_openai_agent_tools_agent() -> None:
|
||||
"""Test OpenAI tools agent."""
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
_make_tools_invocation(
|
||||
{
|
||||
"find_pet": {"pet": "cat"},
|
||||
"check_time": {},
|
||||
}
|
||||
),
|
||||
AIMessage(content="The cat is spying from under the bed."),
|
||||
]
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
@tool
|
||||
def find_pet(pet: str) -> str:
|
||||
"""Find the given pet."""
|
||||
if pet != "cat":
|
||||
raise ValueError("Only cats allowed")
|
||||
return "Spying from under the bed."
|
||||
|
||||
@tool
|
||||
def check_time() -> str:
|
||||
"""Find the given pet."""
|
||||
return "It's time to pet the cat."
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a helpful AI bot. Your name is kitty power meow."),
|
||||
("human", "{question}"),
|
||||
MessagesPlaceholder(
|
||||
variable_name="agent_scratchpad",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# type error due to base tool type below -- would need to be adjusted on tool
|
||||
# decorator.
|
||||
agent = create_openai_tools_agent(
|
||||
model,
|
||||
[find_pet], # type: ignore[list-item]
|
||||
template,
|
||||
)
|
||||
executor = AgentExecutor(agent=agent, tools=[find_pet])
|
||||
|
||||
# Invoke
|
||||
result = executor.invoke({"question": "hello"})
|
||||
assert result == {
|
||||
"output": "The cat is spying from under the bed.",
|
||||
"question": "hello",
|
||||
}
|
||||
|
||||
# astream
|
||||
chunks = [chunk async for chunk in executor.astream({"question": "hello"})]
|
||||
assert chunks == [
|
||||
{
|
||||
"actions": [
|
||||
OpenAIToolAgentAction(
|
||||
tool="find_pet",
|
||||
tool_input={"pet": "cat"},
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
},
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "check_time",
|
||||
"arguments": "{}",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_call_id="0",
|
||||
)
|
||||
],
|
||||
"messages": [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
},
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {"name": "check_time", "arguments": "{}"},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"actions": [
|
||||
OpenAIToolAgentAction(
|
||||
tool="check_time",
|
||||
tool_input={},
|
||||
log="\nInvoking: `check_time` with `{}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
},
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "check_time",
|
||||
"arguments": "{}",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_call_id="1",
|
||||
)
|
||||
],
|
||||
"messages": [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
},
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {"name": "check_time", "arguments": "{}"},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
FunctionMessage(content="Spying from under the bed.", name="find_pet")
|
||||
],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=OpenAIToolAgentAction(
|
||||
tool="find_pet",
|
||||
tool_input={"pet": "cat"},
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
},
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "check_time",
|
||||
"arguments": "{}",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_call_id="0",
|
||||
),
|
||||
observation="Spying from under the bed.",
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
FunctionMessage(
|
||||
content="check_time is not a valid tool, try one of [find_pet].",
|
||||
name="check_time",
|
||||
)
|
||||
],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=OpenAIToolAgentAction(
|
||||
tool="check_time",
|
||||
tool_input={},
|
||||
log="\nInvoking: `check_time` with `{}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
},
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "check_time",
|
||||
"arguments": "{}",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_call_id="1",
|
||||
),
|
||||
observation="check_time is not a valid tool, "
|
||||
"try one of [find_pet].",
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [AIMessage(content="The cat is spying from under the bed.")],
|
||||
"output": "The cat is spying from under the bed.",
|
||||
},
|
||||
]
|
||||
|
||||
# astream_log
|
||||
log_patches = [
|
||||
log_patch async for log_patch in executor.astream_log({"question": "hello"})
|
||||
]
|
||||
|
||||
# Get the tokens from the astream log response.
|
||||
messages = []
|
||||
|
||||
for log_patch in log_patches:
|
||||
for op in log_patch.ops:
|
||||
if op["op"] == "add" and isinstance(op["value"], AIMessageChunk):
|
||||
value = op["value"]
|
||||
if value.content: # Filter out function call messages
|
||||
messages.append(value.content)
|
||||
|
||||
assert messages == [
|
||||
"The",
|
||||
" ",
|
||||
"cat",
|
||||
" ",
|
||||
"is",
|
||||
" ",
|
||||
"spying",
|
||||
" ",
|
||||
"from",
|
||||
" ",
|
||||
"under",
|
||||
" ",
|
||||
"the",
|
||||
" ",
|
||||
"bed.",
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user