Populate streamed_output for all runs handled by atransform_stream_with_config (#15599)
This means that users of astream_log() now get streamed output for virtually all requested runs, whereas before the only streamed output was for the root run and raw LLM runs.
parent 7025fa23aa
commit ef22559f1f
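For context, a minimal sketch of what this change surfaces to consumers of astream_log(), built from the same fake components as the tests in this diff (import paths are assumed; a real chat model behaves the same way):

import asyncio

from langchain_core.language_models.fake import FakeStreamingListLLM
from langchain_core.output_parsers import CommaSeparatedListOutputParser
from langchain_core.prompts import ChatPromptTemplate

async def main() -> None:
    chain = (
        ChatPromptTemplate.from_template("{question}")
        | FakeStreamingListLLM(responses=["bear, dog, cat"])
        | CommaSeparatedListOutputParser()
    )
    async for patch in chain.astream_log({"question": "What is your name?"}):
        for op in patch.ops:
            # With this change, intermediate runs such as the parser emit
            # /logs/<name>/streamed_output/- ops too, not just the root
            # run's /streamed_output/- ops.
            if op["path"].endswith("/streamed_output/-"):
                print(op["path"], "->", op["value"])

asyncio.run(main())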
@@ -1277,6 +1277,8 @@ class Runnable(Generic[Input, Output], ABC):
        """Helper method to transform an Async Iterator of Input values into an Async
        Iterator of Output values, with callbacks.
        Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
        from langchain_core.tracers.log_stream import LogStreamCallbackHandler

        # tee the input so we can iterate over it twice
        input_for_tracing, input_for_transform = atee(input, 2)
        # Start the input iterator to ensure the input runnable starts before this one
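Background on the atee call above: it is the async analogue of itertools.tee, splitting one async iterator into branches that each see every item, which is what lets tracing and transformation both consume the input. A simplified sketch of the idea with unbounded buffering (illustrative only, not the actual helper):

import asyncio
from collections import deque
from typing import AsyncIterator, Deque, Tuple, TypeVar

T = TypeVar("T")

def simple_atee(source: AsyncIterator[T]) -> Tuple[AsyncIterator[T], AsyncIterator[T]]:
    # Each branch drains its own buffer first; on a miss it pulls from the
    # shared source under a lock and pushes the item to the other buffer.
    buf_a: Deque[T] = deque()
    buf_b: Deque[T] = deque()
    lock = asyncio.Lock()
    exhausted = False

    async def branch(mine: Deque[T], other: Deque[T]) -> AsyncIterator[T]:
        nonlocal exhausted
        while True:
            if mine:
                yield mine.popleft()
                continue
            async with lock:
                if mine:  # the other branch refilled us while we waited
                    continue
                if exhausted:
                    return
                try:
                    item = await source.__anext__()
                except StopAsyncIteration:
                    exhausted = True
                    return
                other.append(item)
            yield item

    return branch(buf_a, buf_b), branch(buf_b, buf_a)

async def demo() -> None:
    async def numbers() -> AsyncIterator[int]:
        for i in range(3):
            yield i

    left, right = simple_atee(numbers())
    assert [x async for x in left] == [0, 1, 2]
    assert [x async for x in right] == [0, 1, 2]

asyncio.run(demo())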
@@ -1302,6 +1304,16 @@ class Runnable(Generic[Input, Output], ABC):
            context = copy_context()
            context.run(var_child_runnable_config.set, child_config)
            iterator = context.run(transformer, input_for_transform, **kwargs)  # type: ignore[arg-type]
            if stream_log := next(
                (
                    h
                    for h in run_manager.handlers
                    if isinstance(h, LogStreamCallbackHandler)
                ),
                None,
            ):
                # populates streamed_output in astream_log() output if needed
                iterator = stream_log.tap_output_aiter(run_manager.run_id, iterator)
            try:
                while True:
                    if accepts_context(asyncio.create_task):
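Aside on the idiom above: next((...), None) returns the first handler that passes the isinstance check, or None if there is none, and the walrus operator binds and tests the result in one expression. A standalone illustration:

handlers = ["stdout", ValueError("boom"), "tracer"]
if found := next((h for h in handlers if isinstance(h, Exception)), None):
    print(type(found).__name__)  # -> ValueError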
@@ -2733,6 +2745,7 @@ class RunnableLambda(Runnable[Input, Output]):
                ],
            ]
        ] = None,
        name: Optional[str] = None,
    ) -> None:
        """Create a RunnableLambda from a callable, an async callable, or both.

@@ -2766,7 +2779,9 @@ class RunnableLambda(Runnable[Input, Output]):
        )

        try:
            if func_for_name.__name__ != "<lambda>":
            if name is not None:
                self.name = name
            elif func_for_name.__name__ != "<lambda>":
                self.name = func_for_name.__name__
        except AttributeError:
            pass
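Sketching the effect of the two RunnableLambda hunks above as usage (the assertions just restate the logic shown; add_one is an illustrative function):

from langchain_core.runnables import RunnableLambda

def add_one(x: int) -> int:
    return x + 1

# An explicit name now takes precedence over the wrapped function's __name__.
assert RunnableLambda(add_one).name == "add_one"
assert RunnableLambda(add_one, name="adder").name == "adder"
# Lambdas have no usable __name__, so neither branch sets a name.
assert RunnableLambda(lambda x: x).name is None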
@@ -3046,17 +3061,7 @@ class RunnableLambda(Runnable[Input, Output]):
    def _config(
        self, config: Optional[RunnableConfig], callable: Callable[..., Any]
    ) -> RunnableConfig:
        config = ensure_config(config)

        if config.get("run_name") is None:
            try:
                run_name = callable.__name__
            except AttributeError:
                run_name = None
            if run_name is not None:
                return patch_config(config, run_name=run_name)

        return config
        return ensure_config(config)

    def invoke(
        self,

@@ -12,6 +12,7 @@ from typing import (
    Optional,
    Sequence,
    TypedDict,
    TypeVar,
    Union,
)
from uuid import UUID
@@ -128,6 +129,9 @@ class RunLog(RunLogPatch):
        return f"RunLog({pformat(self.state)})"


T = TypeVar("T")


class LogStreamCallbackHandler(BaseTracer):
    """A tracer that streams run logs to a stream."""

@@ -165,6 +169,28 @@ class LogStreamCallbackHandler(BaseTracer):
    def __aiter__(self) -> AsyncIterator[RunLogPatch]:
        return self.receive_stream.__aiter__()

    async def tap_output_aiter(
        self, run_id: UUID, output: AsyncIterator[T]
    ) -> AsyncIterator[T]:
        """Tap an output async iterator to stream its values to the log."""
        async for chunk in output:
            # root run is handled in .astream_log()
            if run_id != self.root_id:
                # if we can't find the run, silently ignore it,
                # e.g. because this run wasn't included in the log
                if key := self._key_map_by_run_id.get(run_id):
                    self.send_stream.send_nowait(
                        RunLogPatch(
                            {
                                "op": "add",
                                "path": f"/logs/{key}/streamed_output/-",
                                "value": chunk,
                            }
                        )
                    )

            yield chunk

    def include_run(self, run: Run) -> bool:
        if run.id == self.root_id:
            return False

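For intuition, tap_output_aiter is a pass-through async generator: each chunk is forwarded to the log stream as a patch and then yielded unchanged, so downstream consumers are unaffected. The same pattern in isolation (names here are illustrative, not part of the diff):

import asyncio
from typing import AsyncIterator, List, TypeVar

T = TypeVar("T")

async def tap(source: AsyncIterator[T], sink: List[T]) -> AsyncIterator[T]:
    # Forward each chunk to a side channel, then yield it unchanged.
    async for chunk in source:
        sink.append(chunk)
        yield chunk

async def demo() -> None:
    async def chunks() -> AsyncIterator[str]:
        for c in ("bear", "dog", "cat"):
            yield c

    seen: List[str] = []
    assert [c async for c in tap(chunks(), seen)] == ["bear", "dog", "cat"]
    assert seen == ["bear", "dog", "cat"]

asyncio.run(demo())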
File diff suppressed because one or more lines are too long
@@ -2140,6 +2140,272 @@ async def test_prompt_with_llm(
    assert stream_log == expected


@freeze_time("2023-01-01")
async def test_prompt_with_llm_parser(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
    parser = CommaSeparatedListOutputParser()

    chain: Runnable = prompt | llm | parser

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [llm]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    # Test invoke
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "ainvoke")
    parser_spy = mocker.spy(parser.__class__, "ainvoke")
    tracer = FakeTracer()
    assert await chain.ainvoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == ["bear", "dog", "cat"]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )
    assert parser_spy.call_args.args[1] == "bear, dog, cat"
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)
    mocker.stop(parser_spy)

    # Test batch
    prompt_spy = mocker.spy(prompt.__class__, "abatch")
    llm_spy = mocker.spy(llm.__class__, "abatch")
    parser_spy = mocker.spy(parser.__class__, "abatch")
    tracer = FakeTracer()
    assert await chain.abatch(
        [
            {"question": "What is your name?"},
            {"question": "What is your favorite color?"},
        ],
        dict(callbacks=[tracer]),
    ) == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]]
    assert prompt_spy.call_args.args[1] == [
        {"question": "What is your name?"},
        {"question": "What is your favorite color?"},
    ]
    assert llm_spy.call_args.args[1] == [
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your name?"),
            ]
        ),
        ChatPromptValue(
            messages=[
                SystemMessage(content="You are a nice assistant."),
                HumanMessage(content="What is your favorite color?"),
            ]
        ),
    ]
    assert parser_spy.call_args.args[1] == [
        "tomato, lettuce, onion",
        "bear, dog, cat",
    ]
    assert tracer.runs == snapshot
    mocker.stop(prompt_spy)
    mocker.stop(llm_spy)
    mocker.stop(parser_spy)

    # Test stream
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "astream")
    tracer = FakeTracer()
    assert [
        token
        async for token in chain.astream(
            {"question": "What is your name?"}, dict(callbacks=[tracer])
        )
    ] == [["tomato"], ["lettuce"], ["onion"]]
    assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
    assert llm_spy.call_args.args[1] == ChatPromptValue(
        messages=[
            SystemMessage(content="You are a nice assistant."),
            HumanMessage(content="What is your name?"),
        ]
    )

    prompt_spy.reset_mock()
    llm_spy.reset_mock()
    stream_log = [
        part async for part in chain.astream_log({"question": "What is your name?"})
    ]

    # remove ids from logs
    for part in stream_log:
        for op in part.ops:
            if (
                isinstance(op["value"], dict)
                and "id" in op["value"]
                and not isinstance(op["value"]["id"], list)  # serialized lc id
            ):
                del op["value"]["id"]

    expected = [
        RunLogPatch(
            {
                "op": "replace",
                "path": "",
                "value": {
                    "logs": {},
                    "final_output": None,
                    "streamed_output": [],
                },
            }
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/ChatPromptTemplate",
                "value": {
                    "end_time": None,
                    "final_output": None,
                    "metadata": {},
                    "name": "ChatPromptTemplate",
                    "start_time": "2023-01-01T00:00:00.000+00:00",
                    "streamed_output": [],
                    "streamed_output_str": [],
                    "tags": ["seq:step:1"],
                    "type": "prompt",
                },
            }
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/ChatPromptTemplate/final_output",
                "value": ChatPromptValue(
                    messages=[
                        SystemMessage(content="You are a nice assistant."),
                        HumanMessage(content="What is your name?"),
                    ]
                ),
            },
            {
                "op": "add",
                "path": "/logs/ChatPromptTemplate/end_time",
                "value": "2023-01-01T00:00:00.000+00:00",
            },
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/FakeStreamingListLLM",
                "value": {
                    "end_time": None,
                    "final_output": None,
                    "metadata": {},
                    "name": "FakeStreamingListLLM",
                    "start_time": "2023-01-01T00:00:00.000+00:00",
                    "streamed_output": [],
                    "streamed_output_str": [],
                    "tags": ["seq:step:2"],
                    "type": "llm",
                },
            }
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/FakeStreamingListLLM/final_output",
                "value": {
                    "generations": [
                        [
                            {
                                "generation_info": None,
                                "text": "bear, dog, cat",
                                "type": "Generation",
                            }
                        ]
                    ],
                    "llm_output": None,
                    "run": None,
                },
            },
            {
                "op": "add",
                "path": "/logs/FakeStreamingListLLM/end_time",
                "value": "2023-01-01T00:00:00.000+00:00",
            },
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser",
                "value": {
                    "end_time": None,
                    "final_output": None,
                    "metadata": {},
                    "name": "CommaSeparatedListOutputParser",
                    "start_time": "2023-01-01T00:00:00.000+00:00",
                    "streamed_output": [],
                    "streamed_output_str": [],
                    "tags": ["seq:step:3"],
                    "type": "parser",
                },
            }
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
                "value": ["bear"],
            }
        ),
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": ["bear"]},
            {"op": "replace", "path": "/final_output", "value": ["bear"]},
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
                "value": ["dog"],
            }
        ),
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": ["dog"]},
            {"op": "add", "path": "/final_output/1", "value": "dog"},
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
                "value": ["cat"],
            }
        ),
        RunLogPatch(
            {"op": "add", "path": "/streamed_output/-", "value": ["cat"]},
            {"op": "add", "path": "/final_output/2", "value": "cat"},
        ),
        RunLogPatch(
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser/final_output",
                "value": {"output": ["bear", "dog", "cat"]},
            },
            {
                "op": "add",
                "path": "/logs/CommaSeparatedListOutputParser/end_time",
                "value": "2023-01-01T00:00:00.000+00:00",
            },
        ),
    ]
    assert stream_log == expected


@freeze_time("2023-01-01")
async def test_stream_log_retriever() -> None:
    prompt = (
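Note on the expected ops above: the paths follow RFC 6902 (JSON Patch), where a trailing /- appends to an array; that is how each chunk accumulates in a run's streamed_output. A small standalone check using the jsonpatch library (state shape simplified):

import jsonpatch

state = {"logs": {"CommaSeparatedListOutputParser": {"streamed_output": []}}}
ops = [
    {
        "op": "add",
        "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
        "value": ["bear"],
    },
    {
        "op": "add",
        "path": "/logs/CommaSeparatedListOutputParser/streamed_output/-",
        "value": ["dog"],
    },
]
state = jsonpatch.apply_patch(state, ops)
assert state["logs"]["CommaSeparatedListOutputParser"]["streamed_output"] == [
    ["bear"],
    ["dog"],
]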
@@ -4606,6 +4872,14 @@ async def test_runnable_iter_context_config() -> None:
    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]

    tracer = FakeTracer()
    assert [p async for p in agen.astream_log("a", {"callbacks": [tracer]})]
    assert len(tracer.runs) == 1
    assert tracer.runs[0].outputs == {"output": 6}
    assert len(tracer.runs[0].child_runs) == 3
    assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
    assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]

    tracer = FakeTracer()
    assert await agen.abatch(["a", "a"], {"callbacks": [tracer]}) == [6, 6]
    assert len(tracer.runs) == 2