core[major]: On Tool End Observation Casting Fix (#18798)

This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. 

This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well.

Fixes the following issue #18760 raised by @eyurtsev

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Mohammad Mohtashim
2024-03-11 19:59:04 +05:00
committed by GitHub
parent a96a6e0f2c
commit 43db4cd20e
20 changed files with 42 additions and 34 deletions

View File

@@ -133,7 +133,7 @@ class ToolManagerMixin:
def on_tool_end(
self,
output: str,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -440,7 +440,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_tool_end(
self,
output: str,
output: Any,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

View File

@@ -976,14 +976,15 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
def on_tool_end(
self,
output: str,
output: Any,
**kwargs: Any,
) -> None:
"""Run when tool ends running.
Args:
output (str): The output of the tool.
output (Any): The output of the tool.
"""
output = str(output)
handle_event(
self.handlers,
"on_tool_end",
@@ -1038,12 +1039,13 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
)
@shielded
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running.
Args:
output (str): The output of the tool.
output (Any): The output of the tool.
"""
output = str(output)
await ahandle_event(
self.handlers,
"on_tool_end",

View File

@@ -37,13 +37,14 @@ class StdOutCallbackHandler(BaseCallbackHandler):
def on_tool_end(
self,
output: str,
output: Any,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
output = str(output)
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
print_text(output, color=color or self.color)

View File

@@ -59,7 +59,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent action."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:

View File

@@ -410,17 +410,13 @@ class ChildTool(BaseTool):
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
run_manager.on_tool_end(
str(observation), color="red", name=self.name, **kwargs
)
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
return observation
except (Exception, KeyboardInterrupt) as e:
run_manager.on_tool_error(e)
raise e
else:
run_manager.on_tool_end(
str(observation), color=color, name=self.name, **kwargs
)
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
return observation
async def arun(
@@ -502,7 +498,7 @@ class ChildTool(BaseTool):
f"or callable. Received: {self.handle_tool_error}"
)
await run_manager.on_tool_end(
str(observation), color="red", name=self.name, **kwargs
observation, color="red", name=self.name, **kwargs
)
return observation
except (Exception, KeyboardInterrupt) as e:
@@ -510,7 +506,7 @@ class ChildTool(BaseTool):
raise e
else:
await run_manager.on_tool_end(
str(observation), color=color, name=self.name, **kwargs
observation, color=color, name=self.name, **kwargs
)
return observation

View File

@@ -504,8 +504,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._on_tool_start(tool_run)
return tool_run
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
output = str(output)
tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.now(timezone.utc)