[Tracing] String Stacktrace (#14131)

Add full stacktrace
This commit is contained in:
William FH 2023-12-14 22:15:07 -08:00 committed by GitHub
parent 7f42811e14
commit 93c7eb4e6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 173 deletions

View File

@ -30,6 +30,24 @@ class FakeTracer(BaseTracer):
self.runs.append(run) self.runs.append(run)
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
if run.child_runs:
assert len(expected_run.child_runs) == len(run.child_runs)
for received, expected in zip(run.child_runs, expected_run.child_runs):
_compare_run_with_error(received, expected)
received = run.dict(exclude={"child_runs"})
received_err = received.pop("error")
expected = expected_run.dict(exclude={"child_runs"})
expected_err = expected.pop("error")
assert received == expected
if expected_err is not None:
assert received_err is not None
assert expected_err in received_err
else:
assert received_err is None
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
def test_tracer_llm_run() -> None: def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run.""" """Test tracer on an LLM run."""
@ -328,7 +346,8 @@ def test_tracer_llm_run_on_error() -> None:
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid) tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run] assert len(tracer.runs) == 1
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@ -364,7 +383,7 @@ def test_tracer_llm_run_on_error_callback() -> None:
tracer = FakeTracerWithLlmErrorCallback() tracer = FakeTracerWithLlmErrorCallback()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid) tracer.on_llm_error(exception, run_id=uuid)
assert tracer.error_run == compare_run _compare_run_with_error(tracer.error_run, compare_run)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@ -394,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid) tracer.on_chain_error(exception, run_id=uuid)
assert tracer.runs == [compare_run] _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@ -425,7 +444,7 @@ def test_tracer_tool_run_on_error() -> None:
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid) tracer.on_tool_error(exception, run_id=uuid)
assert tracer.runs == [compare_run] _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@ -568,4 +587,6 @@ def test_tracer_nested_runs_on_error() -> None:
), ),
], ],
) )
assert tracer.runs == [compare_run] * 3 assert len(tracer.runs) == 3
for run in tracer.runs:
_compare_run_with_error(run, compare_run)

View File

@ -98,6 +98,15 @@ class FakeTracer(BaseTracer):
return load_default_session() return load_default_session()
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
received = run.dict()
received_err = received.pop("error")
expected = expected_run.dict()
expected_err = expected.pop("error")
assert received == expected
assert expected_err in received_err
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
def test_tracer_llm_run() -> None: def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run.""" """Test tracer on an LLM run."""
@ -376,7 +385,7 @@ def test_tracer_llm_run_on_error() -> None:
tracer.new_session() tracer.new_session()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid) tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run] _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@ -404,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
tracer.new_session() tracer.new_session()
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid) tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid) tracer.on_chain_error(exception, run_id=uuid)
assert tracer.runs == [compare_run] _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@ -433,136 +442,7 @@ def test_tracer_tool_run_on_error() -> None:
tracer.new_session() tracer.new_session()
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid) tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid) tracer.on_tool_error(exception, run_id=uuid)
assert tracer.runs == [compare_run] _compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
"""Test tracer on a nested run with an error."""
exception = Exception("test")
tracer = FakeTracer()
tracer.new_session()
chain_uuid = uuid4()
tool_uuid = uuid4()
llm_uuid1 = uuid4()
llm_uuid2 = uuid4()
llm_uuid3 = uuid4()
for _ in range(3):
tracer.on_chain_start(
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_tool_start(
serialized={"name": "tool"},
input_str="test",
run_id=tool_uuid,
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
)
tracer.on_llm_error(exception, run_id=llm_uuid3)
tracer.on_tool_error(exception, run_id=tool_uuid)
tracer.on_chain_error(exception, run_id=chain_uuid)
compare_run = ChainRun(
uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=5,
serialized={"name": "chain"},
session_id=TEST_SESSION_ID,
error=repr(exception),
inputs={},
outputs=None,
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid1),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
child_execution_order=2,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
LLMRun(
uuid=str(llm_uuid2),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
child_execution_order=3,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
],
child_chain_runs=[],
child_tool_runs=[
ToolRun(
uuid=str(tool_uuid),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
child_execution_order=5,
serialized={"name": "tool"},
session_id=TEST_SESSION_ID,
error=repr(exception),
tool_input="test",
output=None,
action="{'name': 'tool'}",
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid3),
parent_uuid=str(tool_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=5,
child_execution_order=5,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],
response=None,
)
],
child_chain_runs=[],
child_tool_runs=[],
),
],
)
assert tracer.runs == [compare_run] * 3
@pytest.fixture @pytest.fixture

View File

@ -2,9 +2,20 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import sys
import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID from uuid import UUID
from tenacity import RetryCallState from tenacity import RetryCallState
@ -45,6 +56,21 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
"""Persist a run.""" """Persist a run."""
@staticmethod
def _get_stacktrace(error: BaseException) -> str:
"""Get the stacktrace of the parent error."""
msg = repr(error)
try:
if sys.version_info < (3, 10):
tb = traceback.format_exception(
error.__class__, error, error.__traceback__
)
else:
tb = traceback.format_exception(error)
return (msg + "\n\n".join(tb)).strip()
except: # noqa: E722
return msg
def _start_trace(self, run: Run) -> None: def _start_trace(self, run: Run) -> None:
"""Start a trace for a run.""" """Start a trace for a run."""
if run.parent_run_id: if run.parent_run_id:
@ -220,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run: ) -> Run:
"""Handle an error for an LLM run.""" """Handle an error for an LLM run."""
llm_run = self._get_run(run_id, run_type="llm") llm_run = self._get_run(run_id, run_type="llm")
llm_run.error = repr(error) llm_run.error = self._get_stacktrace(error)
llm_run.end_time = datetime.utcnow() llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "error", "time": llm_run.end_time}) llm_run.events.append({"name": "error", "time": llm_run.end_time})
self._end_trace(llm_run) self._end_trace(llm_run)
@ -296,7 +322,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run: ) -> Run:
"""Handle an error for a chain run.""" """Handle an error for a chain run."""
chain_run = self._get_run(run_id) chain_run = self._get_run(run_id)
chain_run.error = repr(error) chain_run.error = self._get_stacktrace(error)
chain_run.end_time = datetime.utcnow() chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time}) chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None: if inputs is not None:
@ -361,7 +387,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run: ) -> Run:
"""Handle an error for a tool run.""" """Handle an error for a tool run."""
tool_run = self._get_run(run_id, run_type="tool") tool_run = self._get_run(run_id, run_type="tool")
tool_run.error = repr(error) tool_run.error = self._get_stacktrace(error)
tool_run.end_time = datetime.utcnow() tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "error", "time": tool_run.end_time}) tool_run.events.append({"name": "error", "time": tool_run.end_time})
self._end_trace(tool_run) self._end_trace(tool_run)
@ -414,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
) -> Run: ) -> Run:
"""Run when Retriever errors.""" """Run when Retriever errors."""
retrieval_run = self._get_run(run_id, run_type="retriever") retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = repr(error) retrieval_run.error = self._get_stacktrace(error)
retrieval_run.end_time = datetime.utcnow() retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
self._end_trace(retrieval_run) self._end_trace(retrieval_run)

View File

@ -630,13 +630,14 @@ def test_lambda_schemas() -> None:
} }
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731 second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
assert RunnableLambda( assert (
second_lambda, # type: ignore[arg-type] RunnableLambda(second_lambda).input_schema.schema() # type: ignore[arg-type]
).input_schema.schema() == { == {
"title": "RunnableLambdaInput", "title": "RunnableLambdaInput",
"type": "object", "type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}}, "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
} }
)
def get_value(input): # type: ignore[no-untyped-def] def get_value(input): # type: ignore[no-untyped-def]
return input["variable_name"] return input["variable_name"]
@ -3624,33 +3625,32 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0] parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo" assert parent_run_foo.inputs["input"] == "foo"
assert parent_run_foo.error == repr(ValueError()) assert repr(ValueError()) in str(parent_run_foo.error)
assert len(parent_run_foo.child_runs) == 4 assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs] == [ assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None, None,
None, None,
None, None,
repr(ValueError()),
] ]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
parent_run_bar = parent_runs[1] parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar" assert parent_run_bar.inputs["input"] == "bar"
assert parent_run_bar.error == repr(ValueError()) assert repr(ValueError()) in str(parent_run_bar.error)
assert len(parent_run_bar.child_runs) == 2 assert len(parent_run_bar.child_runs) == 2
assert [r.error for r in parent_run_bar.child_runs] == [ assert parent_run_bar.child_runs[0].error is None
None, assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
repr(ValueError()),
]
parent_run_baz = parent_runs[2] parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz" assert parent_run_baz.inputs["input"] == "baz"
assert parent_run_baz.error == repr(ValueError()) assert repr(ValueError()) in str(parent_run_baz.error)
assert len(parent_run_baz.child_runs) == 3 assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs] == [
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None, None,
None, None,
repr(ValueError()),
] ]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
parent_run_qux = parent_runs[3] parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux" assert parent_run_qux.inputs["input"] == "qux"
@ -3746,33 +3746,31 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
parent_run_foo = parent_runs[0] parent_run_foo = parent_runs[0]
assert parent_run_foo.inputs["input"] == "foo" assert parent_run_foo.inputs["input"] == "foo"
assert parent_run_foo.error == repr(ValueError()) assert repr(ValueError()) in str(parent_run_foo.error)
assert len(parent_run_foo.child_runs) == 4 assert len(parent_run_foo.child_runs) == 4
assert [r.error for r in parent_run_foo.child_runs] == [ assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
None, None,
None, None,
None, None,
repr(ValueError()),
] ]
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
parent_run_bar = parent_runs[1] parent_run_bar = parent_runs[1]
assert parent_run_bar.inputs["input"] == "bar" assert parent_run_bar.inputs["input"] == "bar"
assert parent_run_bar.error == repr(ValueError()) assert repr(ValueError()) in str(parent_run_bar.error)
assert len(parent_run_bar.child_runs) == 2 assert len(parent_run_bar.child_runs) == 2
assert [r.error for r in parent_run_bar.child_runs] == [ assert parent_run_bar.child_runs[0].error is None
None, assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
repr(ValueError()),
]
parent_run_baz = parent_runs[2] parent_run_baz = parent_runs[2]
assert parent_run_baz.inputs["input"] == "baz" assert parent_run_baz.inputs["input"] == "baz"
assert parent_run_baz.error == repr(ValueError()) assert repr(ValueError()) in str(parent_run_baz.error)
assert len(parent_run_baz.child_runs) == 3 assert len(parent_run_baz.child_runs) == 3
assert [r.error for r in parent_run_baz.child_runs] == [ assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
None, None,
None, None,
repr(ValueError()),
] ]
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
parent_run_qux = parent_runs[3] parent_run_qux = parent_runs[3]
assert parent_run_qux.inputs["input"] == "qux" assert parent_run_qux.inputs["input"] == "qux"
@ -3941,7 +3939,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
branch.invoke(1000, config={"callbacks": [tracer]}) branch.invoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
assert tracer.runs[1].error == "ValueError('x is too large')" assert "ValueError('x is too large')" in str(tracer.runs[1].error)
assert tracer.runs[1].outputs is None assert tracer.runs[1].outputs is None
@ -3968,7 +3966,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
await branch.ainvoke(1000, config={"callbacks": [tracer]}) await branch.ainvoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2 assert len(tracer.runs) == 2
assert tracer.runs[1].error == "ValueError('x is too large')" assert "ValueError('x is too large')" in str(tracer.runs[1].error)
assert tracer.runs[1].outputs is None assert tracer.runs[1].outputs is None