Add Runnable.with_listeners() (#12549)

- This binds start/end/error listeners to a runnable; each listener is
called with the Run object
This commit is contained in:
Nuno Campos 2023-10-31 11:04:51 +00:00 committed by GitHub
parent bcc62d63be
commit 2f563cee20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 295 additions and 15 deletions

View File

@ -1496,6 +1496,18 @@ class CallbackManagerForChainGroup(CallbackManager):
self.parent_run_manager = parent_run_manager
self.ended = False
def copy(self) -> CallbackManagerForChainGroup:
    """Return a new manager of the same class carrying this manager's state.

    The handler lists, tag lists and metadata dicts are shallow-copied, so
    mutating the containers of the copy (for example by adding a handler to
    it) cannot leak back into this manager, regardless of whether the
    constructor stores the containers it is given. The handler objects
    themselves and ``parent_run_manager`` are shared, not cloned.

    Returns:
        A new instance of ``self.__class__`` built from this manager's
        handlers, tags, metadata, parent run id and parent run manager.
    """
    return self.__class__(
        handlers=list(self.handlers),
        inheritable_handlers=list(self.inheritable_handlers),
        parent_run_id=self.parent_run_id,
        tags=list(self.tags),
        inheritable_tags=list(self.inheritable_tags),
        metadata=dict(self.metadata),
        inheritable_metadata=dict(self.inheritable_metadata),
        parent_run_manager=self.parent_run_manager,
    )
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.

View File

@ -0,0 +1,46 @@
from typing import Callable, Optional
from uuid import UUID
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
class RootListenersTracer(BaseTracer):
    """Tracer that forwards the lifecycle of a single root run to listeners.

    The first run this tracer sees being created is remembered as the root
    run. Only that run triggers the listeners; every other run is ignored.
    """

    def __init__(
        self,
        *,
        on_start: Optional[Callable[[Run], None]],
        on_end: Optional[Callable[[Run], None]],
        on_error: Optional[Callable[[Run], None]]
    ) -> None:
        """Store the optional listeners; ``None`` disables the matching hook."""
        super().__init__()

        # Id of the first run created through this tracer; None until then.
        self.root_id: Optional[UUID] = None

        self._arg_on_start = on_start
        self._arg_on_end = on_end
        self._arg_on_error = on_error

    def _persist_run(self, run: Run) -> None:
        """Do nothing.

        This is a legacy method only called once for an entire run tree,
        therefore it is not useful here.
        """

    def _on_run_create(self, run: Run) -> None:
        """Record the first created run as the root and notify ``on_start``."""
        if self.root_id is None:
            self.root_id = run.id

            start_listener = self._arg_on_start
            if start_listener is not None:
                start_listener(run)

    def _on_run_update(self, run: Run) -> None:
        """Notify ``on_end`` or ``on_error`` when the root run is updated."""
        if run.id == self.root_id:
            # A run without an error goes to on_end, otherwise to on_error.
            if run.error is None:
                listener = self._arg_on_end
            else:
                listener = self._arg_on_error

            if listener is not None:
                listener(run)

View File

@ -37,6 +37,7 @@ if TYPE_CHECKING:
CallbackManagerForChainRun,
)
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.runnable.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
@ -585,6 +586,39 @@ class Runnable(Generic[Input, Output], ABC):
kwargs={},
)
def with_listeners(
    self,
    *,
    on_start: Optional[Callable[[Run], None]] = None,
    on_end: Optional[Callable[[Run], None]] = None,
    on_error: Optional[Callable[[Run], None]] = None,
) -> Runnable[Input, Output]:
    """Bind lifecycle listeners to a Runnable, returning a new Runnable.

    Args:
        on_start: Called before the runnable starts running, with the Run
            object.
        on_end: Called after the runnable finishes running, with the Run
            object.
        on_error: Called if the runnable throws an error, with the Run
            object.

    The Run object contains information about the run, including its id,
    type, input, output, error, start_time, end_time, and any tags or
    metadata added to the run.

    Returns:
        A ``RunnableBinding`` around this runnable whose config factory
        supplies a ``RootListenersTracer`` as a callback.
    """
    # Imported inside the method, as the original does.
    from langchain.callbacks.tracers.root_listeners import RootListenersTracer

    def listeners_config() -> RunnableConfig:
        # A new tracer is built on every call of the factory.
        tracer = RootListenersTracer(
            on_start=on_start, on_end=on_end, on_error=on_error
        )
        return {"callbacks": [tracer]}

    return RunnableBinding(bound=self, config_factories=[listeners_config])
def with_types(
self,
*,
@ -2323,6 +2357,30 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.with_config(config, **kwargs))
def with_listeners(
    self,
    *,
    on_start: Optional[Callable[[Run], None]] = None,
    on_end: Optional[Callable[[Run], None]] = None,
    on_error: Optional[Callable[[Run], None]] = None,
) -> RunnableEach[Input, Output]:
    """Bind lifecycle listeners to a Runnable, returning a new Runnable.

    The listeners are attached to the wrapped ``bound`` runnable, and the
    result is wrapped in a new ``RunnableEach``.

    Args:
        on_start: Called before the runnable starts running, with the Run
            object.
        on_end: Called after the runnable finishes running, with the Run
            object.
        on_error: Called if the runnable throws an error, with the Run
            object.

    The Run object contains information about the run, including its id,
    type, input, output, error, start_time, end_time, and any tags or
    metadata added to the run.
    """
    listening_bound = self.bound.with_listeners(
        on_start=on_start, on_end=on_end, on_error=on_error
    )
    return RunnableEach(bound=listening_bound)
def _invoke(
self,
inputs: List[Input],
@ -2363,10 +2421,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
bound: Runnable[Input, Output]
kwargs: Mapping[str, Any]
kwargs: Mapping[str, Any] = Field(default_factory=dict)
config: RunnableConfig = Field(default_factory=dict)
config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)
# Union[Type[Input], BaseModel] + things like List[str]
custom_input_type: Optional[Any] = None
# Union[Type[Output], BaseModel] + things like List[str]
@ -2379,8 +2439,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
self,
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None,
config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
**other_kwargs: Any,
@ -2397,8 +2458,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
)
super().__init__(
bound=bound,
kwargs=kwargs,
config=config,
kwargs=kwargs or {},
config=config or {},
config_factories=config_factories or [],
custom_input_type=custom_input_type,
custom_output_type=custom_output_type,
**other_kwargs,
@ -2472,6 +2534,43 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
custom_output_type=self.custom_output_type,
)
def with_listeners(
    self,
    *,
    on_start: Optional[Callable[[Run], None]] = None,
    on_end: Optional[Callable[[Run], None]] = None,
    on_error: Optional[Callable[[Run], None]] = None,
) -> Runnable[Input, Output]:
    """Bind lifecycle listeners to a Runnable, returning a new Runnable.

    The returned binding keeps this binding's ``bound``, ``kwargs``,
    ``config``, custom input/output types and any config factories already
    attached, so listeners added by an earlier ``with_listeners`` call keep
    firing alongside the new ones.

    Args:
        on_start: Called before the runnable starts running, with the Run
            object.
        on_end: Called after the runnable finishes running, with the Run
            object.
        on_error: Called if the runnable throws an error, with the Run
            object.

    The Run object contains information about the run, including its id,
    type, input, output, error, start_time, end_time, and any tags or
    metadata added to the run.

    Returns:
        A new instance of ``self.__class__`` with one extra config factory
        that supplies a ``RootListenersTracer`` as a callback.
    """
    # Imported inside the method, as the original does.
    from langchain.callbacks.tracers.root_listeners import RootListenersTracer

    return self.__class__(
        bound=self.bound,
        kwargs=self.kwargs,
        config=self.config,
        config_factories=[
            # Keep factories bound earlier; replacing the list would
            # silently drop previously registered listeners.
            *self.config_factories,
            lambda: {
                "callbacks": [
                    RootListenersTracer(
                        on_start=on_start, on_end=on_end, on_error=on_error
                    )
                ],
            },
        ],
        custom_input_type=self.custom_input_type,
        custom_output_type=self.custom_output_type,
    )
def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
@ -2496,6 +2595,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
config=self.config,
)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
    """Merge the bound config, the factory-produced configs and ``configs``.

    Every factory in ``self.config_factories`` is called on each
    invocation of this method. The merge order is: ``self.config`` first,
    then the factory results in list order, then the caller's ``configs``.
    """
    produced = [make_config() for make_config in self.config_factories]
    return merge_configs(self.config, *produced, *configs)
def invoke(
self,
input: Input,
@ -2504,7 +2608,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output:
return self.bound.invoke(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)
@ -2516,7 +2620,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output:
return await self.bound.ainvoke(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)
@ -2531,10 +2635,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
configs = [self._merge_configs(config) for _ in range(len(inputs))]
return self.bound.batch(
inputs,
configs,
@ -2553,10 +2657,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
configs = [self._merge_configs(config) for _ in range(len(inputs))]
return await self.bound.abatch(
inputs,
configs,
@ -2572,7 +2676,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.stream(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)
@ -2584,7 +2688,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
):
yield item
@ -2597,7 +2701,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.transform(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
)
@ -2609,7 +2713,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input,
merge_configs(self.config, config),
self._merge_configs(config),
**{**self.kwargs, **kwargs},
):
yield item

View File

@ -220,6 +220,51 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
elif key == "callbacks":
base_callbacks = base.get("callbacks")
these_callbacks = config["callbacks"]
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if isinstance(these_callbacks, list):
if base_callbacks is None:
base["callbacks"] = these_callbacks
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + these_callbacks
else:
# base_callbacks is a manager
mngr = base_callbacks.copy()
for callback in these_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
elif these_callbacks is not None:
# these_callbacks is a manager
if base_callbacks is None:
base["callbacks"] = these_callbacks
elif isinstance(base_callbacks, list):
mngr = these_callbacks.copy()
for callback in base_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
base["callbacks"] = base_callbacks.__class__(
parent_run_id=base_callbacks.parent_run_id
or these_callbacks.parent_run_id,
handlers=base_callbacks.handlers + these_callbacks.handlers,
inheritable_handlers=base_callbacks.inheritable_handlers
+ these_callbacks.inheritable_handlers,
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
inheritable_tags=list(
set(
base_callbacks.inheritable_tags
+ these_callbacks.inheritable_tags
)
),
metadata={
**base_callbacks.metadata,
**these_callbacks.metadata,
},
)
else:
base[key] = config[key] or base.get(key) # type: ignore
return base

View File

@ -3748,6 +3748,7 @@
]
},
"config": {},
"config_factories": [],
"custom_input_type": null,
"custom_output_type": null
}

View File

@ -0,0 +1,34 @@
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.schema.runnable.config import RunnableConfig, merge_configs
def test_merge_config_callbacks() -> None:
    """Check how ``merge_configs`` combines the ``callbacks`` of two configs."""
    manager_config: RunnableConfig = {
        "callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
    }
    list_config: RunnableConfig = {"callbacks": [ConsoleCallbackHandler()]}
    other_list_config: RunnableConfig = {
        "callbacks": [StreamingStdOutCallbackHandler()]
    }

    # Manager + handler list, in either order, yields a manager holding the
    # manager's own handler first and the listed handler second.
    for first, second in (
        (manager_config, list_config),
        (list_config, manager_config),
    ):
        merged_manager = merge_configs(first, second)["callbacks"]
        assert isinstance(merged_manager, CallbackManager)
        assert len(merged_manager.handlers) == 2
        assert isinstance(merged_manager.handlers[0], StdOutCallbackHandler)
        assert isinstance(merged_manager.handlers[1], ConsoleCallbackHandler)

    # Two handler lists are concatenated in argument order.
    merged_list = merge_configs(list_config, other_list_config)["callbacks"]
    assert isinstance(merged_list, list)
    assert len(merged_list) == 2
    assert isinstance(merged_list[0], ConsoleCallbackHandler)
    assert isinstance(merged_list[1], StreamingStdOutCallbackHandler)

View File

@ -19,7 +19,12 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs
from langchain.callbacks.manager import (
Callbacks,
atrace_as_chain_group,
collect_runs,
trace_as_chain_group,
)
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run
@ -1495,6 +1500,39 @@ def test_prompt_template_params() -> None:
prompt.invoke({})
def test_with_listeners(mocker: MockerFixture) -> None:
    """Listeners fire once for the root run, with or without a chain group."""
    system_prompt = SystemMessagePromptTemplate.from_template(
        "You are a nice assistant."
    )
    prompt = system_prompt + "{question}"
    chat = FakeListChatModel(responses=["foo"])
    chain = prompt | chat

    start_listener = mocker.Mock()
    end_listener = mocker.Mock()

    def assert_root_reported_once() -> None:
        # Each listener is called exactly once, for the sequence itself.
        assert start_listener.call_count == 1
        assert start_listener.call_args[0][0].name == "RunnableSequence"
        assert end_listener.call_count == 1

    # Plain invocation with no surrounding callbacks.
    listening_chain = chain.with_listeners(
        on_start=start_listener, on_end=end_listener
    )
    listening_chain.invoke({"question": "Who are you?"})
    assert_root_reported_once()

    start_listener.reset_mock()
    end_listener.reset_mock()

    # Invocation that passes a chain group's manager as the callbacks.
    with trace_as_chain_group("hello") as group_manager:
        grouped_chain = chain.with_listeners(
            on_start=start_listener, on_end=end_listener
        )
        grouped_chain.invoke(
            {"question": "Who are you?"}, {"callbacks": group_manager}
        )
    assert_root_reported_once()
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model(