This commit is contained in:
Eugene Yurtsev
2024-01-16 22:22:14 -05:00
parent b938cf63d5
commit f99ed4e4fd
3 changed files with 133 additions and 42 deletions

View File

@@ -808,17 +808,28 @@ class Runnable(Generic[Input, Output], ABC):
from langchain_core.runnables.utils import (
_get_standardized_inputs,
_get_standardized_outputs,
_RootEventFilter,
)
from langchain_core.tracers.log_stream import (
RunLog,
)
run_log = RunLog(state=None) # type: ignore[arg-type]
yielded_start_event = False
encountered_start_event = False
config = ensure_config(config)
root_tags = config.get("tags", [])
root_metadata = config.get("metadata", {})
root_name = config.get("run_name", self.get_name())
_root_event_filter = _RootEventFilter(
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
)
async for log in self.astream_log(
input,
@@ -833,19 +844,24 @@ class Runnable(Generic[Input, Output], ABC):
):
run_log = run_log + log
if not yielded_start_event:
if not encountered_start_event:
# Yield the start event for the root runnable.
encountered_start_event = True
state = run_log.state.copy()
yield StreamEvent(
event = StreamEvent(
event=f"on_{state['type']}_start",
name=state["name"],
run_id=state["id"],
name=root_name,
tags=root_tags,
metadata=root_metadata,
data={
"input": input,
},
)
yielded_start_event = True
if _root_event_filter.include_event(event, state["type"]):
yield event
paths = {
op["path"].split("/")[2]
@@ -920,20 +936,23 @@ class Runnable(Generic[Input, Output], ABC):
# Clean up the stream, we don't need it anymore.
state["streamed_output"] = []
yield StreamEvent(
event = StreamEvent(
event=f"on_{state['type']}_stream",
name=state["name"],
run_id=state["id"],
tags=state.get("tags", []),
metadata=state.get("metadata", {}),
tags=root_tags,
metadata=root_metadata,
name=root_name,
data=data,
)
if _root_event_filter.include_event(event, state["type"]):
yield event
state = run_log.state
yield StreamEvent(
# Finally yield the end event for the root runnable.
event = StreamEvent(
event=f"on_{state['type']}_end",
name=state["name"],
name=root_name,
run_id=state["id"],
tags=root_tags,
metadata=root_metadata,
@@ -941,6 +960,8 @@ class Runnable(Generic[Input, Output], ABC):
"output": state["final_output"],
},
)
if _root_event_filter.include_event(event, state["type"]):
yield event
def transform(
self,
@@ -3875,6 +3896,17 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
):
yield item
async def astream_events(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[StreamEvent]:
async for item in self.bound.astream_events(
input, self._merge_configs(config), **{**self.kwargs, **kwargs}
):
yield item
def transform(
self,
input: Iterator[Input],

View File

@@ -446,6 +446,61 @@ class StreamEvent(TypedDict):
"""
class _RootEventFilter:
def __init__(
self,
*,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
) -> None:
"""Utility to filter the root event in the astream_events implementation.
This is simply binding the arguments to the namespace to make save on
a bit of typing in the astream_events implementation.
"""
self.include_names = include_names
self.include_types = include_types
self.include_tags = include_tags
self.exclude_names = exclude_names
self.exclude_types = exclude_types
self.exclude_tags = exclude_tags
def include_event(self, event: StreamEvent, root_type: str) -> bool:
"""Determine whether to include an event."""
if (
self.include_names is None
and self.include_types is None
and self.include_tags is None
):
include = True
else:
include = False
event_tags = event.get("tags") or []
if self.include_names is not None:
include = include or event["name"] in self.include_names
if self.include_types is not None:
include = include or root_type in self.include_types
if self.include_tags is not None:
include = include or any(tag in self.include_tags for tag in event_tags)
if self.exclude_names is not None:
include = include and event["name"] not in self.exclude_names
if self.exclude_types is not None:
include = include and root_type not in self.exclude_types
if self.exclude_tags is not None:
include = include and all(
tag not in self.exclude_tags for tag in event_tags
)
return include
async def _get_standardized_inputs(log: LogEntry) -> Optional[Dict[str, Any]]:
"""Extract standardized inputs from a log entry.

View File

@@ -209,19 +209,11 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
chain = (
r.with_config({"run_name": "1"})
| r.with_config({"run_name": "2"})
| r.with_config({"run_name": "3"})
| r.with_config({"run_name": "2", "tags": ["my_tag"]})
| r.with_config({"run_name": "3", "tags": ["my_tag"]})
)
events = await _collect_events(chain.astream_events("hello", include_names=["1"]))
assert events == [
{
"data": {"input": "hello"},
"event": "on_chain_start",
"metadata": {},
"name": "RunnableSequence",
"run_id": None,
"tags": [],
},
{
"data": {},
"event": "on_chain_start",
@@ -246,21 +238,35 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"run_id": None,
"tags": ["seq:step:1"],
},
]
events = await _collect_events(
chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
)
assert events == [
{
"data": {},
"event": "on_chain_start",
"metadata": {},
"name": "3",
"run_id": None,
"tags": ["my_tag", "seq:step:3"],
},
{
"data": {"chunk": "olleh"},
"event": "on_chain_stream",
"metadata": {},
"name": "RunnableSequence",
"name": "3",
"run_id": None,
"tags": [],
"tags": ["my_tag", "seq:step:3"],
},
{
"data": {"output": "olleh"},
"data": {"input": "hello", "output": "olleh"},
"event": "on_chain_end",
"metadata": {},
"name": "RunnableSequence",
"name": "3",
"run_id": None,
"tags": [],
"tags": ["my_tag", "seq:step:3"],
},
]
@@ -333,10 +339,10 @@ async def test_event_stream_with_simple_chain() -> None:
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
"metadata": {},
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": None,
"tags": [],
"tags": ["my_chain"],
},
{
"data": {"input": {"question": "hello"}},
@@ -382,10 +388,10 @@ async def test_event_stream_with_simple_chain() -> None:
{
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chain_stream",
"metadata": {},
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": None,
"tags": [],
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
@@ -398,10 +404,10 @@ async def test_event_stream_with_simple_chain() -> None:
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chain_stream",
"metadata": {},
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": None,
"tags": [],
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
@@ -414,10 +420,10 @@ async def test_event_stream_with_simple_chain() -> None:
{
"data": {"chunk": AIMessageChunk(content="world!")},
"event": "on_chain_stream",
"metadata": {},
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": None,
"tags": [],
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
@@ -459,14 +465,12 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {
"output": AIMessageChunk(content="hello world!"),
},
"data": {"output": AIMessageChunk(content="hello world!")},
"event": "on_chain_end",
"metadata": {},
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": None,
"tags": [],
"tags": ["my_chain"],
},
]
@@ -543,7 +547,7 @@ async def test_event_stream_with_retriever() -> None:
},
"event": "on_retriever_start",
"metadata": {},
"name": "Retriever",
"name": "HardCodedRetriever",
"run_id": None,
"tags": [],
},
@@ -556,7 +560,7 @@ async def test_event_stream_with_retriever() -> None:
},
"event": "on_retriever_stream",
"metadata": {},
"name": "Retriever",
"name": "HardCodedRetriever",
"run_id": None,
"tags": [],
},
@@ -569,7 +573,7 @@ async def test_event_stream_with_retriever() -> None:
},
"event": "on_retriever_end",
"metadata": {},
"name": "Retriever",
"name": "HardCodedRetriever",
"run_id": None,
"tags": [],
},