[Enhancement] Add support for directly providing a run_id (#18990)

The root run id (~trace id's) is useful for assigning feedback, but the
current recommended approach is to use callbacks to retrieve it, which
has some drawbacks:
1. Doesn't work for streaming until after the first event
2. Doesn't let you call other endpoints with the same trace ID in
parallel (since you have to wait until the call is completed/started to
use

This PR lets you provide = "run_id" in the runnable config.

Couple considerations:

1. For batch calls, we split the trace up into separate trees (to permit
better rendering). We keep the provided run ID for the first one and
generate a unique one for other elements of the batch.
2. For nested calls, the provided ID is ONLY used on the top root/trace.



### Example Usage


```
chain.invoke("foo", {"run_id": uuid.uuid4()})
```
This commit is contained in:
William FH
2024-03-18 15:03:04 -07:00
committed by GitHub
parent bd329e9aad
commit 780337488e
11 changed files with 221 additions and 30 deletions

View File

@@ -1,4 +1,5 @@
import sys
import uuid
from functools import partial
from operator import itemgetter
from typing import (
@@ -136,6 +137,22 @@ class FakeTracer(BaseTracer):
self.runs.append(self._copy_run(run))
def flattened_runs(self) -> List[Run]:
q = [] + self.runs
result = []
while q:
parent = q.pop()
result.append(parent)
if parent.child_runs:
q.extend(parent.child_runs)
return result
@property
def run_ids(self) -> List[Optional[uuid.UUID]]:
runs = self.flattened_runs()
uuids_map = {v: k for k, v in self.uuids_map.items()}
return [uuids_map.get(r.id) for r in runs]
class FakeRunnable(Runnable[str, int]):
def invoke(
@@ -1367,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
recursion_limit=25,
configurable={"hello": "there"},
metadata={"hello": "there", "bye": "now"},
run_id=None,
),
)
spy.reset_mock()
@@ -1508,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
mocker.call(
@@ -1517,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
]
@@ -1542,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
@@ -1552,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
@@ -1620,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
mocker.call(
@@ -1629,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
),
),
]
@@ -4822,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None:
}
tracer = FakeTracer()
assert runnable.invoke(None, {"callbacks": [tracer]}) == 6
run_id = uuid.uuid4()
assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
assert list(runnable.stream(None)) == [1, 2, 3]
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
tracer = FakeTracer()
assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3]
run_id = uuid.uuid4()
assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [
1,
2,
3,
]
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
tracer = FakeTracer()
assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6]
run_id = uuid.uuid4()
with pytest.warns(RuntimeWarning):
assert runnable.batch(
[None, None], {"callbacks": [tracer], "run_id": run_id}
) == [6, 6]
assert len(tracer.runs) == 2
assert tracer.runs[0].outputs == {"output": 6}
assert tracer.runs[1].outputs == {"output": 6}
@@ -4865,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None:
arunnable = RunnableGenerator(agen)
tracer = FakeTracer()
assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6
run_id = uuid.uuid4()
assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6
assert len(tracer.runs) == 1
assert tracer.runs[0].outputs == {"output": 6}
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer.runs.clear()
assert [p async for p in arunnable.astream(None)] == [1, 2, 3]
assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call"
tracer = FakeTracer()
assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [
run_id = uuid.uuid4()
assert [
p
async for p in arunnable.astream(
None, {"callbacks": [tracer], "run_id": run_id}
)
] == [
1,
2,
3,
@@ -4887,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None:
assert len(tracer.runs[0].child_runs) == 3
assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"]
assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3]
run_ids = tracer.run_ids
assert run_id in run_ids
assert len(run_ids) == len(set(run_ids))
tracer = FakeTracer()
assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6]
run_id = uuid.uuid4()
with pytest.warns(RuntimeWarning):
assert await arunnable.abatch(
[None, None], {"callbacks": [tracer], "run_id": run_id}
) == [6, 6]
assert len(tracer.runs) == 2
assert tracer.runs[0].outputs == {"output": 6}
assert tracer.runs[1].outputs == {"output": 6}