Changes to root listener (#12174)

- Implement config_specs to include session_id
- Remove Runnable method and update notebook
- Add more details to notebook, eg. show input schema and config schema
before and after adding message history

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Nuno Campos 2023-11-10 09:53:48 +00:00 committed by GitHub
parent b2b94424db
commit 362a446999
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 298 additions and 27 deletions

View File

@ -1,20 +1,28 @@
from typing import Callable, Optional from typing import Callable, Optional, Union
from uuid import UUID from uuid import UUID
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
from langchain.schema.runnable.config import (
RunnableConfig,
call_func_with_variable_args,
)
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
class RootListenersTracer(BaseTracer): class RootListenersTracer(BaseTracer):
def __init__( def __init__(
self, self,
*, *,
on_start: Optional[Callable[[Run], None]], config: RunnableConfig,
on_end: Optional[Callable[[Run], None]], on_start: Optional[Listener],
on_error: Optional[Callable[[Run], None]], on_end: Optional[Listener],
on_error: Optional[Listener],
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config
self._arg_on_start = on_start self._arg_on_start = on_start
self._arg_on_end = on_end self._arg_on_end = on_end
self._arg_on_error = on_error self._arg_on_error = on_error
@ -32,7 +40,7 @@ class RootListenersTracer(BaseTracer):
self.root_id = run.id self.root_id = run.id
if self._arg_on_start is not None: if self._arg_on_start is not None:
self._arg_on_start(run) call_func_with_variable_args(self._arg_on_start, run, self.config)
def _on_run_update(self, run: Run) -> None: def _on_run_update(self, run: Run) -> None:
if run.id != self.root_id: if run.id != self.root_id:
@ -40,7 +48,7 @@ class RootListenersTracer(BaseTracer):
if run.error is None: if run.error is None:
if self._arg_on_end is not None: if self._arg_on_end is not None:
self._arg_on_end(run) call_func_with_variable_args(self._arg_on_end, run, self.config)
else: else:
if self._arg_on_error is not None: if self._arg_on_error is not None:
self._arg_on_error(run) call_func_with_variable_args(self._arg_on_error, run, self.config)

View File

@ -37,7 +37,7 @@ if TYPE_CHECKING:
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.root_listeners import Listener
from langchain.schema.runnable.fallbacks import ( from langchain.schema.runnable.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT, RunnableWithFallbacks as RunnableWithFallbacksT,
) )
@ -591,9 +591,9 @@ class Runnable(Generic[Input, Output], ABC):
def with_listeners( def with_listeners(
self, self,
*, *,
on_start: Optional[Callable[[Run], None]] = None, on_start: Optional[Listener] = None,
on_end: Optional[Callable[[Run], None]] = None, on_end: Optional[Listener] = None,
on_error: Optional[Callable[[Run], None]] = None, on_error: Optional[Listener] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """
Bind lifecycle listeners to a Runnable, returning a new Runnable. Bind lifecycle listeners to a Runnable, returning a new Runnable.
@ -611,10 +611,13 @@ class Runnable(Generic[Input, Output], ABC):
return RunnableBinding( return RunnableBinding(
bound=self, bound=self,
config_factories=[ config_factories=[
lambda: { lambda config: {
"callbacks": [ "callbacks": [
RootListenersTracer( RootListenersTracer(
on_start=on_start, on_end=on_end, on_error=on_error config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
) )
], ],
} }
@ -2391,9 +2394,9 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
def with_listeners( def with_listeners(
self, self,
*, *,
on_start: Optional[Callable[[Run], None]] = None, on_start: Optional[Listener] = None,
on_end: Optional[Callable[[Run], None]] = None, on_end: Optional[Listener] = None,
on_error: Optional[Callable[[Run], None]] = None, on_error: Optional[Listener] = None,
) -> RunnableEach[Input, Output]: ) -> RunnableEach[Input, Output]:
""" """
Bind lifecycle listeners to a Runnable, returning a new Runnable. Bind lifecycle listeners to a Runnable, returning a new Runnable.
@ -2456,7 +2459,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
config: RunnableConfig = Field(default_factory=dict) config: RunnableConfig = Field(default_factory=dict)
config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list) config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
default_factory=list
)
# Union[Type[Input], BaseModel] + things like List[str] # Union[Type[Input], BaseModel] + things like List[str]
custom_input_type: Optional[Any] = None custom_input_type: Optional[Any] = None
@ -2472,7 +2477,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
bound: Runnable[Input, Output], bound: Runnable[Input, Output],
kwargs: Optional[Mapping[str, Any]] = None, kwargs: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
config_factories: Optional[List[Callable[[], RunnableConfig]]] = None, config_factories: Optional[
List[Callable[[RunnableConfig], RunnableConfig]]
] = None,
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
**other_kwargs: Any, **other_kwargs: Any,
@ -2570,9 +2577,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def with_listeners( def with_listeners(
self, self,
*, *,
on_start: Optional[Callable[[Run], None]] = None, on_start: Optional[Listener] = None,
on_end: Optional[Callable[[Run], None]] = None, on_end: Optional[Listener] = None,
on_error: Optional[Callable[[Run], None]] = None, on_error: Optional[Listener] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """
Bind lifecycle listeners to a Runnable, returning a new Runnable. Bind lifecycle listeners to a Runnable, returning a new Runnable.
@ -2592,10 +2599,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
kwargs=self.kwargs, kwargs=self.kwargs,
config=self.config, config=self.config,
config_factories=[ config_factories=[
lambda: { lambda config: {
"callbacks": [ "callbacks": [
RootListenersTracer( RootListenersTracer(
on_start=on_start, on_end=on_end, on_error=on_error config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
) )
], ],
} }
@ -2629,9 +2639,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) )
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
return merge_configs( config = merge_configs(self.config, *configs)
self.config, *(f() for f in self.config_factories), *configs return merge_configs(config, *(f(config) for f in self.config_factories))
)
def invoke( def invoke(
self, self,

File diff suppressed because one or more lines are too long

View File

@ -1712,8 +1712,41 @@ def test_with_listeners(mocker: MockerFixture) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_listeners_async(mocker: MockerFixture) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat
mock_start = mocker.Mock()
mock_end = mocker.Mock()
await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
{"question": "Who are you?"}
)
assert mock_start.call_count == 1
assert mock_start.call_args[0][0].name == "RunnableSequence"
assert mock_end.call_count == 1
mock_start.reset_mock()
mock_end.reset_mock()
async with atrace_as_chain_group("hello") as manager:
await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
{"question": "Who are you?"}, {"callbacks": manager}
)
assert mock_start.call_count == 1
assert mock_start.call_args[0][0].name == "RunnableSequence"
assert mock_end.call_count == 1
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
async def test_prompt_with_chat_model( def test_prompt_with_chat_model(
mocker: MockerFixture, snapshot: SnapshotAssertion mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None: ) -> None:
prompt = ( prompt = (
@ -1816,6 +1849,114 @@ async def test_prompt_with_chat_model(
) )
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model_async(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["foo"])
chain = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
assert chain.last == chat
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
chat_spy = mocker.spy(chat.__class__, "ainvoke")
tracer = FakeTracer()
assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == AIMessage(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(chat_spy)
# Test batch
prompt_spy = mocker.spy(prompt.__class__, "abatch")
chat_spy = mocker.spy(chat.__class__, "abatch")
tracer = FakeTracer()
assert await chain.abatch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
) == [
AIMessage(content="foo"),
AIMessage(content="foo"),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
assert chat_spy.call_args.args[1] == [
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert (
len(
[
r
for r in tracer.runs
if r.parent_run_id is None and len(r.child_runs) == 2
]
)
== 2
), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
mocker.stop(prompt_spy)
mocker.stop(chat_spy)
# Test stream
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
chat_spy = mocker.spy(chat.__class__, "astream")
tracer = FakeTracer()
assert [
a
async for a in chain.astream(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
] == [
AIMessageChunk(content="f"),
AIMessageChunk(content="o"),
AIMessageChunk(content="o"),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
@pytest.mark.asyncio @pytest.mark.asyncio
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
async def test_prompt_with_llm( async def test_prompt_with_llm(