diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 306f9bde416..276d8132827 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -27,7 +27,7 @@ from langgraph.prebuilt import ToolCallTransformer from langgraph.prebuilt.tool_node import ToolNode from langgraph.types import Command, Send from langsmith import traceable -from typing_extensions import NotRequired, Required, TypedDict +from typing_extensions import NotRequired, Required, TypedDict, overload from langchain.agents._subagent_transformer import SubagentTransformer from langchain.agents.middleware.types import ( @@ -35,15 +35,15 @@ from langchain.agents.middleware.types import ( AgentState, ContextT, ExtendedModelResponse, + InputAgentState, JumpTo, ModelRequest, ModelResponse, OmitFromSchema, + OutputAgentState, ResponseT, StateT_co, ToolCallRequest, - _InputAgentState, - _OutputAgentState, ) from langchain.agents.structured_output import ( AutoStrategy, @@ -714,6 +714,76 @@ def _chain_async_tool_call_wrappers( return result +# No `response_format`: there is no structured output, so `ResponseT` resolves to `Any`. +@overload +def create_agent( + model: str | BaseChatModel, + tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None, + *, + system_prompt: str | SystemMessage | None = None, + middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (), + response_format: None = None, + state_schema: None = None, + context_schema: type[ContextT] | None = None, + checkpointer: Checkpointer | None = None, + store: BaseStore | None = None, + interrupt_before: list[str] | None = None, + interrupt_after: list[str] | None = None, + debug: bool = False, + name: str | None = None, + cache: BaseCache[Any] | None = None, + transformers: Sequence[TransformerFactory] | None = None, +) -> CompiledStateGraph[AgentState[Any], ContextT, InputAgentState, OutputAgentState[Any]]: ... + + +# Raw-dict `response_format`: structured output is an untyped `dict[str, Any]`. +@overload +def create_agent( + model: str | BaseChatModel, + tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None, + *, + system_prompt: str | SystemMessage | None = None, + middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (), + response_format: dict[str, Any], + state_schema: type[AgentState[dict[str, Any]]] | None = None, + context_schema: type[ContextT] | None = None, + checkpointer: Checkpointer | None = None, + store: BaseStore | None = None, + interrupt_before: list[str] | None = None, + interrupt_after: list[str] | None = None, + debug: bool = False, + name: str | None = None, + cache: BaseCache[Any] | None = None, + transformers: Sequence[TransformerFactory] | None = None, +) -> CompiledStateGraph[ + AgentState[dict[str, Any]], ContextT, InputAgentState, OutputAgentState[dict[str, Any]] +]: ... + + +# Schema-typed `response_format`: `ResponseT` is inferred from the schema/type. +@overload +def create_agent( + model: str | BaseChatModel, + tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None, + *, + system_prompt: str | SystemMessage | None = None, + middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (), + response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, + state_schema: type[AgentState[ResponseT]] | None = None, + context_schema: type[ContextT] | None = None, + checkpointer: Checkpointer | None = None, + store: BaseStore | None = None, + interrupt_before: list[str] | None = None, + interrupt_after: list[str] | None = None, + debug: bool = False, + name: str | None = None, + cache: BaseCache[Any] | None = None, + transformers: Sequence[TransformerFactory] | None = None, +) -> CompiledStateGraph[ + AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT] +]: ... + + def create_agent( model: str | BaseChatModel, tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None, @@ -732,7 +802,7 @@ def create_agent( cache: BaseCache[Any] | None = None, transformers: Sequence[TransformerFactory] | None = None, ) -> CompiledStateGraph[ - AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT] + AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT] ]: """Creates an agent graph that calls tools in a loop until a stopping condition is met. @@ -1066,7 +1136,7 @@ def create_agent( # create graph, add nodes graph: StateGraph[ - AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT] + AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT] ] = StateGraph( state_schema=resolved_state_schema, input_schema=input_schema, @@ -1860,7 +1930,7 @@ def _make_tools_to_model_edge( def _add_middleware_edge( graph: StateGraph[ - AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT] + AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT] ], *, name: str, diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index b107e8fa043..c49f39d85e8 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -30,9 +30,11 @@ from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, ExtendedModelResponse, + InputAgentState, ModelCallResult, ModelRequest, ModelResponse, + OutputAgentState, ToolCallRequest, after_agent, after_model, @@ -55,6 +57,7 @@ __all__ = [ "FilesystemFileSearchMiddleware", "HostExecutionPolicy", "HumanInTheLoopMiddleware", + "InputAgentState", "InterruptOnConfig", "LLMToolEmulator", "LLMToolSelectorMiddleware", @@ -64,6 +67,7 @@ __all__ = [ "ModelRequest", "ModelResponse", "ModelRetryMiddleware", + "OutputAgentState", "PIIDetectionError", "PIIMiddleware", "ProviderToolSearchMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 58ebcac29b1..c8d7f2b38fc 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -44,10 +44,12 @@ __all__ = [ "AgentState", "ContextT", "ExtendedModelResponse", + "InputAgentState", "ModelCallResult", "ModelRequest", "ModelResponse", "OmitFromSchema", + "OutputAgentState", "ResponseT", "StateT_co", "ToolCallRequest", @@ -350,19 +352,25 @@ class AgentState(TypedDict, Generic[ResponseT]): structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]] -class _InputAgentState(TypedDict): # noqa: PYI049 +class InputAgentState(TypedDict): """Input state schema for the agent.""" messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]] -class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049 +class OutputAgentState(TypedDict, Generic[ResponseT]): """Output state schema for the agent.""" messages: Required[Annotated[list[AnyMessage], add_messages]] structured_response: NotRequired[ResponseT] +# Deprecated aliases kept for backwards compatibility with external consumers that +# imported the previously private names. Remove in a future release. +_InputAgentState = InputAgentState +_OutputAgentState = OutputAgentState + + StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any]) StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True) StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True) diff --git a/libs/langchain_v1/tests/integration_tests/agents/middleware/test_human_in_the_loop_integration.py b/libs/langchain_v1/tests/integration_tests/agents/middleware/test_human_in_the_loop_integration.py index a6de8f500ba..2d5ef524311 100644 --- a/libs/langchain_v1/tests/integration_tests/agents/middleware/test_human_in_the_loop_integration.py +++ b/libs/langchain_v1/tests/integration_tests/agents/middleware/test_human_in_the_loop_integration.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph - from langchain.agents.middleware.types import _InputAgentState + from langchain.agents.middleware.types import InputAgentState def _get_model(provider: str) -> Any: @@ -53,7 +53,7 @@ def test_hitl_reject_does_not_retry(provider: str) -> None: the interrupt. So a completed run with no new `__interrupt__` is a reliable signal that the model honored the guidance and did not retry. """ - agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent( + agent: CompiledStateGraph[Any, Any, InputAgentState, Any] = create_agent( model=_get_model(provider), tools=[get_weather], middleware=[ diff --git a/libs/langchain_v1/tests/integration_tests/agents/middleware/test_shell_tool_integration.py b/libs/langchain_v1/tests/integration_tests/agents/middleware/test_shell_tool_integration.py index cee6b825ce7..2d670f2b116 100644 --- a/libs/langchain_v1/tests/integration_tests/agents/middleware/test_shell_tool_integration.py +++ b/libs/langchain_v1/tests/integration_tests/agents/middleware/test_shell_tool_integration.py @@ -14,10 +14,6 @@ from langchain.agents.middleware.shell_tool import ShellToolMiddleware if TYPE_CHECKING: from pathlib import Path - from langgraph.graph.state import CompiledStateGraph - - from langchain.agents.middleware.types import _InputAgentState - def _get_model(provider: str) -> Any: """Get chat model for the specified provider.""" @@ -35,7 +31,7 @@ def _get_model(provider: str) -> Any: def test_shell_tool_basic_execution(tmp_path: Path, provider: str) -> None: """Test basic shell command execution across different models.""" workspace = tmp_path / "workspace" - agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent( + agent = create_agent( model=_get_model(provider), middleware=[ShellToolMiddleware(workspace_root=workspace)], ) @@ -57,7 +53,7 @@ def test_shell_tool_basic_execution(tmp_path: Path, provider: str) -> None: def test_shell_session_persistence(tmp_path: Path) -> None: """Test shell session state persists across multiple tool calls.""" workspace = tmp_path / "workspace" - agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent( + agent = create_agent( model=_get_model("anthropic"), middleware=[ShellToolMiddleware(workspace_root=workspace)], ) @@ -84,7 +80,7 @@ def test_shell_session_persistence(tmp_path: Path) -> None: def test_shell_tool_error_handling(tmp_path: Path) -> None: """Test shell tool captures command errors.""" workspace = tmp_path / "workspace" - agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent( + agent = create_agent( model=_get_model("anthropic"), middleware=[ShellToolMiddleware(workspace_root=workspace)], ) @@ -121,7 +117,7 @@ def test_shell_tool_with_custom_tools(tmp_path: Path) -> None: """Greet someone by name.""" return f"Hello, {name}!" - agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent( + agent = create_agent( model=_get_model("anthropic"), tools=[custom_greeting], middleware=[ShellToolMiddleware(workspace_root=workspace)], diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py index b5f5cec6729..499059b67ae 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py @@ -21,15 +21,17 @@ import pytest from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from langchain_core.messages import AIMessage, HumanMessage from pydantic import BaseModel -from typing_extensions import TypedDict, override +from typing_extensions import TypedDict, assert_type, override from langchain.agents import create_agent from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, ContextT, + InputAgentState, ModelRequest, ModelResponse, + OutputAgentState, ResponseT, before_model, ) @@ -72,6 +74,12 @@ class SummaryResult(BaseModel): key_points: list[str] +class CustomAgentState(AgentState[Any]): + """Custom state schema without a structured response format.""" + + user_id: str + + # ============================================================================= # 1. BACKWARDS COMPATIBLE: Middlewares without type parameters # These work when create_agent has NO context_schema or response_format @@ -254,6 +262,61 @@ def test_create_agent_no_context_schema(fake_model: GenericFakeChatModel) -> Non assert agent is not None +def test_create_agent_custom_state_without_response_format( + fake_model: GenericFakeChatModel, +) -> None: + """Custom state without response_format should not infer dict structured output.""" + agent = create_agent( + model=fake_model, + state_schema=CustomAgentState, + ) + + if TYPE_CHECKING: + assert_type( + agent, + CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]], + ) + + assert agent is not None + + +def test_create_agent_dict_response_format(fake_model: GenericFakeChatModel) -> None: + """A raw-dict `response_format` infers an untyped `dict` structured response.""" + json_schema: dict[str, Any] = {"type": "json_schema", "schema": {}} + agent = create_agent(model=fake_model, response_format=json_schema) + + if TYPE_CHECKING: + assert_type( + agent, + CompiledStateGraph[ + AgentState[dict[str, Any]], + None, + InputAgentState, + OutputAgentState[dict[str, Any]], + ], + ) + + assert agent is not None + + +def test_create_agent_typed_response_format(fake_model: GenericFakeChatModel) -> None: + """A schema-typed `response_format` infers a matching `ResponseT`.""" + agent = create_agent(model=fake_model, response_format=AnalysisResult) + + if TYPE_CHECKING: + assert_type( + agent, + CompiledStateGraph[ + AgentState[AnalysisResult], + None, + InputAgentState, + OutputAgentState[AnalysisResult], + ], + ) + + assert agent is not None + + def test_create_agent_with_user_context(fake_model: GenericFakeChatModel) -> None: """Typed: context_schema=UserContext requires matching middleware.""" agent: CompiledStateGraph[Any, UserContext, Any, Any] = create_agent(