mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
[Enhancement] Add support for directly providing a run_id (#18990)
The root run id (~trace id's) is useful for assigning feedback, but the current recommended approach is to use callbacks to retrieve it, which has some drawbacks: 1. Doesn't work for streaming until after the first event 2. Doesn't let you call other endpoints with the same trace ID in parallel (since you have to wait until the call is completed/started to use This PR lets you provide = "run_id" in the runnable config. Couple considerations: 1. For batch calls, we split the trace up into separate trees (to permit better rendering). We keep the provided run ID for the first one and generate a unique one for other elements of the batch. 2. For nested calls, the provided ID is ONLY used on the top root/trace. ### Example Usage ``` chain.invoke("foo", {"run_id": uuid.uuid4()}) ```
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import uuid
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
@@ -136,6 +137,22 @@ class FakeTracer(BaseTracer):
|
||||
|
||||
self.runs.append(self._copy_run(run))
|
||||
|
||||
def flattened_runs(self) -> List[Run]:
|
||||
q = [] + self.runs
|
||||
result = []
|
||||
while q:
|
||||
parent = q.pop()
|
||||
result.append(parent)
|
||||
if parent.child_runs:
|
||||
q.extend(parent.child_runs)
|
||||
return result
|
||||
|
||||
@property
|
||||
def run_ids(self) -> List[Optional[uuid.UUID]]:
|
||||
runs = self.flattened_runs()
|
||||
uuids_map = {v: k for k, v in self.uuids_map.items()}
|
||||
return [uuids_map.get(r.id) for r in runs]
|
||||
|
||||
|
||||
class FakeRunnable(Runnable[str, int]):
|
||||
def invoke(
|
||||
@@ -1367,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
|
||||
recursion_limit=25,
|
||||
configurable={"hello": "there"},
|
||||
metadata={"hello": "there", "bye": "now"},
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
spy.reset_mock()
|
||||
@@ -1508,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
@@ -1517,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -1542,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
|
||||
@@ -1552,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1620,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
@@ -1629,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -4822,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None:
|
||||
}
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
|
||||
run_id = uuid.uuid4()
|
||||
assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
tracer.runs.clear()
|
||||
|
||||
assert list(runnable.stream(None)) == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
|
||||
run_id = uuid.uuid4()
|
||||
assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
]
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
tracer.runs.clear()
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||
run_id = uuid.uuid4()
|
||||
|
||||
with pytest.warns(RuntimeWarning):
|
||||
assert runnable.batch(
|
||||
[None, None], {"callbacks": [tracer], "run_id": run_id}
|
||||
) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
@@ -4865,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None:
|
||||
arunnable = RunnableGenerator(agen)
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
|
||||
|
||||
run_id = uuid.uuid4()
|
||||
assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
tracer.runs.clear()
|
||||
|
||||
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
|
||||
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
|
||||
run_id = uuid.uuid4()
|
||||
assert [
|
||||
p
|
||||
async for p in arunnable.astream(
|
||||
None, {"callbacks": [tracer], "run_id": run_id}
|
||||
)
|
||||
] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
@@ -4887,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None:
|
||||
assert len(tracer.runs[0].child_runs) == 3
|
||||
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
|
||||
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
|
||||
run_ids = tracer.run_ids
|
||||
assert run_id in run_ids
|
||||
assert len(run_ids) == len(set(run_ids))
|
||||
|
||||
tracer = FakeTracer()
|
||||
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
|
||||
run_id = uuid.uuid4()
|
||||
with pytest.warns(RuntimeWarning):
|
||||
assert await arunnable.abatch(
|
||||
[None, None], {"callbacks": [tracer], "run_id": run_id}
|
||||
) == [6, 6]
|
||||
assert len(tracer.runs) == 2
|
||||
assert tracer.runs[0].outputs == {"output": 6}
|
||||
assert tracer.runs[1].outputs == {"output": 6}
|
||||
|
Reference in New Issue
Block a user