chore(langchain): add overloads to create_agent (#34309)

This way mypy can infer the return type when `ResponseT` is not passed.

---------

Co-authored-by: Mason Daugherty <61371264+mdrxy@users.noreply.github.com>
Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-06-11 07:12:20 +02:00
committed by GitHub
parent 86428c63ac
commit d5f7d33f88
6 changed files with 160 additions and 19 deletions

View File

@@ -27,7 +27,7 @@ from langgraph.prebuilt import ToolCallTransformer
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.types import Command, Send
from langsmith import traceable
from typing_extensions import NotRequired, Required, TypedDict
from typing_extensions import NotRequired, Required, TypedDict, overload
from langchain.agents._subagent_transformer import SubagentTransformer
from langchain.agents.middleware.types import (
@@ -35,15 +35,15 @@ from langchain.agents.middleware.types import (
AgentState,
ContextT,
ExtendedModelResponse,
InputAgentState,
JumpTo,
ModelRequest,
ModelResponse,
OmitFromSchema,
OutputAgentState,
ResponseT,
StateT_co,
ToolCallRequest,
_InputAgentState,
_OutputAgentState,
)
from langchain.agents.structured_output import (
AutoStrategy,
@@ -714,6 +714,76 @@ def _chain_async_tool_call_wrappers(
return result
# No `response_format`: there is no structured output, so `ResponseT` resolves to `Any`.
@overload
def create_agent(
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
*,
system_prompt: str | SystemMessage | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
response_format: None = None,
state_schema: None = None,
context_schema: type[ContextT] | None = None,
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
debug: bool = False,
name: str | None = None,
cache: BaseCache[Any] | None = None,
transformers: Sequence[TransformerFactory] | None = None,
) -> CompiledStateGraph[AgentState[Any], ContextT, InputAgentState, OutputAgentState[Any]]: ...
# Raw-dict `response_format`: structured output is an untyped `dict[str, Any]`.
@overload
def create_agent(
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
*,
system_prompt: str | SystemMessage | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
response_format: dict[str, Any],
state_schema: type[AgentState[dict[str, Any]]] | None = None,
context_schema: type[ContextT] | None = None,
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
debug: bool = False,
name: str | None = None,
cache: BaseCache[Any] | None = None,
transformers: Sequence[TransformerFactory] | None = None,
) -> CompiledStateGraph[
AgentState[dict[str, Any]], ContextT, InputAgentState, OutputAgentState[dict[str, Any]]
]: ...
# Schema-typed `response_format`: `ResponseT` is inferred from the schema/type.
@overload
def create_agent(
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
*,
system_prompt: str | SystemMessage | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
state_schema: type[AgentState[ResponseT]] | None = None,
context_schema: type[ContextT] | None = None,
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
debug: bool = False,
name: str | None = None,
cache: BaseCache[Any] | None = None,
transformers: Sequence[TransformerFactory] | None = None,
) -> CompiledStateGraph[
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
]: ...
def create_agent(
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
@@ -732,7 +802,7 @@ def create_agent(
cache: BaseCache[Any] | None = None,
transformers: Sequence[TransformerFactory] | None = None,
) -> CompiledStateGraph[
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
]:
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
@@ -1066,7 +1136,7 @@ def create_agent(
# create graph, add nodes
graph: StateGraph[
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
] = StateGraph(
state_schema=resolved_state_schema,
input_schema=input_schema,
@@ -1860,7 +1930,7 @@ def _make_tools_to_model_edge(
def _add_middleware_edge(
graph: StateGraph[
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
],
*,
name: str,

View File

@@ -30,9 +30,11 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ExtendedModelResponse,
InputAgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
OutputAgentState,
ToolCallRequest,
after_agent,
after_model,
@@ -55,6 +57,7 @@ __all__ = [
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware",
"InputAgentState",
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
@@ -64,6 +67,7 @@ __all__ = [
"ModelRequest",
"ModelResponse",
"ModelRetryMiddleware",
"OutputAgentState",
"PIIDetectionError",
"PIIMiddleware",
"ProviderToolSearchMiddleware",

View File

@@ -44,10 +44,12 @@ __all__ = [
"AgentState",
"ContextT",
"ExtendedModelResponse",
"InputAgentState",
"ModelCallResult",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
"OutputAgentState",
"ResponseT",
"StateT_co",
"ToolCallRequest",
@@ -350,19 +352,25 @@ class AgentState(TypedDict, Generic[ResponseT]):
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
class _InputAgentState(TypedDict): # noqa: PYI049
class InputAgentState(TypedDict):
"""Input state schema for the agent."""
messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]]
class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
class OutputAgentState(TypedDict, Generic[ResponseT]):
"""Output state schema for the agent."""
messages: Required[Annotated[list[AnyMessage], add_messages]]
structured_response: NotRequired[ResponseT]
# Deprecated aliases kept for backwards compatibility with external consumers that
# imported the previously private names. Remove in a future release.
_InputAgentState = InputAgentState
_OutputAgentState = OutputAgentState
StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any])
StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True)
StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True)

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from langchain.agents.middleware.types import _InputAgentState
from langchain.agents.middleware.types import InputAgentState
def _get_model(provider: str) -> Any:
@@ -53,7 +53,7 @@ def test_hitl_reject_does_not_retry(provider: str) -> None:
the interrupt. So a completed run with no new `__interrupt__` is a reliable signal
that the model honored the guidance and did not retry.
"""
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
agent: CompiledStateGraph[Any, Any, InputAgentState, Any] = create_agent(
model=_get_model(provider),
tools=[get_weather],
middleware=[

View File

@@ -14,10 +14,6 @@ from langchain.agents.middleware.shell_tool import ShellToolMiddleware
if TYPE_CHECKING:
from pathlib import Path
from langgraph.graph.state import CompiledStateGraph
from langchain.agents.middleware.types import _InputAgentState
def _get_model(provider: str) -> Any:
"""Get chat model for the specified provider."""
@@ -35,7 +31,7 @@ def _get_model(provider: str) -> Any:
def test_shell_tool_basic_execution(tmp_path: Path, provider: str) -> None:
"""Test basic shell command execution across different models."""
workspace = tmp_path / "workspace"
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
agent = create_agent(
model=_get_model(provider),
middleware=[ShellToolMiddleware(workspace_root=workspace)],
)
@@ -57,7 +53,7 @@ def test_shell_tool_basic_execution(tmp_path: Path, provider: str) -> None:
def test_shell_session_persistence(tmp_path: Path) -> None:
"""Test shell session state persists across multiple tool calls."""
workspace = tmp_path / "workspace"
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
agent = create_agent(
model=_get_model("anthropic"),
middleware=[ShellToolMiddleware(workspace_root=workspace)],
)
@@ -84,7 +80,7 @@ def test_shell_session_persistence(tmp_path: Path) -> None:
def test_shell_tool_error_handling(tmp_path: Path) -> None:
"""Test shell tool captures command errors."""
workspace = tmp_path / "workspace"
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
agent = create_agent(
model=_get_model("anthropic"),
middleware=[ShellToolMiddleware(workspace_root=workspace)],
)
@@ -121,7 +117,7 @@ def test_shell_tool_with_custom_tools(tmp_path: Path) -> None:
"""Greet someone by name."""
return f"Hello, {name}!"
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
agent = create_agent(
model=_get_model("anthropic"),
tools=[custom_greeting],
middleware=[ShellToolMiddleware(workspace_root=workspace)],

View File

@@ -21,15 +21,17 @@ import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel
from typing_extensions import TypedDict, override
from typing_extensions import TypedDict, assert_type, override
from langchain.agents import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
InputAgentState,
ModelRequest,
ModelResponse,
OutputAgentState,
ResponseT,
before_model,
)
@@ -72,6 +74,12 @@ class SummaryResult(BaseModel):
key_points: list[str]
class CustomAgentState(AgentState[Any]):
"""Custom state schema without a structured response format."""
user_id: str
# =============================================================================
# 1. BACKWARDS COMPATIBLE: Middlewares without type parameters
# These work when create_agent has NO context_schema or response_format
@@ -254,6 +262,61 @@ def test_create_agent_no_context_schema(fake_model: GenericFakeChatModel) -> Non
assert agent is not None
def test_create_agent_custom_state_without_response_format(
fake_model: GenericFakeChatModel,
) -> None:
"""Custom state without response_format should not infer dict structured output."""
agent = create_agent(
model=fake_model,
state_schema=CustomAgentState,
)
if TYPE_CHECKING:
assert_type(
agent,
CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]],
)
assert agent is not None
def test_create_agent_dict_response_format(fake_model: GenericFakeChatModel) -> None:
"""A raw-dict `response_format` infers an untyped `dict` structured response."""
json_schema: dict[str, Any] = {"type": "json_schema", "schema": {}}
agent = create_agent(model=fake_model, response_format=json_schema)
if TYPE_CHECKING:
assert_type(
agent,
CompiledStateGraph[
AgentState[dict[str, Any]],
None,
InputAgentState,
OutputAgentState[dict[str, Any]],
],
)
assert agent is not None
def test_create_agent_typed_response_format(fake_model: GenericFakeChatModel) -> None:
"""A schema-typed `response_format` infers a matching `ResponseT`."""
agent = create_agent(model=fake_model, response_format=AnalysisResult)
if TYPE_CHECKING:
assert_type(
agent,
CompiledStateGraph[
AgentState[AnalysisResult],
None,
InputAgentState,
OutputAgentState[AnalysisResult],
],
)
assert agent is not None
def test_create_agent_with_user_context(fake_model: GenericFakeChatModel) -> None:
"""Typed: context_schema=UserContext requires matching middleware."""
agent: CompiledStateGraph[Any, UserContext, Any, Any] = create_agent(