mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
core[major]: On Tool End Observation Casting Fix (#18798)
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a96a6e0f2c
commit
43db4cd20e
@@ -314,8 +314,9 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
output = str(output)
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
|
@@ -328,7 +328,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
|
@@ -196,7 +196,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
|
@@ -279,7 +279,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
|
@@ -243,8 +243,9 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
output = str(output)
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
@@ -303,8 +303,9 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
resp.update({"input_str": input_str})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
output = str(output)
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
@@ -162,7 +162,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
|
@@ -465,13 +465,14 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
tags: Union[List[str], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
output = str(output)
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
|
@@ -518,8 +518,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
output = str(output)
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
@@ -186,8 +186,9 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
output = str(output)
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
@@ -183,13 +183,13 @@ class LLMThought:
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._container.markdown(f"**{output}**")
|
||||
self._container.markdown(f"**{str(output)}**")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
self._container.markdown("**Tool encountered an error...**")
|
||||
@@ -363,12 +363,13 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
output: Any,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
output = str(output)
|
||||
self._require_current_thought().on_tool_end(
|
||||
output, color, observation_prefix, llm_prefix, **kwargs
|
||||
)
|
||||
|
@@ -356,8 +356,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
output = str(output)
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
Reference in New Issue
Block a user