core: Add ruff rules for comprehensions (C4) (#26829)

This commit is contained in:
Christophe Bornet 2024-09-25 15:34:17 +02:00 committed by GitHub
parent 7e5a9c317f
commit 3a1b9259a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 259 additions and 265 deletions

View File

@ -86,9 +86,9 @@ def _config_with_context(
)
}
deps_by_key = {
key: set(
key: {
_key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
)
}
for key, group in grouped_by_key.items()
}
@ -198,7 +198,7 @@ class ContextGet(RunnableSerializable):
configurable = config.get("configurable", {})
if isinstance(self.key, list):
values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
return {key: value for key, value in zip(self.key, values)}
return dict(zip(self.key, values))
else:
return await configurable[self.ids[0]]()

View File

@ -551,7 +551,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
param_string = str(sorted(params.items()))
# This code is not super efficient as it goes back and forth between
# json and dict.
serialized_repr = self._serialized
@ -561,7 +561,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else:
params = self._get_invocation_params(stop=stop, **kwargs)
params = {**params, **kwargs}
return str(sorted([(k, v) for k, v in params.items()]))
return str(sorted(params.items()))
def generate(
self,

View File

@ -166,7 +166,7 @@ def get_prompts(
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
@ -202,7 +202,7 @@ async def aget_prompts(
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}

View File

@ -67,14 +67,14 @@ class Reviver:
Defaults to None.
"""
self.secrets_from_env = secrets_from_env
self.secrets_map = secrets_map or dict()
self.secrets_map = secrets_map or {}
# By default, only support langchain, but user can pass in additional namespaces
self.valid_namespaces = (
[*DEFAULT_NAMESPACES, *valid_namespaces]
if valid_namespaces
else DEFAULT_NAMESPACES
)
self.additional_import_mappings = additional_import_mappings or dict()
self.additional_import_mappings = additional_import_mappings or {}
self.import_mappings = (
{
**ALL_SERIALIZABLE_MAPPINGS,
@ -146,7 +146,7 @@ class Reviver:
# We don't need to recurse on kwargs
# as json.loads will do that for us.
kwargs = value.get("kwargs", dict())
kwargs = value.get("kwargs", {})
return cls(**kwargs)
return value

View File

@ -138,7 +138,7 @@ class Serializable(BaseModel, ABC):
For example,
{"openai_api_key": "OPENAI_API_KEY"}
"""
return dict()
return {}
@property
def lc_attributes(self) -> dict:
@ -188,7 +188,7 @@ class Serializable(BaseModel, ABC):
if not self.is_lc_serializable():
return self.to_json_not_implemented()
secrets = dict()
secrets = {}
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {}
for k, v in self:

View File

@ -108,7 +108,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return "Return a JSON object."
else:
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()}
schema = dict(self._get_schema(self.pydantic_object).items())
# Remove extraneous fields.
reduced_schema = schema

View File

@ -90,7 +90,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
The format instructions for the JSON output.
"""
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.model_json_schema().items()}
schema = dict(self.pydantic_object.model_json_schema().items())
# Remove extraneous fields.
reduced_schema = schema

View File

@ -76,7 +76,7 @@ class LLMResult(BaseModel):
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
llm_output["token_usage"] = {}
else:
llm_output = None
llm_results.append(

View File

@ -1007,11 +1007,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
input_vars.update(_message.input_variables)
kwargs = {
**dict(
input_variables=sorted(input_vars),
optional_variables=sorted(optional_variables),
partial_variables=partial_vars,
),
**{
"input_variables": sorted(input_vars),
"optional_variables": sorted(optional_variables),
"partial_variables": partial_vars,
},
**kwargs,
}
cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)

View File

@ -18,7 +18,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
if "input_variables" not in kwargs:
kwargs["input_variables"] = []
overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
overlap = set(kwargs["input_variables"]) & {"url", "path", "detail"}
if overlap:
raise ValueError(
"input_variables for the image template cannot contain"

View File

@ -144,7 +144,7 @@ class PromptTemplate(StringPromptTemplate):
template = self.template + other.template
# If any do not want to validate, then don't
validate_template = self.validate_template and other.validate_template
partial_variables = {k: v for k, v in self.partial_variables.items()}
partial_variables = dict(self.partial_variables.items())
for k, v in other.partial_variables.items():
if k in partial_variables:
raise ValueError("Cannot have same variable partialed twice.")

View File

@ -3778,7 +3778,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
for key, step in steps.items()
)
)
output = {key: value for key, value in zip(steps, results)}
output = dict(zip(steps, results))
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)

View File

@ -294,7 +294,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
]
to_return: dict[int, Any] = {}
run_again = {i: input for i, input in enumerate(inputs)}
run_again = dict(enumerate(inputs))
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
for runnable in self.runnables:
@ -388,7 +388,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
to_return = {}
run_again = {i: input for i, input in enumerate(inputs)}
run_again = dict(enumerate(inputs))
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
for runnable in self.runnables:

View File

@ -117,7 +117,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
@property
def _kwargs_retrying(self) -> dict[str, Any]:
kwargs: dict[str, Any] = dict()
kwargs: dict[str, Any] = {}
if self.max_attempt_number:
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

View File

@ -10,7 +10,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]:
from importlib import metadata
sub_deps = set()
_underscored_packages = set(pkg.replace("-", "_") for pkg in packages)
_underscored_packages = {pkg.replace("-", "_") for pkg in packages}
for pkg in packages:
try:
@ -33,7 +33,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]:
return sorted(sub_deps, key=lambda x: x.lower())
def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None:
def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
"""Print information about the environment for debugging purposes.
Args:

View File

@ -975,7 +975,7 @@ def _get_all_basemodel_annotations(
) and name not in fields:
continue
annotations[name] = param.annotation
orig_bases: tuple = getattr(cls, "__orig_bases__", tuple())
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
# cls has subscript: cls = FooBar[int]
else:
annotations = _get_all_basemodel_annotations(
@ -1007,11 +1007,9 @@ def _get_all_basemodel_annotations(
# parent_origin = Baz,
# generic_type_vars = (type vars in Baz)
# generic_map = {type var in Baz: str}
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple())
generic_map = {
type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
}
for field in getattr(parent_origin, "__annotations__", dict()):
generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
generic_map = dict(zip(generic_type_vars, get_args(parent)))
for field in getattr(parent_origin, "__annotations__", {}):
annotations[field] = _replace_type_vars(
annotations[field], generic_map, default_to_bound
)

View File

@ -233,9 +233,7 @@ def _convert_any_typed_dicts_to_pydantic(
new_arg_type = _convert_any_typed_dicts_to_pydantic(
annotated_args[0], depth=depth + 1, visited=visited
)
field_kwargs = {
k: v for k, v in zip(("default", "description"), annotated_args[1:])
}
field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
if (field_desc := field_kwargs.get("description")) and not isinstance(
field_desc, str
):

View File

@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]
[tool.ruff.lint]
select = [ "B", "E", "F", "I", "N", "T201", "UP",]
select = [ "B", "C4", "E", "F", "I", "N", "T201", "UP",]
ignore = [ "UP007",]
[tool.coverage.run]

View File

@ -54,5 +54,5 @@ def test_lazy_load() -> None:
expected.append(
Document(example.inputs["first"]["second"].upper(), metadata=metadata)
)
actual = [doc for doc in loader.lazy_load()]
actual = list(loader.lazy_load())
assert expected == actual

View File

@ -55,7 +55,7 @@ async def test_generic_fake_chat_model_stream() -> None:
]
assert len({chunk.id for chunk in chunks}) == 1
chunks = [chunk for chunk in model.stream("meow")]
chunks = list(model.stream("meow"))
assert chunks == [
_any_id_ai_message_chunk(content="hello"),
_any_id_ai_message_chunk(content=" "),

View File

@ -185,11 +185,11 @@ def test_index_simple_delete_full(
):
indexing_result = index(loader, record_manager, vector_store, cleanup="full")
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {"mutated document 1", "This is another document."}
assert indexing_result == {
@ -267,11 +267,11 @@ async def test_aindex_simple_delete_full(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {"mutated document 1", "This is another document."}
# Attempt to index again verify that nothing changes
@ -558,11 +558,11 @@ def test_incremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {"This is another document.", "This is a test document."}
# Attempt to index again verify that nothing changes
@ -617,11 +617,11 @@ def test_incremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {
"mutated document 1",
"mutated document 2",
@ -685,11 +685,11 @@ def test_incremental_indexing_with_batch_size(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {"1", "2", "3", "4"}
@ -735,11 +735,11 @@ def test_incremental_delete_with_batch_size(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {"1", "2", "3", "4"}
# Attempt to index again verify that nothing changes
@ -880,11 +880,11 @@ async def test_aincremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {"This is another document.", "This is a test document."}
# Attempt to index again verify that nothing changes
@ -939,11 +939,11 @@ async def test_aincremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
}
assert doc_texts == {
"mutated document 1",
"mutated document 2",

View File

@ -53,10 +53,10 @@ def test_batch_size(messages: list, messages_2: list) -> None:
with collect_runs() as cb:
llm.batch([messages, messages_2], {"callbacks": [cb]})
assert len(cb.traced_runs) == 2
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
with collect_runs() as cb:
llm.batch([messages], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
assert len(cb.traced_runs) == 1
with collect_runs() as cb:
@ -76,11 +76,11 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
# so we expect batch_size to always be 1
with collect_runs() as cb:
await llm.abatch([messages, messages_2], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
assert len(cb.traced_runs) == 2
with collect_runs() as cb:
await llm.abatch([messages], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
assert len(cb.traced_runs) == 1
with collect_runs() as cb:
@ -146,7 +146,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
return "fake-chat-model"
model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")]
chunks = list(model.stream("anything"))
assert chunks == [_any_id_ai_message(content="hello")]
chunks = [chunk async for chunk in model.astream("anything")]
@ -183,7 +183,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
return "fake-chat-model"
model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")]
chunks = list(model.stream("anything"))
assert chunks == [
_any_id_ai_message_chunk(content="a"),
_any_id_ai_message_chunk(content="b"),

View File

@ -262,7 +262,7 @@ def test_global_cache_stream() -> None:
AIMessage(content="goodbye world"),
]
model = GenericFakeChatModel(messages=iter(messages), cache=True)
chunks = [chunk for chunk in model.stream("some input")]
chunks = list(model.stream("some input"))
assert len(chunks) == 3
# Assert that streaming information gets cached
assert global_cache._cache != {}

View File

@ -40,12 +40,12 @@ def test_batch_size() -> None:
llm = FakeListLLM(responses=["foo"] * 3)
with collect_runs() as cb:
llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs)
assert len(cb.traced_runs) == 3
llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
llm.batch(["foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
assert len(cb.traced_runs) == 1
llm = FakeListLLM(responses=["foo"])
@ -71,12 +71,12 @@ async def test_async_batch_size() -> None:
llm = FakeListLLM(responses=["foo"] * 3)
with collect_runs() as cb:
await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs)
assert len(cb.traced_runs) == 3
llm = FakeListLLM(responses=["foo"])
with collect_runs() as cb:
await llm.abatch(["foo"], {"callbacks": [cb]})
assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
assert len(cb.traced_runs) == 1
llm = FakeListLLM(responses=["foo"])
@ -142,7 +142,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
return "fake-chat-model"
model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")]
chunks = list(model.stream("anything"))
assert chunks == ["hello"]
chunks = [chunk async for chunk in model.astream("anything")]
@ -179,7 +179,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
return "fake-chat-model"
model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")]
chunks = list(model.stream("anything"))
assert chunks == ["a", "b"]
assert type(model)._astream == BaseLLM._astream
astream_chunks = [chunk async for chunk in model.astream("anything")]

View File

@ -93,5 +93,5 @@ def test_base_transform_output_parser() -> None:
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
chain = model | StrInvertCase()
# inputs to models are ignored, response is hard-coded in model definition
chunks = [chunk for chunk in chain.stream("")]
chunks = list(chain.stream(""))
assert chunks == ["HELLO", " ", "WORLD"]

View File

@ -596,10 +596,10 @@ def test_base_model_schema_consistency() -> None:
setup: str
punchline: str
initial_joke_schema = {k: v for k, v in _schema(Joke).items()}
initial_joke_schema = dict(_schema(Joke).items())
SimpleJsonOutputParser(pydantic_object=Joke)
openai_func = convert_to_openai_function(Joke)
retrieved_joke_schema = {k: v for k, v in _schema(Joke).items()}
retrieved_joke_schema = dict(_schema(Joke).items())
assert initial_joke_schema == retrieved_joke_schema
assert openai_func.get("name", None) is not None

View File

@ -391,7 +391,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
}
)
chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})]
chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
for c in chunks:
assert c in achunks

View File

@ -264,7 +264,7 @@ def test_fallbacks_stream() -> None:
runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
[RunnableGenerator(_generate)]
)
assert list(runnable.stream({})) == [c for c in "foo bar"]
assert list(runnable.stream({})) == list("foo bar")
with pytest.raises(ValueError):
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(

View File

@ -1065,7 +1065,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
assert [
part
async for part in seq.astream(
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
"hello", {"metadata": {"key": "value"}}, my_kwarg="value"
)
] == [5]
assert mock.call_args_list == [
@ -1125,12 +1125,9 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
assert call in mock.call_args_list
mock.reset_mock()
assert [
part
for part in seq.stream(
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
)
] == [5]
assert list(
seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
) == [5]
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
@ -1155,13 +1152,13 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
)
assert spy.call_args_list[0].args[1:] == (
"hello",
dict(
tags=["a-tag"],
callbacks=None,
recursion_limit=25,
configurable={"hello": "there", "__secret_key": "nahnah"},
metadata={"hello": "there", "bye": "now"},
),
{
"tags": ["a-tag"],
"callbacks": None,
"recursion_limit": 25,
"configurable": {"hello": "there", "__secret_key": "nahnah"},
"metadata": {"hello": "there", "bye": "now"},
},
)
spy.reset_mock()
@ -1174,7 +1171,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert spy.call_args_list == [
mocker.call(
"hello",
dict(tags=["a-tag"], metadata={}, configurable={}),
{"tags": ["a-tag"], "metadata": {}, "configurable": {}},
),
]
spy.reset_mock()
@ -1200,19 +1197,19 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert [
*fake.with_config(tags=["a-tag"]).stream(
"hello", dict(metadata={"key": "value"})
"hello", {"metadata": {"key": "value"}}
)
] == [5]
assert spy.call_args_list == [
mocker.call(
"hello",
dict(tags=["a-tag"], metadata={"key": "value"}, configurable={}),
{"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},
),
]
spy.reset_mock()
assert fake.with_config(recursion_limit=5).batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]
) == [5, 7]
assert len(spy.call_args_list) == 2
@ -1235,7 +1232,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
c
for c in fake.with_config(recursion_limit=5).batch_as_completed(
["hello", "wooorld"],
[dict(tags=["a-tag"]), dict(metadata={"key": "value"})],
[{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],
)
) == [(0, 5), (1, 7)]
@ -1256,7 +1253,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
spy.reset_mock()
assert fake.with_config(metadata={"a": "b"}).batch(
["hello", "wooorld"], dict(tags=["a-tag"])
["hello", "wooorld"], {"tags": ["a-tag"]}
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
@ -1266,7 +1263,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
spy.reset_mock()
assert sorted(
c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"]))
c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})
) == [(0, 5), (1, 7)]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
@ -1284,7 +1281,12 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert spy.call_args_list == [
mocker.call(
"hello",
dict(callbacks=[handler], metadata={"a": "b"}, configurable={}, tags=[]),
{
"callbacks": [handler],
"metadata": {"a": "b"},
"configurable": {},
"tags": [],
},
),
]
spy.reset_mock()
@ -1293,12 +1295,12 @@ async def test_with_config(mocker: MockerFixture) -> None:
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"a": "b"}, tags=[], configurable={})),
mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),
]
spy.reset_mock()
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
["hello", "wooorld"], dict(metadata={"key": "value"})
["hello", "wooorld"], {"metadata": {"key": "value"}}
) == [
5,
7,
@ -1306,23 +1308,23 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert spy.call_args_list == [
mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
configurable={},
),
{
"metadata": {"key": "value"},
"tags": ["c"],
"callbacks": None,
"recursion_limit": 5,
"configurable": {},
},
),
mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
configurable={},
),
{
"metadata": {"key": "value"},
"tags": ["c"],
"callbacks": None,
"recursion_limit": 5,
"configurable": {},
},
),
]
spy.reset_mock()
@ -1332,7 +1334,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
c
async for c in fake.with_config(
recursion_limit=5, tags=["c"]
).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"}))
).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}})
]
) == [
(0, 5),
@ -1342,24 +1344,24 @@ async def test_with_config(mocker: MockerFixture) -> None:
first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")
assert first_call == mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
configurable={},
),
{
"metadata": {"key": "value"},
"tags": ["c"],
"callbacks": None,
"recursion_limit": 5,
"configurable": {},
},
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
assert second_call == mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
configurable={},
),
{
"metadata": {"key": "value"},
"tags": ["c"],
"callbacks": None,
"recursion_limit": 5,
"configurable": {},
},
)
@ -1367,20 +1369,20 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
fake = FakeRunnable()
spy = mocker.spy(fake, "invoke")
assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
assert fake.invoke("hello", {"tags": ["a-tag"]}) == 5
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("hello", {"tags": ["a-tag"]}),
]
spy.reset_mock()
assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})),
mocker.call("hello", {"metadata": {"key": "value"}}),
]
spy.reset_mock()
assert fake.batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]
) == [5, 7]
assert len(spy.call_args_list) == 2
@ -1398,9 +1400,9 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]}) == [5, 7]
assert len(spy.call_args_list) == 2
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}
for call in spy.call_args_list:
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
@ -1408,7 +1410,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[])),
mocker.call("hello", {"callbacks": []}),
]
spy.reset_mock()
@ -1418,19 +1420,19 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
]
spy.reset_mock()
assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [
5,
7,
]
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}
for call in spy.call_args_list:
assert call.args[1] == dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
recursion_limit=25,
configurable={},
)
assert call.args[1] == {
"metadata": {"key": "value"},
"tags": [],
"callbacks": None,
"recursion_limit": 25,
"configurable": {},
}
async def test_prompt() -> None:
@ -1698,7 +1700,7 @@ def test_prompt_with_chat_model(
chat_spy = mocker.spy(chat.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == _any_id_ai_message(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -1722,7 +1724,7 @@ def test_prompt_with_chat_model(
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
{"callbacks": [tracer]},
) == [
_any_id_ai_message(content="foo"),
_any_id_ai_message(content="foo"),
@ -1763,7 +1765,7 @@ def test_prompt_with_chat_model(
chat_spy = mocker.spy(chat.__class__, "stream")
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
*chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})
] == [
_any_id_ai_message_chunk(content="f"),
_any_id_ai_message_chunk(content="o"),
@ -1804,7 +1806,7 @@ async def test_prompt_with_chat_model_async(
chat_spy = mocker.spy(chat.__class__, "ainvoke")
tracer = FakeTracer()
assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == _any_id_ai_message(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -1828,7 +1830,7 @@ async def test_prompt_with_chat_model_async(
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
{"callbacks": [tracer]},
) == [
_any_id_ai_message(content="foo"),
_any_id_ai_message(content="foo"),
@ -1871,7 +1873,7 @@ async def test_prompt_with_chat_model_async(
assert [
a
async for a in chain.astream(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
)
] == [
_any_id_ai_message_chunk(content="f"),
@ -1910,9 +1912,7 @@ async def test_prompt_with_llm(
llm_spy = mocker.spy(llm.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
await chain.ainvoke({"question": "What is your name?"}, {"callbacks": [tracer]})
== "foo"
)
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -1935,7 +1935,7 @@ async def test_prompt_with_llm(
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
{"callbacks": [tracer]},
) == ["bar", "foo"]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@ -1966,7 +1966,7 @@ async def test_prompt_with_llm(
assert [
token
async for token in chain.astream(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
)
] == ["bar"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -2110,7 +2110,7 @@ async def test_prompt_with_llm_parser(
parser_spy = mocker.spy(parser.__class__, "ainvoke")
tracer = FakeTracer()
assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == ["bear", "dog", "cat"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
@ -2135,7 +2135,7 @@ async def test_prompt_with_llm_parser(
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
{"callbacks": [tracer]},
) == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@ -2171,7 +2171,7 @@ async def test_prompt_with_llm_parser(
assert [
token
async for token in chain.astream(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
)
] == [["tomato"], ["lettuce"], ["onion"]]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -2495,9 +2495,7 @@ async def test_prompt_with_llm_and_async_lambda(
llm_spy = mocker.spy(llm.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
await chain.ainvoke({"question": "What is your name?"}, {"callbacks": [tracer]})
== "foo"
)
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -2539,7 +2537,7 @@ def test_prompt_with_chat_model_and_parser(
parser_spy = mocker.spy(parser.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == ["foo", "bar"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -2608,7 +2606,7 @@ def test_combining_sequences(
# Test invoke
tracer = FakeTracer()
assert combined_chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == ["baz", "qux"]
assert tracer.runs == snapshot
@ -2658,7 +2656,7 @@ Question:
chat_spy = mocker.spy(chat.__class__, "invoke")
parser_spy = mocker.spy(parser.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
assert chain.invoke("What is your name?", {"callbacks": [tracer]}) == [
"foo",
"bar",
]
@ -2725,7 +2723,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
llm_spy = mocker.spy(llm.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == {
"chat": _any_id_ai_message(content="i'm a chatbot"),
"llm": "i'm a textbot",
@ -2788,7 +2786,7 @@ async def test_router_runnable(
router_spy = mocker.spy(router.__class__, "invoke")
tracer = FakeTracer()
assert (
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
chain.invoke({"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]})
== "4"
)
assert router_spy.call_args.args[1] == {
@ -2849,7 +2847,7 @@ async def test_higher_order_lambda_runnable(
math_spy = mocker.spy(math_chain.__class__, "invoke")
tracer = FakeTracer()
assert (
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
chain.invoke({"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]})
== "4"
)
assert math_spy.call_args.args[1] == {
@ -2880,7 +2878,7 @@ async def test_higher_order_lambda_runnable(
tracer = FakeTracer()
assert (
await achain.ainvoke(
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
{"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}
)
== "4"
)
@ -2934,7 +2932,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
llm_spy = mocker.spy(llm.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
{"question": "What is your name?"}, {"callbacks": [tracer]}
) == {
"chat": _any_id_ai_message(content="i'm a chatbot"),
"llm": "i'm a textbot",
@ -3841,7 +3839,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
def test_runnable_lambda_stream() -> None:
"""Test that stream works for both normal functions & those returning Runnable."""
# Normal output should work
output: list[Any] = [chunk for chunk in RunnableLambda(range).stream(5)]
output: list[Any] = list(RunnableLambda(range).stream(5))
assert output == [range(5)]
# Runnable output should also work
@ -4015,7 +4013,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
spy = mocker.spy(ControlledExceptionRunnable, "batch")
tracer = FakeTracer()
inputs = ["foo", "bar", "baz", "qux"]
outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True)
outputs = chain.batch(inputs, {"callbacks": [tracer]}, return_exceptions=True)
assert len(outputs) == 4
assert isinstance(outputs[0], ValueError)
assert isinstance(outputs[1], ValueError)
@ -4135,7 +4133,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
tracer = FakeTracer()
inputs = ["foo", "bar", "baz", "qux"]
outputs = await chain.abatch(
inputs, dict(callbacks=[tracer]), return_exceptions=True
inputs, {"callbacks": [tracer]}, return_exceptions=True
)
assert len(outputs) == 4
assert isinstance(outputs[0], ValueError)
@ -5080,13 +5078,13 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
chain = RunnablePassthrough.assign(urls=idchain_sync)
tracer = FakeTracer()
chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
chain.invoke({"example": [1, 2, 3]}, {"callbacks": [tracer]})
assert tracer.runs[0].name == "RunnableAssign<urls>"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
tracer = FakeTracer()
for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
for _ in chain.stream({"example": [1, 2, 3]}, {"callbacks": [tracer]}):
pass
assert tracer.runs[0].name == "RunnableAssign<urls>"
@ -5100,13 +5098,13 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
chain = RunnablePassthrough.assign(urls=idchain_sync)
tracer = FakeTracer()
await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
await chain.ainvoke({"example": [1, 2, 3]}, {"callbacks": [tracer]})
assert tracer.runs[0].name == "RunnableAssign<urls>"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
tracer = FakeTracer()
async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
async for _ in chain.astream({"example": [1, 2, 3]}, {"callbacks": [tracer]}):
pass
assert tracer.runs[0].name == "RunnableAssign<urls>"
@ -5260,7 +5258,7 @@ async def test_default_atransform_with_dicts() -> None:
def test_passthrough_transform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)
chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))]
chunks = list(runnable.transform(iter([{"foo": "a"}, {"foo": "n"}])))
assert chunks == [{"foo": "a"}, {"foo": "n"}]

View File

@ -2033,7 +2033,7 @@ async def test_sync_in_async_stream_lambdas() -> None:
async def add_one_proxy_(x: int, config: RunnableConfig) -> int:
streaming = add_one.stream(x, config)
results = [result for result in streaming]
results = list(streaming)
return results[0]
add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore
@ -2078,7 +2078,7 @@ async def test_sync_in_sync_lambdas() -> None:
def add_one_proxy(x: int, config: RunnableConfig) -> int:
# Use sync streaming
streaming = add_one_.stream(x, config)
results = [result for result in streaming]
results = list(streaming)
return results[0]
add_one_proxy_ = RunnableLambda(add_one_proxy)

View File

@ -1995,7 +1995,7 @@ async def test_sync_in_async_stream_lambdas() -> None:
async def add_one_proxy(x: int, config: RunnableConfig) -> int:
streaming = add_one_.stream(x, config)
results = [result for result in streaming]
results = list(streaming)
return results[0]
add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore
@ -2035,7 +2035,7 @@ async def test_sync_in_sync_lambdas() -> None:
def add_one_proxy(x: int, config: RunnableConfig) -> int:
# Use sync streaming
streaming = add_one_.stream(x, config)
results = [result for result in streaming]
results = list(streaming)
return results[0]
add_one_proxy_ = RunnableLambda(add_one_proxy)

View File

@ -448,10 +448,10 @@ def test_message_chunk_to_message() -> None:
def test_tool_calls_merge() -> None:
chunks: list[dict] = [
dict(content=""),
dict(
content="",
additional_kwargs={
{"content": ""},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -461,10 +461,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -474,10 +474,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -487,10 +487,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -500,10 +500,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -513,10 +513,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -526,10 +526,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 0,
@ -539,10 +539,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -552,10 +552,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -565,10 +565,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -578,10 +578,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -591,10 +591,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -604,10 +604,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -617,10 +617,10 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(
content="",
additional_kwargs={
},
{
"content": "",
"additional_kwargs": {
"tool_calls": [
{
"index": 1,
@ -630,8 +630,8 @@ def test_tool_calls_merge() -> None:
}
]
},
),
dict(content=""),
},
{"content": ""},
]
final = None

View File

@ -98,7 +98,7 @@ async def test_tracer_chat_model_run() -> None:
],
extra={},
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]),
inputs={"prompts": ["Human: "]},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
@ -134,7 +134,7 @@ async def test_tracer_multiple_llm_runs() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
@ -272,8 +272,8 @@ async def test_tracer_nested_run() -> None:
],
extra={},
serialized={"name": "tool"},
inputs=dict(input="test"),
outputs=dict(output="test"),
inputs={"input": "test"},
outputs={"output": "test"},
error=None,
run_type="tool",
trace_id=chain_uuid,
@ -291,7 +291,7 @@ async def test_tracer_nested_run() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -311,7 +311,7 @@ async def test_tracer_nested_run() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -339,7 +339,7 @@ async def test_tracer_llm_run_on_error() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=None,
error=repr(exception),
run_type="llm",
@ -370,7 +370,7 @@ async def test_tracer_llm_run_on_error_callback() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=None,
error=repr(exception),
run_type="llm",
@ -436,7 +436,7 @@ async def test_tracer_tool_run_on_error() -> None:
],
extra={},
serialized={"name": "tool"},
inputs=dict(input="test"),
inputs={"input": "test"},
outputs=None,
error=repr(exception),
run_type="tool",
@ -527,7 +527,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -545,7 +545,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -563,7 +563,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized={"name": "tool"},
error=repr(exception),
inputs=dict(input="test"),
inputs={"input": "test"},
outputs=None,
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
@ -580,7 +580,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized=SERIALIZED,
error=repr(exception),
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=None,
run_type="llm",
trace_id=chain_uuid,

View File

@ -103,7 +103,7 @@ def test_tracer_chat_model_run() -> None:
],
extra={},
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]),
inputs={"prompts": ["Human: "]},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
@ -139,7 +139,7 @@ def test_tracer_multiple_llm_runs() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None,
run_type="llm",
@ -275,8 +275,8 @@ def test_tracer_nested_run() -> None:
],
extra={},
serialized={"name": "tool"},
inputs=dict(input="test"),
outputs=dict(output="test"),
inputs={"input": "test"},
outputs={"output": "test"},
error=None,
run_type="tool",
trace_id=chain_uuid,
@ -294,7 +294,7 @@ def test_tracer_nested_run() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -314,7 +314,7 @@ def test_tracer_nested_run() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=None,
error=repr(exception),
run_type="llm",
@ -373,7 +373,7 @@ def test_tracer_llm_run_on_error_callback() -> None:
],
extra={},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=None,
error=repr(exception),
run_type="llm",
@ -439,7 +439,7 @@ def test_tracer_tool_run_on_error() -> None:
],
extra={},
serialized={"name": "tool"},
inputs=dict(input="test"),
inputs={"input": "test"},
outputs=None,
error=repr(exception),
run_type="tool",
@ -528,7 +528,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -546,7 +546,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm",
trace_id=chain_uuid,
@ -564,7 +564,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized={"name": "tool"},
error=repr(exception),
inputs=dict(input="test"),
inputs={"input": "test"},
outputs=None,
trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
@ -581,7 +581,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
serialized=SERIALIZED,
error=repr(exception),
inputs=dict(prompts=[]),
inputs={"prompts": []},
outputs=None,
run_type="llm",
trace_id=chain_uuid,