mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
core[minor]: Add an async root listener and with_alisteners method (#22151)
- [x] **Adding AsyncRootListener**: "langchain_core: Adding AsyncRootListener" - **Description:** Adding an AsyncBaseTracer, AsyncRootListener and `with_alistener` function. This is to enable binding async root listener to runnables. This currently only supported for sync listeners. - **Issue:** None - **Dependencies:** None - [x] **Add tests and docs**: Added units tests and example snippet code within the function description of `with_alistener` - [x] **Lint and test**: Run make format_diff, make lint_diff and make test
This commit is contained in:
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
|
||||
RunLog,
|
||||
RunLogPatch,
|
||||
)
|
||||
from langchain_core.tracers.root_listeners import AsyncListener
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
|
||||
@@ -1327,6 +1328,86 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
],
|
||||
)
|
||||
|
||||
def with_alisteners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[AsyncListener] = None,
|
||||
on_end: Optional[AsyncListener] = None,
|
||||
on_error: Optional[AsyncListener] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind asynchronous lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Asynchronously called before the runnable starts running.
|
||||
on_end: Asynchronously called after the runnable finishes running.
|
||||
on_error: Asynchronously called if the runnable throws an error.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
import time
|
||||
|
||||
async def test_runnable(time_to_sleep : int):
|
||||
print(f"Runnable[{time_to_sleep}s]: starts at {format_t(time.time())}")
|
||||
await asyncio.sleep(time_to_sleep)
|
||||
print(f"Runnable[{time_to_sleep}s]: ends at {format_t(time.time())}")
|
||||
|
||||
async def fn_start(run_obj : Runnable):
|
||||
print(f"on start callback starts at {format_t(time.time())}
|
||||
await asyncio.sleep(3)
|
||||
print(f"on start callback ends at {format_t(time.time())}")
|
||||
|
||||
async def fn_end(run_obj : Runnable):
|
||||
print(f"on end callback starts at {format_t(time.time())}
|
||||
await asyncio.sleep(2)
|
||||
print(f"on end callback ends at {format_t(time.time())}")
|
||||
|
||||
runnable = RunnableLambda(test_runnable).with_alisteners(
|
||||
on_start=fn_start,
|
||||
on_end=fn_end
|
||||
)
|
||||
async def concurrent_runs():
|
||||
await asyncio.gather(runnable.ainvoke(2), runnable.ainvoke(3))
|
||||
|
||||
asyncio.run(concurrent_runs())
|
||||
Result:
|
||||
on start callback starts at 2024-05-16T14:20:29.637053+00:00
|
||||
on start callback starts at 2024-05-16T14:20:29.637150+00:00
|
||||
on start callback ends at 2024-05-16T14:20:32.638305+00:00
|
||||
on start callback ends at 2024-05-16T14:20:32.638383+00:00
|
||||
Runnable[3s]: starts at 2024-05-16T14:20:32.638849+00:00
|
||||
Runnable[5s]: starts at 2024-05-16T14:20:32.638999+00:00
|
||||
Runnable[3s]: ends at 2024-05-16T14:20:35.640016+00:00
|
||||
on end callback starts at 2024-05-16T14:20:35.640534+00:00
|
||||
Runnable[5s]: ends at 2024-05-16T14:20:37.640169+00:00
|
||||
on end callback starts at 2024-05-16T14:20:37.640574+00:00
|
||||
on end callback ends at 2024-05-16T14:20:37.640654+00:00
|
||||
on end callback ends at 2024-05-16T14:20:39.641751+00:00
|
||||
|
||||
"""
|
||||
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
|
||||
|
||||
return RunnableBinding(
|
||||
bound=self,
|
||||
config_factories=[
|
||||
lambda config: {
|
||||
"callbacks": [
|
||||
AsyncRootListenersTracer(
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
def with_types(
|
||||
self,
|
||||
*,
|
||||
@@ -4294,6 +4375,33 @@ class RunnableEach(RunnableEachBase[Input, Output]):
|
||||
)
|
||||
)
|
||||
|
||||
def with_alisteners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[AsyncListener] = None,
|
||||
on_end: Optional[AsyncListener] = None,
|
||||
on_error: Optional[AsyncListener] = None,
|
||||
) -> RunnableEach[Input, Output]:
|
||||
"""
|
||||
Bind async lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Called asynchronously before the runnable starts running,
|
||||
with the Run object.
|
||||
on_end: Called asynchronously after the runnable finishes running,
|
||||
with the Run object.
|
||||
on_error: Called asynchronously if the runnable throws an error,
|
||||
with the Run object.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
"""
|
||||
return RunnableEach(
|
||||
bound=self.bound.with_alisteners(
|
||||
on_start=on_start, on_end=on_end, on_error=on_error
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
"""Runnable that delegates calls to another Runnable with a set of kwargs.
|
||||
|
@@ -2,38 +2,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.exceptions import TracerException # noqa
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
from langchain_core.tracers.core import _TracerCore
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -42,90 +31,16 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTracer(BaseCallbackHandler, ABC):
|
||||
class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
"""Base interface for tracers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
_schema_format: Literal[
|
||||
"original", "streaming_events", "original+chat"
|
||||
] = "original",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the tracer.
|
||||
|
||||
Args:
|
||||
_schema_format: Primarily changes how the inputs and outputs are
|
||||
handled. For internal use only. This API will change.
|
||||
- 'original' is the format used by all current tracers.
|
||||
This format is slightly inconsistent with respect to inputs
|
||||
and outputs.
|
||||
- 'streaming_events' is used for supporting streaming events,
|
||||
for internal usage. It will likely change in the future, or
|
||||
be deprecated entirely in favor of a dedicated async tracer
|
||||
for streaming events.
|
||||
- 'original+chat' is a format that is the same as 'original'
|
||||
except it does NOT raise an attribute error on_chat_model_start
|
||||
kwargs: Additional keyword arguments that will be passed to
|
||||
the super class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._schema_format = _schema_format # For internal use only API will change.
|
||||
self.run_map: Dict[str, Run] = {}
|
||||
"""Map of run ID to run. Cleared on run end."""
|
||||
self.order_map: Dict[UUID, Tuple[UUID, str]] = {}
|
||||
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
parent_run: Run,
|
||||
child_run: Run,
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
parent_run.child_runs.append(child_run)
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
|
||||
@staticmethod
|
||||
def _get_stacktrace(error: BaseException) -> str:
|
||||
"""Get the stacktrace of the parent error."""
|
||||
msg = repr(error)
|
||||
try:
|
||||
if sys.version_info < (3, 10):
|
||||
tb = traceback.format_exception(
|
||||
error.__class__, error, error.__traceback__
|
||||
)
|
||||
else:
|
||||
tb = traceback.format_exception(error)
|
||||
return (msg + "\n\n".join(tb)).strip()
|
||||
except: # noqa: E722
|
||||
return msg
|
||||
|
||||
def _start_trace(self, run: Run) -> None:
|
||||
"""Start a trace for a run."""
|
||||
current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id)
|
||||
if run.parent_run_id:
|
||||
if parent := self.order_map.get(run.parent_run_id):
|
||||
run.trace_id, run.dotted_order = parent
|
||||
run.dotted_order += "." + current_dotted_order
|
||||
if parent_run := self.run_map.get(str(run.parent_run_id)):
|
||||
self._add_child_run(parent_run, run)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Parent run {run.parent_run_id} not found for run {run.id}."
|
||||
" Treating as a root run."
|
||||
)
|
||||
run.parent_run_id = None
|
||||
run.trace_id = run.id
|
||||
run.dotted_order = current_dotted_order
|
||||
else:
|
||||
run.trace_id = run.id
|
||||
run.dotted_order = current_dotted_order
|
||||
self.order_map[run.id] = (run.trace_id, run.dotted_order)
|
||||
self.run_map[str(run.id)] = run
|
||||
super()._start_trace(run)
|
||||
self._on_run_create(run)
|
||||
|
||||
def _end_trace(self, run: Run) -> None:
|
||||
@@ -135,25 +50,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self.run_map.pop(str(run.id))
|
||||
self._on_run_update(run)
|
||||
|
||||
def _get_run(
|
||||
self, run_id: UUID, run_type: Union[str, Set[str], None] = None
|
||||
) -> Run:
|
||||
try:
|
||||
run = self.run_map[str(run_id)]
|
||||
except KeyError as exc:
|
||||
raise TracerException(f"No indexed run ID {run_id}.") from exc
|
||||
|
||||
if isinstance(run_type, str):
|
||||
run_types: Union[Set[str], None] = {run_type}
|
||||
else:
|
||||
run_types = run_type
|
||||
if run_types is not None and run.run_type not in run_types:
|
||||
raise TracerException(
|
||||
f"Found {run.run_type} run at ID {run_id}, "
|
||||
f"but expected {run_types} run."
|
||||
)
|
||||
return run
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -167,35 +63,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run."""
|
||||
if self._schema_format not in ("streaming_events", "original+chat"):
|
||||
# Please keep this un-implemented for backwards compatibility.
|
||||
# When it's unimplemented old tracers that use the "original" format
|
||||
# fallback on the on_llm_start method implementation if they
|
||||
# find that the on_chat_model_start method is not implemented.
|
||||
# This can eventually be cleaned up by writing a "modern" tracer
|
||||
# that has all the updated schema changes corresponding to
|
||||
# the "streaming_events" format.
|
||||
raise NotImplementedError(
|
||||
f"Chat model tracing is not supported in "
|
||||
f"for {self._schema_format} format."
|
||||
)
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
chat_model_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
chat_model_run = self._create_chat_model_run(
|
||||
serialized=serialized,
|
||||
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
# WARNING: This is valid ONLY for streaming_events.
|
||||
# run_type="llm" is what's used by virtually all tracers.
|
||||
# Changing this to "chat_model" may break triggering on_llm_start
|
||||
run_type="chat_model",
|
||||
messages=messages,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
name=name, # type: ignore[arg-type]
|
||||
metadata=metadata,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self._start_trace(chat_model_run)
|
||||
self._on_chat_model_start(chat_model_run)
|
||||
@@ -214,21 +90,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run."""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
llm_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
llm_run = self._create_llm_run(
|
||||
serialized=serialized,
|
||||
# TODO: Figure out how to expose kwargs here
|
||||
inputs={"prompts": prompts},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
run_type="llm",
|
||||
tags=tags or [],
|
||||
name=name, # type: ignore[arg-type]
|
||||
prompts=prompts,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
self._on_llm_start(llm_run)
|
||||
@@ -246,16 +116,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
# "chat_model" is only used for the experimental new streaming_events format.
|
||||
# This change should not affect any existing tracers.
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
event_kwargs: Dict[str, Any] = {"token": token}
|
||||
if chunk:
|
||||
event_kwargs["chunk"] = chunk
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "new_token",
|
||||
"time": datetime.now(timezone.utc),
|
||||
"kwargs": event_kwargs,
|
||||
},
|
||||
llm_run = self._llm_run_with_token_event(
|
||||
token=token,
|
||||
run_id=run_id,
|
||||
chunk=chunk,
|
||||
parent_run_id=parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self._on_llm_new_token(llm_run, token, chunk)
|
||||
return llm_run
|
||||
@@ -267,27 +133,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
llm_run = self._get_run(run_id)
|
||||
retry_d: Dict[str, Any] = {
|
||||
"slept": retry_state.idle_for,
|
||||
"attempt": retry_state.attempt_number,
|
||||
}
|
||||
if retry_state.outcome is None:
|
||||
retry_d["outcome"] = "N/A"
|
||||
elif retry_state.outcome.failed:
|
||||
retry_d["outcome"] = "failed"
|
||||
exception = retry_state.outcome.exception()
|
||||
retry_d["exception"] = str(exception)
|
||||
retry_d["exception_type"] = exception.__class__.__name__
|
||||
else:
|
||||
retry_d["outcome"] = "success"
|
||||
retry_d["result"] = str(retry_state.outcome.result())
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "retry",
|
||||
"time": datetime.now(timezone.utc),
|
||||
"kwargs": retry_d,
|
||||
},
|
||||
llm_run = self._llm_run_with_retry_event(
|
||||
retry_state=retry_state,
|
||||
run_id=run_id,
|
||||
)
|
||||
return llm_run
|
||||
|
||||
@@ -295,17 +143,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""End a trace for an LLM run."""
|
||||
# "chat_model" is only used for the experimental new streaming_events format.
|
||||
# This change should not affect any existing tracers.
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
llm_run.outputs = response.dict()
|
||||
for i, generations in enumerate(response.generations):
|
||||
for j, generation in enumerate(generations):
|
||||
output_generation = llm_run.outputs["generations"][i][j]
|
||||
if "message" in output_generation:
|
||||
output_generation["message"] = dumpd(
|
||||
cast(ChatGeneration, generation).message
|
||||
)
|
||||
llm_run.end_time = datetime.now(timezone.utc)
|
||||
llm_run.events.append({"name": "end", "time": llm_run.end_time})
|
||||
llm_run = self._complete_llm_run(
|
||||
response=response,
|
||||
run_id=run_id,
|
||||
)
|
||||
self._end_trace(llm_run)
|
||||
self._on_llm_end(llm_run)
|
||||
return llm_run
|
||||
@@ -320,10 +161,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Handle an error for an LLM run."""
|
||||
# "chat_model" is only used for the experimental new streaming_events format.
|
||||
# This change should not affect any existing tracers.
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
llm_run.error = self._get_stacktrace(error)
|
||||
llm_run.end_time = datetime.now(timezone.utc)
|
||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||
llm_run = self._errored_llm_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
)
|
||||
self._end_trace(llm_run)
|
||||
self._on_llm_error(llm_run)
|
||||
return llm_run
|
||||
@@ -342,48 +183,21 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a chain run."""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
chain_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
chain_run = self._create_chain_run(
|
||||
serialized=serialized,
|
||||
inputs=self._get_chain_inputs(inputs),
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
child_runs=[],
|
||||
run_type=run_type or "chain",
|
||||
name=name, # type: ignore[arg-type]
|
||||
tags=tags or [],
|
||||
inputs=inputs,
|
||||
run_id=run_id,
|
||||
tags=tags,
|
||||
parent_run_id=parent_run_id,
|
||||
metadata=metadata,
|
||||
run_type=run_type,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self._start_trace(chain_run)
|
||||
self._on_chain_start(chain_run)
|
||||
return chain_run
|
||||
|
||||
def _get_chain_inputs(self, inputs: Any) -> Any:
|
||||
"""Get the inputs for a chain run."""
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
return inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||
elif self._schema_format == "streaming_events":
|
||||
return {
|
||||
"input": inputs,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {self._schema_format}")
|
||||
|
||||
def _get_chain_outputs(self, outputs: Any) -> Any:
|
||||
"""Get the outputs for a chain run."""
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
return outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||
elif self._schema_format == "streaming_events":
|
||||
return {
|
||||
"output": outputs,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {self._schema_format}")
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
@@ -393,12 +207,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""End a trace for a chain run."""
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.outputs = self._get_chain_outputs(outputs)
|
||||
chain_run.end_time = datetime.now(timezone.utc)
|
||||
chain_run.events.append({"name": "end", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
chain_run.inputs = self._get_chain_inputs(inputs)
|
||||
chain_run = self._complete_chain_run(
|
||||
outputs=outputs,
|
||||
run_id=run_id,
|
||||
inputs=inputs,
|
||||
**kwargs,
|
||||
)
|
||||
self._end_trace(chain_run)
|
||||
self._on_chain_end(chain_run)
|
||||
return chain_run
|
||||
@@ -412,12 +226,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for a chain run."""
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.error = self._get_stacktrace(error)
|
||||
chain_run.end_time = datetime.now(timezone.utc)
|
||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
chain_run.inputs = self._get_chain_inputs(inputs)
|
||||
chain_run = self._errored_chain_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
inputs=inputs,
|
||||
**kwargs,
|
||||
)
|
||||
self._end_trace(chain_run)
|
||||
self._on_chain_error(chain_run)
|
||||
return chain_run
|
||||
@@ -436,30 +250,16 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a tool run."""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
inputs = {"input": input_str}
|
||||
elif self._schema_format == "streaming_events":
|
||||
inputs = {"input": inputs}
|
||||
else:
|
||||
raise AssertionError(f"Invalid format: {self._schema_format}")
|
||||
|
||||
tool_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tool_run = self._create_tool_run(
|
||||
serialized=serialized,
|
||||
# Wrapping in dict since Run requires a dict object.
|
||||
input_str=input_str,
|
||||
run_id=run_id,
|
||||
tags=tags,
|
||||
parent_run_id=parent_run_id,
|
||||
metadata=metadata,
|
||||
name=name,
|
||||
inputs=inputs,
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
child_runs=[],
|
||||
run_type="tool",
|
||||
tags=tags or [],
|
||||
name=name, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
)
|
||||
self._start_trace(tool_run)
|
||||
self._on_tool_start(tool_run)
|
||||
@@ -467,10 +267,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for a tool run."""
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.outputs = {"output": output}
|
||||
tool_run.end_time = datetime.now(timezone.utc)
|
||||
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
||||
tool_run = self._complete_tool_run(
|
||||
output=output,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self._end_trace(tool_run)
|
||||
self._on_tool_end(tool_run)
|
||||
return tool_run
|
||||
@@ -483,10 +284,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for a tool run."""
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.error = self._get_stacktrace(error)
|
||||
tool_run.end_time = datetime.now(timezone.utc)
|
||||
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
||||
tool_run = self._errored_tool_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
)
|
||||
self._end_trace(tool_run)
|
||||
self._on_tool_error(tool_run)
|
||||
return tool_run
|
||||
@@ -504,21 +305,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Run when Retriever starts running."""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
retrieval_run = Run(
|
||||
id=run_id,
|
||||
name=name or "Retriever",
|
||||
parent_run_id=parent_run_id,
|
||||
retrieval_run = self._create_retrieval_run(
|
||||
serialized=serialized,
|
||||
inputs={"query": query},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
query=query,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
child_runs=[],
|
||||
run_type="retriever",
|
||||
metadata=metadata,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
self._start_trace(retrieval_run)
|
||||
self._on_retriever_start(retrieval_run)
|
||||
@@ -532,10 +327,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Run when Retriever errors."""
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.error = self._get_stacktrace(error)
|
||||
retrieval_run.end_time = datetime.now(timezone.utc)
|
||||
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
||||
retrieval_run = self._errored_retrieval_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_error(retrieval_run)
|
||||
return retrieval_run
|
||||
@@ -544,10 +340,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||
) -> Run:
|
||||
"""Run when Retriever ends running."""
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.outputs = {"documents": documents}
|
||||
retrieval_run.end_time = datetime.now(timezone.utc)
|
||||
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
||||
retrieval_run = self._complete_retrieval_run(
|
||||
documents=documents,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
self._end_trace(retrieval_run)
|
||||
self._on_retriever_end(retrieval_run)
|
||||
return retrieval_run
|
||||
@@ -560,16 +357,349 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Copy the tracer."""
|
||||
return self
|
||||
|
||||
def _on_run_create(self, run: Run) -> None:
|
||||
"""Process a run upon creation."""
|
||||
|
||||
def _on_run_update(self, run: Run) -> None:
|
||||
class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
"""Async Base interface for tracers."""
|
||||
|
||||
@abstractmethod
|
||||
async def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
|
||||
async def _start_trace(self, run: Run) -> None:
|
||||
"""
|
||||
Start a trace for a run.
|
||||
|
||||
Starting a trace will run concurrently with each _on_[run_type]_start method.
|
||||
No _on_[run_type]_start callback should depend on operations in _start_trace.
|
||||
"""
|
||||
super()._start_trace(run)
|
||||
await self._on_run_create(run)
|
||||
|
||||
async def _end_trace(self, run: Run) -> None:
|
||||
"""
|
||||
End a trace for a run.
|
||||
|
||||
Ending a trace will run concurrently with each _on_[run_type]_end method.
|
||||
No _on_[run_type]_end callback should depend on operations in _end_trace.
|
||||
"""
|
||||
if not run.parent_run_id:
|
||||
await self._persist_run(run)
|
||||
self.run_map.pop(str(run.id))
|
||||
await self._on_run_update(run)
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
chat_model_run = self._create_chat_model_run(
|
||||
serialized=serialized,
|
||||
messages=messages,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [
|
||||
self._start_trace(chat_model_run),
|
||||
self._on_chat_model_start(chat_model_run),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
return chat_model_run
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._create_llm_run(
|
||||
serialized=serialized,
|
||||
prompts=prompts,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._start_trace(llm_run), self._on_llm_start(llm_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._llm_run_with_token_event(
|
||||
token=token,
|
||||
run_id=run_id,
|
||||
chunk=chunk,
|
||||
parent_run_id=parent_run_id,
|
||||
**kwargs,
|
||||
)
|
||||
await self._on_llm_new_token(llm_run, token, chunk)
|
||||
|
||||
async def on_retry(
|
||||
self,
|
||||
retry_state: RetryCallState,
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._llm_run_with_retry_event(
|
||||
retry_state=retry_state,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
async def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._complete_llm_run(
|
||||
response=response,
|
||||
run_id=run_id,
|
||||
)
|
||||
tasks = [self._on_llm_end(llm_run), self._end_trace(llm_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
llm_run = self._errored_llm_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
)
|
||||
tasks = [self._on_llm_error(llm_run), self._end_trace(llm_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
chain_run = self._create_chain_run(
|
||||
serialized=serialized,
|
||||
inputs=inputs,
|
||||
run_id=run_id,
|
||||
tags=tags,
|
||||
parent_run_id=parent_run_id,
|
||||
metadata=metadata,
|
||||
run_type=run_type,
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._start_trace(chain_run), self._on_chain_start(chain_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
chain_run = self._complete_chain_run(
|
||||
outputs=outputs,
|
||||
run_id=run_id,
|
||||
inputs=inputs,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
chain_run = self._errored_chain_run(
|
||||
error=error,
|
||||
inputs=inputs,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_run = self._create_tool_run(
|
||||
serialized=serialized,
|
||||
input_str=input_str,
|
||||
run_id=run_id,
|
||||
tags=tags,
|
||||
parent_run_id=parent_run_id,
|
||||
metadata=metadata,
|
||||
inputs=inputs,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._start_trace(tool_run), self._on_tool_start(tool_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_run = self._complete_tool_run(
|
||||
output=output,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
tool_run = self._errored_tool_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
)
|
||||
tasks = [self._end_trace(tool_run), self._on_tool_error(tool_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
retriever_run = self._create_retrieval_run(
|
||||
serialized=serialized,
|
||||
query=query,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
name=name,
|
||||
)
|
||||
tasks = [
|
||||
self._start_trace(retriever_run),
|
||||
self._on_retriever_start(retriever_run),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
retrieval_run = self._errored_retrieval_run(
|
||||
error=error,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [
|
||||
self._end_trace(retrieval_run),
|
||||
self._on_retriever_error(retrieval_run),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
retrieval_run = self._complete_retrieval_run(
|
||||
documents=documents,
|
||||
run_id=run_id,
|
||||
**kwargs,
|
||||
)
|
||||
tasks = [self._end_trace(retrieval_run), self._on_retriever_end(retrieval_run)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _on_run_create(self, run: Run) -> None:
|
||||
"""Process a run upon creation."""
|
||||
pass
|
||||
|
||||
async def _on_run_update(self, run: Run) -> None:
|
||||
"""Process a run upon update."""
|
||||
|
||||
def _on_llm_start(self, run: Run) -> None:
|
||||
async def _on_llm_start(self, run: Run) -> None:
|
||||
"""Process the LLM Run upon start."""
|
||||
|
||||
def _on_llm_new_token(
|
||||
async def _on_llm_end(self, run: Run) -> None:
|
||||
"""Process the LLM Run."""
|
||||
|
||||
async def _on_llm_error(self, run: Run) -> None:
|
||||
"""Process the LLM Run upon error."""
|
||||
|
||||
async def _on_llm_new_token(
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
@@ -577,38 +707,32 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
) -> None:
|
||||
"""Process new LLM token."""
|
||||
|
||||
def _on_llm_end(self, run: Run) -> None:
|
||||
"""Process the LLM Run."""
|
||||
|
||||
def _on_llm_error(self, run: Run) -> None:
|
||||
"""Process the LLM Run upon error."""
|
||||
|
||||
def _on_chain_start(self, run: Run) -> None:
|
||||
async def _on_chain_start(self, run: Run) -> None:
|
||||
"""Process the Chain Run upon start."""
|
||||
|
||||
def _on_chain_end(self, run: Run) -> None:
|
||||
async def _on_chain_end(self, run: Run) -> None:
|
||||
"""Process the Chain Run."""
|
||||
|
||||
def _on_chain_error(self, run: Run) -> None:
|
||||
async def _on_chain_error(self, run: Run) -> None:
|
||||
"""Process the Chain Run upon error."""
|
||||
|
||||
def _on_tool_start(self, run: Run) -> None:
|
||||
async def _on_tool_start(self, run: Run) -> None:
|
||||
"""Process the Tool Run upon start."""
|
||||
|
||||
def _on_tool_end(self, run: Run) -> None:
|
||||
async def _on_tool_end(self, run: Run) -> None:
|
||||
"""Process the Tool Run."""
|
||||
|
||||
def _on_tool_error(self, run: Run) -> None:
|
||||
async def _on_tool_error(self, run: Run) -> None:
|
||||
"""Process the Tool Run upon error."""
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> None:
|
||||
async def _on_chat_model_start(self, run: Run) -> None:
|
||||
"""Process the Chat Model Run upon start."""
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> None:
|
||||
async def _on_retriever_start(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon start."""
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> None:
|
||||
async def _on_retriever_end(self, run: Run) -> None:
|
||||
"""Process the Retriever Run."""
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> None:
|
||||
async def _on_retriever_error(self, run: Run) -> None:
|
||||
"""Process the Retriever Run upon error."""
|
||||
|
566
libs/core/langchain_core/tracers/core.py
Normal file
566
libs/core/langchain_core/tracers/core.py
Normal file
@@ -0,0 +1,566 @@
|
||||
"""Utilities for the root listener."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCHEMA_FORMAT_TYPE = Literal["original", "streaming_events"]
|
||||
|
||||
|
||||
class _TracerCore(ABC):
|
||||
"""
|
||||
Abstract base class for tracers
|
||||
This class provides common methods, and reusable methods for tracers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
_schema_format: Literal[
|
||||
"original", "streaming_events", "original+chat"
|
||||
] = "original",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the tracer.
|
||||
|
||||
Args:
|
||||
_schema_format: Primarily changes how the inputs and outputs are
|
||||
handled. For internal use only. This API will change.
|
||||
- 'original' is the format used by all current tracers.
|
||||
This format is slightly inconsistent with respect to inputs
|
||||
and outputs.
|
||||
- 'streaming_events' is used for supporting streaming events,
|
||||
for internal usage. It will likely change in the future, or
|
||||
be deprecated entirely in favor of a dedicated async tracer
|
||||
for streaming events.
|
||||
- 'original+chat' is a format that is the same as 'original'
|
||||
except it does NOT raise an attribute error on_chat_model_start
|
||||
kwargs: Additional keyword arguments that will be passed to
|
||||
the super class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._schema_format = _schema_format # For internal use only API will change.
|
||||
self.run_map: Dict[str, Run] = {}
|
||||
"""Map of run ID to run. Cleared on run end."""
|
||||
self.order_map: Dict[UUID, Tuple[UUID, str]] = {}
|
||||
"""Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed."""
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Persist a run."""
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
parent_run: Run,
|
||||
child_run: Run,
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
parent_run.child_runs.append(child_run)
|
||||
|
||||
@staticmethod
|
||||
def _get_stacktrace(error: BaseException) -> str:
|
||||
"""Get the stacktrace of the parent error."""
|
||||
msg = repr(error)
|
||||
try:
|
||||
if sys.version_info < (3, 10):
|
||||
tb = traceback.format_exception(
|
||||
error.__class__, error, error.__traceback__
|
||||
)
|
||||
else:
|
||||
tb = traceback.format_exception(error)
|
||||
return (msg + "\n\n".join(tb)).strip()
|
||||
except: # noqa: E722
|
||||
return msg
|
||||
|
||||
def _start_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]: # type: ignore[return]
|
||||
current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id)
|
||||
if run.parent_run_id:
|
||||
if parent := self.order_map.get(run.parent_run_id):
|
||||
run.trace_id, run.dotted_order = parent
|
||||
run.dotted_order += "." + current_dotted_order
|
||||
if parent_run := self.run_map.get(str(run.parent_run_id)):
|
||||
self._add_child_run(parent_run, run)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Parent run {run.parent_run_id} not found for run {run.id}."
|
||||
" Treating as a root run."
|
||||
)
|
||||
run.parent_run_id = None
|
||||
run.trace_id = run.id
|
||||
run.dotted_order = current_dotted_order
|
||||
else:
|
||||
run.trace_id = run.id
|
||||
run.dotted_order = current_dotted_order
|
||||
self.order_map[run.id] = (run.trace_id, run.dotted_order)
|
||||
self.run_map[str(run.id)] = run
|
||||
|
||||
def _get_run(
|
||||
self, run_id: UUID, run_type: Union[str, Set[str], None] = None
|
||||
) -> Run:
|
||||
try:
|
||||
run = self.run_map[str(run_id)]
|
||||
except KeyError as exc:
|
||||
raise TracerException(f"No indexed run ID {run_id}.") from exc
|
||||
|
||||
if isinstance(run_type, str):
|
||||
run_types: Union[Set[str], None] = {run_type}
|
||||
else:
|
||||
run_types = run_type
|
||||
if run_types is not None and run.run_type not in run_types:
|
||||
raise TracerException(
|
||||
f"Found {run.run_type} run at ID {run_id}, "
|
||||
f"but expected {run_types} run."
|
||||
)
|
||||
return run
|
||||
|
||||
def _create_chat_model_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Create a chat model run."""
|
||||
if self._schema_format not in ("streaming_events", "original+chat"):
|
||||
# Please keep this un-implemented for backwards compatibility.
|
||||
# When it's unimplemented old tracers that use the "original" format
|
||||
# fallback on the on_llm_start method implementation if they
|
||||
# find that the on_chat_model_start method is not implemented.
|
||||
# This can eventually be cleaned up by writing a "modern" tracer
|
||||
# that has all the updated schema changes corresponding to
|
||||
# the "streaming_events" format.
|
||||
raise NotImplementedError(
|
||||
f"Chat model tracing is not supported in "
|
||||
f"for {self._schema_format} format."
|
||||
)
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
return Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
# WARNING: This is valid ONLY for streaming_events.
|
||||
# run_type="llm" is what's used by virtually all tracers.
|
||||
# Changing this to "chat_model" may break triggering on_llm_start
|
||||
run_type="chat_model",
|
||||
tags=tags,
|
||||
name=name, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def _create_llm_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Create a llm run"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
return Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
# TODO: Figure out how to expose kwargs here
|
||||
inputs={"prompts": prompts},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
run_type="llm",
|
||||
tags=tags or [],
|
||||
name=name, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def _llm_run_with_token_event(
|
||||
self,
|
||||
token: str,
|
||||
run_id: UUID,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""
|
||||
Append token event to LLM run and return the run
|
||||
"""
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
event_kwargs: Dict[str, Any] = {"token": token}
|
||||
if chunk:
|
||||
event_kwargs["chunk"] = chunk
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "new_token",
|
||||
"time": datetime.now(timezone.utc),
|
||||
"kwargs": event_kwargs,
|
||||
},
|
||||
)
|
||||
return llm_run
|
||||
|
||||
def _llm_run_with_retry_event(
|
||||
self,
|
||||
retry_state: RetryCallState,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
llm_run = self._get_run(run_id)
|
||||
retry_d: Dict[str, Any] = {
|
||||
"slept": retry_state.idle_for,
|
||||
"attempt": retry_state.attempt_number,
|
||||
}
|
||||
if retry_state.outcome is None:
|
||||
retry_d["outcome"] = "N/A"
|
||||
elif retry_state.outcome.failed:
|
||||
retry_d["outcome"] = "failed"
|
||||
exception = retry_state.outcome.exception()
|
||||
retry_d["exception"] = str(exception)
|
||||
retry_d["exception_type"] = exception.__class__.__name__
|
||||
else:
|
||||
retry_d["outcome"] = "success"
|
||||
retry_d["result"] = str(retry_state.outcome.result())
|
||||
llm_run.events.append(
|
||||
{
|
||||
"name": "retry",
|
||||
"time": datetime.now(timezone.utc),
|
||||
"kwargs": retry_d,
|
||||
},
|
||||
)
|
||||
return llm_run
|
||||
|
||||
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
llm_run.outputs = response.dict()
|
||||
for i, generations in enumerate(response.generations):
|
||||
for j, generation in enumerate(generations):
|
||||
output_generation = llm_run.outputs["generations"][i][j]
|
||||
if "message" in output_generation:
|
||||
output_generation["message"] = dumpd(
|
||||
cast(ChatGeneration, generation).message
|
||||
)
|
||||
llm_run.end_time = datetime.now(timezone.utc)
|
||||
llm_run.events.append({"name": "end", "time": llm_run.end_time})
|
||||
|
||||
return llm_run
|
||||
|
||||
def _errored_llm_run(self, error: BaseException, run_id: UUID) -> Run:
|
||||
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
|
||||
llm_run.error = self._get_stacktrace(error)
|
||||
llm_run.end_time = datetime.now(timezone.utc)
|
||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||
|
||||
return llm_run
|
||||
|
||||
def _create_chain_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_type: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Create a chain Run"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
return Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs=self._get_chain_inputs(inputs),
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
child_runs=[],
|
||||
run_type=run_type or "chain",
|
||||
name=name, # type: ignore[arg-type]
|
||||
tags=tags or [],
|
||||
)
|
||||
|
||||
def _get_chain_inputs(self, inputs: Any) -> Any:
|
||||
"""Get the inputs for a chain run."""
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
return inputs if isinstance(inputs, dict) else {"input": inputs}
|
||||
elif self._schema_format == "streaming_events":
|
||||
return {
|
||||
"input": inputs,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {self._schema_format}")
|
||||
|
||||
def _get_chain_outputs(self, outputs: Any) -> Any:
|
||||
"""Get the outputs for a chain run."""
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
return outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||
elif self._schema_format == "streaming_events":
|
||||
return {
|
||||
"output": outputs,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {self._schema_format}")
|
||||
|
||||
def _complete_chain_run(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
run_id: UUID,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Update a chain run with outputs and end time."""
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.outputs = self._get_chain_outputs(outputs)
|
||||
chain_run.end_time = datetime.now(timezone.utc)
|
||||
chain_run.events.append({"name": "end", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
chain_run.inputs = self._get_chain_inputs(inputs)
|
||||
return chain_run
|
||||
|
||||
def _errored_chain_run(
|
||||
self,
|
||||
error: BaseException,
|
||||
inputs: Optional[Dict[str, Any]],
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.error = self._get_stacktrace(error)
|
||||
chain_run.end_time = datetime.now(timezone.utc)
|
||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||
if inputs is not None:
|
||||
chain_run.inputs = self._get_chain_inputs(inputs)
|
||||
return chain_run
|
||||
|
||||
def _create_tool_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Create a tool run."""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
|
||||
if self._schema_format in ("original", "original+chat"):
|
||||
inputs = {"input": input_str}
|
||||
elif self._schema_format == "streaming_events":
|
||||
inputs = {"input": inputs}
|
||||
else:
|
||||
raise AssertionError(f"Invalid format: {self._schema_format}")
|
||||
|
||||
return Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
# Wrapping in dict since Run requires a dict object.
|
||||
inputs=inputs,
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
child_runs=[],
|
||||
run_type="tool",
|
||||
tags=tags or [],
|
||||
name=name, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def _complete_tool_run(
|
||||
self,
|
||||
output: Dict[str, Any],
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Update a tool run with outputs and end time."""
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.outputs = {"output": output}
|
||||
tool_run.end_time = datetime.now(timezone.utc)
|
||||
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
||||
return tool_run
|
||||
|
||||
def _errored_tool_run(
|
||||
self,
|
||||
error: BaseException,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Update a tool run with error and end time."""
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.error = self._get_stacktrace(error)
|
||||
tool_run.end_time = datetime.now(timezone.utc)
|
||||
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
||||
return tool_run
|
||||
|
||||
def _create_retrieval_run(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
query: str,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Create a retrieval run."""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
return Run(
|
||||
id=run_id,
|
||||
name=name or "Retriever",
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"query": query},
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
tags=tags,
|
||||
child_runs=[],
|
||||
run_type="retriever",
|
||||
)
|
||||
|
||||
def _complete_retrieval_run(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Update a retrieval run with outputs and end time."""
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.outputs = {"documents": documents}
|
||||
retrieval_run.end_time = datetime.now(timezone.utc)
|
||||
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
||||
return retrieval_run
|
||||
|
||||
def _errored_retrieval_run(
|
||||
self,
|
||||
error: BaseException,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.error = self._get_stacktrace(error)
|
||||
retrieval_run.end_time = datetime.now(timezone.utc)
|
||||
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
||||
return retrieval_run
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> _TracerCore:
|
||||
"""Deepcopy the tracer."""
|
||||
return self
|
||||
|
||||
def __copy__(self) -> _TracerCore:
|
||||
"""Copy the tracer."""
|
||||
return self
|
||||
|
||||
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""End a trace for a run."""
|
||||
|
||||
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process a run upon creation."""
|
||||
|
||||
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process a run upon update."""
|
||||
|
||||
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run upon start."""
|
||||
|
||||
def _on_llm_new_token(
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process new LLM token."""
|
||||
|
||||
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run."""
|
||||
|
||||
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run upon error."""
|
||||
|
||||
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run upon start."""
|
||||
|
||||
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run."""
|
||||
|
||||
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run upon error."""
|
||||
|
||||
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run upon start."""
|
||||
|
||||
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run."""
|
||||
|
||||
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run upon error."""
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chat Model Run upon start."""
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run upon start."""
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run."""
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run upon error."""
|
@@ -1,14 +1,18 @@
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Awaitable, Callable, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
acall_func_with_variable_args,
|
||||
call_func_with_variable_args,
|
||||
)
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.base import AsyncBaseTracer, BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||
AsyncListener = Union[
|
||||
Callable[[Run], Awaitable[None]], Callable[[Run, RunnableConfig], Awaitable[None]]
|
||||
]
|
||||
|
||||
|
||||
class RootListenersTracer(BaseTracer):
|
||||
@@ -54,3 +58,50 @@ class RootListenersTracer(BaseTracer):
|
||||
else:
|
||||
if self._arg_on_error is not None:
|
||||
call_func_with_variable_args(self._arg_on_error, run, self.config)
|
||||
|
||||
|
||||
class AsyncRootListenersTracer(AsyncBaseTracer):
|
||||
"""Async Tracer that calls listeners on run start, end, and error."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: RunnableConfig,
|
||||
on_start: Optional[AsyncListener],
|
||||
on_end: Optional[AsyncListener],
|
||||
on_error: Optional[AsyncListener],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self._arg_on_start = on_start
|
||||
self._arg_on_end = on_end
|
||||
self._arg_on_error = on_error
|
||||
self.root_id: Optional[UUID] = None
|
||||
|
||||
async def _persist_run(self, run: Run) -> None:
|
||||
# This is a legacy method only called once for an entire run tree
|
||||
# therefore not useful here
|
||||
pass
|
||||
|
||||
async def _on_run_create(self, run: Run) -> None:
|
||||
if self.root_id is not None:
|
||||
return
|
||||
|
||||
self.root_id = run.id
|
||||
|
||||
if self._arg_on_start is not None:
|
||||
await acall_func_with_variable_args(self._arg_on_start, run, self.config)
|
||||
|
||||
async def _on_run_update(self, run: Run) -> None:
|
||||
if run.id != self.root_id:
|
||||
return
|
||||
|
||||
if run.error is None:
|
||||
if self._arg_on_end is not None:
|
||||
await acall_func_with_variable_args(self._arg_on_end, run, self.config)
|
||||
else:
|
||||
if self._arg_on_error is not None:
|
||||
await acall_func_with_variable_args(
|
||||
self._arg_on_error, run, self.config
|
||||
)
|
||||
|
598
libs/core/tests/unit_tests/tracers/test_async_base_tracer.py
Normal file
598
libs/core/tests/unit_tests/tracers/test_async_base_tracer.py
Normal file
@@ -0,0 +1,598 @@
|
||||
"""Test Tracer classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackManager
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.tracers.base import AsyncBaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
SERIALIZED = {"id": ["llm"]}
|
||||
SERIALIZED_CHAT = {"id": ["chat_model"]}
|
||||
|
||||
|
||||
class FakeAsyncTracer(AsyncBaseTracer):
|
||||
"""Fake tracer to test async based tracers."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the tracer."""
|
||||
super().__init__()
|
||||
self.runs: List[Run] = []
|
||||
|
||||
async def _persist_run(self, run: Run) -> None:
|
||||
self.runs.append(run)
|
||||
|
||||
|
||||
def _compare_run_with_error(run: Any, expected_run: Any) -> None:
|
||||
if run.child_runs:
|
||||
assert len(expected_run.child_runs) == len(run.child_runs)
|
||||
for received, expected in zip(run.child_runs, expected_run.child_runs):
|
||||
_compare_run_with_error(received, expected)
|
||||
received = run.dict(exclude={"child_runs"})
|
||||
received_err = received.pop("error")
|
||||
expected = expected_run.dict(exclude={"child_runs"})
|
||||
expected_err = expected.pop("error")
|
||||
|
||||
assert received == expected
|
||||
if expected_err is not None:
|
||||
assert received_err is not None
|
||||
assert expected_err in received_err
|
||||
else:
|
||||
assert received_err is None
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
uuid = uuid4()
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=uuid,
|
||||
parent_run_id=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
inputs={"prompts": []},
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
error=None,
|
||||
run_type="llm",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_chat_model_run() -> None:
|
||||
"""Test tracer on a Chat Model run."""
|
||||
tracer = FakeAsyncTracer()
|
||||
manager = AsyncCallbackManager(handlers=[tracer])
|
||||
run_managers = await manager.on_chat_model_start(
|
||||
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
|
||||
)
|
||||
compare_run = Run(
|
||||
id=str(run_managers[0].run_id), # type: ignore[arg-type]
|
||||
name="chat_model",
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED_CHAT,
|
||||
inputs=dict(prompts=["Human: "]),
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
error=None,
|
||||
run_type="llm",
|
||||
trace_id=run_managers[0].run_id,
|
||||
dotted_order=f"20230101T000000000000Z{run_managers[0].run_id}",
|
||||
)
|
||||
for run_manager in run_managers:
|
||||
await run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
with pytest.raises(TracerException):
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_multiple_llm_runs() -> None:
|
||||
"""Test the tracer with multiple runs."""
|
||||
uuid = uuid4()
|
||||
compare_run = Run(
|
||||
id=uuid,
|
||||
name="llm",
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
error=None,
|
||||
run_type="llm",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_chain_run() -> None:
|
||||
"""Test tracer on a Chain run."""
|
||||
uuid = uuid4()
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "chain"},
|
||||
inputs={},
|
||||
outputs={},
|
||||
error=None,
|
||||
run_type="chain",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
await tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
await tracer.on_chain_end(outputs={}, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_tool_run() -> None:
|
||||
"""Test tracer on a Tool run."""
|
||||
uuid = uuid4()
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "tool"},
|
||||
inputs={"input": "test"},
|
||||
outputs={"output": "test"},
|
||||
error=None,
|
||||
run_type="tool",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
await tracer.on_tool_start(
|
||||
serialized={"name": "tool"}, input_str="test", run_id=uuid
|
||||
)
|
||||
await tracer.on_tool_end("test", run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_nested_run() -> None:
|
||||
"""Test tracer on a nested run."""
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
for _ in range(10):
|
||||
await tracer.on_chain_start(
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
await tracer.on_tool_start(
|
||||
serialized={"name": "tool"},
|
||||
input_str="test",
|
||||
run_id=tool_uuid,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
await tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
await tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
await tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
await tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
||||
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(chain_uuid), # type: ignore[arg-type]
|
||||
error=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "chain"},
|
||||
inputs={},
|
||||
outputs={},
|
||||
run_type="chain",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}",
|
||||
child_runs=[
|
||||
Run( # type: ignore[call-arg]
|
||||
id=tool_uuid,
|
||||
parent_run_id=chain_uuid,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "tool"},
|
||||
inputs=dict(input="test"),
|
||||
outputs=dict(output="test"),
|
||||
error=None,
|
||||
run_type="tool",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
|
||||
child_runs=[
|
||||
Run( # type: ignore[call-arg]
|
||||
id=str(llm_uuid1), # type: ignore[arg-type]
|
||||
parent_run_id=str(tool_uuid), # type: ignore[arg-type]
|
||||
error=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
run_type="llm",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}.20230101T000000000000Z{llm_uuid1}",
|
||||
)
|
||||
],
|
||||
),
|
||||
Run( # type: ignore[call-arg]
|
||||
id=str(llm_uuid2), # type: ignore[arg-type]
|
||||
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
|
||||
error=None,
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
|
||||
run_type="llm",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs[0] == compare_run
|
||||
assert tracer.runs == [compare_run] * 10
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_llm_run_on_error() -> None:
|
||||
"""Test tracer on an LLM run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=None,
|
||||
error=repr(exception),
|
||||
run_type="llm",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
await tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert len(tracer.runs) == 1
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_llm_run_on_error_callback() -> None:
|
||||
"""Test tracer on an LLM run with an error and a callback."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=None,
|
||||
error=repr(exception),
|
||||
run_type="llm",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
|
||||
class FakeTracerWithLlmErrorCallback(FakeAsyncTracer):
|
||||
error_run = None
|
||||
|
||||
async def _on_llm_error(self, run: Run) -> None:
|
||||
self.error_run = run
|
||||
|
||||
tracer = FakeTracerWithLlmErrorCallback()
|
||||
await tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
await tracer.on_llm_error(exception, run_id=uuid)
|
||||
_compare_run_with_error(tracer.error_run, compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_chain_run_on_error() -> None:
|
||||
"""Test tracer on a Chain run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "chain"},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
error=repr(exception),
|
||||
run_type="chain",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
await tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
|
||||
await tracer.on_chain_error(exception, run_id=uuid)
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_tool_run_on_error() -> None:
|
||||
"""Test tracer on a Tool run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "tool"},
|
||||
inputs=dict(input="test"),
|
||||
outputs=None,
|
||||
action="{'name': 'tool'}",
|
||||
error=repr(exception),
|
||||
run_type="tool",
|
||||
trace_id=uuid,
|
||||
dotted_order=f"20230101T000000000000Z{uuid}",
|
||||
)
|
||||
tracer = FakeAsyncTracer()
|
||||
|
||||
await tracer.on_tool_start(
|
||||
serialized={"name": "tool"}, input_str="test", run_id=uuid
|
||||
)
|
||||
await tracer.on_tool_error(exception, run_id=uuid)
|
||||
_compare_run_with_error(tracer.runs[0], compare_run)
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_tracer_nested_runs_on_error() -> None:
|
||||
"""Test tracer on a nested run with an error."""
|
||||
exception = Exception("test")
|
||||
|
||||
tracer = FakeAsyncTracer()
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
llm_uuid3 = uuid4()
|
||||
|
||||
for _ in range(3):
|
||||
await tracer.on_chain_start(
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
await tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
await tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
await tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
await tracer.on_tool_start(
|
||||
serialized={"name": "tool"},
|
||||
input_str="test",
|
||||
run_id=tool_uuid,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
await tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid3,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
await tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||
await tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||
await tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||
|
||||
compare_run = Run( # type: ignore[call-arg]
|
||||
id=str(chain_uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "chain"},
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
run_type="chain",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}",
|
||||
child_runs=[
|
||||
Run( # type: ignore[call-arg]
|
||||
id=str(llm_uuid1), # type: ignore[arg-type]
|
||||
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
error=None,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
|
||||
run_type="llm",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid1}",
|
||||
),
|
||||
Run( # type: ignore[call-arg]
|
||||
id=str(llm_uuid2), # type: ignore[arg-type]
|
||||
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "end", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
error=None,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
|
||||
run_type="llm",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{llm_uuid2}",
|
||||
),
|
||||
Run( # type: ignore[call-arg]
|
||||
id=str(tool_uuid), # type: ignore[arg-type]
|
||||
parent_run_id=str(chain_uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized={"name": "tool"},
|
||||
error=repr(exception),
|
||||
inputs=dict(input="test"),
|
||||
outputs=None,
|
||||
action="{'name': 'tool'}",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
|
||||
child_runs=[
|
||||
Run( # type: ignore[call-arg]
|
||||
id=str(llm_uuid3), # type: ignore[arg-type]
|
||||
parent_run_id=str(tool_uuid), # type: ignore[arg-type]
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
events=[
|
||||
{"name": "start", "time": datetime.now(timezone.utc)},
|
||||
{"name": "error", "time": datetime.now(timezone.utc)},
|
||||
],
|
||||
extra={},
|
||||
serialized=SERIALIZED,
|
||||
error=repr(exception),
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=None,
|
||||
run_type="llm",
|
||||
trace_id=chain_uuid,
|
||||
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}.20230101T000000000000Z{llm_uuid3}",
|
||||
)
|
||||
],
|
||||
run_type="tool",
|
||||
),
|
||||
],
|
||||
)
|
||||
assert len(tracer.runs) == 3
|
||||
for run in tracer.runs:
|
||||
_compare_run_with_error(run, compare_run)
|
@@ -13,10 +13,11 @@ from freezegun import freeze_time
|
||||
from langsmith import Client, traceable
|
||||
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.exceptions import TracerException
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.runnables import chain as as_runnable
|
||||
from langchain_core.tracers.base import BaseTracer, TracerException
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
SERIALIZED = {"id": ["llm"]}
|
||||
|
Reference in New Issue
Block a user