core: Add ruff rules PLR (#30696)

Add ruff rules [PLR](https://docs.astral.sh/ruff/rules/#refactor-plr),
except the PLR09xx rules (the "too many ..." complexity checks) and
PLR2004 (magic-value-comparison).
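
Most of the mechanical changes below come from two of the newly enabled rules: PLR5501 (collapsible-else-if) and PLR1714 (repeated-equality-comparison), plus PLR0402 (manual-from-import) for the import rewrites. An illustrative before/after sketch; `describe` and its modes are hypothetical names, not from this diff:

def describe_before(mode: str) -> str:
    # Flagged by PLR5501 (nested else/if) and PLR1714 (chained `==` joined by `or`).
    if mode == "fast":
        return "fast path"
    else:
        if mode == "slow" or mode == "debug":
            return "slow path"
    return "unknown"


def describe_after(mode: str) -> str:
    # `elif` collapses the nesting; a set membership test replaces the chained `or`.
    if mode == "fast":
        return "fast path"
    elif mode in {"slow", "debug"}:
        return "slow path"
    return "unknown"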

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Author: Christophe Bornet
Date: 2025-04-09 21:15:38 +02:00
Committed by: GitHub
Parent: 68361f9c2d
Commit: 4cc7bc6c93
19 changed files with 289 additions and 321 deletions

View File

@@ -53,8 +53,7 @@ class FileCallbackHandler(BaseCallbackHandler):
         """
         if "name" in kwargs:
             name = kwargs["name"]
-        else:
-            if serialized:
-                name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
-            else:
-                name = "<unknown>"
+        elif serialized:
+            name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
+        else:
+            name = "<unknown>"

View File

@@ -364,17 +364,14 @@ async def _ahandle_event_for_handler(
             event = getattr(handler, event_name)
             if asyncio.iscoroutinefunction(event):
                 await event(*args, **kwargs)
-            else:
-                if handler.run_inline:
-                    event(*args, **kwargs)
-                else:
-                    await asyncio.get_event_loop().run_in_executor(
-                        None,
-                        cast(
-                            "Callable",
-                            functools.partial(
-                                copy_context().run, event, *args, **kwargs
-                            ),
-                        ),
-                    )
+            elif handler.run_inline:
+                event(*args, **kwargs)
+            else:
+                await asyncio.get_event_loop().run_in_executor(
+                    None,
+                    cast(
+                        "Callable",
+                        functools.partial(copy_context().run, event, *args, **kwargs),
+                    ),
+                )
     except NotImplementedError as e:
@@ -2426,8 +2423,7 @@ def _configure(
                     for handler in callback_manager.handlers
                 ):
                     callback_manager.add_handler(var_handler, inheritable)
-            else:
-                if not any(
-                    isinstance(handler, handler_class)
-                    for handler in callback_manager.handlers
-                ):
+            elif not any(
+                isinstance(handler, handler_class)
+                for handler in callback_manager.handlers
+            ):

View File

@@ -37,8 +37,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
         """
         if "name" in kwargs:
             name = kwargs["name"]
-        else:
-            if serialized:
-                name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
-            else:
-                name = "<unknown>"
+        elif serialized:
+            name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
+        else:
+            name = "<unknown>"

View File

@@ -316,7 +316,7 @@ def index(
     )
     raise ValueError(msg)
-    if (cleanup == "incremental" or cleanup == "scoped_full") and source_id_key is None:
+    if (cleanup in {"incremental", "scoped_full"}) and source_id_key is None:
         msg = (
             "Source id key is required when cleanup mode is incremental or scoped_full."
         )
@@ -379,7 +379,7 @@ def index(
             source_id_assigner(doc) for doc in hashed_docs
         ]
-        if cleanup == "incremental" or cleanup == "scoped_full":
+        if cleanup in {"incremental", "scoped_full"}:
             # source ids are required.
             for source_id, hashed_doc in zip(source_ids, hashed_docs):
                 if source_id is None:
@@ -622,7 +622,7 @@ async def aindex(
     )
     raise ValueError(msg)
-    if (cleanup == "incremental" or cleanup == "scoped_full") and source_id_key is None:
+    if (cleanup in {"incremental", "scoped_full"}) and source_id_key is None:
         msg = (
             "Source id key is required when cleanup mode is incremental or scoped_full."
        )
@@ -667,8 +667,7 @@ async def aindex(
         # In such a case, we use the load method and convert it to an async
         # iterator.
         async_doc_iterator = _to_async_iterator(docs_source.load())
-    else:
-        if hasattr(docs_source, "__aiter__"):
-            async_doc_iterator = docs_source  # type: ignore[assignment]
-        else:
-            async_doc_iterator = _to_async_iterator(docs_source)
+    elif hasattr(docs_source, "__aiter__"):
+        async_doc_iterator = docs_source  # type: ignore[assignment]
+    else:
+        async_doc_iterator = _to_async_iterator(docs_source)
@@ -694,7 +693,7 @@ async def aindex(
             source_id_assigner(doc) for doc in hashed_docs
         ]
-        if cleanup == "incremental" or cleanup == "scoped_full":
+        if cleanup in {"incremental", "scoped_full"}:
             # If the cleanup mode is incremental, source ids are required.
             for source_id, hashed_doc in zip(source_ids, hashed_docs):
                 if source_id is None:

View File

@@ -955,8 +955,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 )
                 chunks.append(chunk)
             result = generate_from_stream(iter(chunks))
-        else:
-            if inspect.signature(self._generate).parameters.get("run_manager"):
-                result = self._generate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
+        elif inspect.signature(self._generate).parameters.get("run_manager"):
+            result = self._generate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
@@ -1028,8 +1027,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 )
                 chunks.append(chunk)
             result = generate_from_stream(iter(chunks))
-        else:
-            if inspect.signature(self._agenerate).parameters.get("run_manager"):
-                result = await self._agenerate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
+        elif inspect.signature(self._agenerate).parameters.get("run_manager"):
+            result = await self._agenerate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )

View File

@@ -170,10 +170,9 @@ def merge_content(
             # If both are lists
             merged = merge_lists(cast("list", merged), content)  # type: ignore
         # If the first content is a list, and the second content is a string
-        else:
-            # If the last element of the first content is a string
-            # Add the second content to the last element
-            if merged and isinstance(merged[-1], str):
-                merged[-1] += content
+        # If the last element of the first content is a string
+        # Add the second content to the last element
+        elif merged and isinstance(merged[-1], str):
+            merged[-1] += content
         # If second content is an empty string, treat as a no-op
         elif content == "":

View File

@@ -1030,8 +1030,7 @@ def convert_to_openai_messages(
                 content = message.content
             else:
                 content = [{"type": "text", "text": message.content}]
-        else:
-            if text_format == "string" and all(
-                isinstance(block, str) or block.get("type") == "text"
-                for block in message.content
-            ):
+        elif text_format == "string" and all(
+            isinstance(block, str) or block.get("type") == "text"
+            for block in message.content
+        ):
@@ -1075,9 +1074,7 @@ def convert_to_openai_messages(
                     # Anthropic
                     if source := block.get("source"):
                         if missing := [
-                            k
-                            for k in ("media_type", "type", "data")
-                            if k not in source
+                            k for k in ("media_type", "type", "data") if k not in source
                         ]:
                             err = (
                                 f"Unrecognized content block at "
@@ -1192,9 +1189,7 @@ def convert_to_openai_messages(
                         "text": json.dumps(block["json"]),
                     }
                 )
-            elif (
-                block.get("type") == "guard_content"
-            ) or "guard_content" in block:
+            elif (block.get("type") == "guard_content") or "guard_content" in block:
                 if (
                     "guard_content" not in block
                     or "text" not in block["guard_content"]
@@ -1213,9 +1208,7 @@ def convert_to_openai_messages(
                 content.append({"type": "text", "text": text})
             # VertexAI format
             elif block.get("type") == "media":
-                if missing := [
-                    k for k in ("mime_type", "data") if k not in block
-                ]:
+                if missing := [k for k in ("mime_type", "data") if k not in block]:
                     err = (
                         f"Unrecognized content block at "
                         f"messages[{i}].content[{j}] has 'type': "
@@ -1235,9 +1228,7 @@ def convert_to_openai_messages(
                     {
                         "type": "image_url",
                         "image_url": {
-                            "url": (
-                                f"data:{block['mime_type']};base64,{b64_image}"
-                            )
+                            "url": (f"data:{block['mime_type']};base64,{b64_image}")
                         },
                     }
                 )

View File

@@ -118,12 +118,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                     }
                 except json.JSONDecodeError:
                     return None
-            else:
-                if self.args_only:
-                    try:
-                        return json.loads(
-                            function_call["arguments"], strict=self.strict
-                        )
-                    except (json.JSONDecodeError, TypeError) as exc:
-                        msg = f"Could not parse function call data: {exc}"
-                        raise OutputParserException(msg) from exc
+            elif self.args_only:
+                try:
+                    return json.loads(function_call["arguments"], strict=self.strict)
+                except (json.JSONDecodeError, TypeError) as exc:
+                    msg = f"Could not parse function call data: {exc}"
+                    raise OutputParserException(msg) from exc

View File

@@ -9,10 +9,9 @@
 from typing import Any, Callable, Literal

 from pydantic import BaseModel, create_model

-import langchain_core.utils.mustache as mustache
 from langchain_core.prompt_values import PromptValue, StringPromptValue
 from langchain_core.prompts.base import BasePromptTemplate
-from langchain_core.utils import get_colored_text
+from langchain_core.utils import get_colored_text, mustache
 from langchain_core.utils.formatting import formatter
 from langchain_core.utils.interactive_env import is_interactive_env

View File

@@ -351,9 +351,7 @@ class Graph:
         """
         self.nodes.pop(node.id)
         self.edges = [
-            edge
-            for edge in self.edges
-            if edge.source != node.id and edge.target != node.id
+            edge for edge in self.edges if node.id not in (edge.source, edge.target)
         ]

     def add_edge(

View File

@@ -401,7 +401,7 @@ def _render_mermaid_using_api(
         f"?type={file_type}&bgColor={background_color}"
     )
     response = requests.get(image_url, timeout=10)
-    if response.status_code == 200:
+    if response.status_code == requests.codes.ok:
         img_bytes = response.content
         if output_file_path is not None:
             Path(output_file_path).write_bytes(response.content)
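
Note: `requests.codes.ok` is the requests library's named constant for HTTP 200, so the comparison above is behavior-preserving while avoiding the bare magic number:

import requests

# `requests.codes` maps status-code names to integers; `ok` is 200.
assert requests.codes.ok == 200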

View File

@@ -79,8 +79,7 @@ class RootListenersTracer(BaseTracer):
         if run.error is None:
             if self._arg_on_end is not None:
                 call_func_with_variable_args(self._arg_on_end, run, self.config)
-        else:
-            if self._arg_on_error is not None:
-                call_func_with_variable_args(self._arg_on_error, run, self.config)
+        elif self._arg_on_error is not None:
+            call_func_with_variable_args(self._arg_on_error, run, self.config)
@@ -143,8 +142,5 @@ class AsyncRootListenersTracer(AsyncBaseTracer):
         if run.error is None:
             if self._arg_on_end is not None:
                 await acall_func_with_variable_args(self._arg_on_end, run, self.config)
-        else:
-            if self._arg_on_error is not None:
-                await acall_func_with_variable_args(
-                    self._arg_on_error, run, self.config
-                )
+        elif self._arg_on_error is not None:
+            await acall_func_with_variable_args(self._arg_on_error, run, self.config)

View File

@@ -7,6 +7,8 @@
 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.schemas import Run
 from langchain_core.utils.input import get_bolded_text, get_colored_text

+MILLISECONDS_IN_SECOND = 1000
+

 def try_json_stringify(obj: Any, fallback: str) -> str:
     """Try to stringify an object to JSON.
@@ -36,10 +38,10 @@ def elapsed(run: Any) -> str:
     """
     elapsed_time = run.end_time - run.start_time
-    milliseconds = elapsed_time.total_seconds() * 1000
-    if milliseconds < 1000:
-        return f"{milliseconds:.0f}ms"
-    return f"{(milliseconds / 1000):.2f}s"
+    seconds = elapsed_time.total_seconds()
+    if seconds < 1:
+        return f"{seconds * MILLISECONDS_IN_SECOND:.0f}ms"
+    return f"{seconds:.2f}s"


 class FunctionCallbackHandler(BaseTracer):
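
The refactored helper is easy to check in isolation. A minimal standalone sketch, assuming only that the run object exposes `start_time`/`end_time` datetimes (as `Run` does); the `SimpleNamespace` stand-in is illustrative:

from datetime import datetime, timedelta
from types import SimpleNamespace

MILLISECONDS_IN_SECOND = 1000


def elapsed(run) -> str:
    # Same logic as above: durations under one second print as ms, otherwise as seconds.
    seconds = (run.end_time - run.start_time).total_seconds()
    if seconds < 1:
        return f"{seconds * MILLISECONDS_IN_SECOND:.0f}ms"
    return f"{seconds:.2f}s"


start = datetime(2025, 4, 9)
print(elapsed(SimpleNamespace(start_time=start, end_time=start + timedelta(milliseconds=250))))  # 250ms
print(elapsed(SimpleNamespace(start_time=start, end_time=start + timedelta(seconds=2.5))))  # 2.50s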

View File

@@ -85,7 +85,7 @@ def extract_sub_links(
         try:
             parsed_link = urlparse(link)
             # Some may be absolute links like https://to/path
-            if parsed_link.scheme == "http" or parsed_link.scheme == "https":
+            if parsed_link.scheme in {"http", "https"}:
                 absolute_path = link
             # Some may have omitted the protocol like //to/path
             elif link.startswith("//"):

View File

@@ -78,15 +78,14 @@ def parse_partial_json(s: str, *, strict: bool = False) -> Any:
                 escaped = not escaped
             else:
                 escaped = False
-        else:
-            if char == '"':
-                is_inside_string = True
-                escaped = False
-            elif char == "{":
-                stack.append("}")
-            elif char == "[":
-                stack.append("]")
-            elif char == "}" or char == "]":
-                if stack and stack[-1] == char:
-                    stack.pop()
-                else:
+        elif char == '"':
+            is_inside_string = True
+            escaped = False
+        elif char == "{":
+            stack.append("}")
+        elif char == "[":
+            stack.append("]")
+        elif char in {"}", "]"}:
+            if stack and stack[-1] == char:
+                stack.pop()
+            else:
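
The refactor does not change behavior: the stack still records the closers needed to complete a truncated document. A quick usage sketch; the printed results assume the existing completion behavior of `parse_partial_json`:

from langchain_core.utils.json import parse_partial_json

# Unclosed containers are completed from the pending closers on the stack.
print(parse_partial_json('{"a": [1, 2'))  # {'a': [1, 2]}
print(parse_partial_json('{"a": "hi'))  # {'a': 'hi'}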

View File

@@ -459,8 +459,7 @@ def render(
         # Then we don't need to tokenize it
         # But it does need to be a generator
         tokens: Iterator[tuple[str, str]] = (token for token in template)
-    else:
-        if template in g_token_cache:
-            tokens = (token for token in g_token_cache[template])
-        else:
-            # Otherwise make a generator
+    elif template in g_token_cache:
+        tokens = (token for token in g_token_cache[template])
+    else:
+        # Otherwise make a generator

View File

@@ -103,7 +103,7 @@ ignore = [
     "FBT001",
     "FBT002",
     "PGH003",
-    "PLR",
+    "PLR2004",
     "RUF",
     "SLF",
 ]
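
PLR2004 (magic-value-comparison) stays suppressed package-wide, even though a few spots in this commit fix instances of it by hand (`requests.codes.ok`, `MILLISECONDS_IN_SECOND`). The rule flags bare literals in comparisons; a sketch with hypothetical names:

RETRY_LIMIT = 3


def should_retry(attempts: int) -> bool:
    # PLR2004 would flag `attempts < 3`; the named constant documents intent.
    return attempts < RETRY_LIMIT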

View File

@@ -112,11 +112,8 @@ def pytest_collection_modifyitems(
                     pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                 )
                 break
-        else:
-            if only_extended:
-                item.add_marker(
-                    pytest.mark.skip(reason="Skipping not an extended test.")
-                )
+        elif only_extended:
+            item.add_marker(pytest.mark.skip(reason="Skipping not an extended test."))


 @pytest.fixture

View File

@@ -1,4 +1,4 @@
-import langchain_core.tracers.schemas as schemas
+from langchain_core.tracers import schemas
 from langchain_core.tracers.schemas import __all__ as schemas_all