From 17b799860ff4fb00cf68e733694f1c1277f87bff Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 14 May 2025 07:42:56 -0700 Subject: [PATCH] perf[core]: remove costly async helpers for non-end event handlers (#31230) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Remove `shielded` decorator from non-end event handlers 2. Exit early with a `self.handlers` check instead of doing unnecessary asyncio work Using a benchmark that processes ~200k chunks (a poem about broccoli). Before: ~15s Circled in blue is unnecessary event handling time. This is addressed by point 2 above Screenshot 2025-05-14 at 7 37 53 AM After: ~4.2s The total time is largely reduced by the removal of the `shielded` decorator, which holds little significance for non-end handlers. Screenshot 2025-05-14 at 7 37 22 AM --- libs/core/langchain_core/callbacks/manager.py | 61 +++++++++++++++++-- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 1220ed8103b..24e42994b12 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -520,6 +520,8 @@ class RunManager(BaseRunManager): Returns: Any: The result of the callback. """ + if not self.handlers: + return handle_event( self.handlers, "on_text", @@ -542,6 +544,8 @@ class RunManager(BaseRunManager): retry_state (RetryCallState): The retry state. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_retry", @@ -601,6 +605,8 @@ class AsyncRunManager(BaseRunManager, ABC): Returns: Any: The result of the callback. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_text", @@ -623,6 +629,8 @@ class AsyncRunManager(BaseRunManager, ABC): retry_state (RetryCallState): The retry state. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_retry", @@ -675,6 +683,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): The chunk. Defaults to None. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_llm_new_token", @@ -694,6 +704,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): response (LLMResult): The LLM result. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_llm_end", @@ -718,6 +730,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): - response (LLMResult): The response which was generated before the error occurred. """ + if not self.handlers: + return handle_event( self.handlers, "on_llm_error", @@ -750,7 +764,6 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): inheritable_metadata=self.inheritable_metadata, ) - @shielded async def on_llm_new_token( self, token: str, @@ -766,6 +779,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): The chunk. Defaults to None. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_llm_new_token", @@ -786,6 +801,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): response (LLMResult): The LLM result. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_llm_end", @@ -814,6 +831,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_llm_error", @@ -836,6 +855,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): outputs (Union[dict[str, Any], Any]): The outputs of the chain. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_chain_end", @@ -858,6 +879,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_chain_error", @@ -879,6 +902,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return handle_event( self.handlers, "on_agent_action", @@ -900,6 +925,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return handle_event( self.handlers, "on_agent_finish", @@ -942,6 +969,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): outputs (Union[dict[str, Any], Any]): The outputs of the chain. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_chain_end", @@ -965,6 +994,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_chain_error", @@ -976,7 +1007,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) - @shielded async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run when agent action is received. @@ -987,6 +1017,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_agent_action", @@ -998,7 +1030,6 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) - @shielded async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: """Run when agent finish is received. @@ -1009,6 +1040,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_agent_finish", @@ -1035,6 +1068,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): output (Any): The output of the tool. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_tool_end", @@ -1057,6 +1092,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_tool_error", @@ -1089,7 +1126,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): inheritable_metadata=self.inheritable_metadata, ) - @shielded async def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Async run when the tool ends running. @@ -1097,6 +1133,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): output (Any): The output of the tool. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_tool_end", @@ -1108,7 +1146,6 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): **kwargs, ) - @shielded async def on_tool_error( self, error: BaseException, @@ -1120,6 +1157,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_tool_error", @@ -1146,6 +1185,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): documents (Sequence[Document]): The retrieved documents. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_retriever_end", @@ -1168,6 +1209,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): error (BaseException): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_retriever_error", @@ -1213,6 +1256,8 @@ class AsyncCallbackManagerForRetrieverRun( documents (Sequence[Document]): The retrieved documents. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_retriever_end", @@ -1236,6 +1281,8 @@ class AsyncCallbackManagerForRetrieverRun( error (BaseException): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_retriever_error", @@ -1521,6 +1568,8 @@ class CallbackManager(BaseCallbackManager): .. versionadded:: 0.2.14 """ + if not self.handlers: + return if kwargs: msg = ( "The dispatcher API does not accept additional keyword arguments." @@ -1998,6 +2047,8 @@ class AsyncCallbackManager(BaseCallbackManager): .. versionadded:: 0.2.14 """ + if not self.handlers: + return if run_id is None: run_id = uuid.uuid4()