mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 02:29:17 +00:00
Changes to root listener (#12174)
- Implement config_specs to include session_id - Remove Runnable method and update notebook - Add more details to notebook, eg. show input schema and config schema before and after adding message history --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
b2b94424db
commit
362a446999
@ -1,20 +1,28 @@
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
call_func_with_variable_args,
|
||||
)
|
||||
|
||||
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||
|
||||
|
||||
class RootListenersTracer(BaseTracer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]],
|
||||
on_end: Optional[Callable[[Run], None]],
|
||||
on_error: Optional[Callable[[Run], None]],
|
||||
config: RunnableConfig,
|
||||
on_start: Optional[Listener],
|
||||
on_end: Optional[Listener],
|
||||
on_error: Optional[Listener],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self._arg_on_start = on_start
|
||||
self._arg_on_end = on_end
|
||||
self._arg_on_error = on_error
|
||||
@ -32,7 +40,7 @@ class RootListenersTracer(BaseTracer):
|
||||
self.root_id = run.id
|
||||
|
||||
if self._arg_on_start is not None:
|
||||
self._arg_on_start(run)
|
||||
call_func_with_variable_args(self._arg_on_start, run, self.config)
|
||||
|
||||
def _on_run_update(self, run: Run) -> None:
|
||||
if run.id != self.root_id:
|
||||
@ -40,7 +48,7 @@ class RootListenersTracer(BaseTracer):
|
||||
|
||||
if run.error is None:
|
||||
if self._arg_on_end is not None:
|
||||
self._arg_on_end(run)
|
||||
call_func_with_variable_args(self._arg_on_end, run, self.config)
|
||||
else:
|
||||
if self._arg_on_error is not None:
|
||||
self._arg_on_error(run)
|
||||
call_func_with_variable_args(self._arg_on_error, run, self.config)
|
||||
|
@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.callbacks.tracers.root_listeners import Listener
|
||||
from langchain.schema.runnable.fallbacks import (
|
||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||
)
|
||||
@ -591,9 +591,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]] = None,
|
||||
on_end: Optional[Callable[[Run], None]] = None,
|
||||
on_error: Optional[Callable[[Run], None]] = None,
|
||||
on_start: Optional[Listener] = None,
|
||||
on_end: Optional[Listener] = None,
|
||||
on_error: Optional[Listener] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
@ -611,10 +611,13 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return RunnableBinding(
|
||||
bound=self,
|
||||
config_factories=[
|
||||
lambda: {
|
||||
lambda config: {
|
||||
"callbacks": [
|
||||
RootListenersTracer(
|
||||
on_start=on_start, on_end=on_end, on_error=on_error
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
],
|
||||
}
|
||||
@ -2391,9 +2394,9 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]] = None,
|
||||
on_end: Optional[Callable[[Run], None]] = None,
|
||||
on_error: Optional[Callable[[Run], None]] = None,
|
||||
on_start: Optional[Listener] = None,
|
||||
on_end: Optional[Listener] = None,
|
||||
on_error: Optional[Listener] = None,
|
||||
) -> RunnableEach[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
@ -2456,7 +2459,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
|
||||
config: RunnableConfig = Field(default_factory=dict)
|
||||
|
||||
config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)
|
||||
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
# Union[Type[Input], BaseModel] + things like List[str]
|
||||
custom_input_type: Optional[Any] = None
|
||||
@ -2472,7 +2477,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
bound: Runnable[Input, Output],
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
|
||||
config_factories: Optional[
|
||||
List[Callable[[RunnableConfig], RunnableConfig]]
|
||||
] = None,
|
||||
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
**other_kwargs: Any,
|
||||
@ -2570,9 +2577,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Callable[[Run], None]] = None,
|
||||
on_end: Optional[Callable[[Run], None]] = None,
|
||||
on_error: Optional[Callable[[Run], None]] = None,
|
||||
on_start: Optional[Listener] = None,
|
||||
on_end: Optional[Listener] = None,
|
||||
on_error: Optional[Listener] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
@ -2592,10 +2599,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
config_factories=[
|
||||
lambda: {
|
||||
lambda config: {
|
||||
"callbacks": [
|
||||
RootListenersTracer(
|
||||
on_start=on_start, on_end=on_end, on_error=on_error
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
],
|
||||
}
|
||||
@ -2629,9 +2639,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
return merge_configs(
|
||||
self.config, *(f() for f in self.config_factories), *configs
|
||||
)
|
||||
config = merge_configs(self.config, *configs)
|
||||
return merge_configs(config, *(f(config) for f in self.config_factories))
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
|
File diff suppressed because one or more lines are too long
@ -1712,8 +1712,41 @@ def test_with_listeners(mocker: MockerFixture) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_listeners_async(mocker: MockerFixture) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
mock_start = mocker.Mock()
|
||||
mock_end = mocker.Mock()
|
||||
|
||||
await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
|
||||
{"question": "Who are you?"}
|
||||
)
|
||||
|
||||
assert mock_start.call_count == 1
|
||||
assert mock_start.call_args[0][0].name == "RunnableSequence"
|
||||
assert mock_end.call_count == 1
|
||||
|
||||
mock_start.reset_mock()
|
||||
mock_end.reset_mock()
|
||||
|
||||
async with atrace_as_chain_group("hello") as manager:
|
||||
await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
|
||||
{"question": "Who are you?"}, {"callbacks": manager}
|
||||
)
|
||||
|
||||
assert mock_start.call_count == 1
|
||||
assert mock_start.call_args[0][0].name == "RunnableSequence"
|
||||
assert mock_end.call_count == 1
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_chat_model(
|
||||
def test_prompt_with_chat_model(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
prompt = (
|
||||
@ -1816,6 +1849,114 @@ async def test_prompt_with_chat_model(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_chat_model_async(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
prompt = (
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == []
|
||||
assert chain.last == chat
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||
chat_spy = mocker.spy(chat.__class__, "ainvoke")
|
||||
tracer = FakeTracer()
|
||||
assert await chain.ainvoke(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == AIMessage(content="foo")
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(chat_spy)
|
||||
|
||||
# Test batch
|
||||
prompt_spy = mocker.spy(prompt.__class__, "abatch")
|
||||
chat_spy = mocker.spy(chat.__class__, "abatch")
|
||||
tracer = FakeTracer()
|
||||
assert await chain.abatch(
|
||||
[
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
],
|
||||
dict(callbacks=[tracer]),
|
||||
) == [
|
||||
AIMessage(content="foo"),
|
||||
AIMessage(content="foo"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == [
|
||||
{"question": "What is your name?"},
|
||||
{"question": "What is your favorite color?"},
|
||||
]
|
||||
assert chat_spy.call_args.args[1] == [
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
),
|
||||
ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your favorite color?"),
|
||||
]
|
||||
),
|
||||
]
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
r
|
||||
for r in tracer.runs
|
||||
if r.parent_run_id is None and len(r.child_runs) == 2
|
||||
]
|
||||
)
|
||||
== 2
|
||||
), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
|
||||
mocker.stop(prompt_spy)
|
||||
mocker.stop(chat_spy)
|
||||
|
||||
# Test stream
|
||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||
chat_spy = mocker.spy(chat.__class__, "astream")
|
||||
tracer = FakeTracer()
|
||||
assert [
|
||||
a
|
||||
async for a in chain.astream(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
)
|
||||
] == [
|
||||
AIMessageChunk(content="f"),
|
||||
AIMessageChunk(content="o"),
|
||||
AIMessageChunk(content="o"),
|
||||
]
|
||||
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_prompt_with_llm(
|
||||
|
Loading…
Reference in New Issue
Block a user