From 9e6ffd1264f0973be9aaf56aff910dba031f18c0 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 28 Feb 2025 19:22:20 +0100 Subject: [PATCH] core: Add ruff rules PTH (pathlib) (#29338) See https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth Co-authored-by: ccurme --- libs/core/langchain_core/callbacks/file.py | 3 ++- libs/core/langchain_core/documents/base.py | 10 +++---- .../langchain_core/language_models/llms.py | 6 ++--- libs/core/langchain_core/prompts/base.py | 6 ++--- libs/core/langchain_core/prompts/chat.py | 5 ++-- libs/core/langchain_core/prompts/loading.py | 16 +++++------ libs/core/langchain_core/prompts/prompt.py | 6 ++--- .../langchain_core/runnables/graph_mermaid.py | 10 +++---- libs/core/pyproject.toml | 2 +- .../tests/unit_tests/prompts/test_prompt.py | 2 +- libs/core/tests/unit_tests/test_imports.py | 27 ++++++++----------- 11 files changed, 40 insertions(+), 53 deletions(-) diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index cd20fbe4f71..b68b3ba22e6 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -2,6 +2,7 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, TextIO, cast from langchain_core.callbacks import BaseCallbackHandler @@ -30,7 +31,7 @@ class FileCallbackHandler(BaseCallbackHandler): mode: The mode to open the file in. Defaults to "a". color: The color to use for the text. Defaults to None. """ - self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) # noqa: SIM115 + self.file = cast(TextIO, Path(filename).open(mode, encoding="utf-8")) # noqa: SIM115 self.color = color def __del__(self) -> None: diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index fb4fcd0987e..8d39daaa0c6 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -3,7 +3,7 @@ from __future__ import annotations import contextlib import mimetypes from io import BufferedReader, BytesIO -from pathlib import PurePath +from pathlib import Path, PurePath from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from pydantic import ConfigDict, Field, field_validator, model_validator @@ -151,8 +151,7 @@ class Blob(BaseMedia): def as_string(self) -> str: """Read data as a string.""" if self.data is None and self.path: - with open(str(self.path), encoding=self.encoding) as f: - return f.read() + return Path(self.path).read_text(encoding=self.encoding) elif isinstance(self.data, bytes): return self.data.decode(self.encoding) elif isinstance(self.data, str): @@ -168,8 +167,7 @@ class Blob(BaseMedia): elif isinstance(self.data, str): return self.data.encode(self.encoding) elif self.data is None and self.path: - with open(str(self.path), "rb") as f: - return f.read() + return Path(self.path).read_bytes() else: msg = f"Unable to get bytes for blob {self}" raise ValueError(msg) @@ -180,7 +178,7 @@ class Blob(BaseMedia): if isinstance(self.data, bytes): yield BytesIO(self.data) elif self.data is None and self.path: - with open(str(self.path), "rb") as f: + with Path(self.path).open("rb") as f: yield f else: msg = f"Unable to convert blob {self}" diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 9af4d4e9ac5..3c5ed7337f2 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ 
-1402,7 +1402,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): llm.save(file_path="path/llm.yaml") """ # Convert file to Path object. - save_path = Path(file_path) if isinstance(file_path, str) else file_path + save_path = Path(file_path) directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) @@ -1411,10 +1411,10 @@ class BaseLLM(BaseLanguageModel[str], ABC): prompt_dict = self.dict() if save_path.suffix == ".json": - with open(file_path, "w") as f: + with save_path.open("w") as f: json.dump(prompt_dict, f, indent=4) elif save_path.suffix.endswith((".yaml", ".yml")): - with open(file_path, "w") as f: + with save_path.open("w") as f: yaml.dump(prompt_dict, f, default_flow_style=False) else: msg = f"{save_path} must be json or yaml" diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index bb85416d3ff..16f0058d369 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -368,16 +368,16 @@ class BasePromptTemplate( raise NotImplementedError(msg) # Convert file to Path object. - save_path = Path(file_path) if isinstance(file_path, str) else file_path + save_path = Path(file_path) directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) if save_path.suffix == ".json": - with open(file_path, "w") as f: + with save_path.open("w") as f: json.dump(prompt_dict, f, indent=4) elif save_path.suffix.endswith((".yaml", ".yml")): - with open(file_path, "w") as f: + with save_path.open("w") as f: yaml.dump(prompt_dict, f, default_flow_style=False) else: msg = f"{save_path} must be json or yaml" diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index c819bbe06d7..288e71849f3 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from pathlib import Path from typing import ( TYPE_CHECKING, Annotated, @@ -48,7 +49,6 @@ from langchain_core.utils.interactive_env import is_interactive_env if TYPE_CHECKING: from collections.abc import Sequence - from pathlib import Path class BaseMessagePromptTemplate(Serializable, ABC): @@ -599,8 +599,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): Returns: A new instance of this class. """ - with open(str(template_file)) as f: - template = f.read() + template = Path(template_file).read_text() return cls.from_template(template, input_variables=input_variables, **kwargs) def format_messages(self, **kwargs: Any) -> list[BaseMessage]: diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index da89972e757..3eb24c7817e 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -53,8 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict: template_path = Path(config.pop(f"{var_name}_path")) # Load the template. if template_path.suffix == ".txt": - with open(template_path) as f: - template = f.read() + template = template_path.read_text() else: raise ValueError # Set the template variable to the extracted variable. 
@@ -67,10 +66,11 @@ def _load_examples(config: dict) -> dict: if isinstance(config["examples"], list): pass elif isinstance(config["examples"], str): - with open(config["examples"]) as f: - if config["examples"].endswith(".json"): + path = Path(config["examples"]) + with path.open() as f: + if path.suffix == ".json": examples = json.load(f) - elif config["examples"].endswith((".yaml", ".yml")): + elif path.suffix in {".yaml", ".yml"}: examples = yaml.safe_load(f) else: msg = "Invalid file format. Only json or yaml formats are supported." @@ -168,13 +168,13 @@ def _load_prompt_from_file( ) -> BasePromptTemplate: """Load prompt from file.""" # Convert file to a Path object. - file_path = Path(file) if isinstance(file, str) else file + file_path = Path(file) # Load from either json or yaml. if file_path.suffix == ".json": - with open(file_path, encoding=encoding) as f: + with file_path.open(encoding=encoding) as f: config = json.load(f) elif file_path.suffix.endswith((".yaml", ".yml")): - with open(file_path, encoding=encoding) as f: + with file_path.open(encoding=encoding) as f: config = yaml.safe_load(f) else: msg = f"Got unsupported file type {file_path.suffix}" diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 888fc9ccbc9..1aded63ad7e 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -3,6 +3,7 @@ from __future__ import annotations import warnings +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, model_validator @@ -17,8 +18,6 @@ from langchain_core.prompts.string import ( ) if TYPE_CHECKING: - from pathlib import Path - from langchain_core.runnables.config import RunnableConfig @@ -238,8 +237,7 @@ class PromptTemplate(StringPromptTemplate): Returns: The prompt loaded from the file. 
""" - with open(str(template_file), encoding=encoding) as f: - template = f.read() + template = Path(template_file).read_text(encoding=encoding) if input_variables: warnings.warn( "`input_variables' is deprecated and ignored.", diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index c9b3025bc85..a331b1b0582 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -2,6 +2,7 @@ import asyncio import base64 import re from dataclasses import asdict +from pathlib import Path from typing import Literal, Optional from langchain_core.runnables.graph import ( @@ -290,13 +291,9 @@ async def _render_mermaid_using_pyppeteer( img_bytes = await page.screenshot({"fullPage": False}) await browser.close() - def write_to_file(path: str, bytes: bytes) -> None: - with open(path, "wb") as file: - file.write(bytes) - if output_file_path is not None: await asyncio.get_event_loop().run_in_executor( - None, write_to_file, output_file_path, img_bytes + None, Path(output_file_path).write_bytes, img_bytes ) return img_bytes @@ -337,8 +334,7 @@ def _render_mermaid_using_api( if response.status_code == 200: img_bytes = response.content if output_file_path is not None: - with open(output_file_path, "wb") as file: - file.write(response.content) + Path(output_file_path).write_bytes(response.content) return img_bytes else: diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index dc4a8bbc473..8e39a25e939 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -77,7 +77,7 @@ target-version = "py39" [tool.ruff.lint] -select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",] +select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PTH", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",] ignore = [ "ANN401", "COM812", "UP007", "S110", "S112", "TC001", "TC002", "TC003"] flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"] flake8-annotations.allow-star-arg-any = true diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index cef1e5595d2..7fd7ff9224f 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -354,7 +354,7 @@ def test_prompt_from_file_with_partial_variables() -> None: template = "This is a {foo} test {bar}." 
partial_variables = {"bar": "baz"} # when - with mock.patch("builtins.open", mock.mock_open(read_data=template)): + with mock.patch("pathlib.Path.open", mock.mock_open(read_data=template)): prompt = PromptTemplate.from_file( "mock_file_name", partial_variables=partial_variables ) diff --git a/libs/core/tests/unit_tests/test_imports.py b/libs/core/tests/unit_tests/test_imports.py index d046976e8e5..64f93aa606c 100644 --- a/libs/core/tests/unit_tests/test_imports.py +++ b/libs/core/tests/unit_tests/test_imports.py @@ -1,20 +1,17 @@ import concurrent.futures -import glob import importlib import subprocess from pathlib import Path def test_importable_all() -> None: - for path in glob.glob("../core/langchain_core/*"): - relative_path = Path(path).parts[-1] - if relative_path.endswith(".typed"): - continue - module_name = relative_path.split(".")[0] - module = importlib.import_module("langchain_core." + module_name) - all_ = getattr(module, "__all__", []) - for cls_ in all_: - getattr(module, cls_) + for path in Path("../core/langchain_core/").glob("*"): + module_name = path.stem + if not module_name.startswith(".") and path.suffix != ".typed": + module = importlib.import_module("langchain_core." + module_name) + all_ = getattr(module, "__all__", []) + for cls_ in all_: + getattr(module, cls_) def try_to_import(module_name: str) -> tuple[int, str]: @@ -37,12 +34,10 @@ def test_importable_all_via_subprocess() -> None: for one sequence of imports but not another. """ module_names = [] - for path in glob.glob("../core/langchain_core/*"): - relative_path = Path(path).parts[-1] - if relative_path.endswith(".typed"): - continue - module_name = relative_path.split(".")[0] - module_names.append(module_name) + for path in Path("../core/langchain_core/").glob("*"): + module_name = path.stem + if not module_name.startswith(".") and path.suffix != ".typed": + module_names.append(module_name) with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [
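
Illustrative note, not part of the patch: a minimal sketch of the open()-to-pathlib rewrites that ruff's PTH rules push for, using the same idioms adopted throughout the diff above (Path.read_text, Path.read_bytes, Path.write_bytes, Path.open). The file names below are hypothetical and exist only for illustration.

    from pathlib import Path

    # Hypothetical paths, used only to demonstrate the idioms.
    prompt_path = Path("prompt.txt")
    image_path = Path("diagram.png")

    # Path.write_text / Path.write_bytes replace open(p, "w").write(...)
    # and open(p, "wb").write(...).
    prompt_path.write_text("This is a {foo} test {bar}.", encoding="utf-8")
    image_path.write_bytes(b"\x89PNG\r\n\x1a\n")  # placeholder bytes

    # Path.read_text / Path.read_bytes replace open(p).read()
    # and open(p, "rb").read().
    template = prompt_path.read_text(encoding="utf-8")
    raw = image_path.read_bytes()

    # Path.open keeps an explicit file object when one is still needed,
    # e.g. to hand to json.load or yaml.safe_load.
    with prompt_path.open(encoding="utf-8") as f:
        first_line = f.readline()

The same pattern explains the test change above: code that previously patched builtins.open now patches pathlib.Path.open instead, since the file access goes through Path objects.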