some linting

This commit is contained in:
Sydney Runkle
2025-09-05 10:16:21 -04:00
parent 522f99da34
commit a6a4b0d58f
6 changed files with 19 additions and 46 deletions

View File

@@ -1,4 +1,3 @@
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
from langgraph.prebuilt.interrupt import (
ActionRequest,
HumanInterrupt,
@@ -7,6 +6,8 @@ from langgraph.prebuilt.interrupt import (
)
from langgraph.types import interrupt
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
# Maps tool name -> its HumanInterruptConfig (looked up per tool call in after_model).
ToolInterruptConfig = dict[str, HumanInterruptConfig]
@@ -23,12 +24,12 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
messages = state["messages"]
if not messages:
return
return None
last_message = messages[-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
return
return None
# Separate tool calls that need interrupts from those that don't
interrupt_tool_calls = []
@@ -43,7 +44,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
# If no interrupts needed, return early
if not interrupt_tool_calls:
return
return None
approved_tool_calls = auto_approved_tool_calls.copy()
@@ -53,9 +54,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
for tool_call in interrupt_tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
description = (
f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
)
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
tool_config = self.tool_configs[tool_name]
request: HumanInterrupt = {

View File

@@ -1,18 +1,13 @@
from langchain.agents.new_agent import create_agent
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
import operator
from dataclasses import dataclass
from typing import Annotated
from pydantic import BaseModel
from langchain.agents.structured_output import ToolStrategy
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
class State(AgentState):
    # Cumulative count of model requests made during the run.
    # NOTE(review): Annotated[..., operator.add] is presumably a state
    # reducer so partial updates are summed rather than overwritten —
    # confirm against AgentState's merge semantics.
    model_request_count: Annotated[int, operator.add]
class ModelRequestLimitMiddleware(AgentMiddleware):
"""Terminates after N model requests"""

View File

@@ -1,6 +1,5 @@
import uuid
from collections.abc import Sequence
from typing import Callable, Iterable
from collections.abc import Callable, Iterable, Sequence
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
@@ -85,10 +84,7 @@ class SummarizationMiddleware(AgentMiddleware):
),
ToolMessage(tool_call_id=fake_tool_call_id, content=summary),
]
return {
"messages": [RemoveMessage(id=m.id) for m in messages_to_summarize]
+ fake_messages
}
return {"messages": [RemoveMessage(id=m.id) for m in messages_to_summarize] + fake_messages}
def _summarize_messages(self, messages_to_summarize: Sequence[AnyMessage]) -> str:
system_message = self.summary_system_prompt

View File

@@ -1,11 +1,10 @@
from dataclasses import dataclass, field
from typing import Annotated, Any, Dict, List, cast
from dataclasses import field
from typing import cast
from langchain_core.messages import AIMessage
from typing_extensions import Annotated
from langgraph.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
class State(AgentState):
    # Per-tool running call counts, keyed by tool name (aggregated in after_model).
    # NOTE(review): dataclasses.field(default_factory=dict) only takes effect
    # if AgentState is dataclass-like — confirm; otherwise the default is
    # a bare field() object.
    tool_call_count: dict[str, int] = field(default_factory=dict)
@@ -19,14 +18,14 @@ class ToolCallLimitMiddleware(AgentMiddleware):
self.tool_limits = tool_limits
def after_model(self, state: State) -> AgentUpdate | AgentJump | None:
ai_msg: AIMessage = cast(AIMessage, state["messages"][-1])
ai_msg: AIMessage = cast("AIMessage", state["messages"][-1])
tool_calls = {}
for call in ai_msg.tool_calls or []:
tool_calls[call["name"]] = tool_calls.get(call["name"], 0) + 1
aggregate_calls = state["tool_call_count"].copy()
for tool_name in tool_calls.keys():
for tool_name in tool_calls:
aggregate_calls[tool_name] = aggregate_calls.get(tool_name, 0) + 1
for tool_name, max_calls in self.tool_limits.items():

View File

@@ -422,7 +422,9 @@ def create_agent(
)
if middleware_w_modify_model_request:
for m1, m2 in zip(middleware_w_modify_model_request, middleware_w_modify_model_request[1:], strict=False):
for m1, m2 in zip(
middleware_w_modify_model_request, middleware_w_modify_model_request[1:], strict=False
):
_add_middleware_edge(
graph,
m1.modify_model_request,

View File

@@ -8,25 +8,7 @@ from typing import Annotated
from pydantic import BaseModel
from langchain.agents.structured_output import ToolStrategy
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
class State(AgentState):
    """Agent state extended with a cumulative model-request counter."""

    # NOTE(review): Annotated[..., operator.add] is presumably a state
    # reducer, so the `1` returned from before_model accumulates across
    # steps instead of overwriting — confirm against AgentState semantics.
    model_request_count: Annotated[int, operator.add]


class ModelRequestLimitMiddleware(AgentMiddleware):
    """Terminates after N model requests.

    Every ``before_model`` call contributes +1 to ``model_request_count``;
    once the count reaches ``max_requests`` the agent jumps straight to
    ``__end__`` instead of issuing another model request.
    """

    state_schema = State

    def __init__(self, max_requests: int = 10):
        """
        Args:
            max_requests: Number of model requests allowed before the run ends.
        """
        self.max_requests = max_requests

    def before_model(self, state: State) -> AgentUpdate | AgentJump | None:
        # TODO: want to be able to configure end behavior here
        # Use >= rather than == so the limit still fires if the counter ever
        # overshoots max_requests (e.g. the limit is lowered mid-run, or some
        # other middleware also bumps the counter).
        if state.get("model_request_count", 0) >= self.max_requests:
            return {"jump_to": "__end__"}
        # Increment, not absolute assignment — relies on the additive reducer.
        return {"model_request_count": 1}
from langchain.agents.middleware.model_call_limits import ModelRequestLimitMiddleware