mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 05:20:39 +00:00
parent
7f42811e14
commit
93c7eb4e6b
@ -30,6 +30,24 @@ class FakeTracer(BaseTracer):
|
||||
self.runs.append(run)
|
||||
|
||||
|
||||
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
|
||||
if run.child_runs:
|
||||
assert len(expected_run.child_runs) == len(run.child_runs)
|
||||
for received, expected in zip(run.child_runs, expected_run.child_runs):
|
||||
_compare_run_with_error(received, expected)
|
||||
received = run.dict(exclude={"child_runs"})
|
||||
received_err = received.pop("error")
|
||||
expected = expected_run.dict(exclude={"child_runs"})
|
||||
expected_err = expected.pop("error")
|
||||
|
||||
assert received == expected
|
||||
if expected_err is not None:
|
||||
assert received_err is not None
|
||||
assert expected_err in received_err
|
||||
else:
|
||||
assert received_err is None
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
@ -328,7 +346,8 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
assert len(tracer.runs) == 1
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@ -364,7 +383,7 @@ def test_tracer_llm_run_on_error_callback() -> None:
|
||||
tracer = FakeTracerWithLlmErrorCallback()
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.error_run == compare_run
|
||||
_compare_run_with_error(tracer.error_run, compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@ -394,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
|
||||
|
||||
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@ -425,7 +444,7 @@ def test_tracer_tool_run_on_error() -> None:
|
||||
|
||||
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@ -568,4 +587,6 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
assert len(tracer.runs) == 3
|
||||
for run in tracer.runs:
|
||||
_compare_run_with_error(run, compare_run)
|
||||
|
@ -98,6 +98,15 @@ class FakeTracer(BaseTracer):
|
||||
return load_default_session()
|
||||
|
||||
|
||||
def _compare_run_with_error(run: Run, expected_run: Run) -> None:
|
||||
received = run.dict()
|
||||
received_err = received.pop("error")
|
||||
expected = expected_run.dict()
|
||||
expected_err = expected.pop("error")
|
||||
assert received == expected
|
||||
assert expected_err in received_err
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
@ -376,7 +385,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@ -404,7 +413,7 @@ def test_tracer_chain_run_on_error() -> None:
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@ -433,136 +442,7 @@ def test_tracer_tool_run_on_error() -> None:
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_nested_runs_on_error() -> None:
|
||||
"""Test tracer on a nested run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
llm_uuid3 = uuid4()
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_tool_start(
|
||||
serialized={"name": "tool"},
|
||||
input_str="test",
|
||||
run_id=tool_uuid,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid3,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||
tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||
tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=5,
|
||||
serialized={"name": "chain"},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=2,
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=5,
|
||||
serialized={"name": "tool"},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="{'name': 'tool'}",
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid3),
|
||||
parent_uuid=str(tool_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=5,
|
||||
child_execution_order=5,
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
prompts=[],
|
||||
response=None,
|
||||
)
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -2,9 +2,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
@ -45,6 +56,21 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
|
||||
@staticmethod
|
||||
def _get_stacktrace(error: BaseException) -> str:
|
||||
"""Get the stacktrace of the parent error."""
|
||||
msg = repr(error)
|
||||
try:
|
||||
if sys.version_info < (3, 10):
|
||||
tb = traceback.format_exception(
|
||||
error.__class__, error, error.__traceback__
|
||||
)
|
||||
else:
|
||||
tb = traceback.format_exception(error)
|
||||
return (msg + "\n\n".join(tb)).strip()
|
||||
except: # noqa: E722
|
||||
return msg
|
||||
|
||||
def _start_trace(self, run: Run) -> None:
|
||||
"""Start a trace for a run."""
|
||||
if run.parent_run_id:
|
||||
@ -220,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
) -> Run:
|
||||
"""Handle an error for an LLM run."""
|
||||
llm_run = self._get_run(run_id, run_type="llm")
|
||||
llm_run.error = repr(error)
|
||||
llm_run.error = self._get_stacktrace(error)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||
self._end_trace(llm_run)
|
||||
@ -296,7 +322,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
) -> Run:
|
||||
"""Handle an error for a chain run."""
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.error = repr(error)
|
||||
chain_run.error = self._get_stacktrace(error)
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
@ -361,7 +387,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
) -> Run:
|
||||
"""Handle an error for a tool run."""
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.error = repr(error)
|
||||
tool_run.error = self._get_stacktrace(error)
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
||||
self._end_trace(tool_run)
|
||||
@ -414,7 +440,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
) -> Run:
|
||||
"""Run when Retriever errors."""
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.error = repr(error)
|
||||
retrieval_run.error = self._get_stacktrace(error)
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
||||
self._end_trace(retrieval_run)
|
||||
|
@ -630,13 +630,14 @@ def test_lambda_schemas() -> None:
|
||||
}
|
||||
|
||||
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
|
||||
assert RunnableLambda(
|
||||
second_lambda, # type: ignore[arg-type]
|
||||
).input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
|
||||
}
|
||||
assert (
|
||||
RunnableLambda(second_lambda).input_schema.schema() # type: ignore[arg-type]
|
||||
== {
|
||||
"title": "RunnableLambdaInput",
|
||||
"type": "object",
|
||||
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
|
||||
}
|
||||
)
|
||||
|
||||
def get_value(input): # type: ignore[no-untyped-def]
|
||||
return input["variable_name"]
|
||||
@ -3624,33 +3625,32 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert parent_run_foo.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs] == [
|
||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert parent_run_bar.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert [r.error for r in parent_run_bar.child_runs] == [
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert parent_run_bar.child_runs[0].error is None
|
||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert parent_run_baz.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs] == [
|
||||
|
||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
@ -3746,33 +3746,31 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert parent_run_foo.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_foo.error)
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs] == [
|
||||
assert [r.error for r in parent_run_foo.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_foo.child_runs[-1].error)
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert parent_run_bar.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_bar.error)
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert [r.error for r in parent_run_bar.child_runs] == [
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert parent_run_bar.child_runs[0].error is None
|
||||
assert repr(ValueError()) in str(parent_run_bar.child_runs[1].error)
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert parent_run_baz.error == repr(ValueError())
|
||||
assert repr(ValueError()) in str(parent_run_baz.error)
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs] == [
|
||||
assert [r.error for r in parent_run_baz.child_runs[:-1]] == [
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
assert repr(ValueError()) in str(parent_run_baz.child_runs[-1].error)
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
@ -3941,7 +3939,7 @@ def test_runnable_branch_invoke_callbacks() -> None:
|
||||
branch.invoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
@ -3968,7 +3966,7 @@ async def test_runnable_branch_ainvoke_callbacks() -> None:
|
||||
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user