Populate streamed_output for all runs handled by atransform_stream_with_config (#15599)

This means that users of astream_log() now get streamed output for
virtually all requested runs, whereas before the only streamed output
would be for the root run and raw LLM runs.
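
As a minimal sketch of what this looks like from the consumer side (not from this PR; the two-step chain and names here are illustrative, and assume a langchain_core version containing this change):

import asyncio

from langchain_core.runnables import RunnableLambda

async def main() -> None:
    # Hypothetical two-step chain; before this change, only the root
    # run's streamed_output was populated in the emitted RunLogPatch ops.
    chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
    async for patch in chain.astream_log(1):
        # Ops with path "/logs/<run_name>/streamed_output/-" now appear
        # for intermediate runs too, not just "/streamed_output/-".
        print(patch)

asyncio.run(main())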

Nuno Campos 2024-01-07 19:35:43 -08:00 committed by GitHub
parent 7025fa23aa
commit ef22559f1f
4 changed files with 446 additions and 13 deletions

View File

@@ -1277,6 +1277,8 @@ class Runnable(Generic[Input, Output], ABC):
         """Helper method to transform an Async Iterator of Input values into an Async
         Iterator of Output values, with callbacks.
         Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
+        from langchain_core.tracers.log_stream import LogStreamCallbackHandler
+
         # tee the input so we can iterate over it twice
         input_for_tracing, input_for_transform = atee(input, 2)
         # Start the input iterator to ensure the input runnable starts before this one
@@ -1302,6 +1304,16 @@ class Runnable(Generic[Input, Output], ABC):
             context = copy_context()
             context.run(var_child_runnable_config.set, child_config)
             iterator = context.run(transformer, input_for_transform, **kwargs)  # type: ignore[arg-type]
+            if stream_log := next(
+                (
+                    h
+                    for h in run_manager.handlers
+                    if isinstance(h, LogStreamCallbackHandler)
+                ),
+                None,
+            ):
+                # populates streamed_output in astream_log() output if needed
+                iterator = stream_log.tap_output_aiter(run_manager.run_id, iterator)
             try:
                 while True:
                     if accepts_context(asyncio.create_task):
@@ -2733,6 +2745,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 ],
             ]
         ] = None,
+        name: Optional[str] = None,
     ) -> None:
         """Create a RunnableLambda from a callable, and async callable or both.

@@ -2766,7 +2779,9 @@ class RunnableLambda(Runnable[Input, Output]):
             )

         try:
-            if func_for_name.__name__ != "<lambda>":
+            if name is not None:
+                self.name = name
+            elif func_for_name.__name__ != "<lambda>":
                 self.name = func_for_name.__name__
         except AttributeError:
             pass
@@ -3046,17 +3061,7 @@ class RunnableLambda(Runnable[Input, Output]):

     def _config(
         self, config: Optional[RunnableConfig], callable: Callable[..., Any]
     ) -> RunnableConfig:
-        config = ensure_config(config)
-        if config.get("run_name") is None:
-            try:
-                run_name = callable.__name__
-            except AttributeError:
-                run_name = None
-            if run_name is not None:
-                return patch_config(config, run_name=run_name)
-        return config
+        return ensure_config(config)

     def invoke(
         self,
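
Taken together, the two RunnableLambda changes above move run naming from _config(), which used to guess a run_name from the callable, to an explicit constructor argument. A minimal sketch of the resulting usage, assuming the signature added in this diff ("add_one_step" is an illustrative value, not something defined in this PR):

from langchain_core.runnables import RunnableLambda

def add_one(x: int) -> int:
    return x + 1

# name is the new constructor argument from this diff.
step = RunnableLambda(add_one, name="add_one_step")
assert step.name == "add_one_step"

# Without an explicit name, the callable's __name__ is still used.
assert RunnableLambda(add_one).name == "add_one"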

View File

@@ -12,6 +12,7 @@ from typing import (
     Optional,
     Sequence,
     TypedDict,
+    TypeVar,
     Union,
 )
 from uuid import UUID
@@ -128,6 +129,9 @@ class RunLog(RunLogPatch):
         return f"RunLog({pformat(self.state)})"


+T = TypeVar("T")
+
+
 class LogStreamCallbackHandler(BaseTracer):
     """A tracer that streams run logs to a stream."""
@@ -165,6 +169,28 @@ class LogStreamCallbackHandler(BaseTracer):
     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()

+    async def tap_output_aiter(
+        self, run_id: UUID, output: AsyncIterator[T]
+    ) -> AsyncIterator[T]:
+        """Tap an output async iterator to stream its values to the log."""
+        async for chunk in output:
+            # root run is handled in .astream_log()
+            if run_id != self.root_id:
+                # if we can't find the run silently ignore
+                # eg. because this run wasn't included in the log
+                if key := self._key_map_by_run_id.get(run_id):
+                    self.send_stream.send_nowait(
+                        RunLogPatch(
+                            {
+                                "op": "add",
+                                "path": f"/logs/{key}/streamed_output/-",
+                                "value": chunk,
+                            }
+                        )
+                    )
+            yield chunk
+
     def include_run(self, run: Run) -> bool:
         if run.id == self.root_id:
             return False
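
tap_output_aiter() is an async pass-through: it re-yields every chunk unchanged while mirroring a copy into the log's send stream as a JSONPatch "add" on the run's streamed_output list. The same pattern in isolation, as a self-contained sketch (the names here are illustrative, with a plain callback standing in for the send stream):

import asyncio
from typing import AsyncIterator, Callable, List, TypeVar

T = TypeVar("T")

async def tap_aiter(
    source: AsyncIterator[T], observe: Callable[[T], None]
) -> AsyncIterator[T]:
    async for chunk in source:
        observe(chunk)  # side channel: mirror the chunk elsewhere
        yield chunk  # pass-through: downstream consumers are unaffected

async def numbers() -> AsyncIterator[int]:
    for i in range(3):
        yield i

async def main() -> None:
    seen: List[int] = []
    async for _ in tap_aiter(numbers(), seen.append):
        pass
    assert seen == [0, 1, 2]

asyncio.run(main())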

File diff suppressed because one or more lines are too long

View File

@@ -2140,6 +2140,272 @@ async def test_prompt_with_llm(
     assert stream_log == expected

+@freeze_time("2023-01-01")
+async def test_prompt_with_llm_parser(
+    mocker: MockerFixture, snapshot: SnapshotAssertion
+) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
+    parser = CommaSeparatedListOutputParser()
+
+    chain: Runnable = prompt | llm | parser
+
+    assert isinstance(chain, RunnableSequence)
+    assert chain.first == prompt
+    assert chain.middle == [llm]
+    assert chain.last == parser
+    assert dumps(chain, pretty=True) == snapshot
+
+    # Test invoke
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    llm_spy = mocker.spy(llm.__class__, "ainvoke")
+    parser_spy = mocker.spy(parser.__class__, "ainvoke")
+    tracer = FakeTracer()
+    assert await chain.ainvoke(
+        {"question": "What is your name?"}, dict(callbacks=[tracer])
+    ) == ["bear", "dog", "cat"]
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert llm_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+    assert parser_spy.call_args.args[1] == "bear, dog, cat"
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(llm_spy)
+    mocker.stop(parser_spy)
+
+    # Test batch
+    prompt_spy = mocker.spy(prompt.__class__, "abatch")
+    llm_spy = mocker.spy(llm.__class__, "abatch")
+    parser_spy = mocker.spy(parser.__class__, "abatch")
+    tracer = FakeTracer()
+    assert await chain.abatch(
+        [
+            {"question": "What is your name?"},
+            {"question": "What is your favorite color?"},
+        ],
+        dict(callbacks=[tracer]),
+    ) == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]]
+    assert prompt_spy.call_args.args[1] == [
+        {"question": "What is your name?"},
+        {"question": "What is your favorite color?"},
+    ]
+    assert llm_spy.call_args.args[1] == [
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your name?"),
+            ]
+        ),
+        ChatPromptValue(
+            messages=[
+                SystemMessage(content="You are a nice assistant."),
+                HumanMessage(content="What is your favorite color?"),
+            ]
+        ),
+    ]
+    assert parser_spy.call_args.args[1] == [
+        "tomato, lettuce, onion",
+        "bear, dog, cat",
+    ]
+    assert tracer.runs == snapshot
+    mocker.stop(prompt_spy)
+    mocker.stop(llm_spy)
+    mocker.stop(parser_spy)
+
+    # Test stream
+    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
+    llm_spy = mocker.spy(llm.__class__, "astream")
+    tracer = FakeTracer()
+    assert [
+        token
+        async for token in chain.astream(
+            {"question": "What is your name?"}, dict(callbacks=[tracer])
+        )
+    ] == [["tomato"], ["lettuce"], ["onion"]]
+    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
+    assert llm_spy.call_args.args[1] == ChatPromptValue(
+        messages=[
+            SystemMessage(content="You are a nice assistant."),
+            HumanMessage(content="What is your name?"),
+        ]
+    )
+
+    prompt_spy.reset_mock()
+    llm_spy.reset_mock()
+
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    expected = [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {
+                    "logs": {},
+                    "final_output": None,
+                    "streamed_output": [],
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "ChatPromptTemplate",
+                    "start_time": "2023-01-01T00:00:00.000+00:00",
+                    "streamed_output": [],
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:1"],
+                    "type": "prompt",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
+            },
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/end_time",
+                "value": "2023-01-01T00:00:00.000+00:00",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/FakeStreamingListLLM",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "FakeStreamingListLLM",
+                    "start_time": "2023-01-01T00:00:00.000+00:00",
+                    "streamed_output": [],
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:2"],
+                    "type": "llm",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/FakeStreamingListLLM/final_output",
+                "value": {
+                    "generations": [
+                        [
+                            {
+                                "generation_info": None,
+                                "text": "bear, dog, cat",
+                                "type": "Generation",
+                            }
+                        ]
+                    ],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/FakeStreamingListLLM/end_time",
+                "value": "2023-01-01T00:00:00.000+00:00",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "CommaSeparatedListOutputParser",
+                    "start_time": "2023-01-01T00:00:00.000+00:00",
+                    "streamed_output": [],
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:3"],
+                    "type": "parser",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
+                "value": ["bear"],
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": ["bear"]},
+            {"op": "replace", "path": "/final_output", "value": ["bear"]},
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
+                "value": ["dog"],
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": ["dog"]},
+            {"op": "add", "path": "/final_output/1", "value": "dog"},
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
+                "value": ["cat"],
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": ["cat"]},
+            {"op": "add", "path": "/final_output/2", "value": "cat"},
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/final_output",
+                "value": {"output": ["bear", "dog", "cat"]},
+            },
+            {
+                "op": "add",
+                "path": "/logs/CommaSeparatedListOutputParser/end_time",
+                "value": "2023-01-01T00:00:00.000+00:00",
+            },
+        ),
+    ]
+    assert stream_log == expected


 @freeze_time("2023-01-01")
 async def test_stream_log_retriever() -> None:
     prompt = (
@@ -4606,6 +4872,14 @@ async def test_runnable_iter_context_config() -> None:
     assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
     assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]

+    tracer = FakeTracer()
+    assert [p async for p in agen.astream_log("a", {"callbacks": [tracer]})]
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].outputs == {"output": 6}
+    assert len(tracer.runs[0].child_runs) == 3
+    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
+    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
+
     tracer = FakeTracer()
     assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
     assert len(tracer.runs) == 2