mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-21 02:17:12 +00:00
core[major]: On Tool End Observation Casting Fix (#18798)
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
a96a6e0f2c
commit
43db4cd20e
@ -1068,7 +1068,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" def on_tool_end(\n",
|
" def on_tool_end(\n",
|
||||||
" self,\n",
|
" self,\n",
|
||||||
" output: str,\n",
|
" output: Any,\n",
|
||||||
" *,\n",
|
" *,\n",
|
||||||
" run_id: UUID,\n",
|
" run_id: UUID,\n",
|
||||||
" parent_run_id: Optional[UUID] = None,\n",
|
" parent_run_id: Optional[UUID] = None,\n",
|
||||||
@ -1076,7 +1076,7 @@
|
|||||||
" ) -> Any:\n",
|
" ) -> Any:\n",
|
||||||
" \"\"\"Run when tool ends running.\"\"\"\n",
|
" \"\"\"Run when tool ends running.\"\"\"\n",
|
||||||
" print(\"Tool end\")\n",
|
" print(\"Tool end\")\n",
|
||||||
" print(output)\n",
|
" print(str(output))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" async def on_llm_end(\n",
|
" async def on_llm_end(\n",
|
||||||
" self,\n",
|
" self,\n",
|
||||||
|
@ -59,7 +59,7 @@ class BaseCallbackHandler:
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> Any:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
|
@ -314,8 +314,9 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
|
|
||||||
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
|
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
output = str(output)
|
||||||
aim = import_aim()
|
aim = import_aim()
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.tool_ends += 1
|
self.tool_ends += 1
|
||||||
|
@ -328,7 +328,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -196,7 +196,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -279,7 +279,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -243,8 +243,9 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.logger.report_text(resp)
|
self.logger.report_text(resp)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
output = str(output)
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.tool_ends += 1
|
self.tool_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
@ -303,8 +303,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
resp.update({"input_str": input_str})
|
resp.update({"input_str": input_str})
|
||||||
self.action_records.append(resp)
|
self.action_records.append(resp)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
output = str(output)
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.tool_ends += 1
|
self.tool_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
@ -162,7 +162,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -465,13 +465,14 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Union[UUID, None] = None,
|
parent_run_id: Union[UUID, None] = None,
|
||||||
tags: Union[List[str], None] = None,
|
tags: Union[List[str], None] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
output = str(output)
|
||||||
if self.__has_valid_config is False:
|
if self.__has_valid_config is False:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
@ -518,8 +518,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.records["action_records"].append(resp)
|
self.records["action_records"].append(resp)
|
||||||
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
|
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
output = str(output)
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["tool_ends"] += 1
|
self.metrics["tool_ends"] += 1
|
||||||
self.metrics["ends"] += 1
|
self.metrics["ends"] += 1
|
||||||
|
@ -186,8 +186,9 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
|
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
output = str(output)
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["tool_ends"] += 1
|
self.metrics["tool_ends"] += 1
|
||||||
self.metrics["ends"] += 1
|
self.metrics["ends"] += 1
|
||||||
|
@ -183,13 +183,13 @@ class LLMThought:
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
color: Optional[str] = None,
|
color: Optional[str] = None,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._container.markdown(f"**{output}**")
|
self._container.markdown(f"**{str(output)}**")
|
||||||
|
|
||||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self._container.markdown("**Tool encountered an error...**")
|
self._container.markdown("**Tool encountered an error...**")
|
||||||
@ -363,12 +363,13 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
color: Optional[str] = None,
|
color: Optional[str] = None,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
output = str(output)
|
||||||
self._require_current_thought().on_tool_end(
|
self._require_current_thought().on_tool_end(
|
||||||
output, color, observation_prefix, llm_prefix, **kwargs
|
output, color, observation_prefix, llm_prefix, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -356,8 +356,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.run.log(resp)
|
self.run.log(resp)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
output = str(output)
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.tool_ends += 1
|
self.tool_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
@ -133,7 +133,7 @@ class ToolManagerMixin:
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -440,7 +440,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def on_tool_end(
|
async def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
@ -976,14 +976,15 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool ends running.
|
"""Run when tool ends running.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output (str): The output of the tool.
|
output (Any): The output of the tool.
|
||||||
"""
|
"""
|
||||||
|
output = str(output)
|
||||||
handle_event(
|
handle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_tool_end",
|
"on_tool_end",
|
||||||
@ -1038,12 +1039,13 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@shielded
|
@shielded
|
||||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running.
|
"""Run when tool ends running.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output (str): The output of the tool.
|
output (Any): The output of the tool.
|
||||||
"""
|
"""
|
||||||
|
output = str(output)
|
||||||
await ahandle_event(
|
await ahandle_event(
|
||||||
self.handlers,
|
self.handlers,
|
||||||
"on_tool_end",
|
"on_tool_end",
|
||||||
|
@ -37,13 +37,14 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
output: str,
|
output: Any,
|
||||||
color: Optional[str] = None,
|
color: Optional[str] = None,
|
||||||
observation_prefix: Optional[str] = None,
|
observation_prefix: Optional[str] = None,
|
||||||
llm_prefix: Optional[str] = None,
|
llm_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""If not the final action, print out observation."""
|
"""If not the final action, print out observation."""
|
||||||
|
output = str(output)
|
||||||
if observation_prefix is not None:
|
if observation_prefix is not None:
|
||||||
print_text(f"\n{observation_prefix}")
|
print_text(f"\n{observation_prefix}")
|
||||||
print_text(output, color=color or self.color)
|
print_text(output, color=color or self.color)
|
||||||
|
@ -59,7 +59,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run on agent action."""
|
"""Run on agent action."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
|
||||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
|
@ -410,17 +410,13 @@ class ChildTool(BaseTool):
|
|||||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||||
f"or callable. Received: {self.handle_tool_error}"
|
f"or callable. Received: {self.handle_tool_error}"
|
||||||
)
|
)
|
||||||
run_manager.on_tool_end(
|
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
|
||||||
str(observation), color="red", name=self.name, **kwargs
|
|
||||||
)
|
|
||||||
return observation
|
return observation
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
run_manager.on_tool_error(e)
|
run_manager.on_tool_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
run_manager.on_tool_end(
|
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
|
||||||
str(observation), color=color, name=self.name, **kwargs
|
|
||||||
)
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
async def arun(
|
async def arun(
|
||||||
@ -502,7 +498,7 @@ class ChildTool(BaseTool):
|
|||||||
f"or callable. Received: {self.handle_tool_error}"
|
f"or callable. Received: {self.handle_tool_error}"
|
||||||
)
|
)
|
||||||
await run_manager.on_tool_end(
|
await run_manager.on_tool_end(
|
||||||
str(observation), color="red", name=self.name, **kwargs
|
observation, color="red", name=self.name, **kwargs
|
||||||
)
|
)
|
||||||
return observation
|
return observation
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
@ -510,7 +506,7 @@ class ChildTool(BaseTool):
|
|||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
await run_manager.on_tool_end(
|
await run_manager.on_tool_end(
|
||||||
str(observation), color=color, name=self.name, **kwargs
|
observation, color=color, name=self.name, **kwargs
|
||||||
)
|
)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
@ -504,8 +504,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
self._on_tool_start(tool_run)
|
self._on_tool_start(tool_run)
|
||||||
return tool_run
|
return tool_run
|
||||||
|
|
||||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||||
"""End a trace for a tool run."""
|
"""End a trace for a tool run."""
|
||||||
|
output = str(output)
|
||||||
tool_run = self._get_run(run_id, run_type="tool")
|
tool_run = self._get_run(run_id, run_type="tool")
|
||||||
tool_run.outputs = {"output": output}
|
tool_run.outputs = {"output": output}
|
||||||
tool_run.end_time = datetime.now(timezone.utc)
|
tool_run.end_time = datetime.now(timezone.utc)
|
||||||
|
Loading…
Reference in New Issue
Block a user