rfc: callback changes (#1165)

conceptually, no reason a tool should know what an "agent action" is

unless any objections, can change in all callback handlers
This commit is contained in:
Harrison Chase
2023-02-20 22:54:15 -08:00
committed by GitHub
parent fb83cd4ff4
commit b7708bbec6
14 changed files with 131 additions and 66 deletions

View File

@@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.schema import AgentAction, LLMResult
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
@@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
serialized={},
tool_input="test",
output="test",
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
child_runs=[
@@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_end("test")
@@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
serialized={},
tool_input="test",
output="test",
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_end("test")
assert tracer.runs == [compare_run]
@@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
serialized={},
tool_input="test",
output=None,
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_error(exception)
assert tracer.runs == [compare_run]
@@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
@@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
error=repr(exception),
tool_input="test",
output=None,
action="action",
action="{}",
child_runs=[
LLMRun(
id=None,