feat(langchain): add ruff rules PL (#32079)

See https://docs.astral.sh/ruff/rules/#pylint-pl
This commit is contained in:
Christophe Bornet 2025-07-23 05:55:32 +02:00 committed by GitHub
parent 0f39155f62
commit 3496e1739e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 97 additions and 102 deletions

View File

@ -357,7 +357,7 @@ def __getattr__(name: str) -> Any:
return ElasticVectorSearch
# For backwards compatibility
if name == "SerpAPIChain" or name == "SerpAPIWrapper":
if name in {"SerpAPIChain", "SerpAPIWrapper"}:
from langchain_community.utilities import SerpAPIWrapper
_warn_on_import(

View File

@ -40,13 +40,13 @@ def format_xml(
# Escape XML tags in tool names and inputs using custom delimiters
tool = _escape(action.tool)
tool_input = _escape(str(action.tool_input))
observation = _escape(str(observation))
observation_ = _escape(str(observation))
else:
tool = action.tool
tool_input = str(action.tool_input)
observation = str(observation)
observation_ = str(observation)
log += (
f"<tool>{tool}</tool><tool_input>{tool_input}"
f"</tool_input><observation>{observation}</observation>"
f"</tool_input><observation>{observation_}</observation>"
)
return log

View File

@ -582,7 +582,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
major_version = int(openai.version.VERSION.split(".")[0])
minor_version = int(openai.version.VERSION.split(".")[1])
version_gte_1_14 = (major_version > 1) or (
major_version == 1 and minor_version >= 14
major_version == 1 and minor_version >= 14 # noqa: PLR2004
)
messages = self.client.beta.threads.messages.list(
@ -739,7 +739,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
major_version = int(openai.version.VERSION.split(".")[0])
minor_version = int(openai.version.VERSION.split(".")[1])
version_gte_1_14 = (major_version > 1) or (
major_version == 1 and minor_version >= 14
major_version == 1 and minor_version >= 14 # noqa: PLR2004
)
messages = await self.async_client.beta.threads.messages.list(

View File

@ -24,6 +24,9 @@ if TYPE_CHECKING:
from langchain_community.docstore.base import Docstore
_LOOKUP_AND_SEARCH_TOOLS = {"Lookup", "Search"}
@deprecated(
"0.1.0",
message=AGENT_DEPRECATION_WARNING,
@ -52,11 +55,11 @@ class ReActDocstoreAgent(Agent):
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools)
if len(tools) != 2:
if len(tools) != len(_LOOKUP_AND_SEARCH_TOOLS):
msg = f"Exactly two tools must be specified, but got {tools}"
raise ValueError(msg)
tool_names = {tool.name for tool in tools}
if tool_names != {"Lookup", "Search"}:
if tool_names != _LOOKUP_AND_SEARCH_TOOLS:
msg = f"Tool names should be Lookup and Search, got {tool_names}"
raise ValueError(msg)

View File

@ -47,8 +47,8 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
scheme, domain = _extract_scheme_and_domain(url)
for allowed_domain in limit_to_domains:
allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
if scheme == allowed_scheme and domain == allowed_domain:
allowed_scheme, allowed_domain_ = _extract_scheme_and_domain(allowed_domain)
if scheme == allowed_scheme and domain == allowed_domain_:
return True
return False

View File

@ -193,13 +193,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain input_variables"
)
raise ValueError(msg)
else:
if values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
elif values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
return values
@property

View File

@ -161,13 +161,12 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain input_variables"
)
raise ValueError(msg)
else:
if values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
elif values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
return values
def combine_docs(

View File

@ -325,10 +325,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
_token_max,
**kwargs,
)
result_docs = []
for docs in new_result_doc_list:
new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
result_docs = [
collapse_docs(docs_, _collapse_docs_func, **kwargs)
for docs_ in new_result_doc_list
]
num_tokens = length_func(result_docs, **kwargs)
retries += 1
if self.collapse_max_retries and retries == self.collapse_max_retries:
@ -364,10 +364,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
_token_max,
**kwargs,
)
result_docs = []
for docs in new_result_doc_list:
new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
result_docs = [
await acollapse_docs(docs_, _collapse_docs_func, **kwargs)
for docs_ in new_result_doc_list
]
num_tokens = length_func(result_docs, **kwargs)
retries += 1
if self.collapse_max_retries and retries == self.collapse_max_retries:

View File

@ -140,13 +140,12 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain input_variables"
)
raise ValueError(msg)
else:
if values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
elif values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
return values
def combine_docs(

View File

@ -180,13 +180,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain_variables"
)
raise ValueError(msg)
else:
if values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
elif values["document_variable_name"] not in llm_chain_variables:
msg = (
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
raise ValueError(msg)
return values
@property

View File

@ -322,7 +322,7 @@ class Crawler:
if node_name == "#text" and ancestor_exception and ancestor_node:
text = strings[node_value[index]]
if text == "|" or text == "":
if text in {"|", ""}:
continue
ancestor_node.append({"type": "type", "value": text})
else:
@ -367,7 +367,7 @@ class Crawler:
element_node_value = strings[text_index]
# remove redundant elements
if ancestor_exception and (node_name != "a" and node_name != "button"):
if ancestor_exception and (node_name not in {"a", "button"}):
continue
elements_in_view_port.append(
@ -423,10 +423,7 @@ class Crawler:
# not very elegant, more like a placeholder
if (
(converted_node_name != "button" or meta == "")
and converted_node_name != "link"
and converted_node_name != "input"
and converted_node_name != "img"
and converted_node_name != "textarea"
and converted_node_name not in {"link", "input", "img", "textarea"}
) and inner_text.strip() == "":
continue

View File

@ -51,13 +51,12 @@ def _format_url(url: str, path_params: dict) -> str:
sep = ","
new_val = ""
new_val += sep.join(kv_strs)
elif param[0] == ".":
new_val = f".{val}"
elif param[0] == ";":
new_val = f";{clean_param}={val}"
else:
if param[0] == ".":
new_val = f".{val}"
elif param[0] == ";":
new_val = f";{clean_param}={val}"
else:
new_val = val
new_val = val
new_params[param] = new_val
return url.format(**new_params)
@ -224,7 +223,7 @@ class SimpleRequestChain(Chain):
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
_run_manager.on_text(_text)
api_response: Response = self.request_method(name, args)
if api_response.status_code != 200:
if api_response.status_code != requests.codes.ok:
response = (
f"{api_response.status_code}: {api_response.reason}"
f"\nFor {name} "

View File

@ -87,7 +87,7 @@ class HypotheticalDocumentEmbedder:
)
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
return H(*args, **kwargs) # type: ignore[return-value]
return H(*args, **kwargs) # type: ignore[return-value] # noqa: PLE0101
@classmethod
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:

View File

@ -89,7 +89,6 @@ def _infer_model_and_provider(
if provider is None and ":" in model:
provider, model_name = _parse_model_string(model)
else:
provider = provider
model_name = model
if not provider:

View File

@ -89,7 +89,7 @@ _warned_about_sha1: bool = False
def _warn_about_sha1_encoder() -> None:
"""Emit a one-time warning about SHA-1 collision weaknesses."""
global _warned_about_sha1
global _warned_about_sha1 # noqa: PLW0603
if not _warned_about_sha1:
warnings.warn(
"Using default key encoder: SHA-1 is *not* collision-resistant. "

View File

@ -36,6 +36,8 @@ from langchain.evaluation.agents.trajectory_eval_prompt import (
)
from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain
_MAX_SCORE = 5
class TrajectoryEval(TypedDict):
"""A named tuple containing the score and reasoning for a trajectory."""
@ -86,10 +88,10 @@ class TrajectoryOutputParser(BaseOutputParser):
raise OutputParserException(msg)
score = int(_score.group(1))
# If the score is not in the range 1-5, raise an exception.
if not 1 <= score <= 5:
if not 1 <= score <= _MAX_SCORE:
msg = f"Score is not a digit in the range 1-5: {text}"
raise OutputParserException(msg)
normalized_score = (score - 1) / 4
normalized_score = (score - 1) / (_MAX_SCORE - 1)
return TrajectoryEval(score=normalized_score, reasoning=reasoning)

View File

@ -39,7 +39,7 @@ def set_verbose(
# have migrated to using `set_verbose()` here.
langchain.verbose = value
global _verbose
global _verbose # noqa: PLW0603
_verbose = value
@ -69,7 +69,6 @@ def get_verbose() -> bool:
# directing them to use `set_verbose()` when they import `langchain.verbose`.
old_verbose = langchain.verbose
global _verbose
return _verbose or old_verbose
@ -94,7 +93,7 @@ def set_debug(
# have migrated to using `set_debug()` here.
langchain.debug = value
global _debug
global _debug # noqa: PLW0603
_debug = value
@ -122,7 +121,6 @@ def get_debug() -> bool:
# directing them to use `set_debug()` when they import `langchain.debug`.
old_debug = langchain.debug
global _debug
return _debug or old_debug
@ -147,7 +145,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
# once all users have migrated to using `set_llm_cache()` here.
langchain.llm_cache = value
global _llm_cache
global _llm_cache # noqa: PLW0603
_llm_cache = value
@ -179,5 +177,4 @@ def get_llm_cache() -> "BaseCache":
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
old_llm_cache = langchain.llm_cache
global _llm_cache
return _llm_cache or old_llm_cache

View File

@ -5,6 +5,8 @@ from typing import Any
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.utils import pre_init
_MIN_PARSERS = 2
class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
"""Combine multiple output parsers into one."""
@ -19,7 +21,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate the parsers."""
parsers = values["parsers"]
if len(parsers) < 2:
if len(parsers) < _MIN_PARSERS:
msg = "Must have at least two parsers"
raise ValueError(msg)
for parser in parsers:

View File

@ -79,7 +79,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
def parse(self, request: str) -> dict[str, Any]:
stripped_request_params = None
splitted_request = request.strip().split(":")
if len(splitted_request) != 2:
if len(splitted_request) != 2: # noqa: PLR2004
msg = f"Request '{request}' is not correctly formatted. \
Please refer to the format instructions."
raise OutputParserException(msg)
@ -127,16 +127,15 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
filtered_df[stripped_request_params],
request_type,
)()
elif request_type == "column":
result[request_params] = self.dataframe[request_params]
elif request_type == "row":
result[request_params] = self.dataframe.iloc[int(request_params)]
else:
if request_type == "column":
result[request_params] = self.dataframe[request_params]
elif request_type == "row":
result[request_params] = self.dataframe.iloc[int(request_params)]
else:
result[request_type] = getattr(
self.dataframe[request_params],
request_type,
)()
result[request_type] = getattr(
self.dataframe[request_params],
request_type,
)()
except (AttributeError, IndexError, KeyError) as e:
if request_type not in {"column", "row"}:
msg = f"Unsupported request type '{request_type}'. \

View File

@ -70,9 +70,8 @@ class LLMChainFilter(BaseDocumentCompressor):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
else:
if isinstance(output_, bool):
include_doc = output_
elif isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
@ -101,9 +100,8 @@ class LLMChainFilter(BaseDocumentCompressor):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
else:
if isinstance(output_, bool):
include_doc = output_
elif isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)

View File

@ -147,11 +147,9 @@ select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"DOC", # pydoclint
"DTZ", # flake8-datetimez
"E", # pycodestyle error
"EM", # flake8-errmsg
@ -164,9 +162,10 @@ select = [
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # flake8-implicit-str-concat
"PERF", # perflint
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # perflint
"PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"PYI", # flake8-pyi
@ -175,9 +174,9 @@ select = [
"RSE", # flake8-raise
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"SLF", # flake8-self
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
@ -198,13 +197,15 @@ ignore = [
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLR09", # Too many something (args, statements, etc)
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
# TODO
"TRY301", # tryceratops: raise-within-try
# TODO rules
"PLC0415", # pylint: import-outside-top-level
"TRY301", # tryceratops: raise-within-try
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
@ -217,9 +218,10 @@ pyupgrade.keep-runtime-typing = true
[tool.ruff.lint.extend-per-file-ignores]
"tests/**/*.py" = [
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"SLF001", # Private member access in tests
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"SLF001", # Private member access in tests
"PLR2004", # Magic value comparisons
]
"langchain/chains/constitutional_ai/principles.py" = [
"E501", # Line too long

View File

@ -130,8 +130,7 @@ def pytest_collection_modifyitems(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test."),
)
elif only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test."),
)

View File

@ -1,4 +1,4 @@
import pytest as pytest
import pytest
from langchain_core.documents import Document
from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents

View File

@ -10,6 +10,8 @@ class AnyStr(str):
def __eq__(self, other: object) -> bool:
return isinstance(other, str)
__hash__ = str.__hash__
# The code below creates versions of pydantic models
# that will work in unit tests with AnyStr as the id field