mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
chore(langchain): fix types in test_injected_runtime_create_agent, test_create_agent_tool_validation (#34568)
This commit is contained in:
committed by
GitHub
parent
f1ab8c5c80
commit
5dc8ba3c99
@@ -1,5 +1,5 @@
|
||||
import sys
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -28,7 +28,7 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
|
||||
"""
|
||||
|
||||
# Define a custom state schema with injected data
|
||||
class TestState(AgentState):
|
||||
class TestState(AgentState[Any]):
|
||||
secret_data: str # Example of state data not controlled by LLM
|
||||
|
||||
@dec_tool
|
||||
@@ -94,7 +94,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
"""
|
||||
|
||||
# Define a custom state schema
|
||||
class TestState(AgentState):
|
||||
class TestState(AgentState[Any]):
|
||||
internal_data: str
|
||||
|
||||
@dec_tool
|
||||
@@ -193,10 +193,10 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
This ensures the LLM receives focused, actionable feedback.
|
||||
"""
|
||||
|
||||
class TestState(AgentState):
|
||||
class TestState(AgentState[Any]):
|
||||
user_id: str
|
||||
api_key: str
|
||||
session_data: dict
|
||||
session_data: dict[str, Any]
|
||||
|
||||
@dec_tool
|
||||
def complex_tool(
|
||||
@@ -309,7 +309,7 @@ def test_create_agent_error_only_model_controllable_params() -> None:
|
||||
absent from error messages. This provides focused feedback to the LLM.
|
||||
"""
|
||||
|
||||
class StateWithSecrets(AgentState):
|
||||
class StateWithSecrets(AgentState[Any]):
|
||||
password: str # Example of data not controlled by LLM
|
||||
|
||||
@dec_tool
|
||||
|
||||
@@ -16,7 +16,7 @@ configurations.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
@@ -28,11 +28,14 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.tools import InjectedState, ToolRuntime
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def test_tool_runtime_basic_injection() -> None:
|
||||
"""Test basic ToolRuntime injection in tools with create_agent."""
|
||||
# Track what was injected
|
||||
injected_data = {}
|
||||
injected_data: dict[str, Any] = {}
|
||||
|
||||
@tool
|
||||
def runtime_tool(x: int, runtime: ToolRuntime) -> str:
|
||||
@@ -80,7 +83,7 @@ def test_tool_runtime_basic_injection() -> None:
|
||||
|
||||
async def test_tool_runtime_async_injection() -> None:
|
||||
"""Test ToolRuntime injection works with async tools."""
|
||||
injected_data = {}
|
||||
injected_data: dict[str, Any] = {}
|
||||
|
||||
@tool
|
||||
async def async_runtime_tool(x: int, runtime: ToolRuntime) -> str:
|
||||
@@ -195,7 +198,7 @@ def test_tool_runtime_with_store() -> None:
|
||||
|
||||
def test_tool_runtime_with_multiple_tools() -> None:
|
||||
"""Test multiple tools can all access ToolRuntime."""
|
||||
call_log = []
|
||||
call_log: list[tuple[str, str | None, int | str]] = []
|
||||
|
||||
@tool
|
||||
def tool_a(x: int, runtime: ToolRuntime) -> str:
|
||||
@@ -242,7 +245,7 @@ def test_tool_runtime_with_multiple_tools() -> None:
|
||||
|
||||
def test_tool_runtime_config_access() -> None:
|
||||
"""Test tools can access config through ToolRuntime."""
|
||||
config_data = {}
|
||||
config_data: dict[str, Any] = {}
|
||||
|
||||
@tool
|
||||
def config_tool(x: int, runtime: ToolRuntime) -> str:
|
||||
@@ -282,7 +285,7 @@ def test_tool_runtime_config_access() -> None:
|
||||
def test_tool_runtime_with_custom_state() -> None:
|
||||
"""Test ToolRuntime works with custom state schemas."""
|
||||
|
||||
class CustomState(AgentState):
|
||||
class CustomState(AgentState[Any]):
|
||||
custom_field: str
|
||||
|
||||
runtime_state = {}
|
||||
@@ -464,11 +467,11 @@ def test_tool_runtime_with_middleware() -> None:
|
||||
runtime_calls = []
|
||||
|
||||
class TestMiddleware(AgentMiddleware):
|
||||
def before_model(self, state, runtime) -> dict[str, Any]:
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]:
|
||||
middleware_calls.append("before_model")
|
||||
return {}
|
||||
|
||||
def after_model(self, state, runtime) -> dict[str, Any]:
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]:
|
||||
middleware_calls.append("after_model")
|
||||
return {}
|
||||
|
||||
@@ -515,11 +518,7 @@ def test_tool_runtime_type_hints() -> None:
|
||||
def typed_runtime_tool(x: int, runtime: ToolRuntime) -> str:
|
||||
"""Tool with runtime access."""
|
||||
# Access state dict - verify we can access standard state fields
|
||||
if isinstance(runtime.state, dict):
|
||||
# Count messages in state
|
||||
typed_runtime["message_count"] = len(runtime.state.get("messages", []))
|
||||
else:
|
||||
typed_runtime["message_count"] = len(getattr(runtime.state, "messages", []))
|
||||
typed_runtime["message_count"] = len(runtime.state.get("messages", []))
|
||||
return f"Typed: {x}"
|
||||
|
||||
agent = create_agent(
|
||||
@@ -546,7 +545,7 @@ def test_tool_runtime_type_hints() -> None:
|
||||
|
||||
def test_tool_runtime_name_based_injection() -> None:
|
||||
"""Test that parameter named 'runtime' gets injected without type annotation."""
|
||||
injected_data = {}
|
||||
injected_data: dict[str, Any] = {}
|
||||
|
||||
@tool
|
||||
def name_based_tool(x: int, runtime: Any) -> str:
|
||||
@@ -601,7 +600,7 @@ def test_combined_injected_state_runtime_store() -> None:
|
||||
injected_data = {}
|
||||
|
||||
# Custom state schema with additional fields
|
||||
class CustomState(AgentState):
|
||||
class CustomState(AgentState[Any]):
|
||||
user_id: str
|
||||
session_id: str
|
||||
|
||||
@@ -667,6 +666,7 @@ def test_combined_injected_state_runtime_store() -> None:
|
||||
|
||||
# Verify the tool's args schema only includes LLM-controlled parameters
|
||||
tool_args_schema = multi_injection_tool.args_schema
|
||||
assert isinstance(tool_args_schema, dict)
|
||||
assert "location" in tool_args_schema["properties"]
|
||||
assert "state" not in tool_args_schema["properties"]
|
||||
assert "runtime" not in tool_args_schema["properties"]
|
||||
@@ -718,7 +718,7 @@ async def test_combined_injected_state_runtime_store_async() -> None:
|
||||
injected_data = {}
|
||||
|
||||
# Custom state schema
|
||||
class CustomState(AgentState):
|
||||
class CustomState(AgentState[Any]):
|
||||
api_key: str
|
||||
request_id: str
|
||||
|
||||
@@ -792,6 +792,7 @@ async def test_combined_injected_state_runtime_store_async() -> None:
|
||||
|
||||
# Verify the tool's args schema only includes LLM-controlled parameters
|
||||
tool_args_schema = async_multi_injection_tool.args_schema
|
||||
assert isinstance(tool_args_schema, dict)
|
||||
assert "query" in tool_args_schema["properties"]
|
||||
assert "max_results" in tool_args_schema["properties"]
|
||||
assert "state" not in tool_args_schema["properties"]
|
||||
|
||||
Reference in New Issue
Block a user