feat(langchain): add ruff rules PL (#32079)

See https://docs.astral.sh/ruff/rules/#pylint-pl
Christophe Bornet 2025-07-23 05:55:32 +02:00 committed by GitHub
parent 0f39155f62
commit 3496e1739e
24 changed files with 97 additions and 102 deletions
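
The hunks below are mechanical cleanups driven by ruff's pylint (PL) family. Rules visible or inferable in the diff: PLR1714 (repeated equality comparisons folded into set membership), PLW2901 and PLR1704 (loop variables that overwrite or shadow an outer name), PLR5501 (else-if collapsed to elif), PLR2004 (magic values either named as constants or noqa'd), PLW0603 and PLW0602 (global statements), PLW0127 (self-assignment), PLE0101 (return in __init__), PLC0414 (useless import alias), and PLW1641 (__eq__ without __hash__). Rule attributions other than the explicit noqa codes are inferred from the edits; the commit itself names only the PL family.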

View File

@@ -357,7 +357,7 @@ def __getattr__(name: str) -> Any:
         return ElasticVectorSearch
     # For backwards compatibility
-    if name == "SerpAPIChain" or name == "SerpAPIWrapper":
+    if name in {"SerpAPIChain", "SerpAPIWrapper"}:
         from langchain_community.utilities import SerpAPIWrapper

         _warn_on_import(
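
This is the rewrite suggested by ruff's PLR1714 (repeated-equality-comparison): chained `==` checks against one variable become a single set membership test. A minimal sketch with invented names:

    def is_weekend(day: str) -> bool:
        # before: return day == "sat" or day == "sun"   <- PLR1714
        # after: one containment check against a set literal
        return day in {"sat", "sun"}

    assert is_weekend("sat") and not is_weekend("mon")

The rewrite is at worst performance-neutral: CPython typically folds a constant set literal used with `in` into a frozenset constant.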

View File

@@ -40,13 +40,13 @@ def format_xml(
             # Escape XML tags in tool names and inputs using custom delimiters
             tool = _escape(action.tool)
             tool_input = _escape(str(action.tool_input))
-            observation = _escape(str(observation))
+            observation_ = _escape(str(observation))
         else:
             tool = action.tool
             tool_input = str(action.tool_input)
-            observation = str(observation)
+            observation_ = str(observation)
         log += (
             f"<tool>{tool}</tool><tool_input>{tool_input}"
-            f"</tool_input><observation>{observation}</observation>"
+            f"</tool_input><observation>{observation_}</observation>"
         )
     return log
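
The `observation` to `observation_` rename looks like PLW2901 (redefined-loop-variable): `observation` is bound by the `for` loop over intermediate steps, and reassigning it inside the body overwrites the loop binding. A tiny sketch, names invented:

    items = [" a ", " b "]
    for item in items:
        item_ = item.strip()  # rebinding `item` itself would flag PLW2901
        print(item_)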

View File

@@ -582,7 +582,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
         major_version = int(openai.version.VERSION.split(".")[0])
         minor_version = int(openai.version.VERSION.split(".")[1])
         version_gte_1_14 = (major_version > 1) or (
-            major_version == 1 and minor_version >= 14
+            major_version == 1 and minor_version >= 14  # noqa: PLR2004
         )
         messages = self.client.beta.threads.messages.list(
@@ -739,7 +739,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
         major_version = int(openai.version.VERSION.split(".")[0])
         minor_version = int(openai.version.VERSION.split(".")[1])
         version_gte_1_14 = (major_version > 1) or (
-            major_version == 1 and minor_version >= 14
+            major_version == 1 and minor_version >= 14  # noqa: PLR2004
         )
         messages = await self.async_client.beta.threads.messages.list(
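
Here the magic value 14 is silenced with `# noqa: PLR2004` rather than named: the surrounding variable `version_gte_1_14` already documents the 1.14 boundary, so a separate constant would add nothing. The same version check is duplicated in the sync and async code paths.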

View File

@@ -24,6 +24,9 @@ if TYPE_CHECKING:
     from langchain_community.docstore.base import Docstore

+_LOOKUP_AND_SEARCH_TOOLS = {"Lookup", "Search"}
+
+
 @deprecated(
     "0.1.0",
     message=AGENT_DEPRECATION_WARNING,
@@ -52,11 +55,11 @@ class ReActDocstoreAgent(Agent):
     def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
         validate_tools_single_input(cls.__name__, tools)
         super()._validate_tools(tools)
-        if len(tools) != 2:
+        if len(tools) != len(_LOOKUP_AND_SEARCH_TOOLS):
             msg = f"Exactly two tools must be specified, but got {tools}"
             raise ValueError(msg)
         tool_names = {tool.name for tool in tools}
-        if tool_names != {"Lookup", "Search"}:
+        if tool_names != _LOOKUP_AND_SEARCH_TOOLS:
             msg = f"Tool names should be Lookup and Search, got {tool_names}"
             raise ValueError(msg)
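
This hunk shows the other PLR2004 strategy: derive the number from data instead of hard-coding it, so the bare `2` becomes `len(_LOOKUP_AND_SEARCH_TOOLS)` and the set literal gets a single definition. `_MAX_SCORE` and `_MIN_PARSERS` later in this commit follow the same pattern. A sketch of the idea, with invented names:

    _EXPECTED = {"Lookup", "Search"}

    def check(names: set[str]) -> None:
        # deriving the count from the set keeps the number and the names
        # in one place, instead of a bare literal 2 (PLR2004)
        if len(names) != len(_EXPECTED):
            raise ValueError(f"expected {len(_EXPECTED)} tools")
        if names != _EXPECTED:
            raise ValueError(f"expected {sorted(_EXPECTED)}, got {sorted(names)}")

    check({"Search", "Lookup"})  # passes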

View File

@@ -47,8 +47,8 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
     scheme, domain = _extract_scheme_and_domain(url)
     for allowed_domain in limit_to_domains:
-        allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
-        if scheme == allowed_scheme and domain == allowed_domain:
+        allowed_scheme, allowed_domain_ = _extract_scheme_and_domain(allowed_domain)
+        if scheme == allowed_scheme and domain == allowed_domain_:
             return True
     return False

View File

@@ -193,8 +193,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
                 "multiple llm_chain input_variables"
             )
             raise ValueError(msg)
-        else:
-            if values["document_variable_name"] not in llm_chain_variables:
+        elif values["document_variable_name"] not in llm_chain_variables:
             msg = (
                 f"document_variable_name {values['document_variable_name']} was "
                 f"not found in llm_chain input_variables: {llm_chain_variables}"

View File

@@ -161,8 +161,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
                 "multiple llm_chain input_variables"
             )
             raise ValueError(msg)
-        else:
-            if values["document_variable_name"] not in llm_chain_variables:
+        elif values["document_variable_name"] not in llm_chain_variables:
             msg = (
                 f"document_variable_name {values['document_variable_name']} was "
                 f"not found in llm_chain input_variables: {llm_chain_variables}"

View File

@@ -325,10 +325,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
                 _token_max,
                 **kwargs,
             )
-            result_docs = []
-            for docs in new_result_doc_list:
-                new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
-                result_docs.append(new_doc)
+            result_docs = [
+                collapse_docs(docs_, _collapse_docs_func, **kwargs)
+                for docs_ in new_result_doc_list
+            ]
             num_tokens = length_func(result_docs, **kwargs)
             retries += 1
             if self.collapse_max_retries and retries == self.collapse_max_retries:
@@ -364,10 +364,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
                 _token_max,
                 **kwargs,
             )
-            result_docs = []
-            for docs in new_result_doc_list:
-                new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
-                result_docs.append(new_doc)
+            result_docs = [
+                await acollapse_docs(docs_, _collapse_docs_func, **kwargs)
+                for docs_ in new_result_doc_list
+            ]
             num_tokens = length_func(result_docs, **kwargs)
             retries += 1
             if self.collapse_max_retries and retries == self.collapse_max_retries:
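
Two things happen here: the loop variable `docs` shadowed the enclosing method's `docs` argument (ruff's PLR1704, redefined-argument-from-local, is the likely trigger), and the append loop becomes a list comprehension under the fresh name `docs_`. A sketch with invented names:

    def shout(words: list[str]) -> list[str]:
        # `for words in ...` would rebind the argument; a comprehension
        # with a new name sidesteps that and drops the append loop
        return [w.upper() for w in words]

    assert shout(["hi", "there"]) == ["HI", "THERE"]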

View File

@@ -140,8 +140,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
                 "multiple llm_chain input_variables"
             )
             raise ValueError(msg)
-        else:
-            if values["document_variable_name"] not in llm_chain_variables:
+        elif values["document_variable_name"] not in llm_chain_variables:
             msg = (
                 f"document_variable_name {values['document_variable_name']} was "
                 f"not found in llm_chain input_variables: {llm_chain_variables}"

View File

@@ -180,8 +180,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
                 "multiple llm_chain_variables"
             )
             raise ValueError(msg)
-        else:
-            if values["document_variable_name"] not in llm_chain_variables:
+        elif values["document_variable_name"] not in llm_chain_variables:
             msg = (
                 f"document_variable_name {values['document_variable_name']} was "
                 f"not found in llm_chain input_variables: {llm_chain_variables}"

View File

@@ -322,7 +322,7 @@ class Crawler:
             if node_name == "#text" and ancestor_exception and ancestor_node:
                 text = strings[node_value[index]]
-                if text == "|" or text == "":
+                if text in {"|", ""}:
                     continue
                 ancestor_node.append({"type": "type", "value": text})
             else:
@@ -367,7 +367,7 @@ class Crawler:
                 element_node_value = strings[text_index]

                 # remove redundant elements
-                if ancestor_exception and (node_name != "a" and node_name != "button"):
+                if ancestor_exception and (node_name not in {"a", "button"}):
                     continue

             elements_in_view_port.append(
@@ -423,10 +423,7 @@ class Crawler:
             # not very elegant, more like a placeholder
             if (
                 (converted_node_name != "button" or meta == "")
-                and converted_node_name != "link"
-                and converted_node_name != "input"
-                and converted_node_name != "img"
-                and converted_node_name != "textarea"
+                and converted_node_name not in {"link", "input", "img", "textarea"}
             ) and inner_text.strip() == "":
                 continue
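
The Crawler hunks are the same membership-test consolidation as the SerpAPI hunk above (PLR1714), including the last one, where four chained `!=` comparisons against `converted_node_name` fold into a single `not in {...}` test.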

View File

@@ -51,8 +51,7 @@ def _format_url(url: str, path_params: dict) -> str:
                 sep = ","
                 new_val = ""
             new_val += sep.join(kv_strs)
-        else:
-            if param[0] == ".":
+        elif param[0] == ".":
             new_val = f".{val}"
         elif param[0] == ";":
             new_val = f";{clean_param}={val}"
@@ -224,7 +223,7 @@ class SimpleRequestChain(Chain):
         _text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
         _run_manager.on_text(_text)
         api_response: Response = self.request_method(name, args)
-        if api_response.status_code != 200:
+        if api_response.status_code != requests.codes.ok:
             response = (
                 f"{api_response.status_code}: {api_response.reason}"
                 f"\nFor {name} "

View File

@@ -87,7 +87,7 @@ class HypotheticalDocumentEmbedder:
         )
         from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H

-        return H(*args, **kwargs)  # type: ignore[return-value]
+        return H(*args, **kwargs)  # type: ignore[return-value]  # noqa: PLE0101

     @classmethod
     def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
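
PLE0101 flags an explicit `return` of a value from `__init__`. This deprecation shim appears to return the relocated class's instance on purpose, so the commit silences the rule rather than restructuring the compatibility hack. What the rule normally catches, sketched with invented names:

    class Widget:
        def __init__(self, size: int):
            self.size = size
            # `return self` here would be PLE0101, and CPython raises
            # TypeError at instantiation if __init__ returns non-None

    Widget(3)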

View File

@@ -89,7 +89,6 @@ def _infer_model_and_provider(
     if provider is None and ":" in model:
         provider, model_name = _parse_model_string(model)
     else:
-        provider = provider
         model_name = model

     if not provider:
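
`provider = provider` is a no-op self-assignment (ruff's PLW0127); deleting it is behavior-preserving, since the name simply keeps the value it already had.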

View File

@@ -89,7 +89,7 @@ _warned_about_sha1: bool = False

 def _warn_about_sha1_encoder() -> None:
     """Emit a one-time warning about SHA-1 collision weaknesses."""
-    global _warned_about_sha1
+    global _warned_about_sha1  # noqa: PLW0603
     if not _warned_about_sha1:
         warnings.warn(
             "Using default key encoder: SHA-1 is *not* collision-resistant. "

View File

@@ -36,6 +36,8 @@ from langchain.evaluation.agents.trajectory_eval_prompt import (
 )
 from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain

+_MAX_SCORE = 5
+

 class TrajectoryEval(TypedDict):
     """A named tuple containing the score and reasoning for a trajectory."""
@@ -86,10 +88,10 @@ class TrajectoryOutputParser(BaseOutputParser):
             raise OutputParserException(msg)
         score = int(_score.group(1))
         # If the score is not in the range 1-5, raise an exception.
-        if not 1 <= score <= 5:
+        if not 1 <= score <= _MAX_SCORE:
             msg = f"Score is not a digit in the range 1-5: {text}"
             raise OutputParserException(msg)
-        normalized_score = (score - 1) / 4
+        normalized_score = (score - 1) / (_MAX_SCORE - 1)
         return TrajectoryEval(score=normalized_score, reasoning=reasoning)
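
With `_MAX_SCORE = 5` the new divisor is equivalent to the old literal 4: `(score - 1) / (_MAX_SCORE - 1)` still maps the 1-5 scale onto 0.0-1.0, but the divisor is now tied to the bound checked two lines earlier. Worked check:

    _MAX_SCORE = 5
    assert [(s - 1) / (_MAX_SCORE - 1) for s in (1, 3, 5)] == [0.0, 0.5, 1.0]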

View File

@@ -39,7 +39,7 @@ def set_verbose(
     # have migrated to using `set_verbose()` here.
     langchain.verbose = value

-    global _verbose
+    global _verbose  # noqa: PLW0603
     _verbose = value
@@ -69,7 +69,6 @@ def get_verbose() -> bool:
     # directing them to use `set_verbose()` when they import `langchain.verbose`.
     old_verbose = langchain.verbose

-    global _verbose
     return _verbose or old_verbose
@@ -94,7 +93,7 @@ def set_debug(
     # have migrated to using `set_debug()` here.
     langchain.debug = value

-    global _debug
+    global _debug  # noqa: PLW0603
     _debug = value
@@ -122,7 +121,6 @@ def get_debug() -> bool:
     # directing them to use `set_debug()` when they import `langchain.debug`.
     old_debug = langchain.debug

-    global _debug
     return _debug or old_debug
@@ -147,7 +145,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
     # once all users have migrated to using `set_llm_cache()` here.
     langchain.llm_cache = value

-    global _llm_cache
+    global _llm_cache  # noqa: PLW0603
     _llm_cache = value
@@ -179,5 +177,4 @@ def get_llm_cache() -> "BaseCache":
     # to use `set_llm_cache()` when they import `langchain.llm_cache`.
     old_llm_cache = langchain.llm_cache

-    global _llm_cache
     return _llm_cache or old_llm_cache
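
The setters keep their (noqa'd) `global` because they assign; the deleted `global _verbose` / `global _debug` / `global _llm_cache` lines sat in getters that only read the module variable, and Python resolves reads through module scope without any declaration. Ruff reports the useless declaration as PLW0602 (global-variable-not-assigned), so these deletions are behavior-preserving.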

View File

@@ -5,6 +5,8 @@ from typing import Any
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.utils import pre_init

+_MIN_PARSERS = 2
+

 class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
     """Combine multiple output parsers into one."""
@@ -19,7 +21,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
     def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Validate the parsers."""
         parsers = values["parsers"]
-        if len(parsers) < 2:
+        if len(parsers) < _MIN_PARSERS:
             msg = "Must have at least two parsers"
             raise ValueError(msg)
         for parser in parsers:

View File

@@ -79,7 +79,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
     def parse(self, request: str) -> dict[str, Any]:
         stripped_request_params = None
         splitted_request = request.strip().split(":")
-        if len(splitted_request) != 2:
+        if len(splitted_request) != 2:  # noqa: PLR2004
             msg = f"Request '{request}' is not correctly formatted. \
                 Please refer to the format instructions."
             raise OutputParserException(msg)
@@ -127,8 +127,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
                 filtered_df[stripped_request_params],
                 request_type,
             )()
-        else:
-            if request_type == "column":
+        elif request_type == "column":
             result[request_params] = self.dataframe[request_params]
         elif request_type == "row":
             result[request_params] = self.dataframe.iloc[int(request_params)]

View File

@@ -70,8 +70,7 @@ class LLMChainFilter(BaseDocumentCompressor):
             output = output_[self.llm_chain.output_key]
             if self.llm_chain.prompt.output_parser is not None:
                 include_doc = self.llm_chain.prompt.output_parser.parse(output)
-            else:
-                if isinstance(output_, bool):
+            elif isinstance(output_, bool):
                 include_doc = output_
             if include_doc:
                 filtered_docs.append(doc)
@@ -101,8 +100,7 @@ class LLMChainFilter(BaseDocumentCompressor):
             output = output_[self.llm_chain.output_key]
             if self.llm_chain.prompt.output_parser is not None:
                 include_doc = self.llm_chain.prompt.output_parser.parse(output)
-            else:
-                if isinstance(output_, bool):
+            elif isinstance(output_, bool):
                 include_doc = output_
             if include_doc:
                 filtered_docs.append(doc)

View File

@@ -147,11 +147,9 @@ select = [
     "A", # flake8-builtins
     "B", # flake8-bugbear
     "ASYNC", # flake8-async
-    "B", # flake8-bugbear
     "C4", # flake8-comprehensions
     "COM", # flake8-commas
     "D", # pydocstyle
-    "DOC", # pydoclint
     "DTZ", # flake8-datetimez
     "E", # pycodestyle error
     "EM", # flake8-errmsg
@@ -164,9 +162,10 @@ select = [
     "ICN", # flake8-import-conventions
     "INT", # flake8-gettext
     "ISC", # isort-comprehensions
+    "PERF", # flake8-perf
     "PGH", # pygrep-hooks
     "PIE", # flake8-pie
-    "PERF", # flake8-perf
+    "PL", # pylint
     "PT", # flake8-pytest-style
     "PTH", # flake8-use-pathlib
     "PYI", # flake8-pyi
@@ -175,9 +174,9 @@ select = [
     "RSE", # flake8-rst-docstrings
     "RUF", # ruff
     "S", # flake8-bandit
+    "SLF", # flake8-self
     "SLOT", # flake8-slots
     "SIM", # flake8-simplify
-    "SLF", # flake8-self
     "T10", # flake8-debugger
     "T20", # flake8-print
     "TID", # flake8-tidy-imports
@@ -198,12 +197,14 @@ ignore = [
     "COM812", # Messes with the formatter
     "ISC001", # Messes with the formatter
     "PERF203", # Rarely useful
+    "PLR09", # Too many something (args, statements, etc)
     "S112", # Rarely useful
     "RUF012", # Doesn't play well with Pydantic
     "SLF001", # Private member access
     "UP007", # pyupgrade: non-pep604-annotation-union

-    # TODO
+    # TODO rules
+    "PLC0415", # pylint: import-outside-top-level
     "TRY301", # tryceratops: raise-within-try
 ]
 unfixable = ["B028"] # People should intentionally tune the stacklevel
@@ -220,6 +221,7 @@ pyupgrade.keep-runtime-typing = true
     "S101", # Tests need assertions
     "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
     "SLF001", # Private member access in tests
+    "PLR2004", # Magic value comparisons
 ]
 "langchain/chains/constitutional_ai/principles.py" = [
     "E501", # Line too long

View File

@@ -130,8 +130,7 @@ def pytest_collection_modifyitems(
                     pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
                 )
                 break
-        else:
-            if only_extended:
+        elif only_extended:
             item.add_marker(
                 pytest.mark.skip(reason="Skipping not an extended test."),
             )

View File

@@ -1,4 +1,4 @@
-import pytest as pytest
+import pytest
 from langchain_core.documents import Document

 from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents
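
`import pytest as pytest` aliases a module to its own name; ruff's PLC0414 (useless-import-alias) reduces it to the plain import with no change in behavior.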

View File

@@ -10,6 +10,8 @@ class AnyStr(str):
     def __eq__(self, other: object) -> bool:
         return isinstance(other, str)

+    __hash__ = str.__hash__
+

 # The code below creates version of pydantic models
 # that will work in unit tests with AnyStr as id field
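
Defining `__eq__` on a class implicitly sets `__hash__` to None, making instances unhashable; restoring `str.__hash__` looks like a fix for ruff's PLW1641 (eq-without-hash). Sketch of why the line matters, with an invented stand-in name:

    class AnyText(str):  # stand-in for the test helper
        def __eq__(self, other: object) -> bool:
            return isinstance(other, str)

        __hash__ = str.__hash__  # without this, hash(AnyText("x")) raises

    assert hash(AnyText("x")) == hash("x")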