mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
x
This commit is contained in:
@@ -808,17 +808,28 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
from langchain_core.runnables.utils import (
|
||||
_get_standardized_inputs,
|
||||
_get_standardized_outputs,
|
||||
_RootEventFilter,
|
||||
)
|
||||
from langchain_core.tracers.log_stream import (
|
||||
RunLog,
|
||||
)
|
||||
|
||||
run_log = RunLog(state=None) # type: ignore[arg-type]
|
||||
yielded_start_event = False
|
||||
encountered_start_event = False
|
||||
config = ensure_config(config)
|
||||
|
||||
root_tags = config.get("tags", [])
|
||||
root_metadata = config.get("metadata", {})
|
||||
root_name = config.get("run_name", self.get_name())
|
||||
|
||||
_root_event_filter = _RootEventFilter(
|
||||
include_names=include_names,
|
||||
include_types=include_types,
|
||||
include_tags=include_tags,
|
||||
exclude_names=exclude_names,
|
||||
exclude_types=exclude_types,
|
||||
exclude_tags=exclude_tags,
|
||||
)
|
||||
|
||||
async for log in self.astream_log(
|
||||
input,
|
||||
@@ -833,19 +844,24 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
):
|
||||
run_log = run_log + log
|
||||
|
||||
if not yielded_start_event:
|
||||
if not encountered_start_event:
|
||||
# Yield the start event for the root runnable.
|
||||
encountered_start_event = True
|
||||
state = run_log.state.copy()
|
||||
yield StreamEvent(
|
||||
|
||||
event = StreamEvent(
|
||||
event=f"on_{state['type']}_start",
|
||||
name=state["name"],
|
||||
run_id=state["id"],
|
||||
name=root_name,
|
||||
tags=root_tags,
|
||||
metadata=root_metadata,
|
||||
data={
|
||||
"input": input,
|
||||
},
|
||||
)
|
||||
yielded_start_event = True
|
||||
|
||||
if _root_event_filter.include_event(event, state["type"]):
|
||||
yield event
|
||||
|
||||
paths = {
|
||||
op["path"].split("/")[2]
|
||||
@@ -920,20 +936,23 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
# Clean up the stream, we don't need it anymore.
|
||||
state["streamed_output"] = []
|
||||
|
||||
yield StreamEvent(
|
||||
event = StreamEvent(
|
||||
event=f"on_{state['type']}_stream",
|
||||
name=state["name"],
|
||||
run_id=state["id"],
|
||||
tags=state.get("tags", []),
|
||||
metadata=state.get("metadata", {}),
|
||||
tags=root_tags,
|
||||
metadata=root_metadata,
|
||||
name=root_name,
|
||||
data=data,
|
||||
)
|
||||
if _root_event_filter.include_event(event, state["type"]):
|
||||
yield event
|
||||
|
||||
state = run_log.state
|
||||
|
||||
yield StreamEvent(
|
||||
# Finally yield the end event for the root runnable.
|
||||
event = StreamEvent(
|
||||
event=f"on_{state['type']}_end",
|
||||
name=state["name"],
|
||||
name=root_name,
|
||||
run_id=state["id"],
|
||||
tags=root_tags,
|
||||
metadata=root_metadata,
|
||||
@@ -941,6 +960,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"output": state["final_output"],
|
||||
},
|
||||
)
|
||||
if _root_event_filter.include_event(event, state["type"]):
|
||||
yield event
|
||||
|
||||
def transform(
|
||||
self,
|
||||
@@ -3875,6 +3896,17 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
):
|
||||
yield item
|
||||
|
||||
async def astream_events(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
async for item in self.bound.astream_events(
|
||||
input, self._merge_configs(config), **{**self.kwargs, **kwargs}
|
||||
):
|
||||
yield item
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
|
||||
@@ -446,6 +446,61 @@ class StreamEvent(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
class _RootEventFilter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
) -> None:
|
||||
"""Utility to filter the root event in the astream_events implementation.
|
||||
|
||||
This is simply binding the arguments to the namespace to make save on
|
||||
a bit of typing in the astream_events implementation.
|
||||
"""
|
||||
self.include_names = include_names
|
||||
self.include_types = include_types
|
||||
self.include_tags = include_tags
|
||||
self.exclude_names = exclude_names
|
||||
self.exclude_types = exclude_types
|
||||
self.exclude_tags = exclude_tags
|
||||
|
||||
def include_event(self, event: StreamEvent, root_type: str) -> bool:
|
||||
"""Determine whether to include an event."""
|
||||
if (
|
||||
self.include_names is None
|
||||
and self.include_types is None
|
||||
and self.include_tags is None
|
||||
):
|
||||
include = True
|
||||
else:
|
||||
include = False
|
||||
|
||||
event_tags = event.get("tags") or []
|
||||
|
||||
if self.include_names is not None:
|
||||
include = include or event["name"] in self.include_names
|
||||
if self.include_types is not None:
|
||||
include = include or root_type in self.include_types
|
||||
if self.include_tags is not None:
|
||||
include = include or any(tag in self.include_tags for tag in event_tags)
|
||||
|
||||
if self.exclude_names is not None:
|
||||
include = include and event["name"] not in self.exclude_names
|
||||
if self.exclude_types is not None:
|
||||
include = include and root_type not in self.exclude_types
|
||||
if self.exclude_tags is not None:
|
||||
include = include and all(
|
||||
tag not in self.exclude_tags for tag in event_tags
|
||||
)
|
||||
|
||||
return include
|
||||
|
||||
|
||||
async def _get_standardized_inputs(log: LogEntry) -> Optional[Dict[str, Any]]:
|
||||
"""Extract standardized inputs from a log entry.
|
||||
|
||||
|
||||
@@ -209,19 +209,11 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
||||
|
||||
chain = (
|
||||
r.with_config({"run_name": "1"})
|
||||
| r.with_config({"run_name": "2"})
|
||||
| r.with_config({"run_name": "3"})
|
||||
| r.with_config({"run_name": "2", "tags": ["my_tag"]})
|
||||
| r.with_config({"run_name": "3", "tags": ["my_tag"]})
|
||||
)
|
||||
events = await _collect_events(chain.astream_events("hello", include_names=["1"]))
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": "hello"},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "RunnableSequence",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {},
|
||||
"event": "on_chain_start",
|
||||
@@ -246,21 +238,35 @@ async def test_event_stream_with_triple_lambda_test_filtering() -> None:
|
||||
"run_id": None,
|
||||
"tags": ["seq:step:1"],
|
||||
},
|
||||
]
|
||||
|
||||
events = await _collect_events(
|
||||
chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "3",
|
||||
"run_id": None,
|
||||
"tags": ["my_tag", "seq:step:3"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": "olleh"},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "RunnableSequence",
|
||||
"name": "3",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_tag", "seq:step:3"],
|
||||
},
|
||||
{
|
||||
"data": {"output": "olleh"},
|
||||
"data": {"input": "hello", "output": "olleh"},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "RunnableSequence",
|
||||
"name": "3",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_tag", "seq:step:3"],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -333,10 +339,10 @@ async def test_event_stream_with_simple_chain() -> None:
|
||||
{
|
||||
"data": {"input": {"question": "hello"}},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"metadata": {"foo": "bar"},
|
||||
"name": "my_chain",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_chain"],
|
||||
},
|
||||
{
|
||||
"data": {"input": {"question": "hello"}},
|
||||
@@ -382,10 +388,10 @@ async def test_event_stream_with_simple_chain() -> None:
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"metadata": {"foo": "bar"},
|
||||
"name": "my_chain",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_chain"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="hello")},
|
||||
@@ -398,10 +404,10 @@ async def test_event_stream_with_simple_chain() -> None:
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"metadata": {"foo": "bar"},
|
||||
"name": "my_chain",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_chain"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content=" ")},
|
||||
@@ -414,10 +420,10 @@ async def test_event_stream_with_simple_chain() -> None:
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!")},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"metadata": {"foo": "bar"},
|
||||
"name": "my_chain",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_chain"],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": AIMessageChunk(content="world!")},
|
||||
@@ -459,14 +465,12 @@ async def test_event_stream_with_simple_chain() -> None:
|
||||
"tags": ["my_chain", "my_model", "seq:step:2"],
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"output": AIMessageChunk(content="hello world!"),
|
||||
},
|
||||
"data": {"output": AIMessageChunk(content="hello world!")},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"metadata": {"foo": "bar"},
|
||||
"name": "my_chain",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
"tags": ["my_chain"],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -543,7 +547,7 @@ async def test_event_stream_with_retriever() -> None:
|
||||
},
|
||||
"event": "on_retriever_start",
|
||||
"metadata": {},
|
||||
"name": "Retriever",
|
||||
"name": "HardCodedRetriever",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
},
|
||||
@@ -556,7 +560,7 @@ async def test_event_stream_with_retriever() -> None:
|
||||
},
|
||||
"event": "on_retriever_stream",
|
||||
"metadata": {},
|
||||
"name": "Retriever",
|
||||
"name": "HardCodedRetriever",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
},
|
||||
@@ -569,7 +573,7 @@ async def test_event_stream_with_retriever() -> None:
|
||||
},
|
||||
"event": "on_retriever_end",
|
||||
"metadata": {},
|
||||
"name": "Retriever",
|
||||
"name": "HardCodedRetriever",
|
||||
"run_id": None,
|
||||
"tags": [],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user