mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
rfc: callback changes (#1165)
conceptually, no reason a tool should know what an "agent action" is unless any objections, can change in all callback handlers
This commit is contained in:
@@ -109,8 +109,10 @@ def test_agent_with_callbacks_global() -> None:
|
||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == handler.tool_ends == 1
|
||||
assert handler.starts == 6
|
||||
assert handler.tool_starts == 2
|
||||
assert handler.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
@@ -155,8 +157,10 @@ def test_agent_with_callbacks_local() -> None:
|
||||
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == handler.tool_ends == 1
|
||||
assert handler.starts == 6
|
||||
assert handler.tool_starts == 2
|
||||
assert handler.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
|
@@ -2,7 +2,6 @@
|
||||
import pytest
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@@ -101,7 +100,4 @@ async def test_create_async_tool() -> None:
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.coroutine is not None
|
||||
assert (
|
||||
await test_tool.arun(AgentAction(tool_input="foo", tool="test_name", log=""))
|
||||
== "foo"
|
||||
)
|
||||
assert await test_tool.arun("foo") == "foo"
|
||||
|
@@ -94,7 +94,7 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
@@ -120,6 +120,11 @@ class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
"""Fake async callback handler for testing."""
|
||||
@@ -165,7 +170,7 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
@@ -190,3 +195,8 @@ class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
@@ -9,7 +9,7 @@ from langchain.callbacks.base import (
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
@@ -27,7 +27,7 @@ def _test_callback_manager(
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
manager.on_tool_start({}, "")
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
@@ -44,7 +44,7 @@ async def _test_callback_manager_async(
|
||||
await manager.on_chain_start({"name": "foo"}, {})
|
||||
await manager.on_chain_end({})
|
||||
await manager.on_chain_error(Exception())
|
||||
await manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
await manager.on_tool_start({}, "")
|
||||
await manager.on_tool_end("")
|
||||
await manager.on_tool_error(Exception())
|
||||
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
@@ -73,7 +73,7 @@ def _test_callback_manager_pass_in_verbose(
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
manager.on_chain_error(Exception(), verbose=True)
|
||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
||||
@@ -136,7 +136,7 @@ def test_ignore_agent() -> None:
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
||||
|
@@ -19,7 +19,7 @@ from langchain.callbacks.tracers.base import (
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
@@ -47,7 +47,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_runs=[
|
||||
@@ -84,9 +84,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
"""Perform a nested run."""
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_end("test")
|
||||
@@ -303,16 +301,14 @@ def test_tracer_tool_run() -> None:
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_tool_end("test")
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@@ -390,16 +386,14 @@ def test_tracer_tool_run_on_error() -> None:
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_tool_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@@ -418,9 +412,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_tool_error(exception)
|
||||
@@ -473,7 +465,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
action="{}",
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
|
Reference in New Issue
Block a user