This commit is contained in:
Eugene Yurtsev
2026-04-10 12:01:46 -04:00
parent 2c9296c423
commit 4beaff0224

View File

@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Annotated, Any
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from typing_extensions import NotRequired, Required
from langchain.agents import create_agent
from langchain.agents import create_agent, factory
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
@@ -253,3 +254,51 @@ def test_state_schema_with_private_state_field() -> None:
# Verify the agent executed normally
assert len(result["messages"]) == 4 # Human, AI (tool call), Tool result, AI (final)
def test_get_schema_type_hints_cache_hits_for_reused_schema() -> None:
"""Test repeated schema resolution reuses cached type hints for the same schema."""
class CachedState(AgentState[Any]):
cached_field: str
required_field: Required[int]
optional_field: NotRequired[Annotated[str, PrivateStateAttr]]
factory._get_schema_type_hints.cache_clear()
factory._resolve_schemas({CachedState})
first_info = factory._get_schema_type_hints.cache_info()
factory._resolve_schemas({CachedState})
second_info = factory._get_schema_type_hints.cache_info()
assert first_info.misses == 1
assert first_info.hits == 0
assert second_info.misses == 1
assert second_info.hits == 1
def test_get_schema_type_hints_cache_accepts_distinct_local_schema_types() -> None:
"""Test locally defined schema classes remain hashable cache keys."""
factory._get_schema_type_hints.cache_clear()
def make_state_schema(name: str) -> type[AgentState[Any]]:
class LocalState(AgentState[Any]):
value: str
required_value: Required[int]
optional_private_value: NotRequired[Annotated[str, PrivateStateAttr]]
LocalState.__name__ = name
return LocalState
schema_a = make_state_schema("LocalStateA")
schema_b = make_state_schema("LocalStateB")
factory._resolve_schemas({schema_a, schema_b})
first_info = factory._get_schema_type_hints.cache_info()
factory._resolve_schemas({schema_a, schema_b})
second_info = factory._get_schema_type_hints.cache_info()
assert first_info.misses == 2
assert first_info.hits == 0
assert second_info.misses == 2
assert second_info.hits == 2