add tracing support to langchain (#741)

* add implementations of `BaseCallbackHandler` to support tracing:
`SharedTracer` which is thread-safe and `Tracer` which is not and is
meant to be used locally.
* Tracers persist runs to a locally running `langchain-server`

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Ankush Gola
2023-01-26 17:38:13 -08:00
committed by GitHub
parent 7f76a1189c
commit 57609845df
23 changed files with 1422 additions and 71 deletions

View File

@@ -1,17 +1,43 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from pydantic import BaseModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Fake callback handler for testing."""
starts: int = 0
ends: int = 0
errors: int = 0
text: int = 0
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
always_verbose_: bool = False
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -23,7 +49,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
self.errors += 1
@@ -37,7 +65,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
self.errors += 1
@@ -51,7 +81,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
self.errors += 1

View File

@@ -8,6 +8,31 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def _test_callback_manager(
    manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
    """Fire one of every callback on *manager* (without ``verbose=True``) and
    check each handler's counters.

    Always-verbose handlers should see 3 starts (llm/chain/tool), 4 ends
    (llm/chain/tool + agent_finish) and 3 errors; non-verbose handlers should
    see nothing, since no call here passes ``verbose=True``.
    """
    manager.on_llm_start({}, [])
    manager.on_llm_end(LLMResult(generations=[]))
    manager.on_llm_error(Exception())
    manager.on_chain_start({"name": "foo"}, {})
    manager.on_chain_end({})
    manager.on_chain_error(Exception())
    manager.on_tool_start({}, AgentAction("", "", ""))
    manager.on_tool_end("")
    manager.on_tool_error(Exception())
    manager.on_agent_finish(AgentFinish(log="", return_values={}))
    for handler in handlers:
        if handler.always_verbose:
            assert handler.starts == 3
            assert handler.ends == 4
            assert handler.errors == 3
        else:
            # Without verbose=True on the calls, non-verbose handlers are skipped.
            assert handler.starts == 0
            assert handler.ends == 0
            assert handler.errors == 0
def _test_callback_manager_pass_in_verbose(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [], verbose=True)
@@ -19,7 +44,7 @@ def _test_callback_manager(
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_end("", verbose=True)
manager.on_tool_error(Exception(), verbose=True)
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
for handler in handlers:
assert handler.starts == 3
assert handler.ends == 4
@@ -27,17 +52,25 @@ def _test_callback_manager(
def test_callback_manager() -> None:
    """Test the CallbackManager."""
    # One always-verbose handler and one that is not: the helper asserts the
    # verbose handler saw every event and the non-verbose one saw none.
    handler1 = FakeCallbackHandler(always_verbose_=True)
    handler2 = FakeCallbackHandler(always_verbose_=False)
    manager = CallbackManager([handler1, handler2])
    _test_callback_manager(manager, handler1, handler2)
def test_callback_manager_pass_in_verbose() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2)
manager = CallbackManager([handler1, handler2])
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
def test_ignore_llm() -> None:
"""Test ignore llm param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_llm=True)
handler2 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_llm_start({}, [], verbose=True)
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
@@ -52,8 +85,8 @@ def test_ignore_llm() -> None:
def test_ignore_chain() -> None:
"""Test ignore chain param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_chain=True)
handler2 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
manager.on_chain_end({}, verbose=True)
@@ -68,8 +101,8 @@ def test_ignore_chain() -> None:
def test_ignore_agent() -> None:
"""Test ignore agent param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_agent=True)
handler2 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
handler2 = FakeCallbackHandler(always_verbose_=True)
manager = CallbackManager(handlers=[handler1, handler2])
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
manager.on_tool_end("", verbose=True)
@@ -90,7 +123,7 @@ def test_shared_callback_manager() -> None:
assert manager1 is manager2
handler1 = FakeCallbackHandler()
handler1 = FakeCallbackHandler(always_verbose_=True)
handler2 = FakeCallbackHandler()
manager1.add_handler(handler1)
manager2.add_handler(handler2)

View File

@@ -0,0 +1 @@
"""Tests for correct functioning of tracers."""

View File

@@ -0,0 +1,530 @@
"""Test Tracer classes."""
from __future__ import annotations
import threading
from datetime import datetime
from typing import List, Optional, Union
import pytest
from freezegun import freeze_time
from langchain.callbacks.tracers.base import (
BaseTracer,
ChainRun,
LLMRun,
SharedTracer,
ToolRun,
Tracer,
TracerException,
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.schema import AgentAction, LLMResult
TEST_SESSION_ID = 2023
@freeze_time("2023-01-01")
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
return ChainRun(
id=None,
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
child_runs=[
ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
serialized={},
tool_input="test",
output="test",
action="action",
session_id=TEST_SESSION_ID,
error=None,
child_runs=[
LLMRun(
id=None,
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
)
],
),
LLMRun(
id=None,
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
),
],
)
def _perform_nested_run(tracer: BaseTracer) -> None:
    """Perform a nested run: a chain containing a tool with an inner LLM
    call, followed by a second LLM call directly under the chain.

    The resulting tree is what ``_get_compare_run`` constructs; call order
    determines nesting and execution_order, so do not reorder these calls.
    """
    tracer.on_chain_start(serialized={}, inputs={})
    tracer.on_tool_start(
        serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
    )
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult([[]]))
    tracer.on_tool_end("test")
    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult([[]]))
    tracer.on_chain_end(outputs={})
def _add_child_run(
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
parent_run.child_runs.append(child_run)
def _generate_id() -> Optional[Union[int, str]]:
"""Generate an id for a run."""
return None
def load_session(session_name: str) -> TracerSession:
    """Return a canned session (id=1) carrying the requested name."""
    started = datetime.utcnow()
    return TracerSession(id=1, name=session_name, start_time=started)
def _persist_session(session: TracerSessionCreate) -> TracerSession:
    """Promote a session-create payload to a full session with the test id."""
    fields = session.dict()
    return TracerSession(id=TEST_SESSION_ID, **fields)
def load_default_session() -> TracerSession:
    """Return a canned session named "default" (id=1)."""
    started = datetime.utcnow()
    return TracerSession(id=1, name="default", start_time=started)
class FakeTracer(Tracer):
    """Fake (non-shared) tracer that records LangChain execution in memory.

    All abstract hooks delegate to the module-level helpers so this fake and
    ``FakeSharedTracer`` behave identically apart from locking/sharing.
    """

    def __init__(self) -> None:
        """Initialize the tracer with an empty, per-instance run list."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run by appending it to the in-memory list."""
        self.runs.append(run)

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run (always None in the fakes)."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session (returns one with TEST_SESSION_ID)."""
        return _persist_session(session)

    def load_session(self, session_name: str) -> TracerSession:
        """Load a tracing session."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Load the default tracing session."""
        return load_default_session()
class FakeSharedTracer(SharedTracer):
    """Fake shared tracer that records LangChain execution.

    NOTE(review): ``runs`` is a class-level mutable list, so recorded runs are
    shared across uses of the (presumably singleton) shared tracer — tests
    call ``remove_runs`` to reset state. Confirm SharedTracer is a singleton.
    """

    # Shared record of all persisted runs; mutated under self._lock.
    runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run; guarded by the tracer's lock for thread safety."""
        with self._lock:
            self.runs.append(run)

    def remove_runs(self) -> None:
        """Remove all recorded runs (rebinds ``runs`` under the lock)."""
        with self._lock:
            self.runs = []

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Add child run to a chain run or tool run."""
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Generate an id for a run (always None in the fakes)."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Persist a tracing session (returns one with TEST_SESSION_ID)."""
        return _persist_session(session)

    def load_session(self, session_name: str) -> TracerSession:
        """Load a tracing session."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Load the default tracing session."""
        return load_default_session()
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
compare_run = LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_session() -> None:
"""Test tracer on an LLM run without a session."""
tracer = FakeTracer()
with pytest.raises(TracerException):
tracer.on_llm_start(serialized={}, prompts=[])
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""
tracer = FakeTracer()
tracer.new_session()
with pytest.raises(TracerException):
tracer.on_llm_end(response=LLMResult([[]]))
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
"""Test the tracer with multiple runs."""
compare_run = LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
prompts=[],
response=LLMResult([[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
assert tracer.runs == [compare_run] * num_runs
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
compare_run = ChainRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_chain_end(outputs={})
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
compare_run = ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
tool_input="test",
output="test",
action="action",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_end("test")
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
"""Test tracer on a nested run."""
tracer = FakeTracer()
tracer.new_session()
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()]
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
"""Test tracer on an LLM run with an error."""
exception = Exception("test")
compare_run = LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
prompts=[],
response=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
"""Test tracer on a Chain run with an error."""
exception = Exception("test")
compare_run = ChainRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
inputs={},
outputs=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_chain_error(exception)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
"""Test tracer on a Tool run with an error."""
exception = Exception("test")
compare_run = ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
tool_input="test",
output=None,
action="action",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_tool_error(exception)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
"""Test tracer on a nested run with an error."""
exception = Exception("test")
tracer = FakeTracer()
tracer.new_session()
for _ in range(3):
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult([[]]))
tracer.on_tool_start(
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
)
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
tracer.on_chain_error(exception)
compare_run = ChainRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
inputs={},
outputs=None,
child_runs=[
LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
serialized={},
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
serialized={},
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
ToolRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
tool_input="test",
output=None,
action="action",
child_runs=[
LLMRun(
id=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=5,
serialized={},
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],
response=None,
)
],
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
),
],
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
)
assert tracer.runs == [compare_run] * 3
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run() -> None:
"""Test shared tracer on a nested run."""
tracer = FakeSharedTracer()
tracer.new_session()
tracer.remove_runs()
_perform_nested_run(tracer)
assert tracer.runs == [_get_compare_run()]
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run_multithreaded() -> None:
"""Test shared tracer on a nested run."""
tracer = FakeSharedTracer()
tracer.remove_runs()
tracer.new_session()
threads = []
num_threads = 10
for _ in range(num_threads):
thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
assert tracer.runs == [_get_compare_run()] * num_threads