Changes to root listener (#12174)

- Implement config_specs to include session_id
- Remove Runnable method and update notebook
- Add more details to the notebook, e.g. show the input schema and config schema
before and after adding message history

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Author: Nuno Campos, 2023-11-10 09:53:48 +00:00 (committed by GitHub)
Parent: b2b94424db
Commit: 362a446999
4 changed files with 298 additions and 27 deletions
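
With this change, listeners attached via `with_listeners` may optionally accept the merged `RunnableConfig` as a second argument, which gives them access to configurable fields such as the `session_id` mentioned above. A minimal sketch of the new usage (the chain, the listener bodies, and the `session_id` key are illustrative, not part of this commit):

    from langchain.callbacks.tracers.schemas import Run
    from langchain.schema.runnable import RunnableLambda
    from langchain.schema.runnable.config import RunnableConfig

    # One-argument listeners keep working unchanged.
    def log_start(run: Run) -> None:
        print("started:", run.name)

    # Two-argument listeners now also receive the config for the run.
    def log_end(run: Run, config: RunnableConfig) -> None:
        session_id = config.get("configurable", {}).get("session_id")
        print("finished:", run.name, "session:", session_id)

    chain = RunnableLambda(lambda x: x + 1).with_listeners(
        on_start=log_start, on_end=log_end
    )
    chain.invoke(1, {"configurable": {"session_id": "abc123"}})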

File: langchain/callbacks/tracers/root_listeners.py

@@ -1,20 +1,28 @@
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from uuid import UUID

 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
+from langchain.schema.runnable.config import (
+    RunnableConfig,
+    call_func_with_variable_args,
+)
+
+Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]


 class RootListenersTracer(BaseTracer):
     def __init__(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]],
-        on_end: Optional[Callable[[Run], None]],
-        on_error: Optional[Callable[[Run], None]],
+        config: RunnableConfig,
+        on_start: Optional[Listener],
+        on_end: Optional[Listener],
+        on_error: Optional[Listener],
     ) -> None:
         super().__init__()

+        self.config = config
         self._arg_on_start = on_start
         self._arg_on_end = on_end
         self._arg_on_error = on_error
@@ -32,7 +40,7 @@ class RootListenersTracer(BaseTracer):
         self.root_id = run.id

         if self._arg_on_start is not None:
-            self._arg_on_start(run)
+            call_func_with_variable_args(self._arg_on_start, run, self.config)

     def _on_run_update(self, run: Run) -> None:
         if run.id != self.root_id:
@@ -40,7 +48,7 @@ class RootListenersTracer(BaseTracer):

         if run.error is None:
             if self._arg_on_end is not None:
-                self._arg_on_end(run)
+                call_func_with_variable_args(self._arg_on_end, run, self.config)
         else:
             if self._arg_on_error is not None:
-                self._arg_on_error(run)
+                call_func_with_variable_args(self._arg_on_error, run, self.config)
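
Both the start and end/error paths now dispatch through `call_func_with_variable_args`, which inspects the listener's signature and forwards the config only when the callable can accept it. Roughly, the dispatch works like this (a hypothetical simplified sketch, not the actual helper in `langchain.schema.runnable.config`):

    import inspect
    from typing import Any, Callable

    def call_listener(listener: Callable[..., Any], run: Any, config: Any) -> Any:
        # Forward `config` only if the listener declares a second
        # parameter, e.g. `def on_end(run, config): ...`
        accepts_config = len(inspect.signature(listener).parameters) > 1
        return listener(run, config) if accepts_config else listener(run)

This is what keeps existing one-argument listeners working while enabling the new two-argument form.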

File: langchain/schema/runnable/base.py

@@ -37,7 +37,7 @@ if TYPE_CHECKING:
         CallbackManagerForChainRun,
     )
     from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
-    from langchain.callbacks.tracers.schemas import Run
+    from langchain.callbacks.tracers.root_listeners import Listener
     from langchain.schema.runnable.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
@@ -591,9 +591,9 @@ class Runnable(Generic[Input, Output], ABC):
     def with_listeners(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]] = None,
-        on_end: Optional[Callable[[Run], None]] = None,
-        on_error: Optional[Callable[[Run], None]] = None,
+        on_start: Optional[Listener] = None,
+        on_end: Optional[Listener] = None,
+        on_error: Optional[Listener] = None,
     ) -> Runnable[Input, Output]:
         """
        Bind lifecycle listeners to a Runnable, returning a new Runnable.
@@ -611,10 +611,13 @@
         return RunnableBinding(
             bound=self,
             config_factories=[
-                lambda: {
+                lambda config: {
                     "callbacks": [
                         RootListenersTracer(
-                            on_start=on_start, on_end=on_end, on_error=on_error
+                            config=config,
+                            on_start=on_start,
+                            on_end=on_end,
+                            on_error=on_error,
                         )
                     ],
                 }
@@ -2391,9 +2394,9 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def with_listeners(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]] = None,
-        on_end: Optional[Callable[[Run], None]] = None,
-        on_error: Optional[Callable[[Run], None]] = None,
+        on_start: Optional[Listener] = None,
+        on_end: Optional[Listener] = None,
+        on_error: Optional[Listener] = None,
     ) -> RunnableEach[Input, Output]:
         """
         Bind lifecycle listeners to a Runnable, returning a new Runnable.
@@ -2456,7 +2459,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):

     config: RunnableConfig = Field(default_factory=dict)

-    config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)
+    config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
+        default_factory=list
+    )

     # Union[Type[Input], BaseModel] + things like List[str]
     custom_input_type: Optional[Any] = None
@@ -2472,7 +2477,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
         bound: Runnable[Input, Output],
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[RunnableConfig] = None,
-        config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
+        config_factories: Optional[
+            List[Callable[[RunnableConfig], RunnableConfig]]
+        ] = None,
         custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
         custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
         **other_kwargs: Any,
@@ -2570,9 +2577,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
     def with_listeners(
         self,
         *,
-        on_start: Optional[Callable[[Run], None]] = None,
-        on_end: Optional[Callable[[Run], None]] = None,
-        on_error: Optional[Callable[[Run], None]] = None,
+        on_start: Optional[Listener] = None,
+        on_end: Optional[Listener] = None,
+        on_error: Optional[Listener] = None,
     ) -> Runnable[Input, Output]:
         """
         Bind lifecycle listeners to a Runnable, returning a new Runnable.
@@ -2592,10 +2599,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
             kwargs=self.kwargs,
             config=self.config,
             config_factories=[
-                lambda: {
+                lambda config: {
                     "callbacks": [
                         RootListenersTracer(
-                            on_start=on_start, on_end=on_end, on_error=on_error
+                            config=config,
+                            on_start=on_start,
+                            on_end=on_end,
+                            on_error=on_error,
                         )
                     ],
                 }
@@ -2629,9 +2639,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
         )

     def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
-        return merge_configs(
-            self.config, *(f() for f in self.config_factories), *configs
-        )
+        config = merge_configs(self.config, *configs)
+        return merge_configs(config, *(f(config) for f in self.config_factories))

     def invoke(
         self,
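
The `_merge_configs` change above is what threads the caller's config into the listener tracers: the binding's own config and the call-time configs are merged first, and each config factory then receives that merged result. A simplified illustration of the ordering (assuming `merge_configs` from `langchain.schema.runnable.config`):

    from langchain.schema.runnable.config import RunnableConfig, merge_configs

    def merge_with_factories(base, factories, *call_configs) -> RunnableConfig:
        # Previously the factories took no arguments, so a factory-built
        # RootListenersTracer could never see the call-time config:
        #     merge_configs(base, *(f() for f in factories), *call_configs)
        # Now the configs are merged first, then handed to each factory:
        config = merge_configs(base, *call_configs)
        return merge_configs(config, *(f(config) for f in factories))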

File diff suppressed because one or more lines are too long

File: tests/unit_tests/schema/runnable/test_runnable.py

@@ -1712,8 +1712,41 @@ def test_with_listeners(mocker: MockerFixture) -> None:
     assert mock_end.call_count == 1


+@pytest.mark.asyncio
+async def test_with_listeners_async(mocker: MockerFixture) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    chat = FakeListChatModel(responses=["foo"])
+    chain = prompt | chat
+
+    mock_start = mocker.Mock()
+    mock_end = mocker.Mock()
+
+    await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
+        {"question": "Who are you?"}
+    )
+
+    assert mock_start.call_count == 1
+    assert mock_start.call_args[0][0].name == "RunnableSequence"
+    assert mock_end.call_count == 1
+
+    mock_start.reset_mock()
+    mock_end.reset_mock()
+
+    async with atrace_as_chain_group("hello") as manager:
+        await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
+            {"question": "Who are you?"}, {"callbacks": manager}
+        )
+
+    assert mock_start.call_count == 1
+    assert mock_start.call_args[0][0].name == "RunnableSequence"
+    assert mock_end.call_count == 1
+
+
 @freeze_time("2023-01-01")
-async def test_prompt_with_chat_model(
+def test_prompt_with_chat_model(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
     prompt = (
@@ -1816,6 +1849,114 @@ async def test_prompt_with_chat_model(
     )


+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_prompt_with_chat_model_async(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    chat = FakeListChatModel(responses=["foo"])
+    chain = prompt | chat
+
+    assert repr(chain) == snapshot
+    assert isinstance(chain, RunnableSequence)
+    assert chain.first == prompt
+    assert chain.middle == []
+    assert chain.last == chat
+    assert dumps(chain, pretty=True) == snapshot
+
+    # Test invoke
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    chat_spy = mocker.spy(chat.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert await chain.ainvoke(
+        {"question": "What is your name?"}, dict(callbacks=[tracer])
+    ) == AIMessage(content="foo")
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert chat_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(chat_spy)
+
+    # Test batch
+    prompt_spy = mocker.spy(prompt.__class__, "abatch")
+    chat_spy = mocker.spy(chat.__class__, "abatch")
+    tracer = FakeTracer()
+    assert await chain.abatch(
+        [
+            {"question": "What is your name?"},
+            {"question": "What is your favorite color?"},
+        ],
+        dict(callbacks=[tracer]),
+    ) == [
+        AIMessage(content="foo"),
+        AIMessage(content="foo"),
+    ]
+    assert prompt_spy.call_args.args[1] == [
+        {"question": "What is your name?"},
+        {"question": "What is your favorite color?"},
+    ]
+    assert chat_spy.call_args.args[1] == [
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your name?"),
+            ]
+        ),
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your favorite color?"),
+            ]
+        ),
+    ]
+    assert (
+        len(
+            [
+                r
+                for r in tracer.runs
+                if r.parent_run_id is None and len(r.child_runs) == 2
+            ]
+        )
+        == 2
+    ), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
+    mocker.stop(prompt_spy)
+    mocker.stop(chat_spy)
+
+    # Test stream
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    chat_spy = mocker.spy(chat.__class__, "astream")
+    tracer = FakeTracer()
+    assert [
+        a
+        async for a in chain.astream(
+            {"question": "What is your name?"}, dict(callbacks=[tracer])
+        )
+    ] == [
+        AIMessageChunk(content="f"),
+        AIMessageChunk(content="o"),
+        AIMessageChunk(content="o"),
+    ]
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert chat_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+
+
 @pytest.mark.asyncio
 @freeze_time("2023-01-01")
 async def test_prompt_with_llm(