mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 16:01:33 +00:00
chore(langchain): rename create_react_agent
-> create_agent
(#32789)
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
|
"""langgraph.prebuilt exposes a higher-level API for creating and executing agents and tools."""
|
||||||
|
|
||||||
from langchain.agents.react_agent import AgentState, create_react_agent
|
from langchain.agents.react_agent import AgentState, create_agent
|
||||||
from langchain.agents.tool_node import ToolNode
|
from langchain.agents.tool_node import ToolNode
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"ToolNode",
|
"ToolNode",
|
||||||
"create_react_agent",
|
"create_agent",
|
||||||
]
|
]
|
||||||
|
@@ -898,7 +898,7 @@ def _supports_native_structured_output(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_react_agent( # noqa: D417
|
def create_agent( # noqa: D417
|
||||||
model: Union[
|
model: Union[
|
||||||
str,
|
str,
|
||||||
BaseChatModel,
|
BaseChatModel,
|
||||||
@@ -928,7 +928,7 @@ def create_react_agent( # noqa: D417
|
|||||||
) -> CompiledStateGraph[StateT, ContextT]:
|
) -> CompiledStateGraph[StateT, ContextT]:
|
||||||
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
|
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
|
||||||
|
|
||||||
For more details on using `create_react_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
|
For more details on using `create_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The language model for the agent. Supports static and dynamic
|
model: The language model for the agent. Supports static and dynamic
|
||||||
@@ -1096,13 +1096,13 @@ def create_react_agent( # noqa: D417
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
from langchain.agents import create_react_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
def check_weather(location: str) -> str:
|
def check_weather(location: str) -> str:
|
||||||
'''Return the weather forecast for the specified location.'''
|
'''Return the weather forecast for the specified location.'''
|
||||||
return f"It's always sunny in {location}"
|
return f"It's always sunny in {location}"
|
||||||
|
|
||||||
graph = create_react_agent(
|
graph = create_agent(
|
||||||
"anthropic:claude-3-7-sonnet-latest",
|
"anthropic:claude-3-7-sonnet-latest",
|
||||||
tools=[check_weather],
|
tools=[check_weather],
|
||||||
prompt="You are a helpful assistant",
|
prompt="You are a helpful assistant",
|
||||||
@@ -1123,7 +1123,7 @@ def create_react_agent( # noqa: D417
|
|||||||
context_schema = config_schema
|
context_schema = config_schema
|
||||||
|
|
||||||
if len(deprecated_kwargs) > 0:
|
if len(deprecated_kwargs) > 0:
|
||||||
msg = f"create_react_agent() got unexpected keyword arguments: {deprecated_kwargs}"
|
msg = f"create_agent() got unexpected keyword arguments: {deprecated_kwargs}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
if response_format and not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
|
if response_format and not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
|
||||||
@@ -1171,5 +1171,5 @@ __all__ = [
|
|||||||
"AgentStatePydantic",
|
"AgentStatePydantic",
|
||||||
"AgentStateWithStructuredResponse",
|
"AgentStateWithStructuredResponse",
|
||||||
"AgentStateWithStructuredResponsePydantic",
|
"AgentStateWithStructuredResponsePydantic",
|
||||||
"create_react_agent",
|
"create_agent",
|
||||||
]
|
]
|
||||||
|
@@ -2,7 +2,7 @@ import pytest
|
|||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.agents import create_react_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.structured_output import ToolStrategy
|
from langchain.agents.structured_output import ToolStrategy
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ def test_inference_to_native_output() -> None:
|
|||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
model = ChatOpenAI(model="gpt-5")
|
model = ChatOpenAI(model="gpt-5")
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
prompt=(
|
prompt=(
|
||||||
"You are a helpful weather assistant. Please call the get_weather tool, "
|
"You are a helpful weather assistant. Please call the get_weather tool, "
|
||||||
@@ -54,7 +54,7 @@ def test_inference_to_tool_output() -> None:
|
|||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
model = ChatOpenAI(model="gpt-4")
|
model = ChatOpenAI(model="gpt-4")
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
prompt=(
|
prompt=(
|
||||||
"You are a helpful weather assistant. Please call the get_weather tool, "
|
"You are a helpful weather assistant. Please call the get_weather tool, "
|
||||||
|
@@ -32,7 +32,7 @@ from typing_extensions import TypedDict
|
|||||||
from langchain.agents import (
|
from langchain.agents import (
|
||||||
AgentState,
|
AgentState,
|
||||||
ToolNode,
|
ToolNode,
|
||||||
create_react_agent,
|
create_agent,
|
||||||
)
|
)
|
||||||
from langchain.agents.react_agent import _validate_chat_history
|
from langchain.agents.react_agent import _validate_chat_history
|
||||||
from langchain.agents.tool_node import (
|
from langchain.agents.tool_node import (
|
||||||
@@ -52,7 +52,7 @@ pytestmark = pytest.mark.anyio
|
|||||||
def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
|
def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
|
||||||
model = FakeToolCallingModel()
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
checkpointer=sync_checkpointer,
|
checkpointer=sync_checkpointer,
|
||||||
@@ -82,7 +82,7 @@ def test_no_prompt(sync_checkpointer: BaseCheckpointSaver) -> None:
|
|||||||
async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
|
async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
|
||||||
model = FakeToolCallingModel()
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
agent = create_react_agent(model, [], checkpointer=async_checkpointer)
|
agent = create_agent(model, [], checkpointer=async_checkpointer)
|
||||||
inputs = [HumanMessage("hi?")]
|
inputs = [HumanMessage("hi?")]
|
||||||
thread = {"configurable": {"thread_id": "123"}}
|
thread = {"configurable": {"thread_id": "123"}}
|
||||||
response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
|
response = await agent.ainvoke({"messages": inputs}, thread, debug=True)
|
||||||
@@ -107,7 +107,7 @@ async def test_no_prompt_async(async_checkpointer: BaseCheckpointSaver) -> None:
|
|||||||
|
|
||||||
def test_system_message_prompt() -> None:
|
def test_system_message_prompt() -> None:
|
||||||
prompt = SystemMessage(content="Foo")
|
prompt = SystemMessage(content="Foo")
|
||||||
agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
|
agent = create_agent(FakeToolCallingModel(), [], prompt=prompt)
|
||||||
inputs = [HumanMessage("hi?")]
|
inputs = [HumanMessage("hi?")]
|
||||||
response = agent.invoke({"messages": inputs})
|
response = agent.invoke({"messages": inputs})
|
||||||
expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
|
expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
|
||||||
@@ -116,7 +116,7 @@ def test_system_message_prompt() -> None:
|
|||||||
|
|
||||||
def test_string_prompt() -> None:
|
def test_string_prompt() -> None:
|
||||||
prompt = "Foo"
|
prompt = "Foo"
|
||||||
agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
|
agent = create_agent(FakeToolCallingModel(), [], prompt=prompt)
|
||||||
inputs = [HumanMessage("hi?")]
|
inputs = [HumanMessage("hi?")]
|
||||||
response = agent.invoke({"messages": inputs})
|
response = agent.invoke({"messages": inputs})
|
||||||
expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
|
expected_response = {"messages": [*inputs, AIMessage(content="Foo-hi?", id="0", tool_calls=[])]}
|
||||||
@@ -128,7 +128,7 @@ def test_callable_prompt() -> None:
|
|||||||
modified_message = f"Bar {state['messages'][-1].content}"
|
modified_message = f"Bar {state['messages'][-1].content}"
|
||||||
return [HumanMessage(content=modified_message)]
|
return [HumanMessage(content=modified_message)]
|
||||||
|
|
||||||
agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
|
agent = create_agent(FakeToolCallingModel(), [], prompt=prompt)
|
||||||
inputs = [HumanMessage("hi?")]
|
inputs = [HumanMessage("hi?")]
|
||||||
response = agent.invoke({"messages": inputs})
|
response = agent.invoke({"messages": inputs})
|
||||||
expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]}
|
expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]}
|
||||||
@@ -140,7 +140,7 @@ async def test_callable_prompt_async() -> None:
|
|||||||
modified_message = f"Bar {state['messages'][-1].content}"
|
modified_message = f"Bar {state['messages'][-1].content}"
|
||||||
return [HumanMessage(content=modified_message)]
|
return [HumanMessage(content=modified_message)]
|
||||||
|
|
||||||
agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
|
agent = create_agent(FakeToolCallingModel(), [], prompt=prompt)
|
||||||
inputs = [HumanMessage("hi?")]
|
inputs = [HumanMessage("hi?")]
|
||||||
response = await agent.ainvoke({"messages": inputs})
|
response = await agent.ainvoke({"messages": inputs})
|
||||||
expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]}
|
expected_response = {"messages": [*inputs, AIMessage(content="Bar hi?", id="0")]}
|
||||||
@@ -152,7 +152,7 @@ def test_runnable_prompt() -> None:
|
|||||||
lambda state: [HumanMessage(content=f"Baz {state['messages'][-1].content}")]
|
lambda state: [HumanMessage(content=f"Baz {state['messages'][-1].content}")]
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(FakeToolCallingModel(), [], prompt=prompt)
|
agent = create_agent(FakeToolCallingModel(), [], prompt=prompt)
|
||||||
inputs = [HumanMessage("hi?")]
|
inputs = [HumanMessage("hi?")]
|
||||||
response = agent.invoke({"messages": inputs})
|
response = agent.invoke({"messages": inputs})
|
||||||
expected_response = {"messages": [*inputs, AIMessage(content="Baz hi?", id="0")]}
|
expected_response = {"messages": [*inputs, AIMessage(content="Baz hi?", id="0")]}
|
||||||
@@ -179,7 +179,7 @@ def test_prompt_with_store() -> None:
|
|||||||
model = FakeToolCallingModel()
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
# test state modifier that uses store works
|
# test state modifier that uses store works
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[add],
|
[add],
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -189,7 +189,7 @@ def test_prompt_with_store() -> None:
|
|||||||
assert response["messages"][-1].content == "User name is Alice-hi"
|
assert response["messages"][-1].content == "User name is Alice-hi"
|
||||||
|
|
||||||
# test state modifier that doesn't use store works
|
# test state modifier that doesn't use store works
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[add],
|
[add],
|
||||||
prompt=prompt_no_store,
|
prompt=prompt_no_store,
|
||||||
@@ -219,14 +219,14 @@ async def test_prompt_with_store_async() -> None:
|
|||||||
model = FakeToolCallingModel()
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
# test state modifier that uses store works
|
# test state modifier that uses store works
|
||||||
agent = create_react_agent(model, [add], prompt=prompt, store=in_memory_store)
|
agent = create_agent(model, [add], prompt=prompt, store=in_memory_store)
|
||||||
response = await agent.ainvoke(
|
response = await agent.ainvoke(
|
||||||
{"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}}
|
{"messages": [("user", "hi")]}, {"configurable": {"user_id": "1"}}
|
||||||
)
|
)
|
||||||
assert response["messages"][-1].content == "User name is Alice-hi"
|
assert response["messages"][-1].content == "User name is Alice-hi"
|
||||||
|
|
||||||
# test state modifier that doesn't use store works
|
# test state modifier that doesn't use store works
|
||||||
agent = create_react_agent(model, [add], prompt=prompt_no_store, store=in_memory_store)
|
agent = create_agent(model, [add], prompt=prompt_no_store, store=in_memory_store)
|
||||||
response = await agent.ainvoke(
|
response = await agent.ainvoke(
|
||||||
{"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}}
|
{"messages": [("user", "hi")]}, {"configurable": {"user_id": "2"}}
|
||||||
)
|
)
|
||||||
@@ -267,7 +267,7 @@ def test_model_with_tools(tool_style: str, include_builtin: bool) -> None:
|
|||||||
)
|
)
|
||||||
# check valid agent constructor
|
# check valid agent constructor
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
create_react_agent(
|
create_agent(
|
||||||
model.bind_tools(tools),
|
model.bind_tools(tools),
|
||||||
tools,
|
tools,
|
||||||
)
|
)
|
||||||
@@ -414,7 +414,7 @@ def test_react_agent_with_structured_response() -> None:
|
|||||||
model = FakeToolCallingModel[WeatherResponse](
|
model = FakeToolCallingModel[WeatherResponse](
|
||||||
tool_calls=tool_calls, structured_response=expected_structured_response
|
tool_calls=tool_calls, structured_response=expected_structured_response
|
||||||
)
|
)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather],
|
[get_weather],
|
||||||
response_format=WeatherResponse,
|
response_format=WeatherResponse,
|
||||||
@@ -472,7 +472,7 @@ def test_react_agent_update_state(
|
|||||||
|
|
||||||
tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}]]
|
tool_calls = [[{"args": {}, "id": "1", "name": "get_user_name"}]]
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_user_name],
|
[get_user_name],
|
||||||
state_schema=CustomState,
|
state_schema=CustomState,
|
||||||
@@ -523,7 +523,7 @@ def test_react_agent_parallel_tool_calls(
|
|||||||
[],
|
[],
|
||||||
]
|
]
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[human_assistance, get_weather],
|
[human_assistance, get_weather],
|
||||||
checkpointer=sync_checkpointer,
|
checkpointer=sync_checkpointer,
|
||||||
@@ -561,7 +561,7 @@ class AgentStateExtraKey(AgentState):
|
|||||||
foo: int
|
foo: int
|
||||||
|
|
||||||
|
|
||||||
def test_create_react_agent_inject_vars() -> None:
|
def test_create_agent_inject_vars() -> None:
|
||||||
"""Test that the agent can inject state and store into tool functions."""
|
"""Test that the agent can inject state and store into tool functions."""
|
||||||
store = InMemoryStore()
|
store = InMemoryStore()
|
||||||
namespace = ("test",)
|
namespace = ("test",)
|
||||||
@@ -583,7 +583,7 @@ def test_create_react_agent_inject_vars() -> None:
|
|||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
}
|
}
|
||||||
model = FakeToolCallingModel(tool_calls=[[tool_call], []])
|
model = FakeToolCallingModel(tool_calls=[[tool_call], []])
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
ToolNode([tool1], handle_tool_errors=False),
|
ToolNode([tool1], handle_tool_errors=False),
|
||||||
state_schema=AgentStateExtraKey,
|
state_schema=AgentStateExtraKey,
|
||||||
@@ -623,7 +623,7 @@ async def test_return_direct() -> None:
|
|||||||
tool_calls=first_tool_call,
|
tool_calls=first_tool_call,
|
||||||
)
|
)
|
||||||
model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
|
model = FakeToolCallingModel(tool_calls=[first_tool_call, []])
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[tool_return_direct, tool_normal],
|
[tool_return_direct, tool_normal],
|
||||||
)
|
)
|
||||||
@@ -648,7 +648,7 @@ async def test_return_direct() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
|
model = FakeToolCallingModel(tool_calls=[second_tool_call, []])
|
||||||
agent = create_react_agent(model, [tool_return_direct, tool_normal])
|
agent = create_agent(model, [tool_return_direct, tool_normal])
|
||||||
result = agent.invoke({"messages": [HumanMessage(content="Test normal", id="hum1")]})
|
result = agent.invoke({"messages": [HumanMessage(content="Test normal", id="hum1")]})
|
||||||
assert result["messages"] == [
|
assert result["messages"] == [
|
||||||
HumanMessage(content="Test normal", id="hum1"),
|
HumanMessage(content="Test normal", id="hum1"),
|
||||||
@@ -675,7 +675,7 @@ async def test_return_direct() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
|
model = FakeToolCallingModel(tool_calls=[both_tool_calls, []])
|
||||||
agent = create_react_agent(model, [tool_return_direct, tool_normal])
|
agent = create_agent(model, [tool_return_direct, tool_normal])
|
||||||
result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
|
result = agent.invoke({"messages": [HumanMessage(content="Test both", id="hum2")]})
|
||||||
assert result["messages"] == [
|
assert result["messages"] == [
|
||||||
HumanMessage(content="Test both", id="hum2"),
|
HumanMessage(content="Test both", id="hum2"),
|
||||||
@@ -712,7 +712,7 @@ def test__get_state_args() -> None:
|
|||||||
|
|
||||||
def test_inspect_react() -> None:
|
def test_inspect_react() -> None:
|
||||||
model = FakeToolCallingModel(tool_calls=[])
|
model = FakeToolCallingModel(tool_calls=[])
|
||||||
agent = create_react_agent(model, [])
|
agent = create_agent(model, [])
|
||||||
inspect.getclosurevars(agent.nodes["agent"].bound.func)
|
inspect.getclosurevars(agent.nodes["agent"].bound.func)
|
||||||
|
|
||||||
|
|
||||||
@@ -766,7 +766,7 @@ def test_react_with_subgraph_tools(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
tool_node = ToolNode([addition, multiplication], handle_tool_errors=False)
|
tool_node = ToolNode([addition, multiplication], handle_tool_errors=False)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tool_node,
|
tool_node,
|
||||||
checkpointer=sync_checkpointer,
|
checkpointer=sync_checkpointer,
|
||||||
@@ -812,7 +812,7 @@ def test_react_agent_subgraph_streaming_sync() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tools=[get_weather],
|
tools=[get_weather],
|
||||||
prompt="You are a helpful travel assistant.",
|
prompt="You are a helpful travel assistant.",
|
||||||
@@ -899,7 +899,7 @@ async def test_react_agent_subgraph_streaming() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tools=[get_weather],
|
tools=[get_weather],
|
||||||
prompt="You are a helpful travel assistant.",
|
prompt="You are a helpful travel assistant.",
|
||||||
@@ -994,7 +994,7 @@ def test_tool_node_node_interrupt(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
config = {"configurable": {"thread_id": "1"}}
|
config = {"configurable": {"thread_id": "1"}}
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[tool_interrupt, tool_normal],
|
[tool_interrupt, tool_normal],
|
||||||
checkpointer=sync_checkpointer,
|
checkpointer=sync_checkpointer,
|
||||||
@@ -1045,7 +1045,7 @@ def test_dynamic_model_basic() -> None:
|
|||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [])
|
agent = create_agent(dynamic_model, [])
|
||||||
|
|
||||||
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
||||||
assert len(result["messages"]) == 2
|
assert len(result["messages"]) == 2
|
||||||
@@ -1082,7 +1082,7 @@ def test_dynamic_model_with_tools() -> None:
|
|||||||
tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "basic_tool"}], []]
|
tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "basic_tool"}], []]
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [basic_tool, advanced_tool])
|
agent = create_agent(dynamic_model, [basic_tool, advanced_tool])
|
||||||
|
|
||||||
# Test basic tool usage
|
# Test basic tool usage
|
||||||
result = agent.invoke({"messages": [HumanMessage("basic request")]})
|
result = agent.invoke({"messages": [HumanMessage("basic request")]})
|
||||||
@@ -1114,7 +1114,7 @@ def test_dynamic_model_with_context() -> None:
|
|||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], context_schema=Context)
|
agent = create_agent(dynamic_model, [], context_schema=Context)
|
||||||
|
|
||||||
# Test with basic user
|
# Test with basic user
|
||||||
result = agent.invoke(
|
result = agent.invoke(
|
||||||
@@ -1143,7 +1143,7 @@ def test_dynamic_model_with_state_schema() -> None:
|
|||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], state_schema=CustomDynamicState)
|
agent = create_agent(dynamic_model, [], state_schema=CustomDynamicState)
|
||||||
|
|
||||||
result = agent.invoke({"messages": [HumanMessage("hello")], "model_preference": "advanced"})
|
result = agent.invoke({"messages": [HumanMessage("hello")], "model_preference": "advanced"})
|
||||||
assert len(result["messages"]) == 2
|
assert len(result["messages"]) == 2
|
||||||
@@ -1157,7 +1157,7 @@ def test_dynamic_model_with_prompt() -> None:
|
|||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
# Test with string prompt
|
# Test with string prompt
|
||||||
agent = create_react_agent(dynamic_model, [], prompt="system_msg")
|
agent = create_agent(dynamic_model, [], prompt="system_msg")
|
||||||
result = agent.invoke({"messages": [HumanMessage("human_msg")]})
|
result = agent.invoke({"messages": [HumanMessage("human_msg")]})
|
||||||
assert result["messages"][-1].content == "system_msg-human_msg"
|
assert result["messages"][-1].content == "system_msg-human_msg"
|
||||||
|
|
||||||
@@ -1166,7 +1166,7 @@ def test_dynamic_model_with_prompt() -> None:
|
|||||||
"""Generate a dynamic system message based on state."""
|
"""Generate a dynamic system message based on state."""
|
||||||
return [{"role": "system", "content": "system_msg"}, *list(state["messages"])]
|
return [{"role": "system", "content": "system_msg"}, *list(state["messages"])]
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], prompt=dynamic_prompt)
|
agent = create_agent(dynamic_model, [], prompt=dynamic_prompt)
|
||||||
result = agent.invoke({"messages": [HumanMessage("human_msg")]})
|
result = agent.invoke({"messages": [HumanMessage("human_msg")]})
|
||||||
assert result["messages"][-1].content == "system_msg-human_msg"
|
assert result["messages"][-1].content == "system_msg-human_msg"
|
||||||
|
|
||||||
@@ -1177,7 +1177,7 @@ async def test_dynamic_model_async() -> None:
|
|||||||
def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel:
|
def dynamic_model(state: AgentState, runtime: Runtime) -> BaseChatModel:
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [])
|
agent = create_agent(dynamic_model, [])
|
||||||
|
|
||||||
result = await agent.ainvoke({"messages": [HumanMessage("hello async")]})
|
result = await agent.ainvoke({"messages": [HumanMessage("hello async")]})
|
||||||
assert len(result["messages"]) == 2
|
assert len(result["messages"]) == 2
|
||||||
@@ -1205,7 +1205,7 @@ def test_dynamic_model_with_structured_response() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], response_format=TestResponse)
|
agent = create_agent(dynamic_model, [], response_format=TestResponse)
|
||||||
|
|
||||||
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
||||||
assert "structured_response" in result
|
assert "structured_response" in result
|
||||||
@@ -1229,7 +1229,7 @@ def test_dynamic_model_with_checkpointer(sync_checkpointer) -> None:
|
|||||||
index=call_count,
|
index=call_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], checkpointer=sync_checkpointer)
|
agent = create_agent(dynamic_model, [], checkpointer=sync_checkpointer)
|
||||||
config = {"configurable": {"thread_id": "test_dynamic"}}
|
config = {"configurable": {"thread_id": "test_dynamic"}}
|
||||||
|
|
||||||
# First call
|
# First call
|
||||||
@@ -1267,7 +1267,7 @@ def test_dynamic_model_state_dependent_tools() -> None:
|
|||||||
tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "tool_a"}], []]
|
tool_calls=[[{"args": {"x": 1}, "id": "1", "name": "tool_a"}], []]
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [tool_a, tool_b])
|
agent = create_agent(dynamic_model, [tool_a, tool_b])
|
||||||
|
|
||||||
# Ask to use tool B
|
# Ask to use tool B
|
||||||
result = agent.invoke({"messages": [HumanMessage("use_b please")]})
|
result = agent.invoke({"messages": [HumanMessage("use_b please")]})
|
||||||
@@ -1291,7 +1291,7 @@ def test_dynamic_model_error_handling() -> None:
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(failing_dynamic_model, [])
|
agent = create_agent(failing_dynamic_model, [])
|
||||||
|
|
||||||
# Normal operation should work
|
# Normal operation should work
|
||||||
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
result = agent.invoke({"messages": [HumanMessage("hello")]})
|
||||||
@@ -1306,13 +1306,13 @@ def test_dynamic_model_vs_static_model_behavior() -> None:
|
|||||||
"""Test that dynamic and static models produce equivalent results when configured the same."""
|
"""Test that dynamic and static models produce equivalent results when configured the same."""
|
||||||
# Static model
|
# Static model
|
||||||
static_model = FakeToolCallingModel(tool_calls=[])
|
static_model = FakeToolCallingModel(tool_calls=[])
|
||||||
static_agent = create_react_agent(static_model, [])
|
static_agent = create_agent(static_model, [])
|
||||||
|
|
||||||
# Dynamic model returning the same model
|
# Dynamic model returning the same model
|
||||||
def dynamic_model(state, runtime: Runtime):
|
def dynamic_model(state, runtime: Runtime):
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
dynamic_agent = create_react_agent(dynamic_model, [])
|
dynamic_agent = create_agent(dynamic_model, [])
|
||||||
|
|
||||||
input_msg = {"messages": [HumanMessage("test message")]}
|
input_msg = {"messages": [HumanMessage("test message")]}
|
||||||
|
|
||||||
@@ -1337,7 +1337,7 @@ def test_dynamic_model_receives_correct_state() -> None:
|
|||||||
received_states.append(state)
|
received_states.append(state)
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], state_schema=CustomAgentState)
|
agent = create_agent(dynamic_model, [], state_schema=CustomAgentState)
|
||||||
|
|
||||||
# Test with initial state
|
# Test with initial state
|
||||||
input_state = {"messages": [HumanMessage("hello")], "custom_field": "test_value"}
|
input_state = {"messages": [HumanMessage("hello")], "custom_field": "test_value"}
|
||||||
@@ -1368,7 +1368,7 @@ async def test_dynamic_model_receives_correct_state_async() -> None:
|
|||||||
received_states.append(state)
|
received_states.append(state)
|
||||||
return FakeToolCallingModel(tool_calls=[])
|
return FakeToolCallingModel(tool_calls=[])
|
||||||
|
|
||||||
agent = create_react_agent(dynamic_model, [], state_schema=CustomAgentStateAsync)
|
agent = create_agent(dynamic_model, [], state_schema=CustomAgentStateAsync)
|
||||||
|
|
||||||
# Test with initial state
|
# Test with initial state
|
||||||
input_state = {
|
input_state = {
|
||||||
@@ -1397,7 +1397,7 @@ def test_pre_model_hook() -> None:
|
|||||||
def pre_model_hook(state: AgentState):
|
def pre_model_hook(state: AgentState):
|
||||||
return {"llm_input_messages": [HumanMessage("Hello!")]}
|
return {"llm_input_messages": [HumanMessage("Hello!")]}
|
||||||
|
|
||||||
agent = create_react_agent(model, [], pre_model_hook=pre_model_hook)
|
agent = create_agent(model, [], pre_model_hook=pre_model_hook)
|
||||||
assert "pre_model_hook" in agent.nodes
|
assert "pre_model_hook" in agent.nodes
|
||||||
result = agent.invoke({"messages": [HumanMessage("hi?")]})
|
result = agent.invoke({"messages": [HumanMessage("hi?")]})
|
||||||
assert result == {
|
assert result == {
|
||||||
@@ -1411,7 +1411,7 @@ def test_pre_model_hook() -> None:
|
|||||||
def pre_model_hook(state: AgentState):
|
def pre_model_hook(state: AgentState):
|
||||||
return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), HumanMessage("Hello!")]}
|
return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), HumanMessage("Hello!")]}
|
||||||
|
|
||||||
agent = create_react_agent(model, [], pre_model_hook=pre_model_hook)
|
agent = create_agent(model, [], pre_model_hook=pre_model_hook)
|
||||||
result = agent.invoke({"messages": [HumanMessage("hi?")]})
|
result = agent.invoke({"messages": [HumanMessage("hi?")]})
|
||||||
assert result == {
|
assert result == {
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -1430,9 +1430,7 @@ def test_post_model_hook() -> None:
|
|||||||
def post_model_hook(state: FlagState) -> dict[str, bool]:
|
def post_model_hook(state: FlagState) -> dict[str, bool]:
|
||||||
return {"flag": True}
|
return {"flag": True}
|
||||||
|
|
||||||
pmh_agent = create_react_agent(
|
pmh_agent = create_agent(model, [], post_model_hook=post_model_hook, state_schema=FlagState)
|
||||||
model, [], post_model_hook=post_model_hook, state_schema=FlagState
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "post_model_hook" in pmh_agent.nodes
|
assert "post_model_hook" in pmh_agent.nodes
|
||||||
|
|
||||||
@@ -1480,7 +1478,7 @@ def test_post_model_hook_with_structured_output() -> None:
|
|||||||
return {"flag": True}
|
return {"flag": True}
|
||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather],
|
[get_weather],
|
||||||
response_format=WeatherResponse,
|
response_format=WeatherResponse,
|
||||||
@@ -1496,7 +1494,7 @@ def test_post_model_hook_with_structured_output() -> None:
|
|||||||
|
|
||||||
# Reset the state of the model
|
# Reset the state of the model
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather],
|
[get_weather],
|
||||||
response_format=WeatherResponse,
|
response_format=WeatherResponse,
|
||||||
@@ -1568,7 +1566,7 @@ def test_post_model_hook_with_structured_output() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_create_react_agent_inject_vars_with_post_model_hook() -> None:
|
def test_create_agent_inject_vars_with_post_model_hook() -> None:
|
||||||
store = InMemoryStore()
|
store = InMemoryStore()
|
||||||
namespace = ("test",)
|
namespace = ("test",)
|
||||||
store.put(namespace, "test_key", {"bar": 3})
|
store.put(namespace, "test_key", {"bar": 3})
|
||||||
@@ -1594,7 +1592,7 @@ def test_create_react_agent_inject_vars_with_post_model_hook() -> None:
|
|||||||
return {"foo": 2}
|
return {"foo": 2}
|
||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=[[tool_call], []])
|
model = FakeToolCallingModel(tool_calls=[[tool_call], []])
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
ToolNode([tool1], handle_tool_errors=False),
|
ToolNode([tool1], handle_tool_errors=False),
|
||||||
state_schema=AgentStateExtraKey,
|
state_schema=AgentStateExtraKey,
|
||||||
@@ -1629,7 +1627,7 @@ def test_response_format_using_tool_choice() -> None:
|
|||||||
|
|
||||||
expected_structured_response = WeatherResponse(temperature=75)
|
expected_structured_response = WeatherResponse(temperature=75)
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather],
|
[get_weather],
|
||||||
response_format=WeatherResponse,
|
response_format=WeatherResponse,
|
||||||
|
@@ -5,7 +5,7 @@ import pytest
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from langchain.agents import create_react_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
from .model import FakeToolCallingModel
|
from .model import FakeToolCallingModel
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ def test_react_agent_graph_structure(
|
|||||||
pre_model_hook: Union[Callable, None],
|
pre_model_hook: Union[Callable, None],
|
||||||
post_model_hook: Union[Callable, None],
|
post_model_hook: Union[Callable, None],
|
||||||
) -> None:
|
) -> None:
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
pre_model_hook=pre_model_hook,
|
pre_model_hook=pre_model_hook,
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
"""Test suite for create_react_agent with structured output response_format permutations."""
|
"""Test suite for create_agent with structured output response_format permutations."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langchain.agents import create_react_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.structured_output import (
|
from langchain.agents.structured_output import (
|
||||||
MultipleStructuredOutputsError,
|
MultipleStructuredOutputsError,
|
||||||
ProviderStrategy,
|
ProviderStrategy,
|
||||||
@@ -114,7 +114,7 @@ class TestResponseFormatAsModel:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(model, [get_weather], response_format=WeatherBaseModel)
|
agent = create_agent(model, [get_weather], response_format=WeatherBaseModel)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||||
@@ -135,7 +135,7 @@ class TestResponseFormatAsModel:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(model, [get_weather], response_format=WeatherDataclass)
|
agent = create_agent(model, [get_weather], response_format=WeatherDataclass)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
||||||
@@ -156,7 +156,7 @@ class TestResponseFormatAsModel:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(model, [get_weather], response_format=WeatherTypedDict)
|
agent = create_agent(model, [get_weather], response_format=WeatherTypedDict)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||||
@@ -177,7 +177,7 @@ class TestResponseFormatAsModel:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(model, [get_weather], response_format=weather_json_schema)
|
agent = create_agent(model, [get_weather], response_format=weather_json_schema)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||||
@@ -200,9 +200,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(model, [get_weather], response_format=ToolStrategy(WeatherBaseModel))
|
||||||
model, [get_weather], response_format=ToolStrategy(WeatherBaseModel)
|
|
||||||
)
|
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||||
@@ -223,9 +221,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(model, [get_weather], response_format=ToolStrategy(WeatherDataclass))
|
||||||
model, [get_weather], response_format=ToolStrategy(WeatherDataclass)
|
|
||||||
)
|
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
assert response["structured_response"] == EXPECTED_WEATHER_DATACLASS
|
||||||
@@ -246,9 +242,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(model, [get_weather], response_format=ToolStrategy(WeatherTypedDict))
|
||||||
model, [get_weather], response_format=ToolStrategy(WeatherTypedDict)
|
|
||||||
)
|
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
|
|
||||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||||
@@ -269,7 +263,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model, [get_weather], response_format=ToolStrategy(weather_json_schema)
|
model, [get_weather], response_format=ToolStrategy(weather_json_schema)
|
||||||
)
|
)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
@@ -292,7 +286,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather, get_location],
|
[get_weather, get_location],
|
||||||
response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
|
response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
|
||||||
@@ -316,7 +310,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
|
model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
|
||||||
|
|
||||||
agent_location = create_react_agent(
|
agent_location = create_agent(
|
||||||
model_location,
|
model_location,
|
||||||
[get_weather, get_location],
|
[get_weather, get_location],
|
||||||
response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
|
response_format=ToolStrategy({"oneOf": [weather_json_schema, location_json_schema]}),
|
||||||
@@ -344,7 +338,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
tool_calls=tool_calls
|
tool_calls=tool_calls
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather, get_location],
|
[get_weather, get_location],
|
||||||
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
||||||
@@ -368,7 +362,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
|
model_location = FakeToolCallingModel(tool_calls=tool_calls_location)
|
||||||
|
|
||||||
agent_location = create_react_agent(
|
agent_location = create_agent(
|
||||||
model_location,
|
model_location,
|
||||||
[get_weather, get_location],
|
[get_weather, get_location],
|
||||||
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
||||||
@@ -397,7 +391,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
response_format=ToolStrategy(
|
response_format=ToolStrategy(
|
||||||
@@ -438,7 +432,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
response_format=ToolStrategy(
|
response_format=ToolStrategy(
|
||||||
@@ -467,7 +461,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
response_format=ToolStrategy(
|
response_format=ToolStrategy(
|
||||||
@@ -503,7 +497,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
response_format=ToolStrategy(
|
response_format=ToolStrategy(
|
||||||
@@ -549,7 +543,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
return "Custom error: Multiple outputs not allowed"
|
return "Custom error: Multiple outputs not allowed"
|
||||||
return "Custom error"
|
return "Custom error"
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
response_format=ToolStrategy(
|
response_format=ToolStrategy(
|
||||||
@@ -587,7 +581,7 @@ class TestResponseFormatAsToolStrategy:
|
|||||||
|
|
||||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[],
|
[],
|
||||||
response_format=ToolStrategy(
|
response_format=ToolStrategy(
|
||||||
@@ -617,7 +611,7 @@ class TestResponseFormatAsProviderStrategy:
|
|||||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
|
model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
|
||||||
)
|
)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
@@ -635,7 +629,7 @@ class TestResponseFormatAsProviderStrategy:
|
|||||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
|
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model, [get_weather], response_format=ProviderStrategy(WeatherDataclass)
|
model, [get_weather], response_format=ProviderStrategy(WeatherDataclass)
|
||||||
)
|
)
|
||||||
response = agent.invoke(
|
response = agent.invoke(
|
||||||
@@ -655,7 +649,7 @@ class TestResponseFormatAsProviderStrategy:
|
|||||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model, [get_weather], response_format=ProviderStrategy(WeatherTypedDict)
|
model, [get_weather], response_format=ProviderStrategy(WeatherTypedDict)
|
||||||
)
|
)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
@@ -673,7 +667,7 @@ class TestResponseFormatAsProviderStrategy:
|
|||||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model, [get_weather], response_format=ProviderStrategy(weather_json_schema)
|
model, [get_weather], response_format=ProviderStrategy(weather_json_schema)
|
||||||
)
|
)
|
||||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||||
@@ -699,7 +693,7 @@ def test_union_of_types() -> None:
|
|||||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
[get_weather, get_location],
|
[get_weather, get_location],
|
||||||
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
response_format=ToolStrategy(Union[WeatherBaseModel, LocationResponse]),
|
||||||
|
@@ -119,7 +119,7 @@ def test_responses_integration_matrix(case: TestCase) -> None:
|
|||||||
http_client=http_client,
|
http_client=http_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tools=[role_tool["tool"], dept_tool["tool"]],
|
tools=[role_tool["tool"], dept_tool["tool"]],
|
||||||
prompt=AGENT_PROMPT,
|
prompt=AGENT_PROMPT,
|
||||||
|
@@ -71,14 +71,14 @@ def test_return_direct_integration_matrix(case: TestCase) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if case.response_format:
|
if case.response_format:
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tools=[poll_tool["tool"]],
|
tools=[poll_tool["tool"]],
|
||||||
prompt=AGENT_PROMPT,
|
prompt=AGENT_PROMPT,
|
||||||
response_format=ToolStrategy(case.response_format),
|
response_format=ToolStrategy(case.response_format),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
agent = create_react_agent(
|
agent = create_agent(
|
||||||
model,
|
model,
|
||||||
tools=[poll_tool["tool"]],
|
tools=[poll_tool["tool"]],
|
||||||
prompt=AGENT_PROMPT,
|
prompt=AGENT_PROMPT,
|
||||||
|
Reference in New Issue
Block a user