mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
core[major]: On Tool End Observation Casting Fix (#18798)
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a96a6e0f2c
commit
43db4cd20e
@@ -133,7 +133,7 @@ class ToolManagerMixin:
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
@@ -440,7 +440,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
|
@@ -976,14 +976,15 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
Args:
|
||||
output (str): The output of the tool.
|
||||
output (Any): The output of the tool.
|
||||
"""
|
||||
output = str(output)
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
@@ -1038,12 +1039,13 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
async def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
Args:
|
||||
output (str): The output of the tool.
|
||||
output (Any): The output of the tool.
|
||||
"""
|
||||
output = str(output)
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_tool_end",
|
||||
|
@@ -37,13 +37,14 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
output = str(output)
|
||||
if observation_prefix is not None:
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color=color or self.color)
|
||||
|
@@ -59,7 +59,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Run on agent action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
|
@@ -410,17 +410,13 @@ class ChildTool(BaseTool):
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
run_manager.on_tool_end(
|
||||
str(observation), color="red", name=self.name, **kwargs
|
||||
)
|
||||
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
|
||||
return observation
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_tool_end(
|
||||
str(observation), color=color, name=self.name, **kwargs
|
||||
)
|
||||
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
|
||||
return observation
|
||||
|
||||
async def arun(
|
||||
@@ -502,7 +498,7 @@ class ChildTool(BaseTool):
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
await run_manager.on_tool_end(
|
||||
str(observation), color="red", name=self.name, **kwargs
|
||||
observation, color="red", name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
@@ -510,7 +506,7 @@ class ChildTool(BaseTool):
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_tool_end(
|
||||
str(observation), color=color, name=self.name, **kwargs
|
||||
observation, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
|
@@ -504,8 +504,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self._on_tool_start(tool_run)
|
||||
return tool_run
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for a tool run."""
|
||||
output = str(output)
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.outputs = {"output": output}
|
||||
tool_run.end_time = datetime.now(timezone.utc)
|
||||
|
Reference in New Issue
Block a user