mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 16:39:20 +00:00
feat(langchain): add ruff rules PL (#32079)
See https://docs.astral.sh/ruff/rules/#pylint-pl
This commit is contained in:
parent
0f39155f62
commit
3496e1739e
@ -357,7 +357,7 @@ def __getattr__(name: str) -> Any:
|
||||
|
||||
return ElasticVectorSearch
|
||||
# For backwards compatibility
|
||||
if name == "SerpAPIChain" or name == "SerpAPIWrapper":
|
||||
if name in {"SerpAPIChain", "SerpAPIWrapper"}:
|
||||
from langchain_community.utilities import SerpAPIWrapper
|
||||
|
||||
_warn_on_import(
|
||||
|
@ -40,13 +40,13 @@ def format_xml(
|
||||
# Escape XML tags in tool names and inputs using custom delimiters
|
||||
tool = _escape(action.tool)
|
||||
tool_input = _escape(str(action.tool_input))
|
||||
observation = _escape(str(observation))
|
||||
observation_ = _escape(str(observation))
|
||||
else:
|
||||
tool = action.tool
|
||||
tool_input = str(action.tool_input)
|
||||
observation = str(observation)
|
||||
observation_ = str(observation)
|
||||
log += (
|
||||
f"<tool>{tool}</tool><tool_input>{tool_input}"
|
||||
f"</tool_input><observation>{observation}</observation>"
|
||||
f"</tool_input><observation>{observation_}</observation>"
|
||||
)
|
||||
return log
|
||||
|
@ -582,7 +582,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
||||
major_version = int(openai.version.VERSION.split(".")[0])
|
||||
minor_version = int(openai.version.VERSION.split(".")[1])
|
||||
version_gte_1_14 = (major_version > 1) or (
|
||||
major_version == 1 and minor_version >= 14
|
||||
major_version == 1 and minor_version >= 14 # noqa: PLR2004
|
||||
)
|
||||
|
||||
messages = self.client.beta.threads.messages.list(
|
||||
@ -739,7 +739,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
||||
major_version = int(openai.version.VERSION.split(".")[0])
|
||||
minor_version = int(openai.version.VERSION.split(".")[1])
|
||||
version_gte_1_14 = (major_version > 1) or (
|
||||
major_version == 1 and minor_version >= 14
|
||||
major_version == 1 and minor_version >= 14 # noqa: PLR2004
|
||||
)
|
||||
|
||||
messages = await self.async_client.beta.threads.messages.list(
|
||||
|
@ -24,6 +24,9 @@ if TYPE_CHECKING:
|
||||
from langchain_community.docstore.base import Docstore
|
||||
|
||||
|
||||
_LOOKUP_AND_SEARCH_TOOLS = {"Lookup", "Search"}
|
||||
|
||||
|
||||
@deprecated(
|
||||
"0.1.0",
|
||||
message=AGENT_DEPRECATION_WARNING,
|
||||
@ -52,11 +55,11 @@ class ReActDocstoreAgent(Agent):
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 2:
|
||||
if len(tools) != len(_LOOKUP_AND_SEARCH_TOOLS):
|
||||
msg = f"Exactly two tools must be specified, but got {tools}"
|
||||
raise ValueError(msg)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
if tool_names != {"Lookup", "Search"}:
|
||||
if tool_names != _LOOKUP_AND_SEARCH_TOOLS:
|
||||
msg = f"Tool names should be Lookup and Search, got {tool_names}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
@ -47,8 +47,8 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
|
||||
scheme, domain = _extract_scheme_and_domain(url)
|
||||
|
||||
for allowed_domain in limit_to_domains:
|
||||
allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
|
||||
if scheme == allowed_scheme and domain == allowed_domain:
|
||||
allowed_scheme, allowed_domain_ = _extract_scheme_and_domain(allowed_domain)
|
||||
if scheme == allowed_scheme and domain == allowed_domain_:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -193,13 +193,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
"multiple llm_chain input_variables"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
elif values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@property
|
||||
|
@ -161,13 +161,12 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
"multiple llm_chain input_variables"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
elif values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
def combine_docs(
|
||||
|
@ -325,10 +325,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
_token_max,
|
||||
**kwargs,
|
||||
)
|
||||
result_docs = []
|
||||
for docs in new_result_doc_list:
|
||||
new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
|
||||
result_docs.append(new_doc)
|
||||
result_docs = [
|
||||
collapse_docs(docs_, _collapse_docs_func, **kwargs)
|
||||
for docs_ in new_result_doc_list
|
||||
]
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
retries += 1
|
||||
if self.collapse_max_retries and retries == self.collapse_max_retries:
|
||||
@ -364,10 +364,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
_token_max,
|
||||
**kwargs,
|
||||
)
|
||||
result_docs = []
|
||||
for docs in new_result_doc_list:
|
||||
new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
|
||||
result_docs.append(new_doc)
|
||||
result_docs = [
|
||||
await acollapse_docs(docs_, _collapse_docs_func, **kwargs)
|
||||
for docs_ in new_result_doc_list
|
||||
]
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
retries += 1
|
||||
if self.collapse_max_retries and retries == self.collapse_max_retries:
|
||||
|
@ -140,13 +140,12 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
"multiple llm_chain input_variables"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
elif values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
def combine_docs(
|
||||
|
@ -180,13 +180,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
"multiple llm_chain_variables"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
elif values["document_variable_name"] not in llm_chain_variables:
|
||||
msg = (
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@property
|
||||
|
@ -322,7 +322,7 @@ class Crawler:
|
||||
|
||||
if node_name == "#text" and ancestor_exception and ancestor_node:
|
||||
text = strings[node_value[index]]
|
||||
if text == "|" or text == "•":
|
||||
if text in {"|", "•"}:
|
||||
continue
|
||||
ancestor_node.append({"type": "type", "value": text})
|
||||
else:
|
||||
@ -367,7 +367,7 @@ class Crawler:
|
||||
element_node_value = strings[text_index]
|
||||
|
||||
# remove redundant elements
|
||||
if ancestor_exception and (node_name != "a" and node_name != "button"):
|
||||
if ancestor_exception and (node_name not in {"a", "button"}):
|
||||
continue
|
||||
|
||||
elements_in_view_port.append(
|
||||
@ -423,10 +423,7 @@ class Crawler:
|
||||
# not very elegant, more like a placeholder
|
||||
if (
|
||||
(converted_node_name != "button" or meta == "")
|
||||
and converted_node_name != "link"
|
||||
and converted_node_name != "input"
|
||||
and converted_node_name != "img"
|
||||
and converted_node_name != "textarea"
|
||||
and converted_node_name not in {"link", "input", "img", "textarea"}
|
||||
) and inner_text.strip() == "":
|
||||
continue
|
||||
|
||||
|
@ -51,13 +51,12 @@ def _format_url(url: str, path_params: dict) -> str:
|
||||
sep = ","
|
||||
new_val = ""
|
||||
new_val += sep.join(kv_strs)
|
||||
elif param[0] == ".":
|
||||
new_val = f".{val}"
|
||||
elif param[0] == ";":
|
||||
new_val = f";{clean_param}={val}"
|
||||
else:
|
||||
if param[0] == ".":
|
||||
new_val = f".{val}"
|
||||
elif param[0] == ";":
|
||||
new_val = f";{clean_param}={val}"
|
||||
else:
|
||||
new_val = val
|
||||
new_val = val
|
||||
new_params[param] = new_val
|
||||
return url.format(**new_params)
|
||||
|
||||
@ -224,7 +223,7 @@ class SimpleRequestChain(Chain):
|
||||
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
|
||||
_run_manager.on_text(_text)
|
||||
api_response: Response = self.request_method(name, args)
|
||||
if api_response.status_code != 200:
|
||||
if api_response.status_code != requests.codes.ok:
|
||||
response = (
|
||||
f"{api_response.status_code}: {api_response.reason}"
|
||||
f"\nFor {name} "
|
||||
|
@ -87,7 +87,7 @@ class HypotheticalDocumentEmbedder:
|
||||
)
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
|
||||
|
||||
return H(*args, **kwargs) # type: ignore[return-value]
|
||||
return H(*args, **kwargs) # type: ignore[return-value] # noqa: PLE0101
|
||||
|
||||
@classmethod
|
||||
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
|
@ -89,7 +89,6 @@ def _infer_model_and_provider(
|
||||
if provider is None and ":" in model:
|
||||
provider, model_name = _parse_model_string(model)
|
||||
else:
|
||||
provider = provider
|
||||
model_name = model
|
||||
|
||||
if not provider:
|
||||
|
@ -89,7 +89,7 @@ _warned_about_sha1: bool = False
|
||||
|
||||
def _warn_about_sha1_encoder() -> None:
|
||||
"""Emit a one-time warning about SHA-1 collision weaknesses."""
|
||||
global _warned_about_sha1
|
||||
global _warned_about_sha1 # noqa: PLW0603
|
||||
if not _warned_about_sha1:
|
||||
warnings.warn(
|
||||
"Using default key encoder: SHA-1 is *not* collision-resistant. "
|
||||
|
@ -36,6 +36,8 @@ from langchain.evaluation.agents.trajectory_eval_prompt import (
|
||||
)
|
||||
from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain
|
||||
|
||||
_MAX_SCORE = 5
|
||||
|
||||
|
||||
class TrajectoryEval(TypedDict):
|
||||
"""A named tuple containing the score and reasoning for a trajectory."""
|
||||
@ -86,10 +88,10 @@ class TrajectoryOutputParser(BaseOutputParser):
|
||||
raise OutputParserException(msg)
|
||||
score = int(_score.group(1))
|
||||
# If the score is not in the range 1-5, raise an exception.
|
||||
if not 1 <= score <= 5:
|
||||
if not 1 <= score <= _MAX_SCORE:
|
||||
msg = f"Score is not a digit in the range 1-5: {text}"
|
||||
raise OutputParserException(msg)
|
||||
normalized_score = (score - 1) / 4
|
||||
normalized_score = (score - 1) / (_MAX_SCORE - 1)
|
||||
return TrajectoryEval(score=normalized_score, reasoning=reasoning)
|
||||
|
||||
|
||||
|
@ -39,7 +39,7 @@ def set_verbose(
|
||||
# have migrated to using `set_verbose()` here.
|
||||
langchain.verbose = value
|
||||
|
||||
global _verbose
|
||||
global _verbose # noqa: PLW0603
|
||||
_verbose = value
|
||||
|
||||
|
||||
@ -69,7 +69,6 @@ def get_verbose() -> bool:
|
||||
# directing them to use `set_verbose()` when they import `langchain.verbose`.
|
||||
old_verbose = langchain.verbose
|
||||
|
||||
global _verbose
|
||||
return _verbose or old_verbose
|
||||
|
||||
|
||||
@ -94,7 +93,7 @@ def set_debug(
|
||||
# have migrated to using `set_debug()` here.
|
||||
langchain.debug = value
|
||||
|
||||
global _debug
|
||||
global _debug # noqa: PLW0603
|
||||
_debug = value
|
||||
|
||||
|
||||
@ -122,7 +121,6 @@ def get_debug() -> bool:
|
||||
# directing them to use `set_debug()` when they import `langchain.debug`.
|
||||
old_debug = langchain.debug
|
||||
|
||||
global _debug
|
||||
return _debug or old_debug
|
||||
|
||||
|
||||
@ -147,7 +145,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
|
||||
# once all users have migrated to using `set_llm_cache()` here.
|
||||
langchain.llm_cache = value
|
||||
|
||||
global _llm_cache
|
||||
global _llm_cache # noqa: PLW0603
|
||||
_llm_cache = value
|
||||
|
||||
|
||||
@ -179,5 +177,4 @@ def get_llm_cache() -> "BaseCache":
|
||||
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
|
||||
old_llm_cache = langchain.llm_cache
|
||||
|
||||
global _llm_cache
|
||||
return _llm_cache or old_llm_cache
|
||||
|
@ -5,6 +5,8 @@ from typing import Any
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
_MIN_PARSERS = 2
|
||||
|
||||
|
||||
class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
"""Combine multiple output parsers into one."""
|
||||
@ -19,7 +21,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate the parsers."""
|
||||
parsers = values["parsers"]
|
||||
if len(parsers) < 2:
|
||||
if len(parsers) < _MIN_PARSERS:
|
||||
msg = "Must have at least two parsers"
|
||||
raise ValueError(msg)
|
||||
for parser in parsers:
|
||||
|
@ -79,7 +79,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
def parse(self, request: str) -> dict[str, Any]:
|
||||
stripped_request_params = None
|
||||
splitted_request = request.strip().split(":")
|
||||
if len(splitted_request) != 2:
|
||||
if len(splitted_request) != 2: # noqa: PLR2004
|
||||
msg = f"Request '{request}' is not correctly formatted. \
|
||||
Please refer to the format instructions."
|
||||
raise OutputParserException(msg)
|
||||
@ -127,16 +127,15 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
filtered_df[stripped_request_params],
|
||||
request_type,
|
||||
)()
|
||||
elif request_type == "column":
|
||||
result[request_params] = self.dataframe[request_params]
|
||||
elif request_type == "row":
|
||||
result[request_params] = self.dataframe.iloc[int(request_params)]
|
||||
else:
|
||||
if request_type == "column":
|
||||
result[request_params] = self.dataframe[request_params]
|
||||
elif request_type == "row":
|
||||
result[request_params] = self.dataframe.iloc[int(request_params)]
|
||||
else:
|
||||
result[request_type] = getattr(
|
||||
self.dataframe[request_params],
|
||||
request_type,
|
||||
)()
|
||||
result[request_type] = getattr(
|
||||
self.dataframe[request_params],
|
||||
request_type,
|
||||
)()
|
||||
except (AttributeError, IndexError, KeyError) as e:
|
||||
if request_type not in {"column", "row"}:
|
||||
msg = f"Unsupported request type '{request_type}'. \
|
||||
|
@ -70,9 +70,8 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
output = output_[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||
else:
|
||||
if isinstance(output_, bool):
|
||||
include_doc = output_
|
||||
elif isinstance(output_, bool):
|
||||
include_doc = output_
|
||||
if include_doc:
|
||||
filtered_docs.append(doc)
|
||||
|
||||
@ -101,9 +100,8 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
output = output_[self.llm_chain.output_key]
|
||||
if self.llm_chain.prompt.output_parser is not None:
|
||||
include_doc = self.llm_chain.prompt.output_parser.parse(output)
|
||||
else:
|
||||
if isinstance(output_, bool):
|
||||
include_doc = output_
|
||||
elif isinstance(output_, bool):
|
||||
include_doc = output_
|
||||
if include_doc:
|
||||
filtered_docs.append(doc)
|
||||
|
||||
|
@ -147,11 +147,9 @@ select = [
|
||||
"A", # flake8-builtins
|
||||
"B", # flake8-bugbear
|
||||
"ASYNC", # flake8-async
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"COM", # flake8-commas
|
||||
"D", # pydocstyle
|
||||
"DOC", # pydoclint
|
||||
"DTZ", # flake8-datetimez
|
||||
"E", # pycodestyle error
|
||||
"EM", # flake8-errmsg
|
||||
@ -164,9 +162,10 @@ select = [
|
||||
"ICN", # flake8-import-conventions
|
||||
"INT", # flake8-gettext
|
||||
"ISC", # flake8-implicit-str-concat
|
||||
"PERF", # flake8-perf
|
||||
"PGH", # pygrep-hooks
|
||||
"PIE", # flake8-pie
|
||||
"PERF", # flake8-perf
|
||||
"PL", # pylint
|
||||
"PT", # flake8-pytest-style
|
||||
"PTH", # flake8-use-pathlib
|
||||
"PYI", # flake8-pyi
|
||||
@ -175,9 +174,9 @@ select = [
|
||||
"RSE", # flake8-raise
|
||||
"RUF", # ruff
|
||||
"S", # flake8-bandit
|
||||
"SLF", # flake8-self
|
||||
"SLOT", # flake8-slots
|
||||
"SIM", # flake8-simplify
|
||||
"SLF", # flake8-self
|
||||
"T10", # flake8-debugger
|
||||
"T20", # flake8-print
|
||||
"TID", # flake8-tidy-imports
|
||||
@ -198,13 +197,15 @@ ignore = [
|
||||
"COM812", # Messes with the formatter
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
"PLR09", # Too many something (args, statements, etc)
|
||||
"S112", # Rarely useful
|
||||
"RUF012", # Doesn't play well with Pydantic
|
||||
"SLF001", # Private member access
|
||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||
|
||||
# TODO
|
||||
"TRY301", # tryceratops: raise-within-try
|
||||
# TODO rules
|
||||
"PLC0415", # pylint: import-outside-top-level
|
||||
"TRY301", # tryceratops: raise-within-try
|
||||
]
|
||||
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||
|
||||
@ -217,9 +218,10 @@ pyupgrade.keep-runtime-typing = true
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"tests/**/*.py" = [
|
||||
"S101", # Tests need assertions
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"SLF001", # Private member access in tests
|
||||
"S101", # Tests need assertions
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"SLF001", # Private member access in tests
|
||||
"PLR2004", # Magic value comparisons
|
||||
]
|
||||
"langchain/chains/constitutional_ai/principles.py" = [
|
||||
"E501", # Line too long
|
||||
|
@ -130,8 +130,7 @@ def pytest_collection_modifyitems(
|
||||
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
|
||||
)
|
||||
break
|
||||
else:
|
||||
if only_extended:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="Skipping not an extended test."),
|
||||
)
|
||||
elif only_extended:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="Skipping not an extended test."),
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
import pytest as pytest
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents
|
||||
|
@ -10,6 +10,8 @@ class AnyStr(str):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, str)
|
||||
|
||||
__hash__ = str.__hash__
|
||||
|
||||
|
||||
# The code below creates versions of pydantic models
|
||||
# that will work in unit tests with AnyStr as id field
|
||||
|
Loading…
Reference in New Issue
Block a user