diff --git a/libs/cli/langchain_cli/dev_scripts.py b/libs/cli/langchain_cli/dev_scripts.py index b4fab46377a..00048d8afe3 100644 --- a/libs/cli/langchain_cli/dev_scripts.py +++ b/libs/cli/langchain_cli/dev_scripts.py @@ -13,7 +13,7 @@ def create_demo_server( *, config_keys: Sequence[str] = (), playground_type: Literal["default", "chat"] = "default", -): +) -> FastAPI: """Create a demo server for the current template.""" app = FastAPI() package_root = get_package_root() @@ -40,11 +40,11 @@ def create_demo_server( return app -def create_demo_server_configurable(): +def create_demo_server_configurable() -> FastAPI: """Create a configurable demo server.""" return create_demo_server(config_keys=["configurable"]) -def create_demo_server_chat(): +def create_demo_server_chat() -> FastAPI: """Create a chat demo server.""" return create_demo_server(playground_type="chat") diff --git a/libs/cli/langchain_cli/namespaces/app.py b/libs/cli/langchain_cli/namespaces/app.py index 35f61d61537..2894b0013e9 100644 --- a/libs/cli/langchain_cli/namespaces/app.py +++ b/libs/cli/langchain_cli/namespaces/app.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Annotated, Optional import typer +import uvicorn from langchain_cli.utils.events import create_events from langchain_cli.utils.git import ( @@ -261,7 +262,7 @@ def add( cmd = ["pip", "install", "-e", *installed_destination_strs] cmd_str = " \\\n ".join(installed_destination_strs) typer.echo(f"Running: pip install -e \\\n {cmd_str}") - subprocess.run(cmd, cwd=cwd) # noqa: S603 + subprocess.run(cmd, cwd=cwd, check=True) # noqa: S603 chain_names = [] for e in installed_exports: @@ -367,8 +368,6 @@ def serve( app_str = app if app is not None else "app.server:app" host_str = host if host is not None else "127.0.0.1" - import uvicorn - uvicorn.run( app_str, host=host_str, diff --git a/libs/cli/langchain_cli/namespaces/integration.py b/libs/cli/langchain_cli/namespaces/integration.py index b577c8a65d3..c96452703ad 100644 --- a/libs/cli/langchain_cli/namespaces/integration.py +++ b/libs/cli/langchain_cli/namespaces/integration.py @@ -129,6 +129,7 @@ def new( subprocess.run( ["poetry", "install", "--with", "lint,test,typing,test_integration"], # noqa: S607 cwd=destination_dir, + check=True, ) else: # confirm src and dst are the same length diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py b/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py index 24b866c7a70..98167d8c231 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py @@ -3,6 +3,7 @@ import importlib import inspect import pkgutil +from types import ModuleType def generate_raw_migrations( @@ -89,7 +90,7 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]: items = [] # Function to handle importing from modules - def handle_module(module, module_name) -> None: + def handle_module(module: ModuleType, module_name: str) -> None: if hasattr(module, "__all__"): all_objects = module.__all__ for name in all_objects: diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py index de923256c11..4870763a704 100644 --- a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py +++ b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py @@ -5,10 +5,9 @@ import inspect import os import pathlib from pathlib import Path +from types import ModuleType from typing import Any, Optional -from typing_extensions import override - HERE = Path(__file__).parent # Should bring us to [root]/src PKGS_ROOT = HERE.parent.parent.parent.parent.parent @@ -19,15 +18,14 @@ PARTNER_PKGS = PKGS_ROOT / "partners" class ImportExtractor(ast.NodeVisitor): - """Import extractor""" + """Import extractor.""" def __init__(self, *, from_package: Optional[str] = None) -> None: """Extract all imports from the given code, optionally filtering by package.""" self.imports: list = [] self.package = from_package - @override - def visit_ImportFrom(self, node) -> None: + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802 if node.module and ( self.package is None or str(node.module).startswith(self.package) ): @@ -46,7 +44,7 @@ def _get_class_names(code: str) -> list[str]: # Define a node visitor class to collect class names class ClassVisitor(ast.NodeVisitor): - def visit_ClassDef(self, node) -> None: # noqa: N802 + def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: N802 class_names.append(node.name) self.generic_visit(node) @@ -65,7 +63,7 @@ def is_subclass(class_obj: Any, classes_: list[type]) -> bool: ) -def find_subclasses_in_module(module, classes_: list[type]) -> list[str]: +def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]: """Find all classes in the module that inherit from one of the classes.""" subclasses = [] # Iterate over all attributes of the module that are classes @@ -77,8 +75,7 @@ def find_subclasses_in_module(module, classes_: list[type]) -> list[str]: def _get_all_classnames_from_file(file: Path, pkg: str) -> list[tuple[str, str]]: """Extract all class names from a file.""" - with open(file, encoding="utf-8") as f: - code = f.read() + code = Path(file).read_text(encoding="utf-8") module_name = _get_current_module(file, pkg) class_names = _get_class_names(code) @@ -91,8 +88,7 @@ def identify_all_imports_in_file( from_package: Optional[str] = None, ) -> list[tuple[str, str]]: """Let's also identify all the imports in the given file.""" - with open(file, encoding="utf-8") as f: - code = f.read() + code = Path(file).read_text(encoding="utf-8") return find_imports_from_package(code, from_package=from_package) @@ -162,8 +158,7 @@ def find_imports_from_package( def _get_current_module(path: Path, pkg_root: str) -> str: """Convert a path to a module name.""" - path_as_pathlib = pathlib.Path(os.path.abspath(path)) - relative_path = path_as_pathlib.relative_to(pkg_root).with_suffix("") + relative_path = path.relative_to(pkg_root).with_suffix("") posix_path = relative_path.as_posix() norm_path = os.path.normpath(str(posix_path)) fully_qualified_module = norm_path.replace("/", ".") diff --git a/libs/cli/langchain_cli/namespaces/template.py b/libs/cli/langchain_cli/namespaces/template.py index 7ad5c14f494..4c2c0272e7f 100644 --- a/libs/cli/langchain_cli/namespaces/template.py +++ b/libs/cli/langchain_cli/namespaces/template.py @@ -7,7 +7,9 @@ from pathlib import Path from typing import Annotated, Optional import typer +import uvicorn +from langchain_cli.utils.github import list_packages from langchain_cli.utils.packages import get_langserve_export, get_package_root package_cli = typer.Typer(no_args_is_help=True, add_completion=False) @@ -79,7 +81,7 @@ def new( # poetry install if with_poetry: - subprocess.run(["poetry", "install"], cwd=destination_dir) # noqa: S607 + subprocess.run(["poetry", "install"], cwd=destination_dir, check=True) # noqa: S607 @package_cli.command() @@ -128,8 +130,6 @@ def serve( ) ) - import uvicorn - uvicorn.run( script, factory=True, @@ -142,8 +142,6 @@ def serve( @package_cli.command() def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None: # noqa: A001 """List all or search for available templates.""" - from langchain_cli.utils.github import list_packages - packages = list_packages(contains=contains) for package in packages: typer.echo(package) diff --git a/libs/cli/langchain_cli/utils/events.py b/libs/cli/langchain_cli/utils/events.py index b3b6edaef52..860c5aaedd9 100644 --- a/libs/cli/langchain_cli/utils/events.py +++ b/libs/cli/langchain_cli/utils/events.py @@ -16,6 +16,7 @@ class EventDict(TypedDict): event: The name of the event. properties: Optional dictionary of event properties. """ + event: str properties: Optional[dict[str, Any]] diff --git a/libs/cli/langchain_cli/utils/git.py b/libs/cli/langchain_cli/utils/git.py index 7ca46815224..9d5cf506229 100644 --- a/libs/cli/langchain_cli/utils/git.py +++ b/libs/cli/langchain_cli/utils/git.py @@ -1,6 +1,7 @@ """Git utilities.""" import hashlib +import logging import re import shutil from collections.abc import Sequence @@ -15,6 +16,8 @@ from langchain_cli.constants import ( DEFAULT_GIT_SUBDIRECTORY, ) +logger = logging.getLogger(__name__) + class DependencySource(TypedDict): """Dependency source information.""" @@ -181,16 +184,15 @@ def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path: # try pulling try: repo = Repo(repo_path) - if repo.active_branch.name != ref: - raise ValueError - repo.remotes.origin.pull() + if repo.active_branch.name == ref: + repo.remotes.origin.pull() + return repo_path except Exception: - # if it fails, delete and clone again - shutil.rmtree(repo_path) - Repo.clone_from(gitstring, repo_path, branch=ref, depth=1) - else: - Repo.clone_from(gitstring, repo_path, branch=ref, depth=1) + logger.exception("Failed to pull existing repo") + # if it fails, delete and clone again + shutil.rmtree(repo_path) + Repo.clone_from(gitstring, repo_path, branch=ref, depth=1) return repo_path @@ -203,7 +205,7 @@ def copy_repo( Raises FileNotFound error if it can't find source """ - def ignore_func(_, files): + def ignore_func(_: str, files: list[str]) -> list[str]: return [f for f in files if f == ".git"] shutil.copytree(source, destination, ignore=ignore_func) diff --git a/libs/cli/langchain_cli/utils/packages.py b/libs/cli/langchain_cli/utils/packages.py index 4d634e47792..0d3610c3887 100644 --- a/libs/cli/langchain_cli/utils/packages.py +++ b/libs/cli/langchain_cli/utils/packages.py @@ -39,7 +39,7 @@ class LangServeExport(TypedDict): def get_langserve_export(filepath: Path) -> LangServeExport: """Get LangServe export information from a pyproject.toml file.""" - with open(filepath) as f: + with filepath.open() as f: data: dict[str, Any] = load(f) try: module = data["tool"]["langserve"]["export_module"] diff --git a/libs/cli/langchain_cli/utils/pyproject.py b/libs/cli/langchain_cli/utils/pyproject.py index 732abd0ec76..3a90243f88c 100644 --- a/libs/cli/langchain_cli/utils/pyproject.py +++ b/libs/cli/langchain_cli/utils/pyproject.py @@ -20,7 +20,7 @@ def add_dependencies_to_pyproject_toml( local_editable_dependencies: Iterable[tuple[str, Path]], ) -> None: """Add dependencies to pyproject.toml.""" - with open(pyproject_toml, encoding="utf-8") as f: + with pyproject_toml.open(encoding="utf-8") as f: # tomlkit types aren't amazing - treat as Dict instead pyproject: dict[str, Any] = load(f) pyproject["tool"]["poetry"]["dependencies"].update( @@ -29,7 +29,7 @@ def add_dependencies_to_pyproject_toml( for name, loc in local_editable_dependencies }, ) - with open(pyproject_toml, "w", encoding="utf-8") as f: + with pyproject_toml.open("w", encoding="utf-8") as f: dump(pyproject, f) @@ -38,12 +38,13 @@ def remove_dependencies_from_pyproject_toml( local_editable_dependencies: Iterable[str], ) -> None: """Remove dependencies from pyproject.toml.""" - with open(pyproject_toml, encoding="utf-8") as f: + with pyproject_toml.open(encoding="utf-8") as f: pyproject: dict[str, Any] = load(f) # tomlkit types aren't amazing - treat as Dict instead dependencies = pyproject["tool"]["poetry"]["dependencies"] for name in local_editable_dependencies: with contextlib.suppress(KeyError): del dependencies[name] - with open(pyproject_toml, "w", encoding="utf-8") as f: + + with pyproject_toml.open("w", encoding="utf-8") as f: dump(pyproject, f) diff --git a/libs/cli/pyproject.toml b/libs/cli/pyproject.toml index 8b0c1494973..05c4bbb0526 100644 --- a/libs/cli/pyproject.toml +++ b/libs/cli/pyproject.toml @@ -48,53 +48,40 @@ exclude = [ ] [tool.ruff.lint] -select = [ - "A", # flake8-builtins - "B", # flake8-bugbear - "ARG", # flake8-unused-arguments - "ASYNC", # flake8-async - "C4", # flake8-comprehensions - "COM", # flake8-commas - "D1", # pydocstyle - "E", # pycodestyle error - "EM", # flake8-errmsg - "F", # pyflakes - "FA", # flake8-future-annotations - "FBT", # flake8-boolean-trap - "FLY", # flake8-flynt - "I", # isort - "ICN", # flake8-import-conventions - "INT", # flake8-gettext - "ISC", # isort-comprehensions - "N", # pep8-naming - "PT", # flake8-pytest-style - "PGH", # pygrep-hooks - "PIE", # flake8-pie - "PERF", # flake8-perf - "PYI", # flake8-pyi - "Q", # flake8-quotes - "RET", # flake8-return - "RSE", # flake8-rst-docstrings - "RUF", # ruff - "S", # flake8-bandit - "SLF", # flake8-self - "SLOT", # flake8-slots - "SIM", # flake8-simplify - "T10", # flake8-debugger - "T20", # flake8-print - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warning - "YTT", # flake8-2020 -] +select = [ "ALL",] ignore = [ - "D407", # pydocstyle: Missing-dashed-underline-after-section + "C90", # McCabe complexity "COM812", # Messes with the formatter + "FIX002", # Line contains TODO + "PERF203", # Rarely useful + "PLR09", # Too many something (arg, statements, etc) + "RUF012", # Doesn't play well with Pydantic + "TC001", # Doesn't play well with Pydantic + "TC002", # Doesn't play well with Pydantic + "TC003", # Doesn't play well with Pydantic + "TD002", # Missing author in TODO + "TD003", # Missing issue link in TODO + + # TODO rules + "ANN401", + "BLE", + "D1", ] +unfixable = [ + "B028", # People should intentionally tune the stacklevel + "PLW1510", # People should intentionally set the check argument +] + +flake8-annotations.allow-star-arg-any = true +flake8-annotations.mypy-init-return = true +flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"] +pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",] +pydocstyle.convention = "google" pyupgrade.keep-runtime-typing = true [tool.ruff.lint.per-file-ignores] -"tests/**" = [ "D1"] +"tests/**" = [ "D1", "S", "SLF",] +"scripts/**" = [ "INP", "S",] [tool.mypy] exclude = [ diff --git a/libs/cli/scripts/generate_migrations.py b/libs/cli/scripts/generate_migrations.py index 01dc652a633..9d76bff28ac 100644 --- a/libs/cli/scripts/generate_migrations.py +++ b/libs/cli/scripts/generate_migrations.py @@ -1,8 +1,8 @@ """Script to generate migrations for the migration script.""" import json -import os import pkgutil +from pathlib import Path from typing import Optional import click @@ -73,8 +73,7 @@ def generic( else: dumped = dump_migrations_as_grit(name, migrations) - with open(output, "w") as f: - f.write(dumped) + Path(output).write_text(dumped) def handle_partner(pkg: str, output: Optional[str] = None) -> None: @@ -85,8 +84,7 @@ def handle_partner(pkg: str, output: Optional[str] = None) -> None: data = dump_migrations_as_grit(name, migrations) output_name = f"{name}.grit" if output is None else output if migrations: - with open(output_name, "w") as f: - f.write(data) + Path(output_name).write_text(data) click.secho(f"LangChain migration script saved to {output_name}") else: click.secho(f"No migrations found for {pkg}", fg="yellow") @@ -105,13 +103,13 @@ def partner(pkg: str, output: str) -> None: @click.argument("json_file") def json_to_grit(json_file: str) -> None: """Generate a Grit migration from an old JSON migration file.""" - with open(json_file) as f: + file = Path(json_file) + with file.open() as f: migrations = json.load(f) - name = os.path.basename(json_file).removesuffix(".json").removesuffix(".grit") + name = file.stem data = dump_migrations_as_grit(name, migrations) output_name = f"{name}.grit" - with open(output_name, "w") as f: - f.write(data) + Path(output_name).write_text(data) click.secho(f"GritQL migration script saved to {output_name}") diff --git a/libs/cli/tests/integration_tests/__init__.py b/libs/cli/tests/integration_tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/cli/tests/unit_tests/migrate/cli_runner/file.py b/libs/cli/tests/unit_tests/migrate/cli_runner/file.py index 1e0a5b4b497..c571e7c6099 100644 --- a/libs/cli/tests/unit_tests/migrate/cli_runner/file.py +++ b/libs/cli/tests/unit_tests/migrate/cli_runner/file.py @@ -14,3 +14,6 @@ class File: return False return self.content == __value.content + + def __hash__(self) -> int: + return hash((self.name, self.content)) diff --git a/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py b/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py index d40d494d35c..424420de278 100644 --- a/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py +++ b/libs/cli/tests/unit_tests/migrate/cli_runner/folder.py @@ -57,3 +57,6 @@ class Folder: return False return True + + def __hash__(self) -> int: + return hash((self.name, tuple(self.files)))