From 17b799860ff4fb00cf68e733694f1c1277f87bff Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 14 May 2025 07:42:56 -0700
Subject: [PATCH] perf[core]: remove costly async helpers for non-end event
handlers (#31230)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
1. Remove `shielded` decorator from non-end event handlers
2. Exit early with a `self.handlers` check instead of doing unnecessary
asyncio work
Using a benchmark that processes ~200k chunks (a poem about broccoli).
Before: ~15s
Circled in blue is unnecessary event handling time. This is addressed by
point 2 above
After: ~4.2s
The total time is largely reduced by the removal of the `shielded`
decorator, which holds little significance for non-end handlers.
---
libs/core/langchain_core/callbacks/manager.py | 61 +++++++++++++++++--
1 file changed, 56 insertions(+), 5 deletions(-)
diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index 1220ed8103b..24e42994b12 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -520,6 +520,8 @@ class RunManager(BaseRunManager):
Returns:
Any: The result of the callback.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_text",
@@ -542,6 +544,8 @@ class RunManager(BaseRunManager):
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_retry",
@@ -601,6 +605,8 @@ class AsyncRunManager(BaseRunManager, ABC):
Returns:
Any: The result of the callback.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_text",
@@ -623,6 +629,8 @@ class AsyncRunManager(BaseRunManager, ABC):
retry_state (RetryCallState): The retry state.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_retry",
@@ -675,6 +683,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_llm_new_token",
@@ -694,6 +704,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_llm_end",
@@ -718,6 +730,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
- response (LLMResult): The response which was generated before
the error occurred.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_llm_error",
@@ -750,7 +764,6 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
- @shielded
async def on_llm_new_token(
self,
token: str,
@@ -766,6 +779,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
The chunk. Defaults to None.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_llm_new_token",
@@ -786,6 +801,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
response (LLMResult): The LLM result.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_llm_end",
@@ -814,6 +831,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_llm_error",
@@ -836,6 +855,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_chain_end",
@@ -858,6 +879,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_chain_error",
@@ -879,6 +902,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_agent_action",
@@ -900,6 +925,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_agent_finish",
@@ -942,6 +969,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
outputs (Union[dict[str, Any], Any]): The outputs of the chain.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_chain_end",
@@ -965,6 +994,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_chain_error",
@@ -976,7 +1007,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
- @shielded
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received.
@@ -987,6 +1017,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_agent_action",
@@ -998,7 +1030,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
- @shielded
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received.
@@ -1009,6 +1040,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_agent_finish",
@@ -1035,6 +1068,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_tool_end",
@@ -1057,6 +1092,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_tool_error",
@@ -1089,7 +1126,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
- @shielded
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Async run when the tool ends running.
@@ -1097,6 +1133,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
output (Any): The output of the tool.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_tool_end",
@@ -1108,7 +1146,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
**kwargs,
)
- @shielded
async def on_tool_error(
self,
error: BaseException,
@@ -1120,6 +1157,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
error (Exception or KeyboardInterrupt): The error.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_tool_error",
@@ -1146,6 +1185,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_retriever_end",
@@ -1168,6 +1209,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
handle_event(
self.handlers,
"on_retriever_error",
@@ -1213,6 +1256,8 @@ class AsyncCallbackManagerForRetrieverRun(
documents (Sequence[Document]): The retrieved documents.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_retriever_end",
@@ -1236,6 +1281,8 @@ class AsyncCallbackManagerForRetrieverRun(
error (BaseException): The error.
**kwargs (Any): Additional keyword arguments.
"""
+ if not self.handlers:
+ return
await ahandle_event(
self.handlers,
"on_retriever_error",
@@ -1521,6 +1568,8 @@ class CallbackManager(BaseCallbackManager):
.. versionadded:: 0.2.14
"""
+ if not self.handlers:
+ return
if kwargs:
msg = (
"The dispatcher API does not accept additional keyword arguments."
@@ -1998,6 +2047,8 @@ class AsyncCallbackManager(BaseCallbackManager):
.. versionadded:: 0.2.14
"""
+ if not self.handlers:
+ return
if run_id is None:
run_id = uuid.uuid4()