mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
add tracing support to langchain (#741)
* add implementations of `BaseCallbackHandler` to support tracing: `SharedTracer` which is thread-safe and `Tracer` which is not and is meant to be used locally. * Tracers persist runs to locally running `langchain-server` Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -1,17 +1,43 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
starts: int = 0
|
||||
ends: int = 0
|
||||
errors: int = 0
|
||||
text: int = 0
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
ignore_agent_: bool = False
|
||||
always_verbose_: bool = False
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return self.always_verbose_
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
@@ -23,7 +49,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Run when LLM ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
|
||||
@@ -37,7 +65,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Run when chain ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
|
||||
@@ -51,7 +81,9 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Run when tool ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
|
||||
|
@@ -8,6 +8,31 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
for handler in handlers:
|
||||
if handler.always_verbose:
|
||||
assert handler.starts == 3
|
||||
assert handler.ends == 4
|
||||
assert handler.errors == 3
|
||||
else:
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def _test_callback_manager_pass_in_verbose(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [], verbose=True)
|
||||
@@ -19,7 +44,7 @@ def _test_callback_manager(
|
||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
||||
for handler in handlers:
|
||||
assert handler.starts == 3
|
||||
assert handler.ends == 4
|
||||
@@ -27,17 +52,25 @@ def _test_callback_manager(
|
||||
|
||||
|
||||
def test_callback_manager() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=False)
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_callback_manager_pass_in_verbose() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_ignore_llm() -> None:
|
||||
"""Test ignore llm param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_llm=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_llm_start({}, [], verbose=True)
|
||||
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
|
||||
@@ -52,8 +85,8 @@ def test_ignore_llm() -> None:
|
||||
|
||||
def test_ignore_chain() -> None:
|
||||
"""Test ignore chain param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_chain=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
@@ -68,8 +101,8 @@ def test_ignore_chain() -> None:
|
||||
|
||||
def test_ignore_agent() -> None:
|
||||
"""Test ignore agent param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_agent=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_tool_start({}, AgentAction("", "", ""), verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
@@ -90,7 +123,7 @@ def test_shared_callback_manager() -> None:
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager1.add_handler(handler1)
|
||||
manager2.add_handler(handler2)
|
||||
|
1
tests/unit_tests/callbacks/tracers/__init__.py
Normal file
1
tests/unit_tests/callbacks/tracers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for correct functioning of tracers."""
|
530
tests/unit_tests/callbacks/tracers/test_tracer.py
Normal file
530
tests/unit_tests/callbacks/tracers/test_tracer.py
Normal file
@@ -0,0 +1,530 @@
|
||||
"""Test Tracer classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain.callbacks.tracers.base import (
|
||||
BaseTracer,
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
SharedTracer,
|
||||
ToolRun,
|
||||
Tracer,
|
||||
TracerException,
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
return ChainRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
child_runs=[
|
||||
ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
),
|
||||
LLMRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
"""Perform a nested run."""
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_end("test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_chain_end(outputs={})
|
||||
|
||||
|
||||
def _add_child_run(
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
parent_run.child_runs.append(child_run)
|
||||
|
||||
|
||||
def _generate_id() -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return None
|
||||
|
||||
|
||||
def load_session(session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
||||
|
||||
|
||||
def _persist_session(session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return TracerSession(id=TEST_SESSION_ID, **session.dict())
|
||||
|
||||
|
||||
def load_default_session() -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name="default", start_time=datetime.utcnow())
|
||||
|
||||
|
||||
class FakeTracer(Tracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
_add_child_run(parent_run, child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return _generate_id()
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
class FakeSharedTracer(SharedTracer):
|
||||
"""Fake shared tracer that records LangChain execution."""
|
||||
|
||||
runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
with self._lock:
|
||||
self.runs.append(run)
|
||||
|
||||
def remove_runs(self) -> None:
|
||||
"""Remove all runs."""
|
||||
with self._lock:
|
||||
self.runs = []
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
_add_child_run(parent_run, child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return _generate_id()
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_session() -> None:
|
||||
"""Test tracer on an LLM run without a session."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_multiple_llm_runs() -> None:
|
||||
"""Test the tracer with multiple runs."""
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult([[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run() -> None:
|
||||
"""Test tracer on a Chain run."""
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_chain_end(outputs={})
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run() -> None:
|
||||
"""Test tracer on a Tool run."""
|
||||
compare_run = ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="action",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_end("test")
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_run() -> None:
|
||||
"""Test tracer on a nested run."""
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
_perform_nested_run(tracer)
|
||||
assert tracer.runs == [_get_compare_run()]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_on_error() -> None:
|
||||
"""Test tracer on an LLM run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run_on_error() -> None:
|
||||
"""Test tracer on a Chain run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_chain_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run_on_error() -> None:
|
||||
"""Test tracer on a Tool run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
compare_run = ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_tool_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_runs_on_error() -> None:
|
||||
"""Test tracer on a nested run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult([[]]))
|
||||
tracer.on_tool_start(
|
||||
serialized={}, action=AgentAction(tool="action", tool_input="test", log="")
|
||||
)
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_tool_error(exception)
|
||||
tracer.on_chain_error(exception)
|
||||
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="action",
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
prompts=[],
|
||||
response=None,
|
||||
)
|
||||
],
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
),
|
||||
],
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
)
|
||||
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_shared_tracer_nested_run() -> None:
|
||||
"""Test shared tracer on a nested run."""
|
||||
tracer = FakeSharedTracer()
|
||||
tracer.new_session()
|
||||
tracer.remove_runs()
|
||||
_perform_nested_run(tracer)
|
||||
assert tracer.runs == [_get_compare_run()]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_shared_tracer_nested_run_multithreaded() -> None:
|
||||
"""Test shared tracer on a nested run."""
|
||||
tracer = FakeSharedTracer()
|
||||
tracer.remove_runs()
|
||||
tracer.new_session()
|
||||
threads = []
|
||||
num_threads = 10
|
||||
for _ in range(num_threads):
|
||||
thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert tracer.runs == [_get_compare_run()] * num_threads
|
Reference in New Issue
Block a user