core: Add ruff rules for comprehensions (C4) (#26829)

Christophe Bornet 2024-09-25 15:34:17 +02:00 committed by GitHub
parent 7e5a9c317f
commit 3a1b9259a7
34 changed files with 259 additions and 265 deletions
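
Background: the C4 rules come from flake8-comprehensions, as implemented by ruff, and flag collection calls and comprehensions that can be written as more direct literals or constructor calls. A minimal sketch of the kinds of rewrites applied throughout this commit (values are illustrative, not taken from the diff):

    words = ["ab", "c", "ab"]
    # Before: patterns the C4 rules flag
    lengths = set(len(w) for w in words)        # C401: generator inside set()
    pairs = dict([(w, len(w)) for w in words])  # C404: list comprehension inside dict()
    empty = dict()                              # C408: dict() call instead of a literal
    # After: the autofixed forms
    lengths = {len(w) for w in words}
    pairs = {w: len(w) for w in words}
    empty = {}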

View File

@@ -86,9 +86,9 @@ def _config_with_context(
         )
     }
     deps_by_key = {
-        key: set(
+        key: {
             _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or [])
-        )
+        }
         for key, group in grouped_by_key.items()
     }
@@ -198,7 +198,7 @@ class ContextGet(RunnableSerializable):
         configurable = config.get("configurable", {})
         if isinstance(self.key, list):
             values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids))
-            return {key: value for key, value in zip(self.key, values)}
+            return dict(zip(self.key, values))
         else:
             return await configurable[self.ids[0]]()
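
The two rewrites above are typical C4 fixes: a generator expression passed to set() becomes a set comprehension (C401), and a comprehension that merely repackages its iterable becomes a plain constructor call (C416). A self-contained sketch with made-up values:

    keys = ["a", "b"]
    values = [1, 2]
    deps = set(k.upper() for k in keys)             # C401 -> {k.upper() for k in keys}
    mapping = {k: v for k, v in zip(keys, values)}  # C416 -> dict(zip(keys, values))
    assert mapping == dict(zip(keys, values))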

View File

@@ -551,7 +551,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str:
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
-            param_string = str(sorted([(k, v) for k, v in params.items()]))
+            param_string = str(sorted(params.items()))
             # This code is not super efficient as it goes back and forth between
             # json and dict.
             serialized_repr = self._serialized
@@ -561,7 +561,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             params = self._get_invocation_params(stop=stop, **kwargs)
             params = {**params, **kwargs}
-            return str(sorted([(k, v) for k, v in params.items()]))
+            return str(sorted(params.items()))

     def generate(
         self,
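
Here the list comprehension `[(k, v) for k, v in params.items()]` only copies the key-value pairs, so it is flagged (likely as C416) and `sorted()` can consume `params.items()` directly; the result is identical. Sketch with invented params:

    params = {"stop": None, "temperature": 0.7}
    assert str(sorted([(k, v) for k, v in params.items()])) == str(sorted(params.items()))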

View File

@@ -166,7 +166,7 @@ def get_prompts(
     Raises:
         ValueError: If the cache is not set and cache is True.
     """
-    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    llm_string = str(sorted(params.items()))
     missing_prompts = []
     missing_prompt_idxs = []
     existing_prompts = {}
@@ -202,7 +202,7 @@ async def aget_prompts(
     Raises:
         ValueError: If the cache is not set and cache is True.
     """
-    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    llm_string = str(sorted(params.items()))
     missing_prompts = []
     missing_prompt_idxs = []
     existing_prompts = {}

View File

@@ -67,14 +67,14 @@ class Reviver:
                 Defaults to None.
         """
         self.secrets_from_env = secrets_from_env
-        self.secrets_map = secrets_map or dict()
+        self.secrets_map = secrets_map or {}
         # By default, only support langchain, but user can pass in additional namespaces
         self.valid_namespaces = (
             [*DEFAULT_NAMESPACES, *valid_namespaces]
             if valid_namespaces
             else DEFAULT_NAMESPACES
         )
-        self.additional_import_mappings = additional_import_mappings or dict()
+        self.additional_import_mappings = additional_import_mappings or {}
         self.import_mappings = (
             {
                 **ALL_SERIALIZABLE_MAPPINGS,
@@ -146,7 +146,7 @@ class Reviver:
                 # We don't need to recurse on kwargs
                 # as json.loads will do that for us.
-                kwargs = value.get("kwargs", dict())
+                kwargs = value.get("kwargs", {})
                 return cls(**kwargs)
         return value
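
All three changes in this file are C408: a bare `dict()` call is replaced by the `{}` literal, which is equivalent and skips a global name lookup plus a call. Sketch:

    secrets_map = None
    assert (secrets_map or dict()) == (secrets_map or {}) == {}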

View File

@@ -138,7 +138,7 @@ class Serializable(BaseModel, ABC):
         For example,
             {"openai_api_key": "OPENAI_API_KEY"}
         """
-        return dict()
+        return {}

     @property
     def lc_attributes(self) -> dict:
@@ -188,7 +188,7 @@ class Serializable(BaseModel, ABC):
         if not self.is_lc_serializable():
             return self.to_json_not_implemented()

-        secrets = dict()
+        secrets = {}
         # Get latest values for kwargs if there is an attribute with same name
         lc_kwargs = {}
         for k, v in self:

View File

@@ -108,7 +108,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
             return "Return a JSON object."
         else:
             # Copy schema to avoid altering original Pydantic schema.
-            schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()}
+            schema = dict(self._get_schema(self.pydantic_object).items())
             # Remove extraneous fields.
             reduced_schema = schema

View File

@@ -90,7 +90,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
             The format instructions for the JSON output.
         """
         # Copy schema to avoid altering original Pydantic schema.
-        schema = {k: v for k, v in self.pydantic_object.model_json_schema().items()}
+        schema = dict(self.pydantic_object.model_json_schema().items())
         # Remove extraneous fields.
         reduced_schema = schema

View File

@@ -76,7 +76,7 @@ class LLMResult(BaseModel):
         else:
             if self.llm_output is not None:
                 llm_output = deepcopy(self.llm_output)
-                llm_output["token_usage"] = dict()
+                llm_output["token_usage"] = {}
             else:
                 llm_output = None
             llm_results.append(

View File

@@ -1007,11 +1007,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
             input_vars.update(_message.input_variables)

         kwargs = {
-            **dict(
-                input_variables=sorted(input_vars),
-                optional_variables=sorted(optional_variables),
-                partial_variables=partial_vars,
-            ),
+            **{
+                "input_variables": sorted(input_vars),
+                "optional_variables": sorted(optional_variables),
+                "partial_variables": partial_vars,
+            },
             **kwargs,
         }
         cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
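
`dict()` called with keyword arguments is also C408; the autofix turns each keyword into a quoted string key, which is why the `**dict(...)` splat above becomes a `**{...}` splat. Sketch with invented variables:

    input_vars = {"b", "a"}
    assert dict(input_variables=sorted(input_vars)) == {"input_variables": sorted(input_vars)}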

View File

@@ -18,7 +18,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
         if "input_variables" not in kwargs:
             kwargs["input_variables"] = []

-        overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
+        overlap = set(kwargs["input_variables"]) & {"url", "path", "detail"}
         if overlap:
             raise ValueError(
                 "input_variables for the image template cannot contain"

View File

@@ -144,7 +144,7 @@ class PromptTemplate(StringPromptTemplate):
             template = self.template + other.template
             # If any do not want to validate, then don't
             validate_template = self.validate_template and other.validate_template
-            partial_variables = {k: v for k, v in self.partial_variables.items()}
+            partial_variables = dict(self.partial_variables.items())
             for k, v in other.partial_variables.items():
                 if k in partial_variables:
                     raise ValueError("Cannot have same variable partialed twice.")

View File

@@ -3778,7 +3778,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
                     for key, step in steps.items()
                 )
             )
-            output = {key: value for key, value in zip(steps, results)}
+            output = dict(zip(steps, results))
         # finish the root run
         except BaseException as e:
             await run_manager.on_chain_error(e)

View File

@@ -294,7 +294,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         ]

         to_return: dict[int, Any] = {}
-        run_again = {i: input for i, input in enumerate(inputs)}
+        run_again = dict(enumerate(inputs))
         handled_exceptions: dict[int, BaseException] = {}
         first_to_raise = None
         for runnable in self.runnables:
@@ -388,7 +388,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         )

         to_return = {}
-        run_again = {i: input for i, input in enumerate(inputs)}
+        run_again = dict(enumerate(inputs))
         handled_exceptions: dict[int, BaseException] = {}
         first_to_raise = None
         for runnable in self.runnables:
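
A dict comprehension that just unpacks `enumerate()` pairs back into keys and values is C416; `dict(enumerate(...))` is the direct form. Sketch:

    inputs = ["x", "y"]
    assert {i: value for i, value in enumerate(inputs)} == dict(enumerate(inputs))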

View File

@@ -117,7 +117,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
     @property
     def _kwargs_retrying(self) -> dict[str, Any]:
-        kwargs: dict[str, Any] = dict()
+        kwargs: dict[str, Any] = {}

         if self.max_attempt_number:
             kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

View File

@@ -10,7 +10,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]:
     from importlib import metadata

     sub_deps = set()
-    _underscored_packages = set(pkg.replace("-", "_") for pkg in packages)
+    _underscored_packages = {pkg.replace("-", "_") for pkg in packages}

     for pkg in packages:
         try:
@@ -33,7 +33,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]:
     return sorted(sub_deps, key=lambda x: x.lower())


-def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None:
+def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
     """Print information about the environment for debugging purposes.

     Args:
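
`tuple()` with no arguments falls under C408 as well: the empty-tuple literal `()` is the equivalent constant, here used as a keyword default and as a `getattr` fallback elsewhere in this commit. Sketch:

    def f(items: tuple = ()) -> int:  # was: items: tuple = tuple()
        return len(items)
    assert f() == 0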

View File

@@ -975,7 +975,7 @@ def _get_all_basemodel_annotations(
             ) and name not in fields:
                 continue
             annotations[name] = param.annotation
-        orig_bases: tuple = getattr(cls, "__orig_bases__", tuple())
+        orig_bases: tuple = getattr(cls, "__orig_bases__", ())
     # cls has subscript: cls = FooBar[int]
     else:
         annotations = _get_all_basemodel_annotations(
@@ -1007,11 +1007,9 @@ def _get_all_basemodel_annotations(
             # parent_origin = Baz,
             # generic_type_vars = (type vars in Baz)
             # generic_map = {type var in Baz: str}
-            generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple())
-            generic_map = {
-                type_var: t for type_var, t in zip(generic_type_vars, get_args(parent))
-            }
-            for field in getattr(parent_origin, "__annotations__", dict()):
+            generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ())
+            generic_map = dict(zip(generic_type_vars, get_args(parent)))
+            for field in getattr(parent_origin, "__annotations__", {}):
                 annotations[field] = _replace_type_vars(
                     annotations[field], generic_map, default_to_bound
                 )

View File

@@ -233,9 +233,7 @@ def _convert_any_typed_dicts_to_pydantic(
                 new_arg_type = _convert_any_typed_dicts_to_pydantic(
                     annotated_args[0], depth=depth + 1, visited=visited
                 )
-                field_kwargs = {
-                    k: v for k, v in zip(("default", "description"), annotated_args[1:])
-                }
+                field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
                 if (field_desc := field_kwargs.get("description")) and not isinstance(
                     field_desc, str
                 ):

View File

@@ -44,7 +44,7 @@ python = ">=3.12.4"
 [tool.poetry.extras]

 [tool.ruff.lint]
-select = [ "B", "E", "F", "I", "N", "T201", "UP",]
+select = [ "B", "C4", "E", "F", "I", "N", "T201", "UP",]
 ignore = [ "UP007",]

 [tool.coverage.run]
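
This is the change that enables everything above: adding "C4" to `select` under `[tool.ruff.lint]` turns on the flake8-comprehensions family for the package. Assuming a standard ruff setup, rewrites of this kind are generally autofixable with `ruff check --fix`.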

View File

@@ -54,5 +54,5 @@ def test_lazy_load() -> None:
         expected.append(
             Document(example.inputs["first"]["second"].upper(), metadata=metadata)
         )
-    actual = [doc for doc in loader.lazy_load()]
+    actual = list(loader.lazy_load())
     assert expected == actual

View File

@@ -55,7 +55,7 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1

-    chunks = [chunk for chunk in model.stream("meow")]
+    chunks = list(model.stream("meow"))
     assert chunks == [
         _any_id_ai_message_chunk(content="hello"),
         _any_id_ai_message_chunk(content=" "),

View File

@@ -185,11 +185,11 @@ def test_index_simple_delete_full(
     ):
         indexing_result = index(loader, record_manager, vector_store, cleanup="full")

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"mutated document 1", "This is another document."}

     assert indexing_result == {
@@ -267,11 +267,11 @@ async def test_aindex_simple_delete_full(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"mutated document 1", "This is another document."}

     # Attempt to index again verify that nothing changes
@@ -558,11 +558,11 @@ def test_incremental_delete(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"This is another document.", "This is a test document."}

     # Attempt to index again verify that nothing changes
@@ -617,11 +617,11 @@ def test_incremental_delete(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {
         "mutated document 1",
         "mutated document 2",
@@ -685,11 +685,11 @@ def test_incremental_indexing_with_batch_size(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"1", "2", "3", "4"}
@@ -735,11 +735,11 @@ def test_incremental_delete_with_batch_size(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"1", "2", "3", "4"}

     # Attempt to index again verify that nothing changes
@@ -880,11 +880,11 @@ async def test_aincremental_delete(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {"This is another document.", "This is a test document."}

     # Attempt to index again verify that nothing changes
@@ -939,11 +939,11 @@ async def test_aincremental_delete(
         "num_updated": 0,
     }

-    doc_texts = set(
+    doc_texts = {
         # Ignoring type since doc should be in the store and not a None
         vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
-    )
+    }
     assert doc_texts == {
         "mutated document 1",
         "mutated document 2",

View File

@@ -53,10 +53,10 @@ def test_batch_size(messages: list, messages_2: list) -> None:
     with collect_runs() as cb:
         llm.batch([messages, messages_2], {"callbacks": [cb]})
         assert len(cb.traced_runs) == 2
-        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
     with collect_runs() as cb:
         llm.batch([messages], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 1
     with collect_runs() as cb:
@@ -76,11 +76,11 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
     # so we expect batch_size to always be 1
     with collect_runs() as cb:
         await llm.abatch([messages, messages_2], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 2
     with collect_runs() as cb:
         await llm.abatch([messages], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 1
     with collect_runs() as cb:
@@ -146,7 +146,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
             return "fake-chat-model"

     model = ModelWithGenerate()
-    chunks = [chunk for chunk in model.stream("anything")]
+    chunks = list(model.stream("anything"))
     assert chunks == [_any_id_ai_message(content="hello")]

     chunks = [chunk async for chunk in model.astream("anything")]
@@ -183,7 +183,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
             return "fake-chat-model"

     model = ModelWithSyncStream()
-    chunks = [chunk for chunk in model.stream("anything")]
+    chunks = list(model.stream("anything"))
     assert chunks == [
         _any_id_ai_message_chunk(content="a"),
         _any_id_ai_message_chunk(content="b"),

View File

@@ -262,7 +262,7 @@ def test_global_cache_stream() -> None:
         AIMessage(content="goodbye world"),
     ]
     model = GenericFakeChatModel(messages=iter(messages), cache=True)
-    chunks = [chunk for chunk in model.stream("some input")]
+    chunks = list(model.stream("some input"))
     assert len(chunks) == 3

     # Assert that streaming information gets cached
     assert global_cache._cache != {}

View File

@@ -40,12 +40,12 @@ def test_batch_size() -> None:
     llm = FakeListLLM(responses=["foo"] * 3)
     with collect_runs() as cb:
         llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 3

     llm = FakeListLLM(responses=["foo"])
     with collect_runs() as cb:
         llm.batch(["foo"], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 1

     llm = FakeListLLM(responses=["foo"])
@@ -71,12 +71,12 @@ async def test_async_batch_size() -> None:
     llm = FakeListLLM(responses=["foo"] * 3)
     with collect_runs() as cb:
         await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 3

     llm = FakeListLLM(responses=["foo"])
     with collect_runs() as cb:
         await llm.abatch(["foo"], {"callbacks": [cb]})
-        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
         assert len(cb.traced_runs) == 1

     llm = FakeListLLM(responses=["foo"])
@@ -142,7 +142,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
             return "fake-chat-model"

     model = ModelWithGenerate()
-    chunks = [chunk for chunk in model.stream("anything")]
+    chunks = list(model.stream("anything"))
     assert chunks == ["hello"]

     chunks = [chunk async for chunk in model.astream("anything")]
@@ -179,7 +179,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
             return "fake-chat-model"

     model = ModelWithSyncStream()
-    chunks = [chunk for chunk in model.stream("anything")]
+    chunks = list(model.stream("anything"))
     assert chunks == ["a", "b"]

     assert type(model)._astream == BaseLLM._astream
     astream_chunks = [chunk async for chunk in model.astream("anything")]

View File

@@ -93,5 +93,5 @@ def test_base_transform_output_parser() -> None:
     model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
     chain = model | StrInvertCase()
     # inputs to models are ignored, response is hard-coded in model definition
-    chunks = [chunk for chunk in chain.stream("")]
+    chunks = list(chain.stream(""))
     assert chunks == ["HELLO", " ", "WORLD"]

View File

@@ -596,10 +596,10 @@ def test_base_model_schema_consistency() -> None:
         setup: str
         punchline: str

-    initial_joke_schema = {k: v for k, v in _schema(Joke).items()}
+    initial_joke_schema = dict(_schema(Joke).items())
     SimpleJsonOutputParser(pydantic_object=Joke)
     openai_func = convert_to_openai_function(Joke)
-    retrieved_joke_schema = {k: v for k, v in _schema(Joke).items()}
+    retrieved_joke_schema = dict(_schema(Joke).items())

     assert initial_joke_schema == retrieved_joke_schema
     assert openai_func.get("name", None) is not None

View File

@@ -391,7 +391,7 @@ async def test_runnable_seq_streaming_chunks() -> None:
         }
     )

-    chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})]
+    chunks = list(chain.stream({"foo": "foo", "bar": "bar"}))
     achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})]
     for c in chunks:
         assert c in achunks

View File

@@ -264,7 +264,7 @@ def test_fallbacks_stream() -> None:
     runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
         [RunnableGenerator(_generate)]
     )
-    assert list(runnable.stream({})) == [c for c in "foo bar"]
+    assert list(runnable.stream({})) == list("foo bar")

     with pytest.raises(ValueError):
         runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(

View File

@@ -1065,7 +1065,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
     assert [
         part
         async for part in seq.astream(
-            "hello", dict(metadata={"key": "value"}), my_kwarg="value"
+            "hello", {"metadata": {"key": "value"}}, my_kwarg="value"
         )
     ] == [5]
     assert mock.call_args_list == [
@@ -1125,12 +1125,9 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
         assert call in mock.call_args_list
     mock.reset_mock()

-    assert [
-        part
-        for part in seq.stream(
-            "hello", dict(metadata={"key": "value"}), my_kwarg="value"
-        )
-    ] == [5]
+    assert list(
+        seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value")
+    ) == [5]
     assert mock.call_args_list == [
         mocker.call("hello", my_kwarg="value"),
         mocker.call(5),
@@ -1155,13 +1152,13 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
     )
     assert spy.call_args_list[0].args[1:] == (
         "hello",
-        dict(
-            tags=["a-tag"],
-            callbacks=None,
-            recursion_limit=25,
-            configurable={"hello": "there", "__secret_key": "nahnah"},
-            metadata={"hello": "there", "bye": "now"},
-        ),
+        {
+            "tags": ["a-tag"],
+            "callbacks": None,
+            "recursion_limit": 25,
+            "configurable": {"hello": "there", "__secret_key": "nahnah"},
+            "metadata": {"hello": "there", "bye": "now"},
+        },
     )

     spy.reset_mock()
@@ -1174,7 +1171,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
     assert spy.call_args_list == [
         mocker.call(
             "hello",
-            dict(tags=["a-tag"], metadata={}, configurable={}),
+            {"tags": ["a-tag"], "metadata": {}, "configurable": {}},
         ),
     ]
     spy.reset_mock()
@@ -1200,19 +1197,19 @@ async def test_with_config(mocker: MockerFixture) -> None:
     assert [
         *fake.with_config(tags=["a-tag"]).stream(
-            "hello", dict(metadata={"key": "value"})
+            "hello", {"metadata": {"key": "value"}}
         )
     ] == [5]
     assert spy.call_args_list == [
         mocker.call(
             "hello",
-            dict(tags=["a-tag"], metadata={"key": "value"}, configurable={}),
+            {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}},
         ),
     ]
     spy.reset_mock()

     assert fake.with_config(recursion_limit=5).batch(
-        ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
+        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]
     ) == [5, 7]

     assert len(spy.call_args_list) == 2
@@ -1235,7 +1232,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
         c
         for c in fake.with_config(recursion_limit=5).batch_as_completed(
             ["hello", "wooorld"],
-            [dict(tags=["a-tag"]), dict(metadata={"key": "value"})],
+            [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}],
         )
     ) == [(0, 5), (1, 7)]
@@ -1256,7 +1253,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
     spy.reset_mock()

     assert fake.with_config(metadata={"a": "b"}).batch(
-        ["hello", "wooorld"], dict(tags=["a-tag"])
+        ["hello", "wooorld"], {"tags": ["a-tag"]}
     ) == [5, 7]
     assert len(spy.call_args_list) == 2
     for i, call in enumerate(spy.call_args_list):
@@ -1266,7 +1263,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
     spy.reset_mock()

     assert sorted(
-        c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"]))
+        c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]})
     ) == [(0, 5), (1, 7)]
     assert len(spy.call_args_list) == 2
     for i, call in enumerate(spy.call_args_list):
@@ -1284,7 +1281,12 @@ async def test_with_config(mocker: MockerFixture) -> None:
     assert spy.call_args_list == [
         mocker.call(
             "hello",
-            dict(callbacks=[handler], metadata={"a": "b"}, configurable={}, tags=[]),
+            {
+                "callbacks": [handler],
+                "metadata": {"a": "b"},
+                "configurable": {},
+                "tags": [],
+            },
         ),
     ]
     spy.reset_mock()
@@ -1293,12 +1295,12 @@ async def test_with_config(mocker: MockerFixture) -> None:
         part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
     ] == [5]
     assert spy.call_args_list == [
-        mocker.call("hello", dict(metadata={"a": "b"}, tags=[], configurable={})),
+        mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}),
    ]
     spy.reset_mock()

     assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
-        ["hello", "wooorld"], dict(metadata={"key": "value"})
+        ["hello", "wooorld"], {"metadata": {"key": "value"}}
     ) == [
         5,
         7,
@@ -1306,23 +1308,23 @@ async def test_with_config(mocker: MockerFixture) -> None:
     assert spy.call_args_list == [
         mocker.call(
             "hello",
-            dict(
-                metadata={"key": "value"},
-                tags=["c"],
-                callbacks=None,
-                recursion_limit=5,
-                configurable={},
-            ),
+            {
+                "metadata": {"key": "value"},
+                "tags": ["c"],
+                "callbacks": None,
+                "recursion_limit": 5,
+                "configurable": {},
+            },
         ),
         mocker.call(
             "wooorld",
-            dict(
-                metadata={"key": "value"},
-                tags=["c"],
-                callbacks=None,
-                recursion_limit=5,
-                configurable={},
-            ),
+            {
+                "metadata": {"key": "value"},
+                "tags": ["c"],
+                "callbacks": None,
+                "recursion_limit": 5,
+                "configurable": {},
+            },
         ),
     ]
     spy.reset_mock()
@@ -1332,7 +1334,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
             c
             async for c in fake.with_config(
                 recursion_limit=5, tags=["c"]
-            ).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"}))
+            ).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}})
         ]
     ) == [
         (0, 5),
@@ -1342,24 +1344,24 @@ async def test_with_config(mocker: MockerFixture) -> None:
     first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")
     assert first_call == mocker.call(
         "hello",
-        dict(
-            metadata={"key": "value"},
-            tags=["c"],
-            callbacks=None,
-            recursion_limit=5,
-            configurable={},
-        ),
+        {
+            "metadata": {"key": "value"},
+            "tags": ["c"],
+            "callbacks": None,
+            "recursion_limit": 5,
+            "configurable": {},
+        },
     )
     second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
     assert second_call == mocker.call(
         "wooorld",
-        dict(
-            metadata={"key": "value"},
-            tags=["c"],
-            callbacks=None,
-            recursion_limit=5,
-            configurable={},
-        ),
+        {
+            "metadata": {"key": "value"},
+            "tags": ["c"],
+            "callbacks": None,
+            "recursion_limit": 5,
+            "configurable": {},
+        },
     )
@@ -1367,20 +1369,20 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
     spy = mocker.spy(fake, "invoke")

-    assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
+    assert fake.invoke("hello", {"tags": ["a-tag"]}) == 5
     assert spy.call_args_list == [
-        mocker.call("hello", dict(tags=["a-tag"])),
+        mocker.call("hello", {"tags": ["a-tag"]}),
     ]
     spy.reset_mock()

-    assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
+    assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5]
     assert spy.call_args_list == [
-        mocker.call("hello", dict(metadata={"key": "value"})),
+        mocker.call("hello", {"metadata": {"key": "value"}}),
     ]
     spy.reset_mock()

     assert fake.batch(
-        ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
+        ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}]
     ) == [5, 7]

     assert len(spy.call_args_list) == 2
@@ -1398,9 +1400,9 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     spy.reset_mock()

-    assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
+    assert fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]}) == [5, 7]
     assert len(spy.call_args_list) == 2
-    assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
+    assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}
     for call in spy.call_args_list:
         assert call.args[1].get("tags") == ["a-tag"]
         assert call.args[1].get("metadata") == {}
@@ -1408,7 +1410,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
     assert spy.call_args_list == [
-        mocker.call("hello", dict(callbacks=[])),
+        mocker.call("hello", {"callbacks": []}),
     ]
     spy.reset_mock()
@@ -1418,19 +1420,19 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     ]
     spy.reset_mock()

-    assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
+    assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [
         5,
         7,
     ]
-    assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
+    assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"}
     for call in spy.call_args_list:
-        assert call.args[1] == dict(
-            metadata={"key": "value"},
-            tags=[],
-            callbacks=None,
-            recursion_limit=25,
-            configurable={},
-        )
+        assert call.args[1] == {
+            "metadata": {"key": "value"},
+            "tags": [],
+            "callbacks": None,
+            "recursion_limit": 25,
+            "configurable": {},
+        }


 async def test_prompt() -> None:
@@ -1698,7 +1700,7 @@ def test_prompt_with_chat_model(
     chat_spy = mocker.spy(chat.__class__, "invoke")
     tracer = FakeTracer()
     assert chain.invoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == _any_id_ai_message(content="foo")

     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -1722,7 +1724,7 @@ def test_prompt_with_chat_model(
             {"question": "What is your name?"},
             {"question": "What is your favorite color?"},
         ],
-        dict(callbacks=[tracer]),
+        {"callbacks": [tracer]},
     ) == [
         _any_id_ai_message(content="foo"),
         _any_id_ai_message(content="foo"),
@@ -1763,7 +1765,7 @@ def test_prompt_with_chat_model(
     chat_spy = mocker.spy(chat.__class__, "stream")
     tracer = FakeTracer()
     assert [
-        *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
+        *chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]})
     ] == [
         _any_id_ai_message_chunk(content="f"),
         _any_id_ai_message_chunk(content="o"),
@@ -1804,7 +1806,7 @@ async def test_prompt_with_chat_model_async(
     chat_spy = mocker.spy(chat.__class__, "ainvoke")
     tracer = FakeTracer()
     assert await chain.ainvoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == _any_id_ai_message(content="foo")

     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -1828,7 +1830,7 @@ async def test_prompt_with_chat_model_async(
             {"question": "What is your name?"},
             {"question": "What is your favorite color?"},
         ],
-        dict(callbacks=[tracer]),
+        {"callbacks": [tracer]},
     ) == [
         _any_id_ai_message(content="foo"),
         _any_id_ai_message(content="foo"),
@@ -1871,7 +1873,7 @@ async def test_prompt_with_chat_model_async(
     assert [
         a
         async for a in chain.astream(
-            {"question": "What is your name?"}, dict(callbacks=[tracer])
+            {"question": "What is your name?"}, {"callbacks": [tracer]}
         )
     ] == [
         _any_id_ai_message_chunk(content="f"),
@@ -1910,9 +1912,7 @@ async def test_prompt_with_llm(
     llm_spy = mocker.spy(llm.__class__, "ainvoke")
     tracer = FakeTracer()
     assert (
-        await chain.ainvoke(
-            {"question": "What is your name?"}, dict(callbacks=[tracer])
-        )
+        await chain.ainvoke({"question": "What is your name?"}, {"callbacks": [tracer]})
         == "foo"
     )

     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -1935,7 +1935,7 @@ async def test_prompt_with_llm(
             {"question": "What is your name?"},
             {"question": "What is your favorite color?"},
         ],
-        dict(callbacks=[tracer]),
+        {"callbacks": [tracer]},
     ) == ["bar", "foo"]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -1966,7 +1966,7 @@ async def test_prompt_with_llm(
     assert [
         token
         async for token in chain.astream(
-            {"question": "What is your name?"}, dict(callbacks=[tracer])
+            {"question": "What is your name?"}, {"callbacks": [tracer]}
         )
     ] == ["bar"]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -2110,7 +2110,7 @@ async def test_prompt_with_llm_parser(
     parser_spy = mocker.spy(parser.__class__, "ainvoke")
     tracer = FakeTracer()
     assert await chain.ainvoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == ["bear", "dog", "cat"]

     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert llm_spy.call_args.args[1] == ChatPromptValue(
@@ -2135,7 +2135,7 @@ async def test_prompt_with_llm_parser(
             {"question": "What is your name?"},
             {"question": "What is your favorite color?"},
         ],
-        dict(callbacks=[tracer]),
+        {"callbacks": [tracer]},
     ) == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]]
     assert prompt_spy.call_args.args[1] == [
         {"question": "What is your name?"},
@@ -2171,7 +2171,7 @@ async def test_prompt_with_llm_parser(
     assert [
         token
         async for token in chain.astream(
-            {"question": "What is your name?"}, dict(callbacks=[tracer])
+            {"question": "What is your name?"}, {"callbacks": [tracer]}
         )
     ] == [["tomato"], ["lettuce"], ["onion"]]
     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -2495,9 +2495,7 @@ async def test_prompt_with_llm_and_async_lambda(
     llm_spy = mocker.spy(llm.__class__, "ainvoke")
     tracer = FakeTracer()
     assert (
-        await chain.ainvoke(
-            {"question": "What is your name?"}, dict(callbacks=[tracer])
-        )
+        await chain.ainvoke({"question": "What is your name?"}, {"callbacks": [tracer]})
         == "foo"
     )

     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@@ -2539,7 +2537,7 @@ def test_prompt_with_chat_model_and_parser(
     parser_spy = mocker.spy(parser.__class__, "invoke")
     tracer = FakeTracer()
     assert chain.invoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == ["foo", "bar"]

     assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
     assert chat_spy.call_args.args[1] == ChatPromptValue(
@@ -2608,7 +2606,7 @@ def test_combining_sequences(
     # Test invoke
     tracer = FakeTracer()
     assert combined_chain.invoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == ["baz", "qux"]

     assert tracer.runs == snapshot
@@ -2658,7 +2656,7 @@ Question:
     chat_spy = mocker.spy(chat.__class__, "invoke")
     parser_spy = mocker.spy(parser.__class__, "invoke")
     tracer = FakeTracer()
-    assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
+    assert chain.invoke("What is your name?", {"callbacks": [tracer]}) == [
         "foo",
         "bar",
     ]
@@ -2725,7 +2723,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
     llm_spy = mocker.spy(llm.__class__, "invoke")
     tracer = FakeTracer()
     assert chain.invoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == {
         "chat": _any_id_ai_message(content="i'm a chatbot"),
         "llm": "i'm a textbot",
@@ -2788,7 +2786,7 @@ async def test_router_runnable(
     router_spy = mocker.spy(router.__class__, "invoke")
     tracer = FakeTracer()
     assert (
-        chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
+        chain.invoke({"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]})
         == "4"
     )
     assert router_spy.call_args.args[1] == {
@@ -2849,7 +2847,7 @@ async def test_higher_order_lambda_runnable(
     math_spy = mocker.spy(math_chain.__class__, "invoke")
     tracer = FakeTracer()
     assert (
-        chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
+        chain.invoke({"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]})
         == "4"
     )
     assert math_spy.call_args.args[1] == {
@@ -2880,7 +2878,7 @@ async def test_higher_order_lambda_runnable(
     tracer = FakeTracer()
     assert (
         await achain.ainvoke(
-            {"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
+            {"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}
         )
         == "4"
     )
@@ -2934,7 +2932,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
     llm_spy = mocker.spy(llm.__class__, "invoke")
     tracer = FakeTracer()
     assert chain.invoke(
-        {"question": "What is your name?"}, dict(callbacks=[tracer])
+        {"question": "What is your name?"}, {"callbacks": [tracer]}
     ) == {
         "chat": _any_id_ai_message(content="i'm a chatbot"),
         "llm": "i'm a textbot",
@@ -3841,7 +3839,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
 def test_runnable_lambda_stream() -> None:
     """Test that stream works for both normal functions & those returning Runnable."""
     # Normal output should work
-    output: list[Any] = [chunk for chunk in RunnableLambda(range).stream(5)]
+    output: list[Any] = list(RunnableLambda(range).stream(5))
     assert output == [range(5)]

     # Runnable output should also work
@@ -4015,7 +4013,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
     spy = mocker.spy(ControlledExceptionRunnable, "batch")
     tracer = FakeTracer()
     inputs = ["foo", "bar", "baz", "qux"]
-    outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True)
+    outputs = chain.batch(inputs, {"callbacks": [tracer]}, return_exceptions=True)
     assert len(outputs) == 4
     assert isinstance(outputs[0], ValueError)
     assert isinstance(outputs[1], ValueError)
@@ -4135,7 +4133,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
     tracer = FakeTracer()
     inputs = ["foo", "bar", "baz", "qux"]
     outputs = await chain.abatch(
-        inputs, dict(callbacks=[tracer]), return_exceptions=True
+        inputs, {"callbacks": [tracer]}, return_exceptions=True
     )
     assert len(outputs) == 4
     assert isinstance(outputs[0], ValueError)
@@ -5080,13 +5078,13 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
     chain = RunnablePassthrough.assign(urls=idchain_sync)

     tracer = FakeTracer()
-    chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
+    chain.invoke({"example": [1, 2, 3]}, {"callbacks": [tracer]})

     assert tracer.runs[0].name == "RunnableAssign<urls>"
     assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"

     tracer = FakeTracer()
-    for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
+    for _ in chain.stream({"example": [1, 2, 3]}, {"callbacks": [tracer]}):
         pass

     assert tracer.runs[0].name == "RunnableAssign<urls>"
@@ -5100,13 +5098,13 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
     chain = RunnablePassthrough.assign(urls=idchain_sync)

     tracer = FakeTracer()
-    await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))
+    await chain.ainvoke({"example": [1, 2, 3]}, {"callbacks": [tracer]})

     assert tracer.runs[0].name == "RunnableAssign<urls>"
     assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"

     tracer = FakeTracer()
-    async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
+    async for _ in chain.astream({"example": [1, 2, 3]}, {"callbacks": [tracer]}):
         pass

     assert tracer.runs[0].name == "RunnableAssign<urls>"
@@ -5260,7 +5258,7 @@ async def test_default_atransform_with_dicts() -> None:
 def test_passthrough_transform_with_dicts() -> None:
     """Test that default transform works with dicts."""
     runnable = RunnablePassthrough(lambda x: x)
-    chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))]
+    chunks = list(runnable.transform(iter([{"foo": "a"}, {"foo": "n"}])))
     assert chunks == [{"foo": "a"}, {"foo": "n"}]

View File

@@ -2033,7 +2033,7 @@ async def test_sync_in_async_stream_lambdas() -> None:
     async def add_one_proxy_(x: int, config: RunnableConfig) -> int:
         streaming = add_one.stream(x, config)
-        results = [result for result in streaming]
+        results = list(streaming)
         return results[0]

     add_one_proxy = RunnableLambda(add_one_proxy_)  # type: ignore
@@ -2078,7 +2078,7 @@ async def test_sync_in_sync_lambdas() -> None:
     def add_one_proxy(x: int, config: RunnableConfig) -> int:
         # Use sync streaming
         streaming = add_one_.stream(x, config)
-        results = [result for result in streaming]
+        results = list(streaming)
         return results[0]

     add_one_proxy_ = RunnableLambda(add_one_proxy)

View File

@@ -1995,7 +1995,7 @@ async def test_sync_in_async_stream_lambdas() -> None:
     async def add_one_proxy(x: int, config: RunnableConfig) -> int:
         streaming = add_one_.stream(x, config)
-        results = [result for result in streaming]
+        results = list(streaming)
         return results[0]

     add_one_proxy_ = RunnableLambda(add_one_proxy)  # type: ignore
@@ -2035,7 +2035,7 @@ async def test_sync_in_sync_lambdas() -> None:
     def add_one_proxy(x: int, config: RunnableConfig) -> int:
         # Use sync streaming
         streaming = add_one_.stream(x, config)
-        results = [result for result in streaming]
+        results = list(streaming)
         return results[0]

     add_one_proxy_ = RunnableLambda(add_one_proxy)

View File

@ -448,10 +448,10 @@ def test_message_chunk_to_message() -> None:
def test_tool_calls_merge() -> None: def test_tool_calls_merge() -> None:
chunks: list[dict] = [ chunks: list[dict] = [
dict(content=""), {"content": ""},
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -461,10 +461,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -474,10 +474,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -487,10 +487,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -500,10 +500,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -513,10 +513,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -526,10 +526,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": 0,
@ -539,10 +539,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -552,10 +552,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -565,10 +565,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -578,10 +578,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -591,10 +591,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -604,10 +604,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -617,10 +617,10 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict( {
content="", "content": "",
additional_kwargs={ "additional_kwargs": {
"tool_calls": [ "tool_calls": [
{ {
"index": 1, "index": 1,
@ -630,8 +630,8 @@ def test_tool_calls_merge() -> None:
} }
] ]
}, },
), },
dict(content=""), {"content": ""},
] ]
final = None final = None
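Editor's note: every chunk in this fixture was built with dict(content="", additional_kwargs={...}). C408 flags the outer dict() call even when its values are already literals, so the autofix converts the call while leaving the nested literals untouched. A minimal sketch of one converted chunk (fields shortened for brevity):

# Sketch: C408 converts the outer dict() call; nested literals are untouched.
flagged = dict(
    content="",
    additional_kwargs={"tool_calls": [{"index": 0}]},
)
fixed = {
    "content": "",
    "additional_kwargs": {"tool_calls": [{"index": 0}]},
}
assert flagged == fixed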

View File

@ -98,7 +98,7 @@ async def test_tracer_chat_model_run() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED_CHAT, serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]), inputs={"prompts": ["Human: "]},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None, error=None,
run_type="llm", run_type="llm",
@ -134,7 +134,7 @@ async def test_tracer_multiple_llm_runs() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None, error=None,
run_type="llm", run_type="llm",
@ -272,8 +272,8 @@ async def test_tracer_nested_run() -> None:
], ],
extra={}, extra={},
serialized={"name": "tool"}, serialized={"name": "tool"},
inputs=dict(input="test"), inputs={"input": "test"},
outputs=dict(output="test"), outputs={"output": "test"},
error=None, error=None,
run_type="tool", run_type="tool",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -291,7 +291,7 @@ async def test_tracer_nested_run() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -311,7 +311,7 @@ async def test_tracer_nested_run() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -339,7 +339,7 @@ async def test_tracer_llm_run_on_error() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=None, outputs=None,
error=repr(exception), error=repr(exception),
run_type="llm", run_type="llm",
@ -370,7 +370,7 @@ async def test_tracer_llm_run_on_error_callback() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=None, outputs=None,
error=repr(exception), error=repr(exception),
run_type="llm", run_type="llm",
@ -436,7 +436,7 @@ async def test_tracer_tool_run_on_error() -> None:
], ],
extra={}, extra={},
serialized={"name": "tool"}, serialized={"name": "tool"},
inputs=dict(input="test"), inputs={"input": "test"},
outputs=None, outputs=None,
error=repr(exception), error=repr(exception),
run_type="tool", run_type="tool",
@ -527,7 +527,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
error=None, error=None,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -545,7 +545,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
error=None, error=None,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -563,7 +563,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized={"name": "tool"}, serialized={"name": "tool"},
error=repr(exception), error=repr(exception),
inputs=dict(input="test"), inputs={"input": "test"},
outputs=None, outputs=None,
trace_id=chain_uuid, trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}", dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
@ -580,7 +580,7 @@ async def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
error=repr(exception), error=repr(exception),
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=None, outputs=None,
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
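Editor's note: beyond avoiding a call, the literal spelling is also the more general one. dict(prompts=[]) can only express keys that are valid Python identifiers, while {"prompts": []} accepts any hashable key. A short illustration:

# Sketch: keyword arguments restrict dict keys to identifiers.
run_inputs = {"prompts": []}
assert run_inputs == dict(prompts=[])  # equal when keys are identifiers

hyphenated = {"run-type": "llm"}  # valid as a literal key
# dict(run-type="llm")            # SyntaxError: not a valid identifier
assert hyphenated["run-type"] == "llm"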

View File

@ -103,7 +103,7 @@ def test_tracer_chat_model_run() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED_CHAT, serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]), inputs={"prompts": ["Human: "]},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None, error=None,
run_type="llm", run_type="llm",
@ -139,7 +139,7 @@ def test_tracer_multiple_llm_runs() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
error=None, error=None,
run_type="llm", run_type="llm",
@ -275,8 +275,8 @@ def test_tracer_nested_run() -> None:
], ],
extra={}, extra={},
serialized={"name": "tool"}, serialized={"name": "tool"},
inputs=dict(input="test"), inputs={"input": "test"},
outputs=dict(output="test"), outputs={"output": "test"},
error=None, error=None,
run_type="tool", run_type="tool",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -294,7 +294,7 @@ def test_tracer_nested_run() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -314,7 +314,7 @@ def test_tracer_nested_run() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] outputs=LLMResult(generations=[[]]), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=None, outputs=None,
error=repr(exception), error=repr(exception),
run_type="llm", run_type="llm",
@ -373,7 +373,7 @@ def test_tracer_llm_run_on_error_callback() -> None:
], ],
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=None, outputs=None,
error=repr(exception), error=repr(exception),
run_type="llm", run_type="llm",
@ -439,7 +439,7 @@ def test_tracer_tool_run_on_error() -> None:
], ],
extra={}, extra={},
serialized={"name": "tool"}, serialized={"name": "tool"},
inputs=dict(input="test"), inputs={"input": "test"},
outputs=None, outputs=None,
error=repr(exception), error=repr(exception),
run_type="tool", run_type="tool",
@ -528,7 +528,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
error=None, error=None,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -546,7 +546,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
error=None, error=None,
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type]
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
@ -564,7 +564,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized={"name": "tool"}, serialized={"name": "tool"},
error=repr(exception), error=repr(exception),
inputs=dict(input="test"), inputs={"input": "test"},
outputs=None, outputs=None,
trace_id=chain_uuid, trace_id=chain_uuid,
dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}", dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}",
@ -581,7 +581,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={}, extra={},
serialized=SERIALIZED, serialized=SERIALIZED,
error=repr(exception), error=repr(exception),
inputs=dict(prompts=[]), inputs={"prompts": []},
outputs=None, outputs=None,
run_type="llm", run_type="llm",
trace_id=chain_uuid, trace_id=chain_uuid,
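Editor's note: these sync tracer hunks mirror the async ones above, and the fix is identical. For the curious, the literal is also built by a dedicated bytecode path, whereas dict(...) emits a name load plus a call; a quick sketch for inspecting this (exact opcodes vary by CPython version):

# Sketch: compare the bytecode of the two spellings.
import dis

dis.dis(compile('{"prompts": []}', "<literal>", "eval"))  # dict built inline
dis.dis(compile('dict(prompts=[])', "<call>", "eval"))    # LOAD_NAME + CALL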