rfc: callback changes (#1165)

conceptually, no reason a tool should know what an "agent action" is

unless any objections, can change in all callback handlers
This commit is contained in:
Harrison Chase
2023-02-20 22:54:15 -08:00
committed by GitHub
parent fb83cd4ff4
commit b7708bbec6
14 changed files with 131 additions and 66 deletions

View File

@@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
assert handler.tool_starts == 2
assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0
@@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
assert handler.tool_starts == 2
assert handler.tool_ends == 1
# 1 extra agent action
assert handler.starts == 7
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0

View File

@@ -2,7 +2,6 @@
import pytest
from langchain.agents.tools import Tool, tool
from langchain.schema import AgentAction
def test_unnamed_decorator() -> None:
@@ -101,7 +100,4 @@ async def test_create_async_tool() -> None:
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.coroutine is not None
assert (
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
== "foo"
)
assert await test_tool.arun("foo") == "foo"

View File

@@ -94,7 +94,7 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
@@ -120,6 +120,11 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
self.agent_ends += 1
self.ends += 1
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Fake async callback handler for testing."""
@@ -165,7 +170,7 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
self.errors += 1
async def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
@@ -190,3 +195,8 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
self.tool_starts += 1
self.starts += 1

View File

@@ -9,7 +9,7 @@ from langchain.callbacks.base import (
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
@@ -27,7 +27,7 @@ def _test_callback_manager(
manager.on_chain_start({"name": "foo"}, {})
manager.on_chain_end({})
manager.on_chain_error(Exception())
manager.on_tool_start({}, AgentAction("", "", ""))
manager.on_tool_start({}, "")
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
@@ -44,7 +44,7 @@ async def _test_callback_manager_async(
await manager.on_chain_start({"name": "foo"}, {})
await manager.on_chain_end({})
await manager.on_chain_error(Exception())
await manager.on_tool_start({}, AgentAction("", "", ""))
await manager.on_tool_start({}, "")
await manager.on_tool_end("")
await manager.on_tool_error(Exception())
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
@@ -73,7 +73,7 @@ def _test_callback_manager_pass_in_verbose(
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
manager.on_chain_error(Exception(), verbose=True)
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
@@ -136,7 +136,7 @@ def test_ignore_agent() -> None:
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_start({}, "", verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)

View File

@@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.schema import AgentAction, LLMResult
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
@@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
serialized={},
tool_input="test",
output="test",
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
child_runs=[
@@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_end("test")
@@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
serialized={},
tool_input="test",
output="test",
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_end("test")
assert tracer.runs == [compare_run]
@@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
serialized={},
tool_input="test",
output=None,
action="action",
action="{}",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_tool_error(exception)
assert tracer.runs == [compare_run]
@@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
@@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
error=repr(exception),
tool_input="test",
output=None,
action="action",
action="{}",
child_runs=[
LLMRun(
id=None,