In Runnable.stream_log build up final_output from adding output chunks (#12781)

Add an arg to omit the streamed_output list; in cases where final_output
is enough, this saves bandwidth.
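
For API consumers, the new flag is a keyword argument to `Runnable.astream_log`. A minimal usage sketch — the `RunnablePassthrough` chain and the input value are illustrative stand-ins, not part of this commit:

import asyncio

from langchain_core.runnables import RunnablePassthrough

chain = RunnablePassthrough()  # stand-in; any Runnable works

async def main() -> None:
    async for patch in chain.astream_log(
        "hello",
        with_streamed_output_list=False,  # new flag: omit /streamed_output patches
    ):
        print(patch.ops)  # /final_output (and log) patches still arrive

asyncio.run(main())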

Nuno Campos, 2023-11-28 21:50:41 +00:00 (committed by GitHub)
parent 970fe23feb
commit 0f255bb6c4
6 changed files with 125 additions and 27 deletions

View File

@@ -92,7 +92,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
             if acc_gen is None:
                 acc_gen = chunk_gen
             else:
-                acc_gen += chunk_gen
+                acc_gen = acc_gen + chunk_gen

             parsed = self.parse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:
@@ -120,9 +120,9 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
             if acc_gen is None:
                 acc_gen = chunk_gen
             else:
-                acc_gen += chunk_gen
+                acc_gen = acc_gen + chunk_gen

-            parsed = self.parse_result([acc_gen], partial=True)
+            parsed = await self.aparse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:
                 if self.diff:
                     yield self._diff(prev_parsed, parsed)
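
The diff does not state why `acc_gen += chunk_gen` became `acc_gen = acc_gen + chunk_gen`. One plausible reading: binary `+` always rebinds the name to a fresh object, while `+=` may mutate the accumulator in place for types that define `__iadd__`, which is unsafe once the same chunk objects are also retained elsewhere (as `final_output` now is). A plain-Python illustration of the difference, not LangChain code:

# `+=` can mutate the left-hand object in place; `a = a + b` rebinds instead.
acc = ["0"]
alias = acc                  # another reference to the same list
acc += ["1"]                 # list.__iadd__ mutates in place
assert alias == ["0", "1"]   # the alias changed too

acc2 = ["0"]
alias2 = acc2
acc2 = acc2 + ["1"]          # builds a new list and rebinds acc2
assert alias2 == ["0"]       # the alias is untouched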

View File

@@ -34,13 +34,16 @@ from langchain_core.runnables.fallbacks import RunnableWithFallbacks
 from langchain_core.runnables.passthrough import RunnablePassthrough
 from langchain_core.runnables.router import RouterInput, RouterRunnable
 from langchain_core.runnables.utils import (
+    AddableDict,
     ConfigurableField,
     ConfigurableFieldMultiOption,
     ConfigurableFieldSingleOption,
+    aadd,
     add,
 )

 __all__ = [
+    "AddableDict",
     "ConfigurableField",
     "ConfigurableFieldSingleOption",
     "ConfigurableFieldMultiOption",
@@ -60,5 +63,6 @@ __all__ = [
     "RunnableSequence",
     "RunnableWithFallbacks",
     "get_config_list",
+    "aadd",
     "add",
 ]
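
`AddableDict` and `aadd` were previously only importable from `langchain_core.runnables.utils`; this hunk re-exports them from the package root. A short sketch of the `AddableDict` addition semantics that the new test relies on, with values mirroring `test_stream_log_lists` below:

from langchain_core.runnables import AddableDict

# Chunks shaped like those yielded by the new test's list_producer.
chunks = [AddableDict(alist=[str(i)]) for i in range(4)]

total = chunks[0]
for chunk in chunks[1:]:
    total = total + chunk  # values under a shared key are added; lists concatenate

assert total == {"alist": ["0", "1", "2", "3"]}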

View File

@@ -5,6 +5,7 @@ import inspect
 import threading
 from abc import ABC, abstractmethod
 from concurrent.futures import FIRST_COMPLETED, wait
+from copy import deepcopy
 from functools import partial
 from itertools import tee
 from operator import itemgetter
@@ -31,7 +32,7 @@ from typing import (
 from typing_extensions import Literal, get_args

-from langchain_core.load.dump import dumpd
+from langchain_core.load.dump import dumpd, dumps
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
 from langchain_core.runnables.config import (
@@ -507,6 +508,7 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         diff: Literal[True] = True,
+        with_streamed_output_list: bool = True,
         include_names: Optional[Sequence[str]] = None,
         include_types: Optional[Sequence[str]] = None,
         include_tags: Optional[Sequence[str]] = None,
@@ -524,6 +526,7 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         diff: Literal[False],
+        with_streamed_output_list: bool = True,
         include_names: Optional[Sequence[str]] = None,
         include_types: Optional[Sequence[str]] = None,
         include_tags: Optional[Sequence[str]] = None,
@@ -540,6 +543,7 @@ class Runnable(Generic[Input, Output], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         diff: bool = True,
+        with_streamed_output_list: bool = True,
         include_names: Optional[Sequence[str]] = None,
         include_types: Optional[Sequence[str]] = None,
         include_tags: Optional[Sequence[str]] = None,
@@ -557,7 +561,20 @@ class Runnable(Generic[Input, Output], ABC):
         step, and the final state of the run.

         The jsonpatch ops can be applied in order to construct state.
+
+        Args:
+            input: The input to the runnable.
+            config: The config to use for the runnable.
+            diff: Whether to yield diffs between each step, or the current state.
+            with_streamed_output_list: Whether to yield the streamed_output list.
+            include_names: Only include logs with these names.
+            include_types: Only include logs with these types.
+            include_tags: Only include logs with these tags.
+            exclude_names: Exclude logs with these names.
+            exclude_types: Exclude logs with these types.
+            exclude_tags: Exclude logs with these tags.
         """
+        import jsonpatch  # type: ignore[import]
+
         from langchain_core.callbacks.base import BaseCallbackManager
         from langchain_core.tracers.log_stream import (
@@ -598,16 +615,36 @@ class Runnable(Generic[Input, Output], ABC):
         # add each chunk to the output stream
         async def consume_astream() -> None:
             try:
+                prev_final_output: Optional[Output] = None
+                final_output: Optional[Output] = None
+
                 async for chunk in self.astream(input, config, **kwargs):
-                    await stream.send_stream.send(
-                        RunLogPatch(
+                    prev_final_output = final_output
+                    if final_output is None:
+                        final_output = chunk
+                    else:
+                        try:
+                            final_output = final_output + chunk  # type: ignore
+                        except TypeError:
+                            final_output = chunk
+                    patches: List[Dict[str, Any]] = []
+                    if with_streamed_output_list:
+                        patches.append(
                             {
                                 "op": "add",
                                 "path": "/streamed_output/-",
-                                "value": chunk,
+                                # chunk cannot be shared between
+                                # streamed_output and final_output
+                                # otherwise jsonpatch.apply will
+                                # modify both
+                                "value": deepcopy(chunk),
                             }
                         )
-                    )
+                    for op in jsonpatch.JsonPatch.from_diff(
+                        prev_final_output, final_output, dumps=dumps
+                    ):
+                        patches.append({**op, "path": f"/final_output{op['path']}"})
+                    await stream.send_stream.send(RunLogPatch(*patches))
             finally:
                 await stream.send_stream.aclose()
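
The rewritten consume_astream diffs the previous and current accumulated output with jsonpatch and re-roots each op under /final_output (the real code also passes LangChain's `dumps` serializer so non-JSON objects such as messages survive the diff). A standalone sketch of that re-rooting step, with plain JSON values borrowed from the new test:

import jsonpatch

prev = {"alist": ["0"]}       # accumulated output before this chunk
curr = {"alist": ["0", "1"]}  # accumulated output after adding the chunk

patches = [
    {**op, "path": f"/final_output{op['path']}"}
    for op in jsonpatch.JsonPatch.from_diff(prev, curr)
]
assert patches == [{"op": "add", "path": "/final_output/alist/1", "value": "1"}]

On the first chunk, prev_final_output is None, so from_diff emits a single replace at path "", which the prefixing turns into the `replace /final_output` op seen in the updated tests.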

View File

@@ -59,7 +59,7 @@ class RunState(TypedDict):
     """List of output chunks streamed by Runnable.stream()"""

     final_output: Optional[Any]
     """Final output of the run, usually the result of aggregating (`+`) streamed_output.
-    Only available after the run has finished successfully."""
+    Updated throughout the run when supported by the Runnable."""

     logs: Dict[str, LogEntry]
     """Map of run names to sub-runs. If filters were supplied, this list will
@@ -151,9 +151,7 @@ class LogStreamCallbackHandler(BaseTracer):
         send_stream: Any
         receive_stream: Any
-        send_stream, receive_stream = create_memory_object_stream(
-            math.inf, item_type=RunLogPatch
-        )
+        send_stream, receive_stream = create_memory_object_stream(math.inf)
         self.lock = threading.Lock()
         self.send_stream = send_stream
         self.receive_stream = receive_stream
@@ -278,15 +276,6 @@ class LogStreamCallbackHandler(BaseTracer):
             )
         finally:
             if run.id == self.root_id:
-                self.send_stream.send_nowait(
-                    RunLogPatch(
-                        {
-                            "op": "replace",
-                            "path": "/final_output",
-                            "value": load(run.outputs),
-                        }
-                    )
-                )
                 if self.auto_close:
                     self.send_stream.close()
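
With the tracer no longer emitting a final `replace /final_output` op, consumers reconstruct state by applying the streamed patches in order, as the astream_log docstring promises. A minimal sketch with hand-written ops mirroring the new test:

import jsonpatch

state = {"final_output": None, "logs": {}, "streamed_output": []}
ops_stream = [
    [{"op": "replace", "path": "/final_output", "value": {"alist": ["0"]}}],
    [{"op": "add", "path": "/final_output/alist/1", "value": "1"}],
]
for ops in ops_stream:
    state = jsonpatch.apply_patch(state, ops)  # returns an updated copy

assert state["final_output"] == {"alist": ["0", "1"]}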

View File

@@ -1,6 +1,7 @@
 from langchain_core.runnables import __all__

 EXPECTED_ALL = [
+    "AddableDict",
     "ConfigurableField",
     "ConfigurableFieldSingleOption",
     "ConfigurableFieldMultiOption",
@@ -20,6 +21,7 @@ EXPECTED_ALL = [
     "RunnableSequence",
     "RunnableWithFallbacks",
     "get_config_list",
+    "aadd",
     "add",
 ]

View File

@@ -49,6 +49,7 @@ from langchain_core.prompts import (
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import (
+    AddableDict,
     ConfigurableField,
     ConfigurableFieldMultiOption,
     ConfigurableFieldSingleOption,
@@ -1542,6 +1543,7 @@ async def test_prompt() -> None:
     assert stream_log[1:] == [
         RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": expected},
             {
                 "op": "replace",
                 "path": "/final_output",
@@ -1551,9 +1553,8 @@ async def test_prompt() -> None:
                     HumanMessage(content="What is your name?"),
                 ]
             ),
-            }
+            },
         ),
-        RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
     ]

     stream_log_state = [
@@ -1612,6 +1613,7 @@ async def test_prompt() -> None:
     assert stream_log_nested[1:] == [
         RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": expected},
             {
                 "op": "replace",
                 "path": "/final_output",
@@ -1621,9 +1623,8 @@ async def test_prompt() -> None:
                     HumanMessage(content="What is your name?"),
                 ]
             ),
-            }
+            },
         ),
-        RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
     ]
@@ -2107,9 +2108,9 @@ async def test_prompt_with_llm(
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
-        RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"}),
         RunLogPatch(
-            {"op": "replace", "path": "/final_output", "value": {"output": "foo"}}
+            {"op": "add", "path": "/streamed_output/-", "value": "foo"},
+            {"op": "replace", "path": "/final_output", "value": "foo"},
         ),
     ]
@@ -2154,6 +2155,71 @@ async def test_stream_log_retriever() -> None:
     ]


+@freeze_time("2023-01-01")
+async def test_stream_log_lists() -> None:
+    async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]:
+        for i in range(4):
+            yield AddableDict(alist=[str(i)])
+
+    chain: Runnable = RunnableGenerator(list_producer)
+
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    assert stream_log == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {"final_output": None, "logs": {}, "streamed_output": []},
+            }
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["0"]}},
+            {"op": "replace", "path": "/final_output", "value": {"alist": ["0"]}},
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["1"]}},
+            {"op": "add", "path": "/final_output/alist/1", "value": "1"},
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["2"]}},
+            {"op": "add", "path": "/final_output/alist/2", "value": "2"},
+        ),
+        RunLogPatch(
+            {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["3"]}},
+            {"op": "add", "path": "/final_output/alist/3", "value": "3"},
+        ),
+    ]
+
+    state = add(stream_log)
+
+    assert isinstance(state, RunLog)
+
+    assert state.state == {
+        "final_output": {"alist": ["0", "1", "2", "3"]},
+        "logs": {},
+        "streamed_output": [
+            {"alist": ["0"]},
+            {"alist": ["1"]},
+            {"alist": ["2"]},
+            {"alist": ["3"]},
+        ],
+    }
+
+
+@pytest.mark.asyncio
 @freeze_time("2023-01-01")
 async def test_prompt_with_llm_and_async_lambda(
     mocker: MockerFixture, snapshot: SnapshotAssertion