Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-06 05:08:20 +00:00)

Set Context in RunnableSequence & RunnableParallel (#25073)

commit 267855b3c1
parent 71c0698ee4
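The change threads a copied contextvars.Context through every step invocation: the child RunnableConfig is written into the copy (via the internal _set_config_context helper, which presumably sets a config contextvar) and the step is executed inside that copy, so nested runnables can see the active config through the ambient context rather than only through explicit arguments. A minimal sketch of the pattern, assuming a config contextvar; _CONFIG and _set_config are illustrative stand-ins, not langchain_core's API:

# Sketch only: copy the current context, set the child config inside the copy,
# then run the step inside that copy so nested calls can read the config.
from contextvars import ContextVar, copy_context

_CONFIG: ContextVar[dict] = ContextVar("config", default={})  # hypothetical stand-in

def _set_config(config: dict) -> None:
    _CONFIG.set(config)

def step(x: int) -> int:
    # A nested call can read the ambient config without it being passed explicitly.
    print("seen config:", _CONFIG.get())
    return x + 1

def invoke_with_context(x: int, config: dict) -> int:
    context = copy_context()
    context.run(_set_config, config)  # mutates only the copy
    return context.run(step, x)       # run the step inside the copy

print(invoke_with_context(1, {"tags": ["demo"]}))  # prints the seen config, then 2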
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import collections
+import functools
 import inspect
 import threading
 from abc import ABC, abstractmethod
@@ -68,8 +69,8 @@ from langchain_core.runnables.utils import (
     Input,
     Output,
     accepts_config,
-    accepts_context,
     accepts_run_manager,
+    asyncio_accepts_context,
     create_model,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
@@ -1830,7 +1831,7 @@ class Runnable(Generic[Input, Output], ABC):
             coro = acall_func_with_variable_args(
                 func, input, config, run_manager, **kwargs
             )
-            if accepts_context(asyncio.create_task):
+            if asyncio_accepts_context():
                 output: Output = await asyncio.create_task(coro, context=context)  # type: ignore
             else:
                 output = await coro
@@ -2156,7 +2157,7 @@ class Runnable(Generic[Input, Output], ABC):
                 iterator = iterator_
             try:
                 while True:
-                    if accepts_context(asyncio.create_task):
+                    if asyncio_accepts_context():
                         chunk: Output = await asyncio.create_task(  # type: ignore[call-arg]
                             py_anext(iterator),  # type: ignore[arg-type]
                             context=context,
@@ -2869,10 +2870,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 config = patch_config(
                     config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
                 )
+                context = copy_context()
+                context.run(_set_config_context, config)
                 if i == 0:
-                    input = step.invoke(input, config, **kwargs)
+                    input = context.run(step.invoke, input, config, **kwargs)
                 else:
-                    input = step.invoke(input, config)
+                    input = context.run(step.invoke, input, config)
         # finish the root run
         except BaseException as e:
             run_manager.on_chain_error(e)
@@ -2907,10 +2910,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
                 config = patch_config(
                     config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
                 )
+                context = copy_context()
+                context.run(_set_config_context, config)
                 if i == 0:
-                    input = await step.ainvoke(input, config, **kwargs)
+                    part = functools.partial(step.ainvoke, input, config, **kwargs)
                 else:
-                    input = await step.ainvoke(input, config)
+                    part = functools.partial(step.ainvoke, input, config)
+                if asyncio_accepts_context():
+                    input = await asyncio.create_task(part(), context=context)  # type: ignore
+                else:
+                    input = await asyncio.create_task(part())
         # finish the root run
         except BaseException as e:
             await run_manager.on_chain_error(e)
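The async paths build the step call with functools.partial and await it through asyncio.create_task(part(), context=context), so the contextvars set in the copied context are active while the step coroutine runs; asyncio.create_task only accepts a context argument on Python 3.11+, which is what the asyncio_accepts_context() guard checks. A self-contained sketch of that mechanism, assuming Python 3.11+ and again using _CONFIG as a hypothetical stand-in for the library's config variable:

# Sketch only: schedule a coroutine inside a copied context (Python 3.11+).
import asyncio
from contextvars import ContextVar, copy_context

_CONFIG: ContextVar[str] = ContextVar("config", default="<unset>")  # hypothetical

async def astep() -> str:
    # Runs inside the context passed to create_task, so it sees the child config.
    return _CONFIG.get()

async def main() -> None:
    context = copy_context()
    context.run(_CONFIG.set, "child-config")
    result = await asyncio.create_task(astep(), context=context)
    print(result)         # child-config
    print(_CONFIG.get())  # <unset>  (the caller's own context is untouched)

asyncio.run(main())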
@@ -3542,21 +3551,30 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
             run_id=config.pop("run_id", None),
         )
 
+        def _invoke_step(
+            step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
+        ) -> Any:
+            child_config = patch_config(
+                config,
+                # mark each step as a child run
+                callbacks=run_manager.get_child(f"map:key:{key}"),
+            )
+            context = copy_context()
+            context.run(_set_config_context, child_config)
+            return context.run(
+                step.invoke,
+                input,
+                child_config,
+            )
+
         # gather results from all steps
         try:
             # copy to avoid issues from the caller mutating the steps during invoke()
             steps = dict(self.steps__)
+
             with get_executor_for_config(config) as executor:
                 futures = [
-                    executor.submit(
-                        step.invoke,
-                        input,
-                        # mark each step as a child run
-                        patch_config(
-                            config,
-                            callbacks=run_manager.get_child(f"map:key:{key}"),
-                        ),
-                    )
+                    executor.submit(_invoke_step, step, input, config, key)
                     for key, step in steps.items()
                 ]
                 output = {key: future.result() for key, future in zip(steps, futures)}
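Factoring the per-step setup into a local _invoke_step helper means each step submitted to the executor copies the context and sets its own child config, so steps running in parallel worker threads cannot observe or overwrite one another's configuration. A sketch of that isolation property under the same assumptions (illustrative names, not the library's code):

# Sketch only: per-step context isolation when steps run in a thread pool.
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar, copy_context

_CONFIG: ContextVar[str] = ContextVar("config", default="<unset>")  # hypothetical

def step(name: str) -> str:
    return f"{name} saw {_CONFIG.get()}"

def invoke_step(name: str) -> str:
    context = copy_context()
    context.run(_CONFIG.set, f"config-for-{name}")  # set only in this step's copy
    return context.run(step, name)

with ThreadPoolExecutor() as executor:
    futures = [executor.submit(invoke_step, key) for key in ("a", "b")]
    print([f.result() for f in futures])  # ['a saw config-for-a', 'b saw config-for-b']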
@@ -3585,18 +3603,34 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
             run_id=config.pop("run_id", None),
         )
 
+        async def _ainvoke_step(
+            step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str
+        ) -> Any:
+            child_config = patch_config(
+                config,
+                callbacks=run_manager.get_child(f"map:key:{key}"),
+            )
+            context = copy_context()
+            context.run(_set_config_context, child_config)
+            if asyncio_accepts_context():
+                return await asyncio.create_task(  # type: ignore
+                    step.ainvoke(input, child_config), context=context
+                )
+            else:
+                return await asyncio.create_task(step.ainvoke(input, child_config))
+
         # gather results from all steps
         try:
             # copy to avoid issues from the caller mutating the steps during invoke()
             steps = dict(self.steps__)
             results = await asyncio.gather(
                 *(
-                    step.ainvoke(
+                    _ainvoke_step(
+                        step,
                         input,
                         # mark each step as a child run
-                        patch_config(
-                            config, callbacks=run_manager.get_child(f"map:key:{key}")
-                        ),
+                        config,
+                        key,
                     )
                     for key, step in steps.items()
                 )
@@ -348,7 +348,6 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
                 base["callbacks"] = mngr
             else:
                 # base_callbacks is also a manager
-
                 manager = base_callbacks.__class__(
                     parent_run_id=base_callbacks.parent_run_id
                     or these_callbacks.parent_run_id,
@@ -1,12 +1,12 @@
 import asyncio
 import inspect
 import typing
+from contextvars import copy_context
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
     AsyncIterator,
-    Awaitable,
     Dict,
     Iterator,
     List,
@@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables.base import Runnable, RunnableSerializable
 from langchain_core.runnables.config import (
     RunnableConfig,
+    _set_config_context,
     ensure_config,
     get_async_callback_manager_for_config,
     get_callback_manager_for_config,
@@ -33,6 +34,7 @@ from langchain_core.runnables.utils import (
     ConfigurableFieldSpec,
     Input,
     Output,
+    asyncio_accepts_context,
     get_unique_config_specs,
 )
 from langchain_core.utils.aiter import py_anext
@@ -172,9 +174,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
             try:
                 if self.exception_key and last_error is not None:
                     input[self.exception_key] = last_error
-                output = runnable.invoke(
+                child_config = patch_config(config, callbacks=run_manager.get_child())
+                context = copy_context()
+                context.run(_set_config_context, child_config)
+                output = context.run(
+                    runnable.invoke,
                     input,
-                    patch_config(config, callbacks=run_manager.get_child()),
                     **kwargs,
                 )
             except self.exceptions_to_handle as e:
@@ -220,11 +225,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
             try:
                 if self.exception_key and last_error is not None:
                     input[self.exception_key] = last_error
-                output = await runnable.ainvoke(
-                    input,
-                    patch_config(config, callbacks=run_manager.get_child()),
-                    **kwargs,
-                )
+                child_config = patch_config(config, callbacks=run_manager.get_child())
+                context = copy_context()
+                context.run(_set_config_context, child_config)
+                coro = runnable.ainvoke(input, child_config, **kwargs)
+                if asyncio_accepts_context():
+                    output = await asyncio.create_task(coro, context=context)  # type: ignore
+                else:
+                    output = await coro
             except self.exceptions_to_handle as e:
                 if first_error is None:
                     first_error = e
@@ -460,12 +468,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
             try:
                 if self.exception_key and last_error is not None:
                     input[self.exception_key] = last_error
-                stream = runnable.stream(
+                child_config = patch_config(config, callbacks=run_manager.get_child())
+                context = copy_context()
+                context.run(_set_config_context, child_config)
+                stream = context.run(
+                    runnable.stream,
                     input,
-                    patch_config(config, callbacks=run_manager.get_child()),
                     **kwargs,
                 )
-                chunk = next(stream)
+                chunk: Output = context.run(next, stream)  # type: ignore
             except self.exceptions_to_handle as e:
                 first_error = e if first_error is None else first_error
                 last_error = e
@@ -520,12 +531,21 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
             try:
                 if self.exception_key and last_error is not None:
                     input[self.exception_key] = last_error
+                child_config = patch_config(config, callbacks=run_manager.get_child())
+                context = copy_context()
+                context.run(_set_config_context, child_config)
                 stream = runnable.astream(
                     input,
-                    patch_config(config, callbacks=run_manager.get_child()),
+                    child_config,
                     **kwargs,
                 )
-                chunk = await cast(Awaitable[Output], py_anext(stream))
+                if asyncio_accepts_context():
+                    chunk: Output = await asyncio.create_task(  # type: ignore[call-arg]
+                        py_anext(stream),  # type: ignore[arg-type]
+                        context=context,
+                    )
+                else:
+                    chunk = cast(Output, await py_anext(stream))
             except self.exceptions_to_handle as e:
                 first_error = e if first_error is None else first_error
                 last_error = e
@@ -118,6 +118,11 @@ def accepts_context(callable: Callable[..., Any]) -> bool:
         return False
 
 
+@lru_cache(maxsize=1)
+def asyncio_accepts_context() -> bool:
+    return accepts_context(asyncio.create_task)
+
+
 class IsLocalDict(ast.NodeVisitor):
     """Check if a name is a local dict."""
 
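asyncio_accepts_context() just caches the result of accepts_context(asyncio.create_task), which is defined earlier in the same module; the lru_cache avoids re-inspecting the signature on every step invocation. A plausible standalone equivalent, assuming the underlying check amounts to signature inspection (an assumption, not the module's exact code):

# Sketch only: cached feature-detection for asyncio.create_task(..., context=...).
import asyncio
import inspect
from functools import lru_cache

@lru_cache(maxsize=1)
def asyncio_accepts_context() -> bool:
    # True on Python 3.11+, where asyncio.create_task gained a `context` parameter.
    try:
        return "context" in inspect.signature(asyncio.create_task).parameters
    except (ValueError, TypeError):
        return False

print(asyncio_accepts_context())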
@@ -87,7 +87,7 @@ from langchain_core.runnables.config import (
     patch_config,
     run_in_executor,
 )
-from langchain_core.runnables.utils import accepts_context
+from langchain_core.runnables.utils import asyncio_accepts_context
 from langchain_core.utils.function_calling import (
     _parse_google_docstring,
     _py_38_safe_origin,
@@ -694,7 +694,7 @@ class ChildTool(BaseTool):
                     tool_kwargs[config_param] = config
 
                 coro = context.run(self._arun, *tool_args, **tool_kwargs)
-                if accepts_context(asyncio.create_task):
+                if asyncio_accepts_context():
                     response = await asyncio.create_task(coro, context=context)  # type: ignore
                 else:
                     response = await coro
@@ -1,12 +1,13 @@
 import json
 import sys
+from typing import Any, AsyncGenerator, Generator
 from unittest.mock import MagicMock, patch
 
 import pytest
 from langsmith import Client, traceable
 from langsmith.run_helpers import tracing_context
 
-from langchain_core.runnables.base import RunnableLambda
+from langchain_core.runnables.base import RunnableLambda, RunnableParallel
 from langchain_core.tracers.langchain import LangChainTracer
 
 
@@ -199,3 +200,141 @@ def test_tracing_enable_disable(
         assert len(mock_posts) == 1
     else:
         assert not mock_posts
+
+
+@pytest.mark.parametrize(
+    "method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"]
+)
+async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
+    if method.startswith("a") and sys.version_info < (3, 11):
+        pytest.skip("Asyncio context vars require Python 3.11+")
+    mock_session = MagicMock()
+    mock_client_ = Client(
+        session=mock_session, api_key="test", auto_batch_tracing=False
+    )
+    tracer = LangChainTracer(client=mock_client_)
+
+    @RunnableLambda
+    def my_child_function(a: int) -> int:
+        return a + 2
+
+    if method.startswith("a"):
+
+        async def other_thing(a: int) -> AsyncGenerator[int, None]:
+            yield 1
+
+    else:
+
+        def other_thing(a: int) -> Generator[int, None, None]:  # type: ignore
+            yield 1
+
+    parallel = RunnableParallel(
+        chain_result=my_child_function.with_config(tags=["atag"]),
+        other_thing=other_thing,
+    )
+
+    def before(x: int) -> int:
+        return x
+
+    def after(x: dict) -> int:
+        return x["chain_result"]
+
+    sequence = before | parallel | after
+    if method.startswith("a"):
+
+        @RunnableLambda  # type: ignore
+        async def parent(a: int) -> int:
+            return await sequence.ainvoke(a)
+
+    else:
+
+        @RunnableLambda
+        def parent(a: int) -> int:
+            return sequence.invoke(a)
+
+    # Now run the chain and check the resulting posts
+    cb = [tracer]
+    if method == "invoke":
+        res: Any = parent.invoke(1, {"callbacks": cb})  # type: ignore
+    elif method == "ainvoke":
+        res = await parent.ainvoke(1, {"callbacks": cb})  # type: ignore
+    elif method == "stream":
+        results = list(parent.stream(1, {"callbacks": cb}))  # type: ignore
+        res = results[-1]
+    elif method == "astream":
+        results = [res async for res in parent.astream(1, {"callbacks": cb})]  # type: ignore
+        res = results[-1]
+    elif method == "batch":
+        res = parent.batch([1], {"callbacks": cb})[0]  # type: ignore
+    elif method == "abatch":
+        res = (await parent.abatch([1], {"callbacks": cb}))[0]  # type: ignore
+    else:
+        raise ValueError(f"Unknown method {method}")
+    assert res == 3
+    posts = _get_posts(mock_client_)
+    name_order = [
+        "parent",
+        "RunnableSequence",
+        "before",
+        "RunnableParallel<chain_result,other_thing>",
+        ["my_child_function", "other_thing"],
+        "after",
+    ]
+    expected_parents = {
+        "parent": None,
+        "RunnableSequence": "parent",
+        "before": "RunnableSequence",
+        "RunnableParallel<chain_result,other_thing>": "RunnableSequence",
+        "my_child_function": "RunnableParallel<chain_result,other_thing>",
+        "other_thing": "RunnableParallel<chain_result,other_thing>",
+        "after": "RunnableSequence",
+    }
+    assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order])
+    prev_dotted_order = None
+    dotted_order_map = {}
+    id_map = {}
+    parent_id_map = {}
+    i = 0
+    for name in name_order:
+        if isinstance(name, list):
+            for n in name:
+                matching_post = next(
+                    p for p in posts[i : i + len(name)] if p["name"] == n
+                )
+                assert matching_post
+                dotted_order = matching_post["dotted_order"]
+                if prev_dotted_order is not None:
+                    assert dotted_order > prev_dotted_order
+                dotted_order_map[n] = dotted_order
+                id_map[n] = matching_post["id"]
+                parent_id_map[n] = matching_post.get("parent_run_id")
+            i += len(name)
+            continue
+        else:
+            assert posts[i]["name"] == name
+            dotted_order = posts[i]["dotted_order"]
+            if prev_dotted_order is not None and not str(
+                expected_parents[name]
+            ).startswith("RunnableParallel"):
+                assert (
+                    dotted_order > prev_dotted_order
+                ), f"{name} not after {name_order[i-1]}"
+            prev_dotted_order = dotted_order
+            if name in dotted_order_map:
+                raise ValueError(f"Duplicate name {name}")
+            dotted_order_map[name] = dotted_order
+            id_map[name] = posts[i]["id"]
+            parent_id_map[name] = posts[i].get("parent_run_id")
+            i += 1
+
+    # Now check the dotted orders
+    for name, parent_ in expected_parents.items():
+        dotted_order = dotted_order_map[name]
+        if parent_ is not None:
+            parent_dotted_order = dotted_order_map[parent_]
+            assert dotted_order.startswith(
+                parent_dotted_order
+            ), f"{name}, {parent_dotted_order} not in {dotted_order}"
+            assert str(parent_id_map[name]) == str(id_map[parent_])
+        else:
+            assert dotted_order.split(".")[0] == dotted_order