diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py b/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py index 9b396d08ec1..cbd66d4a0a3 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py @@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Annotated, Any from langchain_core.messages import HumanMessage from langchain_core.tools import tool +from typing_extensions import NotRequired, Required -from langchain.agents import create_agent +from langchain.agents import create_agent, factory from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, @@ -253,3 +254,51 @@ def test_state_schema_with_private_state_field() -> None: # Verify the agent executed normally assert len(result["messages"]) == 4 # Human, AI (tool call), Tool result, AI (final) + + +def test_get_schema_type_hints_cache_hits_for_reused_schema() -> None: + """Test repeated schema resolution reuses cached type hints for the same schema.""" + + class CachedState(AgentState[Any]): + cached_field: str + required_field: Required[int] + optional_field: NotRequired[Annotated[str, PrivateStateAttr]] + + factory._get_schema_type_hints.cache_clear() + + factory._resolve_schemas({CachedState}) + first_info = factory._get_schema_type_hints.cache_info() + factory._resolve_schemas({CachedState}) + second_info = factory._get_schema_type_hints.cache_info() + + assert first_info.misses == 1 + assert first_info.hits == 0 + assert second_info.misses == 1 + assert second_info.hits == 1 + + +def test_get_schema_type_hints_cache_accepts_distinct_local_schema_types() -> None: + """Test locally defined schema classes remain hashable cache keys.""" + factory._get_schema_type_hints.cache_clear() + + def make_state_schema(name: str) -> type[AgentState[Any]]: + class LocalState(AgentState[Any]): + value: str + required_value: Required[int] + optional_private_value: NotRequired[Annotated[str, PrivateStateAttr]] + + LocalState.__name__ = name + return LocalState + + schema_a = make_state_schema("LocalStateA") + schema_b = make_state_schema("LocalStateB") + + factory._resolve_schemas({schema_a, schema_b}) + first_info = factory._get_schema_type_hints.cache_info() + factory._resolve_schemas({schema_a, schema_b}) + second_info = factory._get_schema_type_hints.cache_info() + + assert first_info.misses == 2 + assert first_info.hits == 0 + assert second_info.misses == 2 + assert second_info.hits == 2