From 0da364fbc8493d69a784f9344e5a934606bc9534 Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Tue, 13 May 2025 19:22:33 -0700 Subject: [PATCH] removing shielded and returning early if handlers is empty --- libs/core/langchain_core/callbacks/manager.py | 66 +++++++++++++++++-- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 1220ed8103b..f5037168898 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -520,6 +520,8 @@ class RunManager(BaseRunManager): Returns: Any: The result of the callback. """ + if not self.handlers: + return handle_event( self.handlers, "on_text", @@ -542,6 +544,8 @@ class RunManager(BaseRunManager): retry_state (RetryCallState): The retry state. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_retry", @@ -601,6 +605,8 @@ class AsyncRunManager(BaseRunManager, ABC): Returns: Any: The result of the callback. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_text", @@ -623,6 +629,8 @@ class AsyncRunManager(BaseRunManager, ABC): retry_state (RetryCallState): The retry state. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_retry", @@ -675,6 +683,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): The chunk. Defaults to None. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_llm_new_token", @@ -694,6 +704,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): response (LLMResult): The LLM result. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_llm_end", @@ -718,6 +730,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): - response (LLMResult): The response which was generated before the error occurred. """ + if not self.handlers: + return handle_event( self.handlers, "on_llm_error", @@ -750,7 +764,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): inheritable_metadata=self.inheritable_metadata, ) - @shielded + async def on_llm_new_token( self, token: str, @@ -766,6 +780,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): The chunk. Defaults to None. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_llm_new_token", @@ -786,6 +802,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): response (LLMResult): The LLM result. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_llm_end", @@ -814,6 +832,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_llm_error", @@ -836,6 +856,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): outputs (Union[dict[str, Any], Any]): The outputs of the chain. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_chain_end", @@ -858,6 +880,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_chain_error", @@ -879,6 +903,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return handle_event( self.handlers, "on_agent_action", @@ -900,6 +926,8 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return handle_event( self.handlers, "on_agent_finish", @@ -942,6 +970,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): outputs (Union[dict[str, Any], Any]): The outputs of the chain. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_chain_end", @@ -965,6 +995,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_chain_error", @@ -976,7 +1008,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) - @shielded + async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run when agent action is received. @@ -987,6 +1019,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_agent_action", @@ -998,7 +1032,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) - @shielded + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: """Run when agent finish is received. @@ -1009,6 +1043,8 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): Returns: Any: The result of the callback. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_agent_finish", @@ -1035,6 +1071,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): output (Any): The output of the tool. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_tool_end", @@ -1057,6 +1095,8 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_tool_error", @@ -1089,7 +1129,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): inheritable_metadata=self.inheritable_metadata, ) - @shielded + async def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Async run when the tool ends running. @@ -1097,6 +1137,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): output (Any): The output of the tool. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_tool_end", @@ -1108,7 +1150,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): **kwargs, ) - @shielded + async def on_tool_error( self, error: BaseException, @@ -1120,6 +1162,8 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): error (Exception or KeyboardInterrupt): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_tool_error", @@ -1146,6 +1190,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): documents (Sequence[Document]): The retrieved documents. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_retriever_end", @@ -1168,6 +1214,8 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): error (BaseException): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return handle_event( self.handlers, "on_retriever_error", @@ -1213,6 +1261,8 @@ class AsyncCallbackManagerForRetrieverRun( documents (Sequence[Document]): The retrieved documents. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_retriever_end", @@ -1236,6 +1286,8 @@ class AsyncCallbackManagerForRetrieverRun( error (BaseException): The error. **kwargs (Any): Additional keyword arguments. """ + if not self.handlers: + return await ahandle_event( self.handlers, "on_retriever_error", @@ -1521,6 +1573,8 @@ class CallbackManager(BaseCallbackManager): .. versionadded:: 0.2.14 """ + if not self.handlers: + return if kwargs: msg = ( "The dispatcher API does not accept additional keyword arguments." @@ -1998,6 +2052,8 @@ class AsyncCallbackManager(BaseCallbackManager): .. versionadded:: 0.2.14 """ + if not self.handlers: + return if run_id is None: run_id = uuid.uuid4()