Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
d5d860964a update root listeners 2023-11-08 10:31:17 -08:00

View File

@@ -1,20 +1,28 @@
from typing import Callable, Optional
from typing import Callable, Optional, Union
from uuid import UUID
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.runnable.config import (
RunnableConfig,
call_func_with_variable_args,
)
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
class RootListenersTracer(BaseTracer):
def __init__(
self,
*,
on_start: Optional[Callable[[Run], None]],
on_end: Optional[Callable[[Run], None]],
on_error: Optional[Callable[[Run], None]],
config: RunnableConfig,
on_start: Optional[Listener],
on_end: Optional[Listener],
on_error: Optional[Listener],
) -> None:
super().__init__()
self.config = config
self._arg_on_start = on_start
self._arg_on_end = on_end
self._arg_on_error = on_error
@@ -32,7 +40,7 @@ class RootListenersTracer(BaseTracer):
self.root_id = run.id
if self._arg_on_start is not None:
self._arg_on_start(run)
call_func_with_variable_args(self._arg_on_start, run, self.config)
def _on_run_update(self, run: Run) -> None:
if run.id != self.root_id:
@@ -40,7 +48,7 @@ class RootListenersTracer(BaseTracer):
if run.error is None:
if self._arg_on_end is not None:
self._arg_on_end(run)
call_func_with_variable_args(self._arg_on_end, run, self.config)
else:
if self._arg_on_error is not None:
self._arg_on_error(run)
call_func_with_variable_args(self._arg_on_error, run, self.config)