mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
Add Runnable.with_listeners() (#12549)
- This binds start/end/error listeners to a runnable, which will be called with the Run object
This commit is contained in:
parent
bcc62d63be
commit
2f563cee20
@ -1496,6 +1496,18 @@ class CallbackManagerForChainGroup(CallbackManager):
|
||||
self.parent_run_manager = parent_run_manager
|
||||
self.ended = False
|
||||
|
||||
def copy(self) -> CallbackManagerForChainGroup:
|
||||
return self.__class__(
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
parent_run_manager=self.parent_run_manager,
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||
"""Run when traced chain group ends.
|
||||
|
||||
|
46
libs/langchain/langchain/callbacks/tracers/root_listeners.py
Normal file
46
libs/langchain/langchain/callbacks/tracers/root_listeners.py
Normal file
@ -0,0 +1,46 @@
|
||||
from typing import Callable, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
|
||||
class RootListenersTracer(BaseTracer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]],
|
||||
on_end: Optional[Callable[[Run], None]],
|
||||
on_error: Optional[Callable[[Run], None]]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._arg_on_start = on_start
|
||||
self._arg_on_end = on_end
|
||||
self._arg_on_error = on_error
|
||||
self.root_id: Optional[UUID] = None
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
# This is a legacy method only called once for an entire run tree
|
||||
# therefore not useful here
|
||||
pass
|
||||
|
||||
def _on_run_create(self, run: Run) -> None:
|
||||
if self.root_id is not None:
|
||||
return
|
||||
|
||||
self.root_id = run.id
|
||||
|
||||
if self._arg_on_start is not None:
|
||||
self._arg_on_start(run)
|
||||
|
||||
def _on_run_update(self, run: Run) -> None:
|
||||
if run.id != self.root_id:
|
||||
return
|
||||
|
||||
if run.error is None:
|
||||
if self._arg_on_end is not None:
|
||||
self._arg_on_end(run)
|
||||
else:
|
||||
if self._arg_on_error is not None:
|
||||
self._arg_on_error(run)
|
@ -37,6 +37,7 @@ if TYPE_CHECKING:
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.schema.runnable.fallbacks import (
|
||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||
)
|
||||
@ -585,6 +586,39 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]] = None,
|
||||
on_end: Optional[Callable[[Run], None]] = None,
|
||||
on_error: Optional[Callable[[Run], None]] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Called before the runnable starts running, with the Run object.
|
||||
on_end: Called after the runnable finishes running, with the Run object.
|
||||
on_error: Called if the runnable throws an error, with the Run object.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
"""
|
||||
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
|
||||
|
||||
return RunnableBinding(
|
||||
bound=self,
|
||||
config_factories=[
|
||||
lambda: {
|
||||
"callbacks": [
|
||||
RootListenersTracer(
|
||||
on_start=on_start, on_end=on_end, on_error=on_error
|
||||
)
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
def with_types(
|
||||
self,
|
||||
*,
|
||||
@ -2323,6 +2357,30 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
) -> RunnableEach[Input, Output]:
|
||||
return RunnableEach(bound=self.bound.with_config(config, **kwargs))
|
||||
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]] = None,
|
||||
on_end: Optional[Callable[[Run], None]] = None,
|
||||
on_error: Optional[Callable[[Run], None]] = None,
|
||||
) -> RunnableEach[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Called before the runnable starts running, with the Run object.
|
||||
on_end: Called after the runnable finishes running, with the Run object.
|
||||
on_error: Called if the runnable throws an error, with the Run object.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
"""
|
||||
return RunnableEach(
|
||||
bound=self.bound.with_listeners(
|
||||
on_start=on_start, on_end=on_end, on_error=on_error
|
||||
)
|
||||
)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
@ -2363,10 +2421,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
kwargs: Mapping[str, Any]
|
||||
kwargs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
config: RunnableConfig = Field(default_factory=dict)
|
||||
|
||||
config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)
|
||||
|
||||
# Union[Type[Input], BaseModel] + things like List[str]
|
||||
custom_input_type: Optional[Any] = None
|
||||
# Union[Type[Output], BaseModel] + things like List[str]
|
||||
@ -2379,8 +2439,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
self,
|
||||
*,
|
||||
bound: Runnable[Input, Output],
|
||||
kwargs: Mapping[str, Any],
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
|
||||
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
**other_kwargs: Any,
|
||||
@ -2397,8 +2458,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
super().__init__(
|
||||
bound=bound,
|
||||
kwargs=kwargs,
|
||||
config=config,
|
||||
kwargs=kwargs or {},
|
||||
config=config or {},
|
||||
config_factories=config_factories or [],
|
||||
custom_input_type=custom_input_type,
|
||||
custom_output_type=custom_output_type,
|
||||
**other_kwargs,
|
||||
@ -2472,6 +2534,43 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]] = None,
|
||||
on_end: Optional[Callable[[Run], None]] = None,
|
||||
on_error: Optional[Callable[[Run], None]] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Called before the runnable starts running, with the Run object.
|
||||
on_end: Called after the runnable finishes running, with the Run object.
|
||||
on_error: Called if the runnable throws an error, with the Run object.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
"""
|
||||
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
|
||||
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
config_factories=[
|
||||
lambda: {
|
||||
"callbacks": [
|
||||
RootListenersTracer(
|
||||
on_start=on_start, on_end=on_end, on_error=on_error
|
||||
)
|
||||
],
|
||||
}
|
||||
],
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_types(
|
||||
self,
|
||||
input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
@ -2496,6 +2595,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
return merge_configs(
|
||||
self.config, *(f() for f in self.config_factories), *configs
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Input,
|
||||
@ -2504,7 +2608,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Output:
|
||||
return self.bound.invoke(
|
||||
input,
|
||||
merge_configs(self.config, config),
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2516,7 +2620,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Output:
|
||||
return await self.bound.ainvoke(
|
||||
input,
|
||||
merge_configs(self.config, config),
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2531,10 +2635,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
[merge_configs(self.config, conf) for conf in config],
|
||||
[self._merge_configs(conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
|
||||
configs = [self._merge_configs(config) for _ in range(len(inputs))]
|
||||
return self.bound.batch(
|
||||
inputs,
|
||||
configs,
|
||||
@ -2553,10 +2657,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
[merge_configs(self.config, conf) for conf in config],
|
||||
[self._merge_configs(conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
configs = [merge_configs(self.config, config) for _ in range(len(inputs))]
|
||||
configs = [self._merge_configs(config) for _ in range(len(inputs))]
|
||||
return await self.bound.abatch(
|
||||
inputs,
|
||||
configs,
|
||||
@ -2572,7 +2676,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.stream(
|
||||
input,
|
||||
merge_configs(self.config, config),
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2584,7 +2688,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.astream(
|
||||
input,
|
||||
merge_configs(self.config, config),
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
@ -2597,7 +2701,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.transform(
|
||||
input,
|
||||
merge_configs(self.config, config),
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2609,7 +2713,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.atransform(
|
||||
input,
|
||||
merge_configs(self.config, config),
|
||||
self._merge_configs(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
|
@ -220,6 +220,51 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
**base.get(key, {}), # type: ignore
|
||||
**(config.get(key) or {}), # type: ignore
|
||||
}
|
||||
elif key == "callbacks":
|
||||
base_callbacks = base.get("callbacks")
|
||||
these_callbacks = config["callbacks"]
|
||||
# callbacks can be either None, list[handler] or manager
|
||||
# so merging two callbacks values has 6 cases
|
||||
if isinstance(these_callbacks, list):
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = these_callbacks
|
||||
elif isinstance(base_callbacks, list):
|
||||
base["callbacks"] = base_callbacks + these_callbacks
|
||||
else:
|
||||
# base_callbacks is a manager
|
||||
mngr = base_callbacks.copy()
|
||||
for callback in these_callbacks:
|
||||
mngr.add_handler(callback, inherit=True)
|
||||
base["callbacks"] = mngr
|
||||
elif these_callbacks is not None:
|
||||
# these_callbacks is a manager
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = these_callbacks
|
||||
elif isinstance(base_callbacks, list):
|
||||
mngr = these_callbacks.copy()
|
||||
for callback in base_callbacks:
|
||||
mngr.add_handler(callback, inherit=True)
|
||||
base["callbacks"] = mngr
|
||||
else:
|
||||
# base_callbacks is also a manager
|
||||
base["callbacks"] = base_callbacks.__class__(
|
||||
parent_run_id=base_callbacks.parent_run_id
|
||||
or these_callbacks.parent_run_id,
|
||||
handlers=base_callbacks.handlers + these_callbacks.handlers,
|
||||
inheritable_handlers=base_callbacks.inheritable_handlers
|
||||
+ these_callbacks.inheritable_handlers,
|
||||
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
|
||||
inheritable_tags=list(
|
||||
set(
|
||||
base_callbacks.inheritable_tags
|
||||
+ these_callbacks.inheritable_tags
|
||||
)
|
||||
),
|
||||
metadata={
|
||||
**base_callbacks.metadata,
|
||||
**these_callbacks.metadata,
|
||||
},
|
||||
)
|
||||
else:
|
||||
base[key] = config[key] or base.get(key) # type: ignore
|
||||
return base
|
||||
|
@ -3748,6 +3748,7 @@
|
||||
]
|
||||
},
|
||||
"config": {},
|
||||
"config_factories": [],
|
||||
"custom_input_type": null,
|
||||
"custom_output_type": null
|
||||
}
|
||||
|
@ -0,0 +1,34 @@
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain.schema.runnable.config import RunnableConfig, merge_configs
|
||||
|
||||
|
||||
def test_merge_config_callbacks() -> None:
|
||||
manager: RunnableConfig = {
|
||||
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
||||
}
|
||||
handlers: RunnableConfig = {"callbacks": [ConsoleCallbackHandler()]}
|
||||
other_handlers: RunnableConfig = {"callbacks": [StreamingStdOutCallbackHandler()]}
|
||||
|
||||
merged = merge_configs(manager, handlers)["callbacks"]
|
||||
|
||||
assert isinstance(merged, CallbackManager)
|
||||
assert len(merged.handlers) == 2
|
||||
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
|
||||
assert isinstance(merged.handlers[1], ConsoleCallbackHandler)
|
||||
|
||||
merged = merge_configs(handlers, manager)["callbacks"]
|
||||
|
||||
assert isinstance(merged, CallbackManager)
|
||||
assert len(merged.handlers) == 2
|
||||
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
|
||||
assert isinstance(merged.handlers[1], ConsoleCallbackHandler)
|
||||
|
||||
merged = merge_configs(handlers, other_handlers)["callbacks"]
|
||||
|
||||
assert isinstance(merged, list)
|
||||
assert len(merged) == 2
|
||||
assert isinstance(merged[0], ConsoleCallbackHandler)
|
||||
assert isinstance(merged[1], StreamingStdOutCallbackHandler)
|
@ -19,7 +19,12 @@ from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs
|
||||
from langchain.callbacks.manager import (
|
||||
Callbacks,
|
||||
atrace_as_chain_group,
|
||||
collect_runs,
|
||||
trace_as_chain_group,
|
||||
)
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
@ -1495,6 +1500,39 @@ def test_prompt_template_params() -> None:
|
||||
prompt.invoke({})
|
||||
|
||||
|
||||
def test_with_listeners(mocker: MockerFixture) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
mock_start = mocker.Mock()
|
||||
mock_end = mocker.Mock()
|
||||
|
||||
chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(
|
||||
{"question": "Who are you?"}
|
||||
)
|
||||
|
||||
assert mock_start.call_count == 1
|
||||
assert mock_start.call_args[0][0].name == "RunnableSequence"
|
||||
assert mock_end.call_count == 1
|
||||
|
||||
mock_start.reset_mock()
|
||||
mock_end.reset_mock()
|
||||
|
||||
with trace_as_chain_group("hello") as manager:
|
||||
chain.with_listeners(on_start=mock_start, on_end=mock_end).invoke(
|
||||
{"question": "Who are you?"}, {"callbacks": manager}
|
||||
)
|
||||
|
||||
assert mock_start.call_count == 1
|
||||
assert mock_start.call_args[0][0].name == "RunnableSequence"
|
||||
assert mock_end.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_chat_model(
|
||||
|
Loading…
Reference in New Issue
Block a user