mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 06:42:37 +00:00
chore(langchain): add overloads to create_agent (#34309)
This way mypy can infer the return type when `ResponseT` is not passed. --------- Co-authored-by: Mason Daugherty <61371264+mdrxy@users.noreply.github.com> Co-authored-by: open-swe[bot] <open-swe@users.noreply.github.com> Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
86428c63ac
commit
d5f7d33f88
@@ -27,7 +27,7 @@ from langgraph.prebuilt import ToolCallTransformer
|
||||
from langgraph.prebuilt.tool_node import ToolNode
|
||||
from langgraph.types import Command, Send
|
||||
from langsmith import traceable
|
||||
from typing_extensions import NotRequired, Required, TypedDict
|
||||
from typing_extensions import NotRequired, Required, TypedDict, overload
|
||||
|
||||
from langchain.agents._subagent_transformer import SubagentTransformer
|
||||
from langchain.agents.middleware.types import (
|
||||
@@ -35,15 +35,15 @@ from langchain.agents.middleware.types import (
|
||||
AgentState,
|
||||
ContextT,
|
||||
ExtendedModelResponse,
|
||||
InputAgentState,
|
||||
JumpTo,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OmitFromSchema,
|
||||
OutputAgentState,
|
||||
ResponseT,
|
||||
StateT_co,
|
||||
ToolCallRequest,
|
||||
_InputAgentState,
|
||||
_OutputAgentState,
|
||||
)
|
||||
from langchain.agents.structured_output import (
|
||||
AutoStrategy,
|
||||
@@ -714,6 +714,76 @@ def _chain_async_tool_call_wrappers(
|
||||
return result
|
||||
|
||||
|
||||
# No `response_format`: there is no structured output, so `ResponseT` resolves to `Any`.
|
||||
@overload
|
||||
def create_agent(
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | SystemMessage | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: None = None,
|
||||
state_schema: None = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
interrupt_before: list[str] | None = None,
|
||||
interrupt_after: list[str] | None = None,
|
||||
debug: bool = False,
|
||||
name: str | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
transformers: Sequence[TransformerFactory] | None = None,
|
||||
) -> CompiledStateGraph[AgentState[Any], ContextT, InputAgentState, OutputAgentState[Any]]: ...
|
||||
|
||||
|
||||
# Raw-dict `response_format`: structured output is an untyped `dict[str, Any]`.
|
||||
@overload
|
||||
def create_agent(
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | SystemMessage | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: dict[str, Any],
|
||||
state_schema: type[AgentState[dict[str, Any]]] | None = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
interrupt_before: list[str] | None = None,
|
||||
interrupt_after: list[str] | None = None,
|
||||
debug: bool = False,
|
||||
name: str | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
transformers: Sequence[TransformerFactory] | None = None,
|
||||
) -> CompiledStateGraph[
|
||||
AgentState[dict[str, Any]], ContextT, InputAgentState, OutputAgentState[dict[str, Any]]
|
||||
]: ...
|
||||
|
||||
|
||||
# Schema-typed `response_format`: `ResponseT` is inferred from the schema/type.
|
||||
@overload
|
||||
def create_agent(
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | SystemMessage | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
state_schema: type[AgentState[ResponseT]] | None = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
interrupt_before: list[str] | None = None,
|
||||
interrupt_after: list[str] | None = None,
|
||||
debug: bool = False,
|
||||
name: str | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
transformers: Sequence[TransformerFactory] | None = None,
|
||||
) -> CompiledStateGraph[
|
||||
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
|
||||
]: ...
|
||||
|
||||
|
||||
def create_agent(
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
|
||||
@@ -732,7 +802,7 @@ def create_agent(
|
||||
cache: BaseCache[Any] | None = None,
|
||||
transformers: Sequence[TransformerFactory] | None = None,
|
||||
) -> CompiledStateGraph[
|
||||
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
||||
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
|
||||
]:
|
||||
"""Creates an agent graph that calls tools in a loop until a stopping condition is met.
|
||||
|
||||
@@ -1066,7 +1136,7 @@ def create_agent(
|
||||
|
||||
# create graph, add nodes
|
||||
graph: StateGraph[
|
||||
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
||||
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
|
||||
] = StateGraph(
|
||||
state_schema=resolved_state_schema,
|
||||
input_schema=input_schema,
|
||||
@@ -1860,7 +1930,7 @@ def _make_tools_to_model_edge(
|
||||
|
||||
def _add_middleware_edge(
|
||||
graph: StateGraph[
|
||||
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
||||
AgentState[ResponseT], ContextT, InputAgentState, OutputAgentState[ResponseT]
|
||||
],
|
||||
*,
|
||||
name: str,
|
||||
|
||||
@@ -30,9 +30,11 @@ from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ExtendedModelResponse,
|
||||
InputAgentState,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OutputAgentState,
|
||||
ToolCallRequest,
|
||||
after_agent,
|
||||
after_model,
|
||||
@@ -55,6 +57,7 @@ __all__ = [
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HostExecutionPolicy",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"InputAgentState",
|
||||
"InterruptOnConfig",
|
||||
"LLMToolEmulator",
|
||||
"LLMToolSelectorMiddleware",
|
||||
@@ -64,6 +67,7 @@ __all__ = [
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"ModelRetryMiddleware",
|
||||
"OutputAgentState",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"ProviderToolSearchMiddleware",
|
||||
|
||||
@@ -44,10 +44,12 @@ __all__ = [
|
||||
"AgentState",
|
||||
"ContextT",
|
||||
"ExtendedModelResponse",
|
||||
"InputAgentState",
|
||||
"ModelCallResult",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"OmitFromSchema",
|
||||
"OutputAgentState",
|
||||
"ResponseT",
|
||||
"StateT_co",
|
||||
"ToolCallRequest",
|
||||
@@ -350,19 +352,25 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
||||
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
|
||||
|
||||
|
||||
class _InputAgentState(TypedDict): # noqa: PYI049
|
||||
class InputAgentState(TypedDict):
|
||||
"""Input state schema for the agent."""
|
||||
|
||||
messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]]
|
||||
|
||||
|
||||
class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
||||
class OutputAgentState(TypedDict, Generic[ResponseT]):
|
||||
"""Output state schema for the agent."""
|
||||
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
structured_response: NotRequired[ResponseT]
|
||||
|
||||
|
||||
# Deprecated aliases kept for backwards compatibility with external consumers that
|
||||
# imported the previously private names. Remove in a future release.
|
||||
_InputAgentState = InputAgentState
|
||||
_OutputAgentState = OutputAgentState
|
||||
|
||||
|
||||
StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any])
|
||||
StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True)
|
||||
StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True)
|
||||
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from langchain.agents.middleware.types import _InputAgentState
|
||||
from langchain.agents.middleware.types import InputAgentState
|
||||
|
||||
|
||||
def _get_model(provider: str) -> Any:
|
||||
@@ -53,7 +53,7 @@ def test_hitl_reject_does_not_retry(provider: str) -> None:
|
||||
the interrupt. So a completed run with no new `__interrupt__` is a reliable signal
|
||||
that the model honored the guidance and did not retry.
|
||||
"""
|
||||
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
|
||||
agent: CompiledStateGraph[Any, Any, InputAgentState, Any] = create_agent(
|
||||
model=_get_model(provider),
|
||||
tools=[get_weather],
|
||||
middleware=[
|
||||
|
||||
@@ -14,10 +14,6 @@ from langchain.agents.middleware.shell_tool import ShellToolMiddleware
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from langchain.agents.middleware.types import _InputAgentState
|
||||
|
||||
|
||||
def _get_model(provider: str) -> Any:
|
||||
"""Get chat model for the specified provider."""
|
||||
@@ -35,7 +31,7 @@ def _get_model(provider: str) -> Any:
|
||||
def test_shell_tool_basic_execution(tmp_path: Path, provider: str) -> None:
|
||||
"""Test basic shell command execution across different models."""
|
||||
workspace = tmp_path / "workspace"
|
||||
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
|
||||
agent = create_agent(
|
||||
model=_get_model(provider),
|
||||
middleware=[ShellToolMiddleware(workspace_root=workspace)],
|
||||
)
|
||||
@@ -57,7 +53,7 @@ def test_shell_tool_basic_execution(tmp_path: Path, provider: str) -> None:
|
||||
def test_shell_session_persistence(tmp_path: Path) -> None:
|
||||
"""Test shell session state persists across multiple tool calls."""
|
||||
workspace = tmp_path / "workspace"
|
||||
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
|
||||
agent = create_agent(
|
||||
model=_get_model("anthropic"),
|
||||
middleware=[ShellToolMiddleware(workspace_root=workspace)],
|
||||
)
|
||||
@@ -84,7 +80,7 @@ def test_shell_session_persistence(tmp_path: Path) -> None:
|
||||
def test_shell_tool_error_handling(tmp_path: Path) -> None:
|
||||
"""Test shell tool captures command errors."""
|
||||
workspace = tmp_path / "workspace"
|
||||
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
|
||||
agent = create_agent(
|
||||
model=_get_model("anthropic"),
|
||||
middleware=[ShellToolMiddleware(workspace_root=workspace)],
|
||||
)
|
||||
@@ -121,7 +117,7 @@ def test_shell_tool_with_custom_tools(tmp_path: Path) -> None:
|
||||
"""Greet someone by name."""
|
||||
return f"Hello, {name}!"
|
||||
|
||||
agent: CompiledStateGraph[Any, Any, _InputAgentState, Any] = create_agent(
|
||||
agent = create_agent(
|
||||
model=_get_model("anthropic"),
|
||||
tools=[custom_greeting],
|
||||
middleware=[ShellToolMiddleware(workspace_root=workspace)],
|
||||
|
||||
@@ -21,15 +21,17 @@ import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict, override
|
||||
from typing_extensions import TypedDict, assert_type, override
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
InputAgentState,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OutputAgentState,
|
||||
ResponseT,
|
||||
before_model,
|
||||
)
|
||||
@@ -72,6 +74,12 @@ class SummaryResult(BaseModel):
|
||||
key_points: list[str]
|
||||
|
||||
|
||||
class CustomAgentState(AgentState[Any]):
|
||||
"""Custom state schema without a structured response format."""
|
||||
|
||||
user_id: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. BACKWARDS COMPATIBLE: Middlewares without type parameters
|
||||
# These work when create_agent has NO context_schema or response_format
|
||||
@@ -254,6 +262,61 @@ def test_create_agent_no_context_schema(fake_model: GenericFakeChatModel) -> Non
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_create_agent_custom_state_without_response_format(
|
||||
fake_model: GenericFakeChatModel,
|
||||
) -> None:
|
||||
"""Custom state without response_format should not infer dict structured output."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
state_schema=CustomAgentState,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert_type(
|
||||
agent,
|
||||
CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]],
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_create_agent_dict_response_format(fake_model: GenericFakeChatModel) -> None:
|
||||
"""A raw-dict `response_format` infers an untyped `dict` structured response."""
|
||||
json_schema: dict[str, Any] = {"type": "json_schema", "schema": {}}
|
||||
agent = create_agent(model=fake_model, response_format=json_schema)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert_type(
|
||||
agent,
|
||||
CompiledStateGraph[
|
||||
AgentState[dict[str, Any]],
|
||||
None,
|
||||
InputAgentState,
|
||||
OutputAgentState[dict[str, Any]],
|
||||
],
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_create_agent_typed_response_format(fake_model: GenericFakeChatModel) -> None:
|
||||
"""A schema-typed `response_format` infers a matching `ResponseT`."""
|
||||
agent = create_agent(model=fake_model, response_format=AnalysisResult)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert_type(
|
||||
agent,
|
||||
CompiledStateGraph[
|
||||
AgentState[AnalysisResult],
|
||||
None,
|
||||
InputAgentState,
|
||||
OutputAgentState[AnalysisResult],
|
||||
],
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_create_agent_with_user_context(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Typed: context_schema=UserContext requires matching middleware."""
|
||||
agent: CompiledStateGraph[Any, UserContext, Any, Any] = create_agent(
|
||||
|
||||
Reference in New Issue
Block a user