mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
some linting
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
|
||||
from langgraph.prebuilt.interrupt import (
|
||||
ActionRequest,
|
||||
HumanInterrupt,
|
||||
@@ -7,6 +6,8 @@ from langgraph.prebuilt.interrupt import (
|
||||
)
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
|
||||
|
||||
ToolInterruptConfig = dict[str, HumanInterruptConfig]
|
||||
|
||||
|
||||
@@ -23,12 +24,12 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
return
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
||||
return
|
||||
return None
|
||||
|
||||
# Separate tool calls that need interrupts from those that don't
|
||||
interrupt_tool_calls = []
|
||||
@@ -43,7 +44,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
|
||||
# If no interrupts needed, return early
|
||||
if not interrupt_tool_calls:
|
||||
return
|
||||
return None
|
||||
|
||||
approved_tool_calls = auto_approved_tool_calls.copy()
|
||||
|
||||
@@ -53,9 +54,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
for tool_call in interrupt_tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
description = (
|
||||
f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||
)
|
||||
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||
tool_config = self.tool_configs[tool_name]
|
||||
|
||||
request: HumanInterrupt = {
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
from langchain.agents.new_agent import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
import operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
from pydantic import BaseModel
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
|
||||
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
|
||||
|
||||
|
||||
class State(AgentState):
|
||||
model_request_count: Annotated[int, operator.add]
|
||||
|
||||
|
||||
class ModelRequestLimitMiddleware(AgentMiddleware):
|
||||
"""Terminates after N model requests"""
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Iterable
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
|
||||
from langchain_core.language_models import LanguageModelLike
|
||||
from langchain_core.messages import (
|
||||
@@ -85,10 +84,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
),
|
||||
ToolMessage(tool_call_id=fake_tool_call_id, content=summary),
|
||||
]
|
||||
return {
|
||||
"messages": [RemoveMessage(id=m.id) for m in messages_to_summarize]
|
||||
+ fake_messages
|
||||
}
|
||||
return {"messages": [RemoveMessage(id=m.id) for m in messages_to_summarize] + fake_messages}
|
||||
|
||||
def _summarize_messages(self, messages_to_summarize: Sequence[AnyMessage]) -> str:
|
||||
system_message = self.summary_system_prompt
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Annotated, Any, Dict, List, cast
|
||||
from dataclasses import field
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langgraph.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
|
||||
|
||||
|
||||
class State(AgentState):
|
||||
tool_call_count: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
@@ -19,14 +18,14 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
||||
self.tool_limits = tool_limits
|
||||
|
||||
def after_model(self, state: State) -> AgentUpdate | AgentJump | None:
|
||||
ai_msg: AIMessage = cast(AIMessage, state["messages"][-1])
|
||||
ai_msg: AIMessage = cast("AIMessage", state["messages"][-1])
|
||||
|
||||
tool_calls = {}
|
||||
for call in ai_msg.tool_calls or []:
|
||||
tool_calls[call["name"]] = tool_calls.get(call["name"], 0) + 1
|
||||
|
||||
aggregate_calls = state["tool_call_count"].copy()
|
||||
for tool_name in tool_calls.keys():
|
||||
for tool_name in tool_calls:
|
||||
aggregate_calls[tool_name] = aggregate_calls.get(tool_name, 0) + 1
|
||||
|
||||
for tool_name, max_calls in self.tool_limits.items():
|
||||
|
||||
@@ -422,7 +422,9 @@ def create_agent(
|
||||
)
|
||||
|
||||
if middleware_w_modify_model_request:
|
||||
for m1, m2 in zip(middleware_w_modify_model_request, middleware_w_modify_model_request[1:], strict=False):
|
||||
for m1, m2 in zip(
|
||||
middleware_w_modify_model_request, middleware_w_modify_model_request[1:], strict=False
|
||||
):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
m1.modify_model_request,
|
||||
|
||||
@@ -8,25 +8,7 @@ from typing import Annotated
|
||||
from pydantic import BaseModel
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
|
||||
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
|
||||
|
||||
class State(AgentState):
|
||||
model_request_count: Annotated[int, operator.add]
|
||||
|
||||
class ModelRequestLimitMiddleware(AgentMiddleware):
|
||||
"""Terminates after N model requests"""
|
||||
|
||||
state_schema = State
|
||||
|
||||
def __init__(self, max_requests: int = 10):
|
||||
self.max_requests = max_requests
|
||||
|
||||
def before_model(self, state: State) -> AgentUpdate | AgentJump | None:
|
||||
# TODO: want to be able to configure end behavior here
|
||||
if state.get("model_request_count", 0) == self.max_requests:
|
||||
return {"jump_to": "__end__"}
|
||||
|
||||
return {"model_request_count": 1}
|
||||
from langchain.agents.middleware.model_call_limits import ModelRequestLimitMiddleware
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user