From 3496e1739ec66d1db85f101b6f445dab3097b44f Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 23 Jul 2025 05:55:32 +0200 Subject: [PATCH] feat(langchain): add ruff rules PL (#32079) See https://docs.astral.sh/ruff/rules/#pylint-pl --- libs/langchain/langchain/__init__.py | 2 +- .../langchain/agents/format_scratchpad/xml.py | 6 +++--- .../langchain/agents/openai_assistant/base.py | 4 ++-- libs/langchain/langchain/agents/react/base.py | 7 +++++-- libs/langchain/langchain/chains/api/base.py | 4 ++-- .../chains/combine_documents/map_reduce.py | 13 ++++++------ .../chains/combine_documents/map_rerank.py | 13 ++++++------ .../chains/combine_documents/reduce.py | 16 +++++++-------- .../chains/combine_documents/refine.py | 13 ++++++------ .../chains/combine_documents/stuff.py | 13 ++++++------ .../langchain/chains/natbot/crawler.py | 9 +++------ .../chains/openai_functions/openapi.py | 13 ++++++------ .../langchain/embeddings/__init__.py | 2 +- libs/langchain/langchain/embeddings/base.py | 1 - libs/langchain/langchain/embeddings/cache.py | 2 +- .../agents/trajectory_eval_chain.py | 6 ++++-- libs/langchain/langchain/globals.py | 9 +++------ .../langchain/output_parsers/combining.py | 4 +++- .../output_parsers/pandas_dataframe.py | 19 +++++++++--------- .../document_compressors/chain_filter.py | 10 ++++------ libs/langchain/pyproject.toml | 20 ++++++++++--------- libs/langchain/tests/unit_tests/conftest.py | 9 ++++----- .../unit_tests/retrievers/test_multi_query.py | 2 +- libs/langchain/tests/unit_tests/stubs.py | 2 ++ 24 files changed, 97 insertions(+), 102 deletions(-) diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py index edc20ad2b3a..2d441e75ced 100644 --- a/libs/langchain/langchain/__init__.py +++ b/libs/langchain/langchain/__init__.py @@ -357,7 +357,7 @@ def __getattr__(name: str) -> Any: return ElasticVectorSearch # For backwards compatibility - if name == "SerpAPIChain" or name == "SerpAPIWrapper": + if name in {"SerpAPIChain", "SerpAPIWrapper"}: from langchain_community.utilities import SerpAPIWrapper _warn_on_import( diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py index 3fb73637790..942b2ec8fa9 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/xml.py +++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py @@ -40,13 +40,13 @@ def format_xml( # Escape XML tags in tool names and inputs using custom delimiters tool = _escape(action.tool) tool_input = _escape(str(action.tool_input)) - observation = _escape(str(observation)) + observation_ = _escape(str(observation)) else: tool = action.tool tool_input = str(action.tool_input) - observation = str(observation) + observation_ = str(observation) log += ( f"{tool}{tool_input}" - f"{observation}" + f"{observation_}" ) return log diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 5a014069133..c1e862b1d67 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -582,7 +582,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): major_version = int(openai.version.VERSION.split(".")[0]) minor_version = int(openai.version.VERSION.split(".")[1]) version_gte_1_14 = (major_version > 1) or ( - major_version == 1 and minor_version >= 14 + major_version == 1 and minor_version >= 14 # noqa: PLR2004 ) messages = self.client.beta.threads.messages.list( @@ -739,7 +739,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): major_version = int(openai.version.VERSION.split(".")[0]) minor_version = int(openai.version.VERSION.split(".")[1]) version_gte_1_14 = (major_version > 1) or ( - major_version == 1 and minor_version >= 14 + major_version == 1 and minor_version >= 14 # noqa: PLR2004 ) messages = await self.async_client.beta.threads.messages.list( diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py index 7430e639650..33c3099815c 100644 --- a/libs/langchain/langchain/agents/react/base.py +++ b/libs/langchain/langchain/agents/react/base.py @@ -24,6 +24,9 @@ if TYPE_CHECKING: from langchain_community.docstore.base import Docstore +_LOOKUP_AND_SEARCH_TOOLS = {"Lookup", "Search"} + + @deprecated( "0.1.0", message=AGENT_DEPRECATION_WARNING, @@ -52,11 +55,11 @@ class ReActDocstoreAgent(Agent): def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: validate_tools_single_input(cls.__name__, tools) super()._validate_tools(tools) - if len(tools) != 2: + if len(tools) != len(_LOOKUP_AND_SEARCH_TOOLS): msg = f"Exactly two tools must be specified, but got {tools}" raise ValueError(msg) tool_names = {tool.name for tool in tools} - if tool_names != {"Lookup", "Search"}: + if tool_names != _LOOKUP_AND_SEARCH_TOOLS: msg = f"Tool names should be Lookup and Search, got {tool_names}" raise ValueError(msg) diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index df2f9623e90..454993c068d 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -47,8 +47,8 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool: scheme, domain = _extract_scheme_and_domain(url) for allowed_domain in limit_to_domains: - allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain) - if scheme == allowed_scheme and domain == allowed_domain: + allowed_scheme, allowed_domain_ = _extract_scheme_and_domain(allowed_domain) + if scheme == allowed_scheme and domain == allowed_domain_: return True return False diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index 8bd373ebd82..bab85195cd9 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -193,13 +193,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): "multiple llm_chain input_variables" ) raise ValueError(msg) - else: - if values["document_variable_name"] not in llm_chain_variables: - msg = ( - f"document_variable_name {values['document_variable_name']} was " - f"not found in llm_chain input_variables: {llm_chain_variables}" - ) - raise ValueError(msg) + elif values["document_variable_name"] not in llm_chain_variables: + msg = ( + f"document_variable_name {values['document_variable_name']} was " + f"not found in llm_chain input_variables: {llm_chain_variables}" + ) + raise ValueError(msg) return values @property diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index c6ad0c611c7..743c2f7282b 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -161,13 +161,12 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): "multiple llm_chain input_variables" ) raise ValueError(msg) - else: - if values["document_variable_name"] not in llm_chain_variables: - msg = ( - f"document_variable_name {values['document_variable_name']} was " - f"not found in llm_chain input_variables: {llm_chain_variables}" - ) - raise ValueError(msg) + elif values["document_variable_name"] not in llm_chain_variables: + msg = ( + f"document_variable_name {values['document_variable_name']} was " + f"not found in llm_chain input_variables: {llm_chain_variables}" + ) + raise ValueError(msg) return values def combine_docs( diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index eb8648685bb..18c2cba7b84 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -325,10 +325,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): _token_max, **kwargs, ) - result_docs = [] - for docs in new_result_doc_list: - new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs) - result_docs.append(new_doc) + result_docs = [ + collapse_docs(docs_, _collapse_docs_func, **kwargs) + for docs_ in new_result_doc_list + ] num_tokens = length_func(result_docs, **kwargs) retries += 1 if self.collapse_max_retries and retries == self.collapse_max_retries: @@ -364,10 +364,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): _token_max, **kwargs, ) - result_docs = [] - for docs in new_result_doc_list: - new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs) - result_docs.append(new_doc) + result_docs = [ + await acollapse_docs(docs_, _collapse_docs_func, **kwargs) + for docs_ in new_result_doc_list + ] num_tokens = length_func(result_docs, **kwargs) retries += 1 if self.collapse_max_retries and retries == self.collapse_max_retries: diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py index b904b909191..5e5e3e74d05 100644 --- a/libs/langchain/langchain/chains/combine_documents/refine.py +++ b/libs/langchain/langchain/chains/combine_documents/refine.py @@ -140,13 +140,12 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): "multiple llm_chain input_variables" ) raise ValueError(msg) - else: - if values["document_variable_name"] not in llm_chain_variables: - msg = ( - f"document_variable_name {values['document_variable_name']} was " - f"not found in llm_chain input_variables: {llm_chain_variables}" - ) - raise ValueError(msg) + elif values["document_variable_name"] not in llm_chain_variables: + msg = ( + f"document_variable_name {values['document_variable_name']} was " + f"not found in llm_chain input_variables: {llm_chain_variables}" + ) + raise ValueError(msg) return values def combine_docs( diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index 138629ab143..0b380dd8bf6 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -180,13 +180,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): "multiple llm_chain_variables" ) raise ValueError(msg) - else: - if values["document_variable_name"] not in llm_chain_variables: - msg = ( - f"document_variable_name {values['document_variable_name']} was " - f"not found in llm_chain input_variables: {llm_chain_variables}" - ) - raise ValueError(msg) + elif values["document_variable_name"] not in llm_chain_variables: + msg = ( + f"document_variable_name {values['document_variable_name']} was " + f"not found in llm_chain input_variables: {llm_chain_variables}" + ) + raise ValueError(msg) return values @property diff --git a/libs/langchain/langchain/chains/natbot/crawler.py b/libs/langchain/langchain/chains/natbot/crawler.py index ab5348d40c7..00a5dd38a6b 100644 --- a/libs/langchain/langchain/chains/natbot/crawler.py +++ b/libs/langchain/langchain/chains/natbot/crawler.py @@ -322,7 +322,7 @@ class Crawler: if node_name == "#text" and ancestor_exception and ancestor_node: text = strings[node_value[index]] - if text == "|" or text == "•": + if text in {"|", "•"}: continue ancestor_node.append({"type": "type", "value": text}) else: @@ -367,7 +367,7 @@ class Crawler: element_node_value = strings[text_index] # remove redundant elements - if ancestor_exception and (node_name != "a" and node_name != "button"): + if ancestor_exception and (node_name not in {"a", "button"}): continue elements_in_view_port.append( @@ -423,10 +423,7 @@ class Crawler: # not very elegant, more like a placeholder if ( (converted_node_name != "button" or meta == "") - and converted_node_name != "link" - and converted_node_name != "input" - and converted_node_name != "img" - and converted_node_name != "textarea" + and converted_node_name not in {"link", "input", "img", "textarea"} ) and inner_text.strip() == "": continue diff --git a/libs/langchain/langchain/chains/openai_functions/openapi.py b/libs/langchain/langchain/chains/openai_functions/openapi.py index fe7a265d98c..a981d9fddb5 100644 --- a/libs/langchain/langchain/chains/openai_functions/openapi.py +++ b/libs/langchain/langchain/chains/openai_functions/openapi.py @@ -51,13 +51,12 @@ def _format_url(url: str, path_params: dict) -> str: sep = "," new_val = "" new_val += sep.join(kv_strs) + elif param[0] == ".": + new_val = f".{val}" + elif param[0] == ";": + new_val = f";{clean_param}={val}" else: - if param[0] == ".": - new_val = f".{val}" - elif param[0] == ";": - new_val = f";{clean_param}={val}" - else: - new_val = val + new_val = val new_params[param] = new_val return url.format(**new_params) @@ -224,7 +223,7 @@ class SimpleRequestChain(Chain): _text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args _run_manager.on_text(_text) api_response: Response = self.request_method(name, args) - if api_response.status_code != 200: + if api_response.status_code != requests.codes.ok: response = ( f"{api_response.status_code}: {api_response.reason}" f"\nFor {name} " diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index 7975ae1f1c9..d377dc181cf 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -87,7 +87,7 @@ class HypotheticalDocumentEmbedder: ) from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H - return H(*args, **kwargs) # type: ignore[return-value] + return H(*args, **kwargs) # type: ignore[return-value] # noqa: PLE0101 @classmethod def from_llm(cls, *args: Any, **kwargs: Any) -> Any: diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py index c0a6c8b891d..6eb47a5f2d2 100644 --- a/libs/langchain/langchain/embeddings/base.py +++ b/libs/langchain/langchain/embeddings/base.py @@ -89,7 +89,6 @@ def _infer_model_and_provider( if provider is None and ":" in model: provider, model_name = _parse_model_string(model) else: - provider = provider model_name = model if not provider: diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py index b42bd959312..fb200c76847 100644 --- a/libs/langchain/langchain/embeddings/cache.py +++ b/libs/langchain/langchain/embeddings/cache.py @@ -89,7 +89,7 @@ _warned_about_sha1: bool = False def _warn_about_sha1_encoder() -> None: """Emit a one-time warning about SHA-1 collision weaknesses.""" - global _warned_about_sha1 + global _warned_about_sha1 # noqa: PLW0603 if not _warned_about_sha1: warnings.warn( "Using default key encoder: SHA-1 is *not* collision-resistant. " diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index e31be77d2b8..f80ab15b973 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -36,6 +36,8 @@ from langchain.evaluation.agents.trajectory_eval_prompt import ( ) from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain +_MAX_SCORE = 5 + class TrajectoryEval(TypedDict): """A named tuple containing the score and reasoning for a trajectory.""" @@ -86,10 +88,10 @@ class TrajectoryOutputParser(BaseOutputParser): raise OutputParserException(msg) score = int(_score.group(1)) # If the score is not in the range 1-5, raise an exception. - if not 1 <= score <= 5: + if not 1 <= score <= _MAX_SCORE: msg = f"Score is not a digit in the range 1-5: {text}" raise OutputParserException(msg) - normalized_score = (score - 1) / 4 + normalized_score = (score - 1) / (_MAX_SCORE - 1) return TrajectoryEval(score=normalized_score, reasoning=reasoning) diff --git a/libs/langchain/langchain/globals.py b/libs/langchain/langchain/globals.py index 7426e06d2fa..34ad94cd0ff 100644 --- a/libs/langchain/langchain/globals.py +++ b/libs/langchain/langchain/globals.py @@ -39,7 +39,7 @@ def set_verbose( # have migrated to using `set_verbose()` here. langchain.verbose = value - global _verbose + global _verbose # noqa: PLW0603 _verbose = value @@ -69,7 +69,6 @@ def get_verbose() -> bool: # directing them to use `set_verbose()` when they import `langchain.verbose`. old_verbose = langchain.verbose - global _verbose return _verbose or old_verbose @@ -94,7 +93,7 @@ def set_debug( # have migrated to using `set_debug()` here. langchain.debug = value - global _debug + global _debug # noqa: PLW0603 _debug = value @@ -122,7 +121,6 @@ def get_debug() -> bool: # directing them to use `set_debug()` when they import `langchain.debug`. old_debug = langchain.debug - global _debug return _debug or old_debug @@ -147,7 +145,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None: # once all users have migrated to using `set_llm_cache()` here. langchain.llm_cache = value - global _llm_cache + global _llm_cache # noqa: PLW0603 _llm_cache = value @@ -179,5 +177,4 @@ def get_llm_cache() -> "BaseCache": # to use `set_llm_cache()` when they import `langchain.llm_cache`. old_llm_cache = langchain.llm_cache - global _llm_cache return _llm_cache or old_llm_cache diff --git a/libs/langchain/langchain/output_parsers/combining.py b/libs/langchain/langchain/output_parsers/combining.py index ed52a89ce75..bf112818184 100644 --- a/libs/langchain/langchain/output_parsers/combining.py +++ b/libs/langchain/langchain/output_parsers/combining.py @@ -5,6 +5,8 @@ from typing import Any from langchain_core.output_parsers import BaseOutputParser from langchain_core.utils import pre_init +_MIN_PARSERS = 2 + class CombiningOutputParser(BaseOutputParser[dict[str, Any]]): """Combine multiple output parsers into one.""" @@ -19,7 +21,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]): def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]: """Validate the parsers.""" parsers = values["parsers"] - if len(parsers) < 2: + if len(parsers) < _MIN_PARSERS: msg = "Must have at least two parsers" raise ValueError(msg) for parser in parsers: diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py index 78b16ba0910..d360865e324 100644 --- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py +++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py @@ -79,7 +79,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): def parse(self, request: str) -> dict[str, Any]: stripped_request_params = None splitted_request = request.strip().split(":") - if len(splitted_request) != 2: + if len(splitted_request) != 2: # noqa: PLR2004 msg = f"Request '{request}' is not correctly formatted. \ Please refer to the format instructions." raise OutputParserException(msg) @@ -127,16 +127,15 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]): filtered_df[stripped_request_params], request_type, )() + elif request_type == "column": + result[request_params] = self.dataframe[request_params] + elif request_type == "row": + result[request_params] = self.dataframe.iloc[int(request_params)] else: - if request_type == "column": - result[request_params] = self.dataframe[request_params] - elif request_type == "row": - result[request_params] = self.dataframe.iloc[int(request_params)] - else: - result[request_type] = getattr( - self.dataframe[request_params], - request_type, - )() + result[request_type] = getattr( + self.dataframe[request_params], + request_type, + )() except (AttributeError, IndexError, KeyError) as e: if request_type not in {"column", "row"}: msg = f"Unsupported request type '{request_type}'. \ diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index b5c6f9d1879..e26d8847c2a 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -70,9 +70,8 @@ class LLMChainFilter(BaseDocumentCompressor): output = output_[self.llm_chain.output_key] if self.llm_chain.prompt.output_parser is not None: include_doc = self.llm_chain.prompt.output_parser.parse(output) - else: - if isinstance(output_, bool): - include_doc = output_ + elif isinstance(output_, bool): + include_doc = output_ if include_doc: filtered_docs.append(doc) @@ -101,9 +100,8 @@ class LLMChainFilter(BaseDocumentCompressor): output = output_[self.llm_chain.output_key] if self.llm_chain.prompt.output_parser is not None: include_doc = self.llm_chain.prompt.output_parser.parse(output) - else: - if isinstance(output_, bool): - include_doc = output_ + elif isinstance(output_, bool): + include_doc = output_ if include_doc: filtered_docs.append(doc) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 3b1add6dca5..c237945ea4e 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -147,11 +147,9 @@ select = [ "A", # flake8-builtins "B", # flake8-bugbear "ASYNC", # flake8-async - "B", # flake8-bugbear "C4", # flake8-comprehensions "COM", # flake8-commas "D", # pydocstyle - "DOC", # pydoclint "DTZ", # flake8-datetimez "E", # pycodestyle error "EM", # flake8-errmsg @@ -164,9 +162,10 @@ select = [ "ICN", # flake8-import-conventions "INT", # flake8-gettext "ISC", # isort-comprehensions + "PERF", # flake8-perf "PGH", # pygrep-hooks "PIE", # flake8-pie - "PERF", # flake8-perf + "PL", # pylint "PT", # flake8-pytest-style "PTH", # flake8-use-pathlib "PYI", # flake8-pyi @@ -175,9 +174,9 @@ select = [ "RSE", # flake8-rst-docstrings "RUF", # ruff "S", # flake8-bandit - "SLF", # flake8-self "SLOT", # flake8-slots "SIM", # flake8-simplify + "SLF", # flake8-self "T10", # flake8-debugger "T20", # flake8-print "TID", # flake8-tidy-imports @@ -198,13 +197,15 @@ ignore = [ "COM812", # Messes with the formatter "ISC001", # Messes with the formatter "PERF203", # Rarely useful + "PLR09", # Too many something (args, statements, etc) "S112", # Rarely useful "RUF012", # Doesn't play well with Pydantic "SLF001", # Private member access "UP007", # pyupgrade: non-pep604-annotation-union - # TODO - "TRY301", # tryceratops: raise-within-try + # TODO rules + "PLC0415", # pylint: import-outside-top-level + "TRY301", # tryceratops: raise-within-try ] unfixable = ["B028"] # People should intentionally tune the stacklevel @@ -217,9 +218,10 @@ pyupgrade.keep-runtime-typing = true [tool.ruff.lint.extend-per-file-ignores] "tests/**/*.py" = [ - "S101", # Tests need assertions - "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes - "SLF001", # Private member access in tests + "S101", # Tests need assertions + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "SLF001", # Private member access in tests + "PLR2004", # Magic value comparisons ] "langchain/chains/constitutional_ai/principles.py" = [ "E501", # Line too long diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py index fe0b62e766f..3a620b411fe 100644 --- a/libs/langchain/tests/unit_tests/conftest.py +++ b/libs/langchain/tests/unit_tests/conftest.py @@ -130,8 +130,7 @@ def pytest_collection_modifyitems( pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"), ) break - else: - if only_extended: - item.add_marker( - pytest.mark.skip(reason="Skipping not an extended test."), - ) + elif only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test."), + ) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py index c9bdea85f28..b492c1b93c6 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py @@ -1,4 +1,4 @@ -import pytest as pytest +import pytest from langchain_core.documents import Document from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents diff --git a/libs/langchain/tests/unit_tests/stubs.py b/libs/langchain/tests/unit_tests/stubs.py index 44450d946bd..742d2df5813 100644 --- a/libs/langchain/tests/unit_tests/stubs.py +++ b/libs/langchain/tests/unit_tests/stubs.py @@ -10,6 +10,8 @@ class AnyStr(str): def __eq__(self, other: object) -> bool: return isinstance(other, str) + __hash__ = str.__hash__ + # The code below creates version of pydantic models # that will work in unit tests with AnyStr as id field