langchain: Add ruff rules C4 (#31879)

All auto-fixes
See https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4
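
For quick reference, here is a minimal sketch of the patterns these rules rewrite, mirroring the diffs below. The `items` list and the `before`/`after` names are illustrative placeholders, not code from the repository.

```python
# Illustrative before/after pairs for the flake8-comprehensions (C4) fixes
# applied in this commit. `items` is placeholder data, not from the diff.
items = ["a", "b", "c"]

# C401/C403: set(...) wrapping a generator or list comprehension
# becomes a set comprehension.
before = set([item.upper() for item in items])
after = {item.upper() for item in items}

# C408: dict()/list()/tuple() calls (including dict(key=...)) become literals.
before = dict(key="value")
after = {"key": "value"}

# C416: a comprehension that only copies its iterable becomes list()/dict().
before = [item for item in items]
after = list(items)

# C419: any()/all() accept a generator, so the intermediate list is unnecessary.
before = any([item == "a" for item in items])
after = any(item == "a" for item in items)

# A dict comprehension mapping every key to the same constant value
# becomes dict.fromkeys().
before = {item: None for item in items}
after = dict.fromkeys(items, None)
```

All of these rewrites are behavior-preserving: they drop an intermediate collection or a needless constructor call, which is why ruff can apply them automatically.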

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Christophe Bornet 2025-07-07 16:55:52 +02:00 committed by GitHub
parent 4134b36db8
commit fceebbb387
22 changed files with 97 additions and 96 deletions

View File

@ -1126,9 +1126,9 @@ class AgentExecutor(Chain):
agent = self.agent
tools = self.tools
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
if allowed_tools is not None and set(allowed_tools) != set(
[tool.name for tool in tools]
):
if allowed_tools is not None and set(allowed_tools) != {
tool.name for tool in tools
}:
msg = (
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
@ -1318,16 +1318,15 @@ class AgentExecutor(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step(
[
a
for a in self._iter_next_step(
list(
self._iter_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager,
)
]
)
)
def _iter_next_step(

View File

@ -37,7 +37,7 @@ class SelfAskOutputParser(AgentOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
last_line = text.split("\n")[-1]
if not any([follow in last_line for follow in self.followups]):
if not any(follow in last_line for follow in self.followups):
if self.finish_string not in last_line:
msg = f"Could not parse output: {text}"
raise OutputParserException(msg)

View File

@ -24,7 +24,7 @@ class InvalidTool(BaseTool):
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
available_tool_names_str = ", ".join(list(available_tool_names))
return (
f"{requested_tool_name} is not a valid tool, "
f"try one of [{available_tool_names_str}]."
@ -37,7 +37,7 @@ class InvalidTool(BaseTool):
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
available_tool_names_str = ", ".join(list(available_tool_names))
return (
f"{requested_tool_name} is not a valid tool, "
f"try one of [{available_tool_names_str}]."

View File

@ -352,7 +352,7 @@ try:
headers: Optional[dict] = None,
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
limit_to_domains: Optional[Sequence[str]] = tuple(),
limit_to_domains: Optional[Sequence[str]] = (),
**kwargs: Any,
) -> APIChain:
"""Load chain from just an LLM and the api docs."""

View File

@ -112,13 +112,15 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model("ChainInput", **{k: (Any, None) for k in self.input_keys})
return create_model("ChainInput", **dict.fromkeys(self.input_keys, (Any, None)))
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model("ChainOutput", **{k: (Any, None) for k in self.output_keys})
return create_model(
"ChainOutput", **dict.fromkeys(self.output_keys, (Any, None))
)
@override
def invoke(

View File

@ -100,7 +100,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
if self.return_intermediate_steps:
schema["intermediate_steps"] = (list[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})
schema.update(dict.fromkeys(self.metadata_keys, (Any, None)))
return create_model("MapRerankOutput", **schema)

View File

@ -48,7 +48,7 @@ class SequentialChain(Chain):
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
memory_keys = list()
memory_keys = []
if "memory" in values and values["memory"] is not None:
"""Validate that prompt input variables are consistent."""
memory_keys = values["memory"].memory_variables

View File

@ -69,7 +69,7 @@ def load_dataset(uri: str) -> list[dict]:
raise ImportError(msg)
dataset = load_dataset(f"LangChainDatasets/{uri}")
return [d for d in dataset["train"]]
return list(dataset["train"])
_EVALUATOR_MAP: dict[

View File

@ -311,10 +311,10 @@ class SQLRecordManager(RecordManager):
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
updated_at=sqlite_insert_stmt.excluded.updated_at,
group_id=sqlite_insert_stmt.excluded.group_id,
),
set_={
"updated_at": sqlite_insert_stmt.excluded.updated_at,
"group_id": sqlite_insert_stmt.excluded.group_id,
},
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
@ -327,10 +327,10 @@ class SQLRecordManager(RecordManager):
)
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
"uix_key_namespace", # Name of constraint
set_=dict(
updated_at=pg_insert_stmt.excluded.updated_at,
group_id=pg_insert_stmt.excluded.group_id,
),
set_={
"updated_at": pg_insert_stmt.excluded.updated_at,
"group_id": pg_insert_stmt.excluded.group_id,
},
)
else:
msg = f"Unsupported dialect {self.dialect}"
@ -393,10 +393,10 @@ class SQLRecordManager(RecordManager):
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
updated_at=sqlite_insert_stmt.excluded.updated_at,
group_id=sqlite_insert_stmt.excluded.group_id,
),
set_={
"updated_at": sqlite_insert_stmt.excluded.updated_at,
"group_id": sqlite_insert_stmt.excluded.group_id,
},
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
@ -409,10 +409,10 @@ class SQLRecordManager(RecordManager):
)
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
"uix_key_namespace", # Name of constraint
set_=dict(
updated_at=pg_insert_stmt.excluded.updated_at,
group_id=pg_insert_stmt.excluded.group_id,
),
set_={
"updated_at": pg_insert_stmt.excluded.updated_at,
"group_id": pg_insert_stmt.excluded.group_id,
},
)
else:
msg = f"Unsupported dialect {self.dialect}"
@ -432,7 +432,7 @@ class SQLRecordManager(RecordManager):
)
)
records = filtered_query.all()
found_keys = set(r.key for r in records)
found_keys = {r.key for r in records}
return [k in found_keys for k in keys]
async def aexists(self, keys: Sequence[str]) -> list[bool]:

View File

@ -8,7 +8,7 @@ class SimpleMemory(BaseMemory):
ever change between prompts.
"""
memories: dict[str, Any] = dict()
memories: dict[str, Any] = {}
@property
def memory_variables(self) -> list[str]:

View File

@ -49,7 +49,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
def parse(self, text: str) -> dict[str, Any]:
"""Parse the output of an LLM call."""
texts = text.split("\n\n")
output = dict()
output = {}
for txt, parser in zip(texts, self.parsers):
output.update(parser.parse(txt.strip()))
return output

View File

@ -82,19 +82,19 @@ class OutputFixingParser(BaseOutputParser[T]):
else:
try:
completion = self.retry_chain.invoke(
dict(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
{
"instructions": self.parser.get_format_instructions(), # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = self.retry_chain.invoke(
dict(
completion=completion,
error=repr(e),
)
{
"completion": completion,
"error": repr(e),
}
)
msg = "Failed to parse"
@ -120,19 +120,19 @@ class OutputFixingParser(BaseOutputParser[T]):
else:
try:
completion = await self.retry_chain.ainvoke(
dict(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
{
"instructions": self.parser.get_format_instructions(), # noqa: E501
"completion": completion,
"error": repr(e),
}
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions
completion = await self.retry_chain.ainvoke(
dict(
completion=completion,
error=repr(e),
)
{
"completion": completion,
"error": repr(e),
}
)
msg = "Failed to parse"

View File

@ -116,10 +116,10 @@ class RetryOutputParser(BaseOutputParser[T]):
)
else:
completion = self.retry_chain.invoke(
dict(
prompt=prompt_value.to_string(),
completion=completion,
)
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)
msg = "Failed to parse"
@ -153,10 +153,10 @@ class RetryOutputParser(BaseOutputParser[T]):
)
else:
completion = await self.retry_chain.ainvoke(
dict(
prompt=prompt_value.to_string(),
completion=completion,
)
{
"prompt": prompt_value.to_string(),
"completion": completion,
}
)
msg = "Failed to parse"
@ -244,11 +244,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
)
else:
completion = self.retry_chain.invoke(
dict(
completion=completion,
prompt=prompt_value.to_string(),
error=repr(e),
)
{
"completion": completion,
"prompt": prompt_value.to_string(),
"error": repr(e),
}
)
msg = "Failed to parse"
@ -273,11 +273,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
)
else:
completion = await self.retry_chain.ainvoke(
dict(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
{
"prompt": prompt_value.to_string(),
"completion": completion,
"error": repr(e),
}
)
msg = "Failed to parse"

View File

@ -43,7 +43,7 @@ class YamlOutputParser(BaseOutputParser[T]):
def get_format_instructions(self) -> str:
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()}
schema = dict(self.pydantic_object.schema().items())
# Remove extraneous fields.
reduced_schema = schema

View File

@ -24,7 +24,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
vectorstore: VectorStore
"""The vectorstore to store documents and determine salience."""
search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
search_kwargs: dict = Field(default_factory=lambda: {"k": 100})
"""Keyword arguments to pass to the vectorstore similarity search."""
# TODO: abstract as a queue

View File

@ -623,7 +623,7 @@ def _load_run_evaluators(
input_key, prediction_key, reference_key = None, None, None
if config.evaluators or (
config.custom_evaluators
and any([isinstance(e, StringEvaluator) for e in config.custom_evaluators])
and any(isinstance(e, StringEvaluator) for e in config.custom_evaluators)
):
input_key, prediction_key, reference_key = _get_keys(
config, run_inputs, run_outputs, example_outputs

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
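
With `C4` added to the `select` list above, ruff's autofixer (for example `ruff check --fix`) produces the comprehension rewrites seen throughout this commit, consistent with the commit message's "All auto-fixes" note.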

View File

@ -205,7 +205,7 @@ def test_agent_stream() -> None:
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
output = [a for a in agent.stream("when was langchain made")]
output = list(agent.stream("when was langchain made"))
assert output == [
{
"actions": [

View File

@ -24,7 +24,7 @@ def test_resolve_criteria_enum(criterion: Criteria) -> None:
def test_resolve_criteria_list_enum() -> None:
val = resolve_pairwise_criteria(list(Criteria))
assert isinstance(val, dict)
assert set(val.keys()) == set(c.value for c in list(Criteria))
assert set(val.keys()) == {c.value for c in list(Criteria)}
def test_PairwiseStringResultOutputParser_parse() -> None:

View File

@ -288,11 +288,11 @@ def test_index_simple_delete_full(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {"mutated document 1", "This is another document."}
# Attempt to index again verify that nothing changes
@ -364,11 +364,11 @@ async def test_aindex_simple_delete_full(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {"mutated document 1", "This is another document."}
# Attempt to index again verify that nothing changes
@ -657,11 +657,11 @@ def test_incremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {"This is another document.", "This is a test document."}
# Attempt to index again verify that nothing changes
@ -716,11 +716,11 @@ def test_incremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {
"mutated document 1",
"mutated document 2",
@ -784,11 +784,11 @@ def test_incremental_indexing_with_batch_size(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {"1", "2", "3", "4"}
@ -834,11 +834,11 @@ def test_incremental_delete_with_batch_size(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {"1", "2", "3", "4"}
# Attempt to index again verify that nothing changes
@ -980,11 +980,11 @@ async def test_aincremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {"This is another document.", "This is a test document."}
# Attempt to index again verify that nothing changes
@ -1039,11 +1039,11 @@ async def test_aincremental_delete(
"num_updated": 0,
}
doc_texts = set(
doc_texts = {
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore[union-attr]
for uid in vector_store.store
)
}
assert doc_texts == {
"mutated document 1",
"mutated document 2",

View File

@ -51,7 +51,7 @@ async def test_generic_fake_chat_model_stream() -> None:
_AnyIdAIMessageChunk(content="goodbye"),
]
chunks = [chunk for chunk in model.stream("meow")]
chunks = list(model.stream("meow"))
assert chunks == [
_AnyIdAIMessageChunk(content="hello"),
_AnyIdAIMessageChunk(content=" "),

View File

@ -28,7 +28,7 @@ def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None:
"""
# Get the dependencies from the [tool.poetry.dependencies] section
dependencies = uv_conf["project"]["dependencies"]
required_dependencies = set(Requirement(dep).name for dep in dependencies)
required_dependencies = {Requirement(dep).name for dep in dependencies}
assert sorted(required_dependencies) == sorted(
[
@ -54,7 +54,7 @@ def test_test_group_dependencies(uv_conf: Mapping[str, Any]) -> None:
"""
dependencies = uv_conf["dependency-groups"]["test"]
test_group_deps = set(Requirement(dep).name for dep in dependencies)
test_group_deps = {Requirement(dep).name for dep in dependencies}
assert sorted(test_group_deps) == sorted(
[