[Tracing] String Stacktrace (#14131)

Add the full stack trace (not just repr(error)) to run.error in tracer error callbacks
William FH 2023-12-14 22:15:07 -08:00 committed by GitHub
parent 7f42811e14
commit 93c7eb4e6b
4 changed files with 98 additions and 173 deletions
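In short: on error callbacks the tracer now stores repr(error) followed by the formatted traceback in run.error, which is why the tests below assert containment rather than equality. A minimal standalone sketch of the new error format (illustrative only; the helper mirrors the _get_stacktrace method added in this commit, with the name shortened for the sketch):

import sys
import traceback

def get_stacktrace(error: BaseException) -> str:
    # Mirrors BaseTracer._get_stacktrace from this commit (standalone copy for illustration).
    msg = repr(error)
    try:
        if sys.version_info < (3, 10):
            tb = traceback.format_exception(error.__class__, error, error.__traceback__)
        else:
            tb = traceback.format_exception(error)
        return (msg + "\n\n".join(tb)).strip()
    except:  # noqa: E722
        return msg

try:
    raise ValueError("x is too large")
except ValueError as exc:
    error_field = get_stacktrace(exc)

# The string starts with repr(error) and continues with the traceback,
# so equality with repr(error) no longer holds, but containment does:
assert "ValueError('x is too large')" in error_field
assert "Traceback (most recent call last)" in error_field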


@@ -30,6 +30,24 @@ class FakeTracer(BaseTracer):
self.runs.append(run)
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
if run.child_runs:
assert len(expected_run.child_runs) == len(run.child_runs)
for received, expected in zip(run.child_runs, expected_run.child_runs):
_compare_run_with_error(received, expected)
received = run.dict(exclude={"child_runs"})
received_err = received.pop("error")
expected = expected_run.dict(exclude={"child_runs"})
expected_err = expected.pop("error")
assert received == expected
if expected_err is not None:
assert received_err is not None
assert expected_err in received_err
else:
assert received_err is None
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
@@ -328,7 +346,8 @@ def test_tracer_llm_run_on_error() -> None:
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
assert len(tracer.runs) == 1
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
@@ -364,7 +383,7 @@ def test_tracer_llm_run_on_error_callback() -> None:
tracer = FakeTracerWithLlmErrorCallback()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.error_run == compare_run
_compare_run_with_error(tracer.error_run, compare_run)
@freeze_time("2023-01-01")
@@ -394,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
@@ -425,7 +444,7 @@ def test_tracer_tool_run_on_error() -> None:
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
@@ -568,4 +587,6 @@ def test_tracer_nested_runs_on_error() -> None:
),
],
)
assert tracer.runs == [compare_run] * 3
assert len(tracer.runs) == 3
for run in tracer.runs:
_compare_run_with_error(run, compare_run)

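The _compare_run_with_error helper above exists because the recorded error now ends with a traceback whose frames and file paths vary by machine and Python version, so runs can no longer be compared with plain equality. A generic sketch of the same idea using hypothetical dicts (not the real Run model):

expected = {"name": "llm", "error": "ValueError('boom')"}
received = {"name": "llm", "error": "ValueError('boom')Traceback (most recent call last): ..."}

expected_err = expected.pop("error")
received_err = received.pop("error")

assert received == expected          # every field except error must match exactly
assert expected_err in received_err  # the error only needs to contain repr(exception)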

@@ -98,6 +98,15 @@ class FakeTracer(BaseTracer):
return load_default_session()
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
received = run.dict()
received_err = received.pop("error")
expected = expected_run.dict()
expected_err = expected.pop("error")
assert received == expected
assert expected_err in received_err
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
@@ -376,7 +385,7 @@ def test_tracer_llm_run_on_error() -> None:
tracer.new_session()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
@@ -404,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
tracer.new_session()
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
@@ -433,136 +442,7 @@ def test_tracer_tool_run_on_error() -> None:
tracer.new_session()
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
"""Test tracer on a nested run with an error."""
exception = Exception("test")
tracer = FakeTracer()
tracer.new_session()
chain_uuid = uuid4()
tool_uuid = uuid4()
llm_uuid1 = uuid4()
llm_uuid2 = uuid4()
llm_uuid3 = uuid4()
for _ in range(3):
tracer.on_chain_start(
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_tool_start(
serialized={"name": "tool"},
input_str="test",
run_id=tool_uuid,
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
)
tracer.on_llm_error(exception, run_id=llm_uuid3)
tracer.on_tool_error(exception, run_id=tool_uuid)
tracer.on_chain_error(exception, run_id=chain_uuid)
compare_run = ChainRun(
uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=5,
serialized={"name": "chain"},
session_id=TEST_SESSION_ID,
error=repr(exception),
inputs={},
outputs=None,
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid1),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
child_execution_order=2,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
LLMRun(
uuid=str(llm_uuid2),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
child_execution_order=3,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
],
child_chain_runs=[],
child_tool_runs=[
ToolRun(
uuid=str(tool_uuid),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
child_execution_order=5,
serialized={"name": "tool"},
session_id=TEST_SESSION_ID,
error=repr(exception),
tool_input="test",
output=None,
action="{'name': 'tool'}",
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid3),
parent_uuid=str(tool_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=5,
child_execution_order=5,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],
response=None,
)
],
child_chain_runs=[],
child_tool_runs=[],
),
],
)
assert tracer.runs == [compare_run] * 3
_compare_run_with_error(tracer.runs[0], compare_run)
@pytest.fixture


@@ -2,9 +2,20 @@
from __future__ import annotations
import logging
import sys
import traceback
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID
from tenacity import RetryCallState
@@ -45,6 +56,21 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
@staticmethod
def _get_stacktrace(error: BaseException) -> str:
"""Get the stacktrace of the parent error."""
msg = repr(error)
try:
if sys.version_info < (3, 10):
tb = traceback.format_exception(
error.__class__, error, error.__traceback__
)
else:
tb = traceback.format_exception(error)
return (msg + "\n\n".join(tb)).strip()
except: # noqa: E722
return msg
def _start_trace(self, run: Run) -> None:
"""Start a trace for a run."""
if run.parent_run_id:
@@ -220,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run:
"""Handle an error for an LLM run."""
llm_run = self._get_run(run_id, run_type="llm")
llm_run.error = repr(error)
llm_run.error = self._get_stacktrace(error)
llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "error", "time": llm_run.end_time})
self._end_trace(llm_run)
@@ -296,7 +322,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run:
"""Handle an error for a chain run."""
chain_run = self._get_run(run_id)
chain_run.error = repr(error)
chain_run.error = self._get_stacktrace(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
@@ -361,7 +387,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run:
"""Handle an error for a tool run."""
tool_run = self._get_run(run_id, run_type="tool")
tool_run.error = repr(error)
tool_run.error = self._get_stacktrace(error)
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "error", "time": tool_run.end_time})
self._end_trace(tool_run)
@@ -414,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run:
"""Run when Retriever errors."""
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = repr(error)
retrieval_run.error = self._get_stacktrace(error)
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
self._end_trace(retrieval_run)

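The version branch in _get_stacktrace is needed because traceback.format_exception only accepts a bare exception instance starting with Python 3.10; older interpreters require the three-argument form. A small standalone illustration (not part of the diff):

import sys
import traceback

try:
    1 / 0
except ZeroDivisionError as exc:
    if sys.version_info < (3, 10):
        # Pre-3.10: legacy (etype, value, tb) signature.
        lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
    else:
        # 3.10+: the exception instance alone is enough.
        lines = traceback.format_exception(exc)
    print("".join(lines))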

@@ -630,13 +630,14 @@ def test_lambda_schemas() -> None:
}
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
assert RunnableLambda(
second_lambda, # type: ignore[arg-type]
).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
}
assert (
RunnableLambda(second_lambda).input_schema.schema() # type: ignore[arg-type]
== {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
}
)
def get_value(input): # type: ignore[no-untyped-def]
return input["variable_name"]
@@ -3624,33 +3625,32 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo"
assert parent_run_foo.error == repr(ValueError())
assert repr(ValueError()) in str(parent_run_foo.error)
assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs] == [
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None,
None,
None,
repr(ValueError()),
]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar"
assert parent_run_bar.error == repr(ValueError())
assert repr(ValueError()) in str(parent_run_bar.error)
assert len(parent_run_bar.child_runs) == 2
assert [r.error for r in parent_run_bar.child_runs] == [
None,
repr(ValueError()),
]
assert parent_run_bar.child_runs[0].error is None
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz"
assert parent_run_baz.error == repr(ValueError())
assert repr(ValueError()) in str(parent_run_baz.error)
assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs] == [
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None,
None,
repr(ValueError()),
]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux"
@@ -3746,33 +3746,31 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo"
assert parent_run_foo.error == repr(ValueError())
assert repr(ValueError()) in str(parent_run_foo.error)
assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs] == [
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None,
None,
None,
repr(ValueError()),
]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar"
assert parent_run_bar.error == repr(ValueError())
assert repr(ValueError()) in str(parent_run_bar.error)
assert len(parent_run_bar.child_runs) == 2
assert [r.error for r in parent_run_bar.child_runs] == [
None,
repr(ValueError()),
]
assert parent_run_bar.child_runs[0].error is None
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz"
assert parent_run_baz.error == repr(ValueError())
assert repr(ValueError()) in str(parent_run_baz.error)
assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs] == [
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None,
None,
repr(ValueError()),
]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux"
@@ -3941,7 +3939,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
branch.invoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2
assert tracer.runs[1].error == "ValueError('x is too large')"
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
assert tracer.runs[1].outputs is None
@ -3968,7 +3966,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
await branch.ainvoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2
assert tracer.runs[1].error == "ValueError('x is too large')"
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
assert tracer.runs[1].outputs is None