Improve output of Runnable.astream_log() (#11391)
- Make logs a dictionary keyed by run name (and counter for repeats)
- Ensure no output shows up in lc_serializable format
- Fix up repr for RunLog and RunLogPatch
Commit: 4d66756d93 (parent: a30f98f534)
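In short: RunState["logs"] changes from a list ordered by start time to a dict keyed by run name, repeated names get a ":2", ":3", ... suffix, and final_output values are passed through load() so they are live objects rather than lc-serialized dicts. A minimal consumer sketch under those assumptions (the helper, chain, and input names are illustrative, not part of this diff):

    # Hypothetical usage; assumes `chain` is any Runnable.
    async def final_logs(chain, inputs):
        run_log = None
        async for patch in chain.astream_log(inputs):
            # RunLogPatch objects support +; summing them yields a RunLog
            # whose .state has all patches applied (see the tests below).
            run_log = patch if run_log is None else run_log + patch
        # With this change, logs is keyed by run name, e.g.
        # {"ChatPromptTemplate": {...}, "FakeListLLM": {...}, "FakeListLLM:2": {...}}
        return run_log.state["logs"]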
@@ -2,6 +2,7 @@ from __future__ import annotations

 import math
 import threading
+from collections import defaultdict
 from typing import (
     Any,
     AsyncIterator,
@@ -19,6 +20,7 @@ from anyio import create_memory_object_stream

 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
+from langchain.load.load import load
 from langchain.schema.output import ChatGenerationChunk, GenerationChunk


@@ -55,7 +57,7 @@ class RunState(TypedDict):
     """Final output of the run, usually the result of aggregating streamed_output.

     Only available after the run has finished successfully."""

-    logs: list[LogEntry]
+    logs: Dict[str, LogEntry]
     """List of sub-runs contained in this run, if any, in the order they were started.
     If filters were supplied, this list will contain only the runs that matched the
     filters."""
@@ -85,7 +87,8 @@ class RunLogPatch:
     def __repr__(self) -> str:
         from pprint import pformat

-        return f"RunLogPatch(ops={pformat(self.ops)})"
+        # 1:-1 to get rid of the [] around the list
+        return f"RunLogPatch({pformat(self.ops)[1:-1]})"

     def __eq__(self, other: object) -> bool:
         return isinstance(other, RunLogPatch) and self.ops == other.ops
@@ -112,7 +115,7 @@ class RunLog(RunLogPatch):
     def __repr__(self) -> str:
         from pprint import pformat

-        return f"RunLog(state={pformat(self.state)})"
+        return f"RunLog({pformat(self.state)})"


 class LogStreamCallbackHandler(BaseTracer):
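For reference, a quick sketch of what the tightened reprs print (the output lines are written by hand here, assuming the pformat behavior shown above):

    from langchain.callbacks.tracers.log_stream import RunLogPatch

    patch = RunLogPatch({"op": "add", "path": "/logs/ChatPromptTemplate", "value": None})
    print(repr(patch))
    # Before: RunLogPatch(ops=[{'op': 'add', 'path': '/logs/ChatPromptTemplate', 'value': None}])
    # After:  RunLogPatch({'op': 'add', 'path': '/logs/ChatPromptTemplate', 'value': None})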
@@ -143,7 +146,8 @@ class LogStreamCallbackHandler(BaseTracer):
         self.lock = threading.Lock()
         self.send_stream = send_stream
         self.receive_stream = receive_stream
-        self._index_map: Dict[UUID, int] = {}
+        self._key_map_by_run_id: Dict[UUID, str] = {}
+        self._counter_map_by_name: Dict[str, int] = defaultdict(int)

     def __aiter__(self) -> AsyncIterator[RunLogPatch]:
         return self.receive_stream.__aiter__()
@@ -196,7 +200,7 @@ class LogStreamCallbackHandler(BaseTracer):
                     id=str(run.id),
                     streamed_output=[],
                     final_output=None,
-                    logs=[],
+                    logs={},
                 ),
             }
         )
@@ -207,14 +211,18 @@ class LogStreamCallbackHandler(BaseTracer):

         # Determine previous index, increment by 1
         with self.lock:
-            self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
+            self._counter_map_by_name[run.name] += 1
+            count = self._counter_map_by_name[run.name]
+            self._key_map_by_run_id[run.id] = (
+                run.name if count == 1 else f"{run.name}:{count}"
+            )

         # Add the run to the stream
         self.send_stream.send_nowait(
             RunLogPatch(
                 {
                     "op": "add",
-                    "path": f"/logs/{self._index_map[run.id]}",
+                    "path": f"/logs/{self._key_map_by_run_id[run.id]}",
                     "value": LogEntry(
                         id=str(run.id),
                         name=run.name,
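The naming scheme introduced here is easy to exercise in isolation; this standalone sketch mirrors the counter logic above (the run names are illustrative):

    from collections import defaultdict

    counter = defaultdict(int)

    def key_for(name: str) -> str:
        # First occurrence keeps the bare name; repeats get a ":<count>" suffix.
        counter[name] += 1
        count = counter[name]
        return name if count == 1 else f"{name}:{count}"

    assert [key_for(n) for n in ["FakeListLLM", "FakeListLLM", "RunnableMap"]] == [
        "FakeListLLM",
        "FakeListLLM:2",
        "RunnableMap",
    ]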
@@ -233,7 +241,7 @@ class LogStreamCallbackHandler(BaseTracer):
     def _on_run_update(self, run: Run) -> None:
         """Finish a run."""
         try:
-            index = self._index_map.get(run.id)
+            index = self._key_map_by_run_id.get(run.id)

             if index is None:
                 return
@@ -243,7 +251,8 @@ class LogStreamCallbackHandler(BaseTracer):
                     {
                         "op": "add",
                         "path": f"/logs/{index}/final_output",
-                        "value": run.outputs,
+                        # to undo the dumpd done by some runnables / tracer / etc
+                        "value": load(run.outputs),
                     },
                     {
                         "op": "add",
@@ -259,7 +268,7 @@ class LogStreamCallbackHandler(BaseTracer):
                     {
                         "op": "replace",
                         "path": "/final_output",
-                        "value": run.outputs,
+                        "value": load(run.outputs),
                     }
                 )
             )
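Both load(run.outputs) calls undo the dumpd serialization applied elsewhere, so final_output carries live objects instead of {"lc": 1, "type": "constructor", ...} dicts. A round-trip sketch (import paths assumed from this era of the codebase; the message is illustrative):

    from langchain.load.dump import dumpd
    from langchain.load.load import load
    from langchain.schema.messages import HumanMessage

    msg = HumanMessage(content="What is your name?")
    serialized = dumpd(msg)  # the lc-serializable dict form
    assert load(serialized) == msg  # load() reconstructs the original object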
@@ -273,7 +282,7 @@ class LogStreamCallbackHandler(BaseTracer):
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
     ) -> None:
         """Process new LLM token."""
-        index = self._index_map.get(run.id)
+        index = self._key_map_by_run_id.get(run.id)

         if index is None:
             return
@@ -1239,7 +1239,7 @@ async def test_prompt() -> None:
     assert len(stream_log[0].ops) == 1
     assert stream_log[0].ops[0]["op"] == "replace"
     assert stream_log[0].ops[0]["path"] == ""
-    assert stream_log[0].ops[0]["value"]["logs"] == []
+    assert stream_log[0].ops[0]["value"]["logs"] == {}
     assert stream_log[0].ops[0]["value"]["final_output"] is None
     assert stream_log[0].ops[0]["value"]["streamed_output"] == []
     assert isinstance(stream_log[0].ops[0]["value"]["id"], str)
@@ -1249,40 +1249,12 @@ async def test_prompt() -> None:
             {
                 "op": "replace",
                 "path": "/final_output",
-                "value": {
-                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
-                    "kwargs": {
-                        "messages": [
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "SystemMessage",
-                                ],
-                                "kwargs": {"content": "You are a nice " "assistant."},
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "HumanMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "What is your " "name?",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                        ]
-                    },
-                    "lc": 1,
-                    "type": "constructor",
-                },
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
             }
         ),
         RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
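The ops asserted in these tests follow JSON Patch (RFC 6902) semantics, which is what makes the dict-keyed logs addressable by name. A sketch with the jsonpatch library (assumed here to match how RunLog applies ops; the state dict is illustrative):

    import jsonpatch

    state = {"logs": {}, "final_output": None, "streamed_output": []}
    ops = [
        {"op": "add", "path": "/logs/ChatPromptTemplate", "value": {"name": "ChatPromptTemplate"}},
        {"op": "add", "path": "/logs/ChatPromptTemplate/end_time", "value": "2023-01-01T00:00:00.000"},
    ]
    state = jsonpatch.apply_patch(state, ops)
    assert sorted(state["logs"]) == ["ChatPromptTemplate"]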
@@ -1525,7 +1497,7 @@ async def test_prompt_with_llm(
                 "op": "replace",
                 "path": "",
                 "value": {
-                    "logs": [],
+                    "logs": {},
                     "final_output": None,
                     "streamed_output": [],
                 },
@@ -1534,7 +1506,7 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/0",
+                "path": "/logs/ChatPromptTemplate",
                 "value": {
                     "end_time": None,
                     "final_output": None,
@@ -1550,55 +1522,24 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/0/final_output",
-                "value": {
-                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
-                    "kwargs": {
-                        "messages": [
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "SystemMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "You are a nice " "assistant.",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                            {
-                                "id": [
-                                    "langchain",
-                                    "schema",
-                                    "messages",
-                                    "HumanMessage",
-                                ],
-                                "kwargs": {
-                                    "additional_kwargs": {},
-                                    "content": "What is your " "name?",
-                                },
-                                "lc": 1,
-                                "type": "constructor",
-                            },
-                        ]
-                    },
-                    "lc": 1,
-                    "type": "constructor",
-                },
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
             },
             {
                 "op": "add",
-                "path": "/logs/0/end_time",
+                "path": "/logs/ChatPromptTemplate/end_time",
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/1",
+                "path": "/logs/FakeListLLM",
                 "value": {
                     "end_time": None,
                     "final_output": None,
@@ -1614,7 +1555,7 @@ async def test_prompt_with_llm(
         RunLogPatch(
             {
                 "op": "add",
-                "path": "/logs/1/final_output",
+                "path": "/logs/FakeListLLM/final_output",
                 "value": {
                     "generations": [[{"generation_info": None, "text": "foo"}]],
                     "llm_output": None,
@@ -1623,7 +1564,7 @@ async def test_prompt_with_llm(
             },
             {
                 "op": "add",
-                "path": "/logs/1/end_time",
+                "path": "/logs/FakeListLLM/end_time",
                 "value": "2023-01-01T00:00:00.000",
             },
         ),
@@ -1634,6 +1575,192 @@ async def test_prompt_with_llm(
     ]


+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_stream_log_retriever() -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{documents}"
+        + "{question}"
+    )
+    llm = FakeListLLM(responses=["foo", "bar"])
+
+    chain: Runnable = (
+        {"documents": FakeRetriever(), "question": itemgetter("question")}
+        | prompt
+        | {"one": llm, "two": llm}
+    )
+
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    assert stream_log[:-9] == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {
+                    "logs": {},
+                    "final_output": None,
+                    "streamed_output": [],
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "RunnableMap",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:1"],
+                    "type": "chain",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "RunnableLambda",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["map:key:question"],
+                    "type": "chain",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda/final_output",
+                "value": {"output": "What is your name?"},
+            },
+            {
+                "op": "add",
+                "path": "/logs/RunnableLambda/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/Retriever",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "Retriever",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["map:key:documents"],
+                    "type": "retriever",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/Retriever/final_output",
+                "value": {
+                    "documents": [
+                        Document(page_content="foo"),
+                        Document(page_content="bar"),
+                    ]
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/Retriever/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap/final_output",
+                "value": {
+                    "documents": [
+                        Document(page_content="foo"),
+                        Document(page_content="bar"),
+                    ],
+                    "question": "What is your name?",
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/RunnableMap/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "ChatPromptTemplate",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:2"],
+                    "type": "prompt",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/final_output",
+                "value": ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(
+                            content="[Document(page_content='foo'), Document(page_content='bar')]"  # noqa: E501
+                        ),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                ),
+            },
+            {
+                "op": "add",
+                "path": "/logs/ChatPromptTemplate/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+    ]
+
+    assert sorted(cast(RunLog, add(stream_log)).state["logs"]) == [
+        "ChatPromptTemplate",
+        "FakeListLLM",
+        "FakeListLLM:2",
+        "Retriever",
+        "RunnableLambda",
+        "RunnableMap",
+        "RunnableMap:2",
+    ]
+
+
 @pytest.mark.asyncio
 @freeze_time("2023-01-01")
 async def test_prompt_with_llm_and_async_lambda(
@@ -2291,14 +2418,18 @@ async def test_map_astream() -> None:
     assert isinstance(final_state.state["id"], str)
     assert len(final_state.ops) == len(streamed_ops)
     assert len(final_state.state["logs"]) == 5
-    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
-    assert final_state.state["logs"][0]["final_output"] == dumpd(
-        prompt.invoke({"question": "What is your name?"})
+    assert (
+        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
     )
-    assert final_state.state["logs"][1]["name"] == "RunnableMap"
-    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+    assert final_state.state["logs"]["ChatPromptTemplate"][
+        "final_output"
+    ] == prompt.invoke({"question": "What is your name?"})
+    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert sorted(final_state.state["logs"]) == [
         "ChatPromptTemplate",
         "FakeListChatModel",
         "FakeStreamingListLLM",
         "RunnableMap",
         "RunnablePassthrough",
     ]
@@ -2316,7 +2447,7 @@ async def test_map_astream() -> None:
     assert final_state.state["final_output"] == final_value
     assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
     assert len(final_state.state["logs"]) == 1
-    assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
+    assert final_state.state["logs"]["FakeListChatModel"]["name"] == "FakeListChatModel"

     # Test astream_log with exclude filters
     final_state = None
@@ -2332,13 +2463,17 @@ async def test_map_astream() -> None:
     assert final_state.state["final_output"] == final_value
     assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
     assert len(final_state.state["logs"]) == 4
-    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
-    assert final_state.state["logs"][0]["final_output"] == dumpd(
+    assert (
+        final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
+    )
+    assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
         prompt.invoke({"question": "What is your name?"})
     )
-    assert final_state.state["logs"][1]["name"] == "RunnableMap"
-    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+    assert final_state.state["logs"]["RunnableMap"]["name"] == "RunnableMap"
+    assert sorted(final_state.state["logs"]) == [
        "ChatPromptTemplate",
         "FakeStreamingListLLM",
         "RunnableMap",
         "RunnablePassthrough",
     ]