langchain: Add ruff rules SIM (#31881)

See https://docs.astral.sh/ruff/rules/#flake8-simplify-sim

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Christophe Bornet 2025-07-07 16:27:04 +02:00 committed by GitHub
parent f06380516f
commit 2df3fdf40d
29 changed files with 138 additions and 196 deletions
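For orientation, here is a minimal illustrative sketch (not taken from the diff; the function and variable names are hypothetical) of the kinds of patterns flake8-simplify (SIM) flags, mirroring the rewrites applied throughout this commit:

from __future__ import annotations

import contextlib
from pathlib import Path
from typing import Any


def to_path(file_path: str | Path) -> Path:
    # SIM108: prefer a ternary over an if/else assignment.
    return Path(file_path) if isinstance(file_path, str) else file_path


def agent_dict(agent: Any) -> dict:
    result: dict = {}
    # SIM105: prefer contextlib.suppress(...) over try/except/pass.
    with contextlib.suppress(NotImplementedError):
        result["_type"] = str(agent._agent_type)
    return result


def extract_arg(tool_input: dict) -> Any:
    # SIM401: prefer dict.get(key, default) over an if/else key check.
    return tool_input.get("__arg1", tool_input)


def check_tools(allowed: list[str] | None, tool_names: list[str]) -> None:
    # SIM102: collapse nested if statements into a single condition.
    if allowed is not None and set(allowed) != set(tool_names):
        msg = f"Allowed tools ({allowed}) different than provided tools ({tool_names})"
        raise ValueError(msg)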

View File

@@ -84,28 +84,28 @@ def create_importer(
not is_interactive_env()
and deprecated_lookups
and name in deprecated_lookups
):
# Depth 3:
# internal.py
# module_import.py
# Module in langchain that uses this function
# [calling code] whose frame we want to inspect.
if not internal.is_caller_internal(depth=3):
warn_deprecated(
since="0.1",
pending=False,
removal="1.0",
message=(
f"Importing {name} from {package} is deprecated. "
f"Please replace deprecated imports:\n\n"
f">> from {package} import {name}\n\n"
"with new imports of:\n\n"
f">> from {new_module} import {name}\n"
"You can use the langchain cli to **automatically** "
"upgrade many imports. Please see documentation here "
"<https://python.langchain.com/docs/versions/v0_2/>"
),
)
and not internal.is_caller_internal(depth=3)
):
warn_deprecated(
since="0.1",
pending=False,
removal="1.0",
message=(
f"Importing {name} from {package} is deprecated. "
f"Please replace deprecated imports:\n\n"
f">> from {package} import {name}\n\n"
"with new imports of:\n\n"
f">> from {new_module} import {name}\n"
"You can use the langchain cli to **automatically** "
"upgrade many imports. Please see documentation here "
"<https://python.langchain.com/docs/versions/v0_2/>"
),
)
return result
except Exception as e:
msg = f"module {new_module} has no attribute {name}"
@@ -115,28 +115,30 @@ def create_importer(
try:
module = importlib.import_module(fallback_module)
result = getattr(module, name)
if not is_interactive_env():
if (
not is_interactive_env()
# Depth 3:
# internal.py
# module_import.py
# Module in langchain that uses this function
# [calling code] whose frame we want to inspect.
if not internal.is_caller_internal(depth=3):
warn_deprecated(
since="0.1",
pending=False,
removal="1.0",
message=(
f"Importing {name} from {package} is deprecated. "
f"Please replace deprecated imports:\n\n"
f">> from {package} import {name}\n\n"
"with new imports of:\n\n"
f">> from {fallback_module} import {name}\n"
"You can use the langchain cli to **automatically** "
"upgrade many imports. Please see documentation here "
"<https://python.langchain.com/docs/versions/v0_2/>"
),
)
and not internal.is_caller_internal(depth=3)
):
warn_deprecated(
since="0.1",
pending=False,
removal="1.0",
message=(
f"Importing {name} from {package} is deprecated. "
f"Please replace deprecated imports:\n\n"
f">> from {package} import {name}\n\n"
"with new imports of:\n\n"
f">> from {fallback_module} import {name}\n"
"You can use the langchain cli to **automatically** "
"upgrade many imports. Please see documentation here "
"<https://python.langchain.com/docs/versions/v0_2/>"
),
)
return result
except Exception as e:

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import builtins
import contextlib
import json
import logging
import time
@@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
save_path = Path(file_path) if isinstance(file_path, str) else file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
@@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().model_dump()
try:
with contextlib.suppress(NotImplementedError):
_dict["_type"] = str(self._agent_type)
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
@@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
save_path = Path(file_path) if isinstance(file_path, str) else file_path
# Fetch dictionary to save
agent_dict = self.dict()
@@ -1133,13 +1126,14 @@ class AgentExecutor(Chain):
agent = self.agent
tools = self.tools
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
if allowed_tools is not None:
if set(allowed_tools) != set([tool.name for tool in tools]):
msg = (
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
raise ValueError(msg)
if allowed_tools is not None and set(allowed_tools) != set(
[tool.name for tool in tools]
):
msg = (
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})"
)
raise ValueError(msg)
return self
@model_validator(mode="before")
@@ -1272,13 +1266,7 @@ class AgentExecutor(Chain):
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is not None and iterations >= self.max_iterations:
return False
if (
self.max_execution_time is not None
and time_elapsed >= self.max_execution_time
):
return False
return True
return self.max_execution_time is None or time_elapsed < self.max_execution_time
def _return(
self,
@@ -1410,10 +1398,7 @@ class AgentExecutor(Chain):
return
actions: list[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
actions = [output] if isinstance(output, AgentAction) else output
for agent_action in actions:
yield agent_action
for agent_action in actions:
@@ -1547,10 +1532,7 @@ class AgentExecutor(Chain):
return
actions: list[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
actions = [output] if isinstance(output, AgentAction) else output
for agent_action in actions:
yield agent_action
@@ -1728,12 +1710,14 @@ class AgentExecutor(Chain):
if len(self._action_agent.return_values) > 0:
return_value_key = self._action_agent.return_values[0]
# Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct:
return AgentFinish(
{return_value_key: observation},
"",
)
if (
agent_action.tool in name_to_tool_map
and name_to_tool_map[agent_action.tool].return_direct
):
return AgentFinish(
{return_value_key: observation},
"",
)
return None
def _prepare_intermediate_steps(

View File

@@ -1,5 +1,6 @@
"""Load agent."""
import contextlib
from collections.abc import Sequence
from typing import Any, Optional
@@ -81,11 +82,9 @@ def initialize_agent(
agent_obj = load_agent(
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
)
try:
with contextlib.suppress(NotImplementedError):
# TODO: Add tags from the serialized object directly.
tags_.append(agent_obj._agent_type)
except NotImplementedError:
pass
else:
msg = (
"Somehow both `agent` and `agent_path` are None, this should never happen."

View File

@@ -128,10 +128,7 @@ def _load_agent_from_file(
"""Load agent from file."""
valid_suffixes = {"json", "yaml"}
# Convert file to Path object.
if isinstance(file, str):
file_path = Path(file)
else:
file_path = file
file_path = Path(file) if isinstance(file, str) else file
# Load from either json or yaml.
if file_path.suffix[1:] == "json":
with open(file_path) as f:

View File

@@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""
_prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
messages = []
messages = [system_message] if system_message else []
messages.extend(
[

View File

@@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
"""
_prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
messages = []
messages = [system_message] if system_message else []
messages.extend(
[

View File

@@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action(
# Open AI does not support passing in a JSON array as an argument.
function_name = tool_call["name"]
_tool_input = tool_call["args"]
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input
tool_input = _tool_input.get("__arg1", _tool_input)
content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"

View File

@@ -1,6 +1,7 @@
"""Base interface that all chains should implement."""
import builtins
import contextlib
import inspect
import json
import logging
@@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
# -> {"_type": "foo", "verbose": False, ...}
"""
_dict = super().dict(**kwargs)
try:
with contextlib.suppress(NotImplementedError):
_dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
@@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
raise NotImplementedError(msg)
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
save_path = Path(file_path) if isinstance(file_path, str) else file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)

View File

@@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
"""Load chain from file."""
# Convert file to Path object.
if isinstance(file, str):
file_path = Path(file)
else:
file_path = file
file_path = Path(file) if isinstance(file, str) else file
# Load from either json or yaml.
if file_path.suffix == ".json":
with open(file_path) as f:

View File

@@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain):
return [self.output_key]
def _moderate(self, text: str, results: Any) -> str:
if self.openai_pre_1_0:
condition = results["flagged"]
else:
condition = results.flagged
condition = results["flagged"] if self.openai_pre_1_0 else results.flagged
if condition:
error_str = "Text was found that violates OpenAI's content policy."
if self.error:

View File

@@ -108,22 +108,26 @@ class QueryTransformer(Transformer):
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
if func_name in set(Comparator):
if self.allowed_comparators is not None:
if func_name not in self.allowed_comparators:
msg = (
f"Received disallowed comparator {func_name}. Allowed "
f"comparators are {self.allowed_comparators}"
)
raise ValueError(msg)
if (
self.allowed_comparators is not None
and func_name not in self.allowed_comparators
):
msg = (
f"Received disallowed comparator {func_name}. Allowed "
f"comparators are {self.allowed_comparators}"
)
raise ValueError(msg)
return Comparator(func_name)
elif func_name in set(Operator):
if self.allowed_operators is not None:
if func_name not in self.allowed_operators:
msg = (
f"Received disallowed operator {func_name}. Allowed operators"
f" are {self.allowed_operators}"
)
raise ValueError(msg)
if (
self.allowed_operators is not None
and func_name not in self.allowed_operators
):
msg = (
f"Received disallowed operator {func_name}. Allowed operators"
f" are {self.allowed_operators}"
)
raise ValueError(msg)
return Operator(func_name)
else:
msg = (

View File

@@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness:
if not isinstance(llm, BaseChatModel):
msg = "Only chat models supported by the current trajectory eval"
raise NotImplementedError(msg)
if agent_tools:
prompt = EVAL_CHAT_PROMPT
else:
prompt = TOOL_FREE_EVAL_CHAT_PROMPT
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
eval_chain = LLMChain(llm=llm, prompt=prompt)
return cls(
agent_tools=agent_tools, # type: ignore[arg-type]

View File

@@ -36,13 +36,12 @@ class CombinedMemory(BaseMemory):
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value:
if isinstance(val, BaseChatMemory):
if val.input_key is None:
warnings.warn(
"When using CombinedMemory, "
"input keys should be so the input is known. "
f" Was not set on {val}"
)
if isinstance(val, BaseChatMemory) and val.input_key is None:
warnings.warn(
"When using CombinedMemory, "
"input keys should be so the input is known. "
f" Was not set on {val}"
)
return value
@property

View File

@@ -51,11 +51,9 @@ class ModelLaboratory:
"Currently only support chains with one output variable, "
f"got {chain.output_keys}"
)
raise ValueError(msg)
if names is not None:
if len(names) != len(chains):
msg = "Length of chains does not match length of names."
raise ValueError(msg)
if names is not None and len(names) != len(chains):
msg = "Length of chains does not match length of names."
raise ValueError(msg)
self.chains = chains
chain_range = [str(i) for i in range(len(self.chains))]
self.chain_colors = get_color_mapping(chain_range)
@@ -93,10 +91,7 @@ class ModelLaboratory:
"""
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
for i, chain in enumerate(self.chains):
if self.names is not None:
name = self.names[i]
else:
name = str(chain)
name = self.names[i] if self.names is not None else str(chain)
print_text(name, end="\n")
output = chain.run(text)
print_text(output, color=self.chain_colors[str(i)], end="\n\n")

View File

@@ -10,14 +10,13 @@ def load_output_parser(config: dict) -> dict:
Returns:
config dict with output parser loaded
"""
if "output_parsers" in config:
if config["output_parsers"] is not None:
_config = config["output_parsers"]
output_parser_type = _config["_type"]
if output_parser_type == "regex_parser":
output_parser = RegexParser(**_config)
else:
msg = f"Unsupported output parser {output_parser_type}"
raise ValueError(msg)
config["output_parsers"] = output_parser
if "output_parsers" in config and config["output_parsers"] is not None:
_config = config["output_parsers"]
output_parser_type = _config["_type"]
if output_parser_type == "regex_parser":
output_parser = RegexParser(**_config)
else:
msg = f"Unsupported output parser {output_parser_type}"
raise ValueError(msg)
config["output_parsers"] = output_parser
return config

View File

@@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]):
try:
# Greedy search for 1st yaml candidate.
match = re.search(self.pattern, text.strip())
yaml_str = ""
if match:
yaml_str = match.group("yaml")
else:
# If no backticks were present, try to parse the entire output as yaml.
yaml_str = text
# If no backticks were present, try to parse the entire output as yaml.
yaml_str = match.group("yaml") if match else text
json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"):

View File

@@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever):
response = await self.llm_chain.ainvoke(
{"question": question}, config={"callbacks": run_manager.get_child()}
)
if isinstance(self.llm_chain, LLMChain):
lines = response["text"]
else:
lines = response
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines
@@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever):
response = self.llm_chain.invoke(
{"question": question}, config={"callbacks": run_manager.get_child()}
)
if isinstance(self.llm_chain, LLMChain):
lines = response["text"]
else:
lines = response
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines

View File

@@ -57,9 +57,7 @@ class EvalConfig(BaseModel):
"""
kwargs = {}
for field, val in self:
if field == "evaluator_type":
continue
elif val is None:
if field == "evaluator_type" or val is None:
continue
kwargs[field] = val
return kwargs

View File

@@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true

View File

@@ -2,6 +2,8 @@
import asyncio
import json
import operator
from functools import reduce
from itertools import cycle
from typing import Any, Optional, Union, cast
@@ -520,14 +522,7 @@ async def test_runnable_agent() -> None:
assert messages != []
# Aggregate state
run_log = None
for result in results:
if run_log is None:
run_log = result
else:
# `+` is defined for RunLogPatch
run_log = run_log + result
run_log = reduce(operator.add, results)
assert isinstance(run_log, RunLog)

View File

@@ -17,7 +17,7 @@ def test_simple_memory() -> None:
output = memory.load_memory_variables({})
assert output == {"baz": "foo"}
assert ["baz"] == memory.memory_variables
assert memory.memory_variables == ["baz"]
@pytest.mark.parametrize(

View File

@@ -1,5 +1,6 @@
"""Embeddings tests."""
import contextlib
import hashlib
import importlib
import warnings
@@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
try:
with contextlib.suppress(ValueError):
cache_embeddings_batch.embed_documents(texts)
except ValueError:
pass
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
# only the first batch of three embeddings should exist
assert len(keys) == 3
@@ -122,10 +121,8 @@ async def test_aembed_documents_batch(
) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
try:
with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts)
except ValueError:
pass
keys = [
key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()

View File

@@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None:
reference = "I like apples."
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
assert "score" in result
assert 0 < result["score"]
assert result["score"] > 0
if normalize_score:
assert result["score"] < 1.0

View File

@@ -45,7 +45,7 @@ def test_combining_dict_result() -> None:
]
combining_parser = CombiningOutputParser(parsers=parsers)
result_dict = combining_parser.parse(DEF_README)
assert DEF_EXPECTED_RESULT == result_dict
assert result_dict == DEF_EXPECTED_RESULT
def test_combining_output_parser_output_type() -> None:

View File

@@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None:
def test_pandas_output_parser_col_multi_elem() -> None:
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
actual_output = parser.parse("column:chicken[0, 1]")
for key in actual_output.keys():
for key in actual_output:
assert expected_output["chicken"].equals(actual_output[key])

View File

@@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None:
output_keys=["confidence", "explanation"],
default_output_key="noConfidence",
)
assert DEF_EXPECTED_RESULT == parser.parse(DEF_README)
assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT
def test_regex_parser_output_type() -> None:

View File

@@ -35,7 +35,7 @@ def test_regex_dict_result() -> None:
)
result_dict = regex_dict_parser.parse(DEF_README)
print("parse_result:", result_dict) # noqa: T201
assert DEF_EXPECTED_RESULT == result_dict
assert result_dict == DEF_EXPECTED_RESULT
def test_regex_dict_output_type() -> None:

View File

@@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None:
model = yaml_parser.parse(result)
print("parse_result:", result) # noqa: T201
assert DEF_EXPECTED_RESULT == model
assert model == DEF_EXPECTED_RESULT
def test_yaml_output_parser_fail() -> None:

View File

@@ -124,9 +124,12 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]:
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP":
if isinstance(node.value, ast.Dict):
return _dict_from_ast(node.value)
if (
isinstance(target, ast.Name)
and target.id == "DEPRECATED_LOOKUP"
and isinstance(node.value, ast.Dict)
):
return _dict_from_ast(node.value)
return None
@@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str:
Returns:
str: The corresponding string value.
"""
if isinstance(node, ast.Constant): # Python 3.8+
if isinstance(node.value, str):
return node.value
if (
isinstance(node, ast.Constant) # Python 3.8+
and isinstance(node.value, str)
):
return node.value
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
raise AssertionError(msg)