mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
core: fix batch ordering test (#20952)
This commit is contained in:
parent
8ed150b2fe
commit
d4befd0cfb
@ -138,16 +138,16 @@ class FakeTracer(BaseTracer):
|
||||
"child_execution_order": None,
|
||||
"trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
|
||||
"dotted_order": new_dotted_order,
|
||||
"inputs": {
|
||||
k: self._replace_message_id(v) for k, v in run.inputs.items()
|
||||
}
|
||||
if isinstance(run.inputs, dict)
|
||||
else run.inputs,
|
||||
"outputs": {
|
||||
k: self._replace_message_id(v) for k, v in run.outputs.items()
|
||||
}
|
||||
if isinstance(run.outputs, dict)
|
||||
else run.outputs,
|
||||
"inputs": (
|
||||
{k: self._replace_message_id(v) for k, v in run.inputs.items()}
|
||||
if isinstance(run.inputs, dict)
|
||||
else run.inputs
|
||||
),
|
||||
"outputs": (
|
||||
{k: self._replace_message_id(v) for k, v in run.outputs.items()}
|
||||
if isinstance(run.outputs, dict)
|
||||
else run.outputs
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@ -1652,11 +1652,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
|
||||
assert len(spy.call_args_list) == 2
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||
if i == 0:
|
||||
call_arg = call.args[0]
|
||||
|
||||
if call_arg == "hello":
|
||||
assert call_arg == "hello"
|
||||
assert call.args[1].get("tags") == ["a-tag"]
|
||||
assert call.args[1].get("metadata") == {}
|
||||
else:
|
||||
assert call_arg == "wooorld"
|
||||
assert call.args[1].get("tags") == []
|
||||
assert call.args[1].get("metadata") == {"key": "value"}
|
||||
|
||||
@ -1664,8 +1667,8 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
|
||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||
assert len(spy.call_args_list) == 2
|
||||
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||
assert call.args[1].get("tags") == ["a-tag"]
|
||||
assert call.args[1].get("metadata") == {}
|
||||
spy.reset_mock()
|
||||
@ -1686,28 +1689,15 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
5,
|
||||
7,
|
||||
]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call(
|
||||
"hello",
|
||||
dict(
|
||||
metadata={"key": "value"},
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
"wooorld",
|
||||
dict(
|
||||
metadata={"key": "value"},
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
|
||||
for call in spy.call_args_list:
|
||||
assert call.args[1] == dict(
|
||||
metadata={"key": "value"},
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_prompt() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user