mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 12:38:45 +00:00
Changes to root listener (#12174)
- Implement config_specs to include session_id - Remove Runnable method and update notebook - Add more details to notebook, eg. show input schema and config schema before and after adding message history --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
b2b94424db
commit
362a446999
@ -1,20 +1,28 @@
|
|||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
call_func_with_variable_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
|
||||||
|
|
||||||
|
|
||||||
class RootListenersTracer(BaseTracer):
|
class RootListenersTracer(BaseTracer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
on_start: Optional[Callable[[Run], None]],
|
config: RunnableConfig,
|
||||||
on_end: Optional[Callable[[Run], None]],
|
on_start: Optional[Listener],
|
||||||
on_error: Optional[Callable[[Run], None]],
|
on_end: Optional[Listener],
|
||||||
|
on_error: Optional[Listener],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self._arg_on_start = on_start
|
self._arg_on_start = on_start
|
||||||
self._arg_on_end = on_end
|
self._arg_on_end = on_end
|
||||||
self._arg_on_error = on_error
|
self._arg_on_error = on_error
|
||||||
@ -32,7 +40,7 @@ class RootListenersTracer(BaseTracer):
|
|||||||
self.root_id = run.id
|
self.root_id = run.id
|
||||||
|
|
||||||
if self._arg_on_start is not None:
|
if self._arg_on_start is not None:
|
||||||
self._arg_on_start(run)
|
call_func_with_variable_args(self._arg_on_start, run, self.config)
|
||||||
|
|
||||||
def _on_run_update(self, run: Run) -> None:
|
def _on_run_update(self, run: Run) -> None:
|
||||||
if run.id != self.root_id:
|
if run.id != self.root_id:
|
||||||
@ -40,7 +48,7 @@ class RootListenersTracer(BaseTracer):
|
|||||||
|
|
||||||
if run.error is None:
|
if run.error is None:
|
||||||
if self._arg_on_end is not None:
|
if self._arg_on_end is not None:
|
||||||
self._arg_on_end(run)
|
call_func_with_variable_args(self._arg_on_end, run, self.config)
|
||||||
else:
|
else:
|
||||||
if self._arg_on_error is not None:
|
if self._arg_on_error is not None:
|
||||||
self._arg_on_error(run)
|
call_func_with_variable_args(self._arg_on_error, run, self.config)
|
||||||
|
@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.root_listeners import Listener
|
||||||
from langchain.schema.runnable.fallbacks import (
|
from langchain.schema.runnable.fallbacks import (
|
||||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||||
)
|
)
|
||||||
@ -591,9 +591,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
def with_listeners(
|
def with_listeners(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
on_start: Optional[Callable[[Run], None]] = None,
|
on_start: Optional[Listener] = None,
|
||||||
on_end: Optional[Callable[[Run], None]] = None,
|
on_end: Optional[Listener] = None,
|
||||||
on_error: Optional[Callable[[Run], None]] = None,
|
on_error: Optional[Listener] = None,
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
"""
|
"""
|
||||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||||
@ -611,10 +611,13 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return RunnableBinding(
|
return RunnableBinding(
|
||||||
bound=self,
|
bound=self,
|
||||||
config_factories=[
|
config_factories=[
|
||||||
lambda: {
|
lambda config: {
|
||||||
"callbacks": [
|
"callbacks": [
|
||||||
RootListenersTracer(
|
RootListenersTracer(
|
||||||
on_start=on_start, on_end=on_end, on_error=on_error
|
config=config,
|
||||||
|
on_start=on_start,
|
||||||
|
on_end=on_end,
|
||||||
|
on_error=on_error,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -2391,9 +2394,9 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
def with_listeners(
|
def with_listeners(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
on_start: Optional[Callable[[Run], None]] = None,
|
on_start: Optional[Listener] = None,
|
||||||
on_end: Optional[Callable[[Run], None]] = None,
|
on_end: Optional[Listener] = None,
|
||||||
on_error: Optional[Callable[[Run], None]] = None,
|
on_error: Optional[Listener] = None,
|
||||||
) -> RunnableEach[Input, Output]:
|
) -> RunnableEach[Input, Output]:
|
||||||
"""
|
"""
|
||||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||||
@ -2456,7 +2459,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
config: RunnableConfig = Field(default_factory=dict)
|
config: RunnableConfig = Field(default_factory=dict)
|
||||||
|
|
||||||
config_factories: List[Callable[[], RunnableConfig]] = Field(default_factory=list)
|
config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
# Union[Type[Input], BaseModel] + things like List[str]
|
# Union[Type[Input], BaseModel] + things like List[str]
|
||||||
custom_input_type: Optional[Any] = None
|
custom_input_type: Optional[Any] = None
|
||||||
@ -2472,7 +2477,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
bound: Runnable[Input, Output],
|
bound: Runnable[Input, Output],
|
||||||
kwargs: Optional[Mapping[str, Any]] = None,
|
kwargs: Optional[Mapping[str, Any]] = None,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
config_factories: Optional[List[Callable[[], RunnableConfig]]] = None,
|
config_factories: Optional[
|
||||||
|
List[Callable[[RunnableConfig], RunnableConfig]]
|
||||||
|
] = None,
|
||||||
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||||
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||||
**other_kwargs: Any,
|
**other_kwargs: Any,
|
||||||
@ -2570,9 +2577,9 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
def with_listeners(
|
def with_listeners(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
on_start: Optional[Callable[[Run], None]] = None,
|
on_start: Optional[Listener] = None,
|
||||||
on_end: Optional[Callable[[Run], None]] = None,
|
on_end: Optional[Listener] = None,
|
||||||
on_error: Optional[Callable[[Run], None]] = None,
|
on_error: Optional[Listener] = None,
|
||||||
) -> Runnable[Input, Output]:
|
) -> Runnable[Input, Output]:
|
||||||
"""
|
"""
|
||||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||||
@ -2592,10 +2599,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
kwargs=self.kwargs,
|
kwargs=self.kwargs,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
config_factories=[
|
config_factories=[
|
||||||
lambda: {
|
lambda config: {
|
||||||
"callbacks": [
|
"callbacks": [
|
||||||
RootListenersTracer(
|
RootListenersTracer(
|
||||||
on_start=on_start, on_end=on_end, on_error=on_error
|
config=config,
|
||||||
|
on_start=on_start,
|
||||||
|
on_end=on_end,
|
||||||
|
on_error=on_error,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -2629,9 +2639,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
return merge_configs(
|
config = merge_configs(self.config, *configs)
|
||||||
self.config, *(f() for f in self.config_factories), *configs
|
return merge_configs(config, *(f(config) for f in self.config_factories))
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1712,8 +1712,41 @@ def test_with_listeners(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
async def test_with_listeners_async(mocker: MockerFixture) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
chat = FakeListChatModel(responses=["foo"])
|
||||||
|
|
||||||
|
chain = prompt | chat
|
||||||
|
|
||||||
|
mock_start = mocker.Mock()
|
||||||
|
mock_end = mocker.Mock()
|
||||||
|
|
||||||
|
await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
|
||||||
|
{"question": "Who are you?"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_start.call_count == 1
|
||||||
|
assert mock_start.call_args[0][0].name == "RunnableSequence"
|
||||||
|
assert mock_end.call_count == 1
|
||||||
|
|
||||||
|
mock_start.reset_mock()
|
||||||
|
mock_end.reset_mock()
|
||||||
|
|
||||||
|
async with atrace_as_chain_group("hello") as manager:
|
||||||
|
await chain.with_listeners(on_start=mock_start, on_end=mock_end).ainvoke(
|
||||||
|
{"question": "Who are you?"}, {"callbacks": manager}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_start.call_count == 1
|
||||||
|
assert mock_start.call_args[0][0].name == "RunnableSequence"
|
||||||
|
assert mock_end.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
async def test_prompt_with_chat_model(
|
def test_prompt_with_chat_model(
|
||||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = (
|
prompt = (
|
||||||
@ -1816,6 +1849,114 @@ async def test_prompt_with_chat_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@freeze_time("2023-01-01")
|
||||||
|
async def test_prompt_with_chat_model_async(
|
||||||
|
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
chat = FakeListChatModel(responses=["foo"])
|
||||||
|
|
||||||
|
chain = prompt | chat
|
||||||
|
|
||||||
|
assert repr(chain) == snapshot
|
||||||
|
assert isinstance(chain, RunnableSequence)
|
||||||
|
assert chain.first == prompt
|
||||||
|
assert chain.middle == []
|
||||||
|
assert chain.last == chat
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
|
||||||
|
# Test invoke
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "ainvoke")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await chain.ainvoke(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
) == AIMessage(content="foo")
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tracer.runs == snapshot
|
||||||
|
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(chat_spy)
|
||||||
|
|
||||||
|
# Test batch
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "abatch")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "abatch")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert await chain.abatch(
|
||||||
|
[
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
],
|
||||||
|
dict(callbacks=[tracer]),
|
||||||
|
) == [
|
||||||
|
AIMessage(content="foo"),
|
||||||
|
AIMessage(content="foo"),
|
||||||
|
]
|
||||||
|
assert prompt_spy.call_args.args[1] == [
|
||||||
|
{"question": "What is your name?"},
|
||||||
|
{"question": "What is your favorite color?"},
|
||||||
|
]
|
||||||
|
assert chat_spy.call_args.args[1] == [
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your favorite color?"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
[
|
||||||
|
r
|
||||||
|
for r in tracer.runs
|
||||||
|
if r.parent_run_id is None and len(r.child_runs) == 2
|
||||||
|
]
|
||||||
|
)
|
||||||
|
== 2
|
||||||
|
), "Each of 2 outer runs contains exactly two inner runs (1 prompt, 1 chat)"
|
||||||
|
mocker.stop(prompt_spy)
|
||||||
|
mocker.stop(chat_spy)
|
||||||
|
|
||||||
|
# Test stream
|
||||||
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
|
chat_spy = mocker.spy(chat.__class__, "astream")
|
||||||
|
tracer = FakeTracer()
|
||||||
|
assert [
|
||||||
|
a
|
||||||
|
async for a in chain.astream(
|
||||||
|
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||||
|
)
|
||||||
|
] == [
|
||||||
|
AIMessageChunk(content="f"),
|
||||||
|
AIMessageChunk(content="o"),
|
||||||
|
AIMessageChunk(content="o"),
|
||||||
|
]
|
||||||
|
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
|
||||||
|
assert chat_spy.call_args.args[1] == ChatPromptValue(
|
||||||
|
messages=[
|
||||||
|
SystemMessage(content="You are a nice assistant."),
|
||||||
|
HumanMessage(content="What is your name?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
async def test_prompt_with_llm(
|
async def test_prompt_with_llm(
|
||||||
|
Loading…
Reference in New Issue
Block a user