From 43db4cd20e0e718f368267528706f92bf604bac9 Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:59:04 +0500 Subject: [PATCH] core[major]: On Tool End Observation Casting Fix (#18798) This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev --- docs/docs/modules/agents/how_to/streaming.ipynb | 4 ++-- docs/docs/modules/callbacks/index.mdx | 2 +- .../langchain_community/callbacks/aim_callback.py | 3 ++- .../callbacks/argilla_callback.py | 2 +- .../langchain_community/callbacks/arize_callback.py | 2 +- .../langchain_community/callbacks/arthur_callback.py | 2 +- .../callbacks/clearml_callback.py | 3 ++- .../callbacks/comet_ml_callback.py | 3 ++- .../callbacks/confident_callback.py | 2 +- .../callbacks/llmonitor_callback.py | 3 ++- .../langchain_community/callbacks/mlflow_callback.py | 3 ++- .../callbacks/sagemaker_callback.py | 3 ++- .../streamlit/streamlit_callback_handler.py | 7 ++++--- .../langchain_community/callbacks/wandb_callback.py | 3 ++- libs/core/langchain_core/callbacks/base.py | 4 ++-- libs/core/langchain_core/callbacks/manager.py | 10 ++++++---- libs/core/langchain_core/callbacks/stdout.py | 3 ++- .../langchain_core/callbacks/streaming_stdout.py | 2 +- libs/core/langchain_core/tools.py | 12 ++++-------- libs/core/langchain_core/tracers/base.py | 3 ++- 20 files changed, 42 insertions(+), 34 deletions(-) diff --git a/docs/docs/modules/agents/how_to/streaming.ipynb b/docs/docs/modules/agents/how_to/streaming.ipynb index b6095de0a46..abe3d1f5e36 100644 --- a/docs/docs/modules/agents/how_to/streaming.ipynb +++ b/docs/docs/modules/agents/how_to/streaming.ipynb @@ -1068,7 +1068,7 @@ "\n", " def on_tool_end(\n", " self,\n", - " output: str,\n", + " output: Any,\n", " *,\n", " run_id: UUID,\n", " parent_run_id: Optional[UUID] = None,\n", @@ -1076,7 +1076,7 @@ " ) -> Any:\n", " \"\"\"Run when tool ends running.\"\"\"\n", " print(\"Tool end\")\n", - " print(output)\n", + " print(str(output))\n", "\n", " async def on_llm_end(\n", " self,\n", diff --git a/docs/docs/modules/callbacks/index.mdx b/docs/docs/modules/callbacks/index.mdx index 7be32ccd0ec..b41f90c046b 100644 --- a/docs/docs/modules/callbacks/index.mdx +++ b/docs/docs/modules/callbacks/index.mdx @@ -59,7 +59,7 @@ class BaseCallbackHandler: ) -> Any: """Run when tool starts running.""" - def on_tool_end(self, output: str, **kwargs: Any) -> Any: + def on_tool_end(self, output: Any, **kwargs: Any) -> Any: """Run when tool ends running.""" def on_tool_error( diff --git a/libs/community/langchain_community/callbacks/aim_callback.py b/libs/community/langchain_community/callbacks/aim_callback.py index e36cfcceafe..46d7987c54e 100644 --- a/libs/community/langchain_community/callbacks/aim_callback.py +++ b/libs/community/langchain_community/callbacks/aim_callback.py @@ -314,8 +314,9 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self._run.track(aim.Text(input_str), name="on_tool_start", context=resp) - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" + output = str(output) aim = import_aim() self.step += 1 self.tool_ends += 1 diff --git a/libs/community/langchain_community/callbacks/argilla_callback.py b/libs/community/langchain_community/callbacks/argilla_callback.py index 157075a2832..ad1fa9c9014 100644 --- a/libs/community/langchain_community/callbacks/argilla_callback.py +++ b/libs/community/langchain_community/callbacks/argilla_callback.py @@ -328,7 +328,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, diff --git a/libs/community/langchain_community/callbacks/arize_callback.py b/libs/community/langchain_community/callbacks/arize_callback.py index 83cba048455..d7905d8d8ca 100644 --- a/libs/community/langchain_community/callbacks/arize_callback.py +++ b/libs/community/langchain_community/callbacks/arize_callback.py @@ -196,7 +196,7 @@ class ArizeCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, diff --git a/libs/community/langchain_community/callbacks/arthur_callback.py b/libs/community/langchain_community/callbacks/arthur_callback.py index b0b28328ab2..c9a81650e54 100644 --- a/libs/community/langchain_community/callbacks/arthur_callback.py +++ b/libs/community/langchain_community/callbacks/arthur_callback.py @@ -279,7 +279,7 @@ class ArthurCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, diff --git a/libs/community/langchain_community/callbacks/clearml_callback.py b/libs/community/langchain_community/callbacks/clearml_callback.py index 8b8aa8e98ab..34358973b41 100644 --- a/libs/community/langchain_community/callbacks/clearml_callback.py +++ b/libs/community/langchain_community/callbacks/clearml_callback.py @@ -243,8 +243,9 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.logger.report_text(resp) - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" + output = str(output) self.step += 1 self.tool_ends += 1 self.ends += 1 diff --git a/libs/community/langchain_community/callbacks/comet_ml_callback.py b/libs/community/langchain_community/callbacks/comet_ml_callback.py index 499e93a07b0..b7e4e918d0c 100644 --- a/libs/community/langchain_community/callbacks/comet_ml_callback.py +++ b/libs/community/langchain_community/callbacks/comet_ml_callback.py @@ -303,8 +303,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): resp.update({"input_str": input_str}) self.action_records.append(resp) - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" + output = str(output) self.step += 1 self.tool_ends += 1 self.ends += 1 diff --git a/libs/community/langchain_community/callbacks/confident_callback.py b/libs/community/langchain_community/callbacks/confident_callback.py index 4162dc12315..b078abf4507 100644 --- a/libs/community/langchain_community/callbacks/confident_callback.py +++ b/libs/community/langchain_community/callbacks/confident_callback.py @@ -162,7 +162,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, diff --git a/libs/community/langchain_community/callbacks/llmonitor_callback.py b/libs/community/langchain_community/callbacks/llmonitor_callback.py index f4f2882dac2..32e8820dbdb 100644 --- a/libs/community/langchain_community/callbacks/llmonitor_callback.py +++ b/libs/community/langchain_community/callbacks/llmonitor_callback.py @@ -465,13 +465,14 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, *, run_id: UUID, parent_run_id: Union[UUID, None] = None, tags: Union[List[str], None] = None, **kwargs: Any, ) -> None: + output = str(output) if self.__has_valid_config is False: return try: diff --git a/libs/community/langchain_community/callbacks/mlflow_callback.py b/libs/community/langchain_community/callbacks/mlflow_callback.py index b81da0c166a..294c3ec6825 100644 --- a/libs/community/langchain_community/callbacks/mlflow_callback.py +++ b/libs/community/langchain_community/callbacks/mlflow_callback.py @@ -518,8 +518,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.records["action_records"].append(resp) self.mlflg.jsonf(resp, f"tool_start_{tool_starts}") - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" + output = str(output) self.metrics["step"] += 1 self.metrics["tool_ends"] += 1 self.metrics["ends"] += 1 diff --git a/libs/community/langchain_community/callbacks/sagemaker_callback.py b/libs/community/langchain_community/callbacks/sagemaker_callback.py index b791425ff00..45295a0b5f7 100644 --- a/libs/community/langchain_community/callbacks/sagemaker_callback.py +++ b/libs/community/langchain_community/callbacks/sagemaker_callback.py @@ -186,8 +186,9 @@ class SageMakerCallbackHandler(BaseCallbackHandler): self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}") - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" + output = str(output) self.metrics["step"] += 1 self.metrics["tool_ends"] += 1 self.metrics["ends"] += 1 diff --git a/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py b/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py index 725862c53b6..89183fd5f46 100644 --- a/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py +++ b/libs/community/langchain_community/callbacks/streamlit/streamlit_callback_handler.py @@ -183,13 +183,13 @@ class LLMThought: def on_tool_end( self, - output: str, + output: Any, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: - self._container.markdown(f"**{output}**") + self._container.markdown(f"**{str(output)}**") def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: self._container.markdown("**Tool encountered an error...**") @@ -363,12 +363,13 @@ class StreamlitCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: + output = str(output) self._require_current_thought().on_tool_end( output, color, observation_prefix, llm_prefix, **kwargs ) diff --git a/libs/community/langchain_community/callbacks/wandb_callback.py b/libs/community/langchain_community/callbacks/wandb_callback.py index 035c44640dd..2ce3cf527ff 100644 --- a/libs/community/langchain_community/callbacks/wandb_callback.py +++ b/libs/community/langchain_community/callbacks/wandb_callback.py @@ -356,8 +356,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.run.log(resp) - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" + output = str(output) self.step += 1 self.tool_ends += 1 self.ends += 1 diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index eb4d1de7060..900cb2fcffc 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -133,7 +133,7 @@ class ToolManagerMixin: def on_tool_end( self, - output: str, + output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -440,7 +440,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_tool_end( self, - output: str, + output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index b1f103871f1..3fceb33f009 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -976,14 +976,15 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): def on_tool_end( self, - output: str, + output: Any, **kwargs: Any, ) -> None: """Run when tool ends running. Args: - output (str): The output of the tool. + output (Any): The output of the tool. """ + output = str(output) handle_event( self.handlers, "on_tool_end", @@ -1038,12 +1039,13 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): ) @shielded - async def on_tool_end(self, output: str, **kwargs: Any) -> None: + async def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running. Args: - output (str): The output of the tool. + output (Any): The output of the tool. """ + output = str(output) await ahandle_event( self.handlers, "on_tool_end", diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py index e129792a010..0408eb19845 100644 --- a/libs/core/langchain_core/callbacks/stdout.py +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -37,13 +37,14 @@ class StdOutCallbackHandler(BaseCallbackHandler): def on_tool_end( self, - output: str, + output: Any, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: """If not the final action, print out observation.""" + output = str(output) if observation_prefix is not None: print_text(f"\n{observation_prefix}") print_text(output, color=color or self.color) diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index aaac043ff93..4966aa2e11c 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -59,7 +59,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): """Run on agent action.""" pass - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 356e326b50f..a67a642329c 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -410,17 +410,13 @@ class ChildTool(BaseTool): f"Got unexpected type of `handle_tool_error`. Expected bool, str " f"or callable. Received: {self.handle_tool_error}" ) - run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs - ) + run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs) return observation except (Exception, KeyboardInterrupt) as e: run_manager.on_tool_error(e) raise e else: - run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) + run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs) return observation async def arun( @@ -502,7 +498,7 @@ class ChildTool(BaseTool): f"or callable. Received: {self.handle_tool_error}" ) await run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs + observation, color="red", name=self.name, **kwargs ) return observation except (Exception, KeyboardInterrupt) as e: @@ -510,7 +506,7 @@ class ChildTool(BaseTool): raise e else: await run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs + observation, color=color, name=self.name, **kwargs ) return observation diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 2481df4e254..a20d85b07e6 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -504,8 +504,9 @@ class BaseTracer(BaseCallbackHandler, ABC): self._on_tool_start(tool_run) return tool_run - def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: + def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for a tool run.""" + output = str(output) tool_run = self._get_run(run_id, run_type="tool") tool_run.outputs = {"output": output} tool_run.end_time = datetime.now(timezone.utc)