diff --git a/libs/cli/langchain_cli/cli.py b/libs/cli/langchain_cli/cli.py
index 7e766713268..b60035d18a7 100644
--- a/libs/cli/langchain_cli/cli.py
+++ b/libs/cli/langchain_cli/cli.py
@@ -35,7 +35,7 @@ app.command(
 def version_callback(show_version: bool) -> None:
     if show_version:
         typer.echo(f"langchain-cli {__version__}")
-        raise typer.Exit()
+        raise typer.Exit
 
 
 @app.callback()
@@ -48,7 +48,7 @@ def main(
         callback=version_callback,
         is_eager=True,
     ),
-):
+) -> None:
     pass
 
 
@@ -62,10 +62,7 @@ def serve(
         Optional[str], typer.Option(help="The host to run the server on")
     ] = None,
 ) -> None:
-    """
-    Start the LangServe app, whether it's a template or an app.
-    """
-
+    """Start the LangServe app, whether it's a template or an app."""
     # see if is a template
     try:
         project_dir = get_package_root()
diff --git a/libs/cli/langchain_cli/dev_scripts.py b/libs/cli/langchain_cli/dev_scripts.py
index b12b47ed166..31c4259339c 100644
--- a/libs/cli/langchain_cli/dev_scripts.py
+++ b/libs/cli/langchain_cli/dev_scripts.py
@@ -1,7 +1,5 @@
 # type: ignore
-"""
-Development Scripts for template packages
-"""
+"""Development Scripts for template packages."""
 
 from collections.abc import Sequence
 
@@ -16,9 +14,7 @@ def create_demo_server(
     config_keys: Sequence[str] = (),
     playground_type: str = "default",
 ):
-    """
-    Creates a demo server for the current template.
-    """
+    """Creates a demo server for the current template."""
     app = FastAPI()
     package_root = get_package_root()
     pyproject = package_root / "pyproject.toml"
@@ -35,9 +31,11 @@ def create_demo_server(
             playground_type=playground_type,
         )
     except KeyError as e:
-        raise KeyError("Missing fields from pyproject.toml") from e
+        msg = "Missing fields from pyproject.toml"
+        raise KeyError(msg) from e
     except ImportError as e:
-        raise ImportError("Could not import module defined in pyproject.toml") from e
+        msg = "Could not import module defined in pyproject.toml"
+        raise ImportError(msg) from e
     return app
 
 
diff --git a/libs/cli/langchain_cli/namespaces/app.py b/libs/cli/langchain_cli/namespaces/app.py
index 21d967dcf56..a259cf48aa2 100644
--- a/libs/cli/langchain_cli/namespaces/app.py
+++ b/libs/cli/langchain_cli/namespaces/app.py
@@ -1,6 +1,4 @@
-"""
-Manage LangChain apps
-"""
+"""Manage LangChain apps."""
 
 import shutil
 import subprocess
@@ -62,15 +60,14 @@ def new(
             is_flag=True,
         ),
     ] = False,
-):
-    """
-    Create a new LangServe application.
-    """
+) -> None:
+    """Create a new LangServe application."""
     has_packages = package is not None and len(package) > 0
 
     if noninteractive:
         if name is None:
-            raise typer.BadParameter("name is required when --non-interactive is set")
+            msg = "name is required when --non-interactive is set"
+            raise typer.BadParameter(msg)
         name_str = name
         pip_bool = bool(pip)  # None should be false
     else:
@@ -154,9 +151,8 @@ def add(
            prompt="Would you like to `pip install -e` the template(s)?",
        ),
    ],
-):
-    """
-    Adds the specified template to the current LangServe app.
+) -> None:
+    """Adds the specified template to the current LangServe app.
 
     e.g.:
     langchain app add extraction-openai-functions
@@ -166,7 +162,8 @@
     if not branch and not repo:
         warnings.warn(
             "Adding templates from the default branch and repo is deprecated."
- " At a minimum, you will have to add `--branch v0.2` for this to work" + " At a minimum, you will have to add `--branch v0.2` for this to work", + stacklevel=2, ) parsed_deps = parse_dependencies(dependencies, repo, branch, api_path) @@ -176,7 +173,7 @@ def add( package_dir = project_root / "packages" create_events( - [{"event": "serve add", "properties": dict(parsed_dep=d)} for d in parsed_deps] + [{"event": "serve add", "properties": {"parsed_dep": d}} for d in parsed_deps] ) # group by repo/ref @@ -216,7 +213,7 @@ def add( destination_path = package_dir / inner_api_path if destination_path.exists(): typer.echo( - f"Folder {str(inner_api_path)} already exists. Skipping...", + f"Folder {inner_api_path} already exists. Skipping...", ) continue copy_repo(source_path, destination_path) @@ -248,7 +245,7 @@ def add( typer.echo("Failed to print install command, continuing...") else: if pip: - cmd = ["pip", "install", "-e"] + installed_destination_strs + cmd = ["pip", "install", "-e", *installed_destination_strs] cmd_str = " \\\n ".join(installed_destination_strs) typer.echo(f"Running: pip install -e \\\n {cmd_str}") subprocess.run(cmd, cwd=cwd) @@ -282,13 +279,15 @@ def add( if len(chain_names) == 1 else f"these {len(chain_names)} templates" ) - lines = ( - ["", f"To use {t}, add the following to your app:\n\n```", ""] - + imports - + [""] - + routes - + ["```"] - ) + lines = [ + "", + f"To use {t}, add the following to your app:\n\n```", + "", + *imports, + "", + *routes, + "```", + ] typer.echo("\n".join(lines)) @@ -299,11 +298,8 @@ def remove( project_dir: Annotated[ Optional[Path], typer.Option(help="The project directory") ] = None, -): - """ - Removes the specified package from the current LangServe app. - """ - +) -> None: + """Removes the specified package from the current LangServe app.""" project_root = get_package_root(project_dir) project_pyproject = project_root / "pyproject.toml" @@ -347,10 +343,7 @@ def serve( Optional[str], typer.Option(help="The app to run, e.g. `app.server:app`") ] = None, ) -> None: - """ - Starts the LangServe app. - """ - + """Starts the LangServe app.""" # add current dir as first entry of path sys.path.append(str(Path.cwd())) diff --git a/libs/cli/langchain_cli/namespaces/integration.py b/libs/cli/langchain_cli/namespaces/integration.py index 34c7f06b995..e92ed283848 100644 --- a/libs/cli/langchain_cli/namespaces/integration.py +++ b/libs/cli/langchain_cli/namespaces/integration.py @@ -1,6 +1,4 @@ -""" -Develop integration packages for LangChain. -""" +"""Develop integration packages for LangChain.""" import re import shutil @@ -28,18 +26,20 @@ class Replacements(TypedDict): def _process_name(name: str, *, community: bool = False) -> Replacements: preprocessed = name.replace("_", "-").lower() - if preprocessed.startswith("langchain-"): - preprocessed = preprocessed[len("langchain-") :] + preprocessed = preprocessed.removeprefix("langchain-") if not re.match(r"^[a-z][a-z0-9-]*$", preprocessed): - raise ValueError( + msg = ( "Name should only contain lowercase letters (a-z), numbers, and hyphens" ", and start with a letter." ) + raise ValueError(msg) if preprocessed.endswith("-"): - raise ValueError("Name should not end with `-`.") + msg = "Name should not end with `-`." + raise ValueError(msg) if preprocessed.find("--") != -1: - raise ValueError("Name should not contain consecutive hyphens.") + msg = "Name should not contain consecutive hyphens." 
+        raise ValueError(msg)
     replacements: Replacements = {
         "__package_name__": f"langchain-{preprocessed}",
         "__module_name__": "langchain_" + preprocessed.replace("-", "_"),
@@ -84,11 +84,8 @@
            ". e.g. `my-integration/my_integration.py`",
        ),
    ] = None,
-):
-    """
-    Creates a new integration package.
-    """
-
+) -> None:
+    """Creates a new integration package."""
     try:
         replacements = _process_name(name)
     except ValueError as e:
@@ -123,7 +120,7 @@ def new(
     shutil.move(destination_dir / "integration_template", package_dir)
 
     # replacements in files
-    replace_glob(destination_dir, "**/*", cast(dict[str, str], replacements))
+    replace_glob(destination_dir, "**/*", cast("dict[str, str]", replacements))
 
     # poetry install
     subprocess.run(
@@ -167,7 +164,7 @@
 
     for src_path, dst_path in zip(src_paths, dst_paths):
         shutil.copy(src_path, dst_path)
-        replace_file(dst_path, cast(dict[str, str], replacements))
+        replace_file(dst_path, cast("dict[str, str]", replacements))
 
 
 TEMPLATE_MAP: dict[str, str] = {
@@ -183,7 +180,7 @@
     "Retriever": "retrievers.ipynb",
 }
 
-_component_types_str = ", ".join(f"`{k}`" for k in TEMPLATE_MAP.keys())
+_component_types_str = ", ".join(f"`{k}`" for k in TEMPLATE_MAP)
 
 
 @integration_cli.command()
@@ -226,10 +223,8 @@ def create_doc(
            prompt="The relative path to the docs directory to place the new file in.",
        ),
    ] = "docs/docs/integrations/chat/",
-):
-    """
-    Creates a new integration doc.
-    """
+) -> None:
+    """Creates a new integration doc."""
     if component_type not in TEMPLATE_MAP:
         typer.echo(
             f"Unrecognized {component_type=}. Expected one of {_component_types_str}."
diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py b/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py
index 3b6631d53ac..e2fa0fa4879 100644
--- a/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py
+++ b/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py
@@ -12,7 +12,7 @@
     package = importlib.import_module(from_package)
 
     items = []
-    for importer, modname, ispkg in pkgutil.walk_packages(
+    for _importer, modname, _ispkg in pkgutil.walk_packages(
         package.__path__, package.__name__ + "."
     ):
         try:
@@ -84,9 +84,9 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
     items = []
 
     # Function to handle importing from modules
-    def handle_module(module, module_name):
+    def handle_module(module, module_name) -> None:
         if hasattr(module, "__all__"):
-            all_objects = getattr(module, "__all__")
+            all_objects = module.__all__
             for name in all_objects:
                 # Attempt to fetch each object declared in __all__
                 obj = getattr(module, name, None)
@@ -105,7 +105,7 @@
     handle_module(package, pkg)
 
     # Only iterate through top-level modules/packages
-    for finder, modname, ispkg in pkgutil.iter_modules(
+    for _finder, modname, ispkg in pkgutil.iter_modules(
         package.__path__, package.__name__ + "."
     ):
         if ispkg:
@@ -126,7 +126,7 @@ def generate_simplified_migrations(
         from_package, to_package, filter_by_all=filter_by_all
     )
     top_level_simplifications = generate_top_level_imports(to_package)
-    top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
+    top_level_dict = dict(top_level_simplifications)
     simple_migrations = []
     for migration in raw_migrations:
         original, new = migration
diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py b/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py
index f39d78be477..3da2c9933b0 100644
--- a/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py
+++ b/libs/cli/langchain_cli/namespaces/migrate/generate/grit.py
@@ -1,12 +1,11 @@
 def split_package(package: str) -> tuple[str, str]:
-    """Split a package name into the containing package and the final name"""
+    """Split a package name into the containing package and the final name."""
     parts = package.split(".")
     return ".".join(parts[:-1]), parts[-1]
 
 
-def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]):
+def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]) -> str:
     """Dump the migration pairs as a Grit file."""
-    output = "language python"
     remapped = ",\n".join(
         [
             f"""
@@ -21,7 +20,7 @@
         ]
     )
     pattern_name = f"langchain_migrate_{name}"
-    output = f"""
+    return f"""
 language python
 
 // This migration is generated automatically - do not manually edit this file
@@ -34,4 +33,3 @@ pattern {pattern_name}() {{
 // Add this for invoking directly
 {pattern_name}()
 """
-    return output
diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py
index d17000e4859..ef8a34409db 100644
--- a/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py
+++ b/libs/cli/langchain_cli/namespaces/migrate/generate/partner.py
@@ -46,9 +46,8 @@ def get_migrations_for_partner_package(pkg_name: str) -> list[tuple[str, str]]:
 
     old_paths = community_classes + imports_for_pkg
 
-    migrations = [
+    return [
         (f"{module}.{item}", f"{pkg_name}.{item}")
         for module, item in old_paths
         if item in classes_
     ]
-    return migrations
diff --git a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py
index 5fa9b8d2cdb..2fbd65f8b8f 100644
--- a/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py
+++ b/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py
@@ -20,7 +20,7 @@
         self.imports: list = []
         self.package = from_package
 
-    def visit_ImportFrom(self, node):
+    def visit_ImportFrom(self, node) -> None:
         if node.module and (
             self.package is None or str(node.module).startswith(self.package)
         ):
@@ -39,7 +39,7 @@ def _get_class_names(code: str) -> list[str]:
 
     # Define a node visitor class to collect class names
     class ClassVisitor(ast.NodeVisitor):
-        def visit_ClassDef(self, node):
+        def visit_ClassDef(self, node) -> None:
             class_names.append(node.name)
             self.generic_visit(node)
 
@@ -62,7 +62,7 @@ def find_subclasses_in_module(module, classes_: list[type]) -> list[str]:
     """Find all classes in the module that inherit from one of the classes."""
     subclasses = []
     # Iterate over all attributes of the module that are classes
-    for name, obj in inspect.getmembers(module, inspect.isclass):
+    for _name, obj in inspect.getmembers(module, inspect.isclass):
         if is_subclass(obj, classes_):
             subclasses.append(obj.__name__)
     return subclasses
@@ -125,7 +125,7 @@ def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
     files = list(Path(pkg_source).rglob("*.py"))
 
     for file in files:
-        if not file.name == "__init__.py":
+        if file.name != "__init__.py":
             continue
         import_in_file = identify_all_imports_in_file(str(file))
         module_name = _get_current_module(file, pkg_root)
diff --git a/libs/cli/langchain_cli/namespaces/template.py b/libs/cli/langchain_cli/namespaces/template.py
index 16b8e0ca4d7..1f348e35178 100644
--- a/libs/cli/langchain_cli/namespaces/template.py
+++ b/libs/cli/langchain_cli/namespaces/template.py
@@ -1,6 +1,4 @@
-"""
-Develop installable templates.
-"""
+"""Develop installable templates."""
 
 import re
 import shutil
@@ -22,10 +20,8 @@ def new(
        bool,
        typer.Option("--with-poetry/--no-poetry", help="Don't run poetry install"),
    ] = False,
-):
-    """
-    Creates a new template package.
-    """
+) -> None:
+    """Creates a new template package."""
     computed_name = name if name != "." else Path.cwd().name
     destination_dir = Path.cwd() / name if name != "." else Path.cwd()
 
@@ -108,9 +104,7 @@ def serve(
        ),
    ] = False,
 ) -> None:
-    """
-    Starts a demo app for this template.
-    """
+    """Starts a demo app for this template."""
     # load pyproject.toml
     project_dir = get_package_root()
     pyproject = project_dir / "pyproject.toml"
@@ -143,9 +137,7 @@
 
 @package_cli.command()
 def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None:
-    """
-    List all or search for available templates.
-    """
+    """List all or search for available templates."""
     from langchain_cli.utils.github import list_packages
 
     packages = list_packages(contains=contains)
diff --git a/libs/cli/langchain_cli/utils/git.py b/libs/cli/langchain_cli/utils/git.py
index d3777a9697c..cef0ba639bc 100644
--- a/libs/cli/langchain_cli/utils/git.py
+++ b/libs/cli/langchain_cli/utils/git.py
@@ -31,10 +31,11 @@
 ) -> DependencySource:
     if dep is not None and dep.startswith("git+"):
         if repo is not None or branch is not None:
-            raise ValueError(
+            msg = (
                 "If a dependency starts with git+, you cannot manually specify "
                 "a repo or branch."
             )
+            raise ValueError(msg)
         # remove git+
         gitstring = dep[4:]
         subdirectory = None
@@ -43,9 +44,8 @@
         if "#subdirectory=" in gitstring:
             gitstring, subdirectory = gitstring.split("#subdirectory=")
             if "#" in subdirectory or "@" in subdirectory:
-                raise ValueError(
-                    "#subdirectory must be the last part of the dependency string"
-                )
+                msg = "#subdirectory must be the last part of the dependency string"
+                raise ValueError(msg)
 
         # find first slash after ://
         # find @ or # after that slash
@@ -54,9 +54,8 @@
 
         # find first slash after ://
         if "://" not in gitstring:
-            raise ValueError(
-                "git+ dependencies must start with git+https:// or git+ssh://"
-            )
+            msg = "git+ dependencies must start with git+https:// or git+ssh://"
+            raise ValueError(msg)
 
         _, find_slash = gitstring.split("://", 1)
 
@@ -79,41 +78,41 @@
             event_metadata={"dependency_string": dep},
         )
 
-    elif dep is not None and dep.startswith("https://"):
-        raise ValueError("Only git dependencies are supported")
-    else:
-        # if repo is none, use default, including subdirectory
-        base_subdir = Path(DEFAULT_GIT_SUBDIRECTORY) if repo is None else Path()
-        subdir = str(base_subdir / dep) if dep is not None else None
-        gitstring = (
-            DEFAULT_GIT_REPO
-            if repo is None
-            else f"https://github.com/{repo.strip('/')}.git"
-        )
-        ref = DEFAULT_GIT_REF if branch is None else branch
-        # it's a default git repo dependency
-        return DependencySource(
-            git=gitstring,
-            ref=ref,
-            subdirectory=subdir,
-            api_path=api_path,
-            event_metadata={
-                "dependency_string": dep,
-                "used_repo_flag": repo is not None,
-                "used_branch_flag": branch is not None,
-            },
-        )
+    if dep is not None and dep.startswith("https://"):
+        msg = "Only git dependencies are supported"
+        raise ValueError(msg)
+    # if repo is none, use default, including subdirectory
+    base_subdir = Path(DEFAULT_GIT_SUBDIRECTORY) if repo is None else Path()
+    subdir = str(base_subdir / dep) if dep is not None else None
+    gitstring = (
+        DEFAULT_GIT_REPO
+        if repo is None
+        else f"https://github.com/{repo.strip('/')}.git"
+    )
+    ref = DEFAULT_GIT_REF if branch is None else branch
+    # it's a default git repo dependency
+    return DependencySource(
+        git=gitstring,
+        ref=ref,
+        subdirectory=subdir,
+        api_path=api_path,
+        event_metadata={
+            "dependency_string": dep,
+            "used_repo_flag": repo is not None,
+            "used_branch_flag": branch is not None,
+        },
+    )
 
 
 def _list_arg_to_length(arg: Optional[list[str]], num: int) -> Sequence[Optional[str]]:
     if not arg:
         return [None] * num
-    elif len(arg) == 1:
+    if len(arg) == 1:
         return arg * num
-    elif len(arg) == num:
+    if len(arg) == num:
         return arg
-    else:
-        raise ValueError(f"Argument must be of length 1 or {num}")
+    msg = f"Argument must be of length 1 or {num}"
+    raise ValueError(msg)
 
 
 def parse_dependencies(
@@ -131,10 +130,11 @@
         or (repo and len(repo) not in [1, num_deps])
         or (branch and len(branch) not in [1, num_deps])
     ):
-        raise ValueError(
+        msg = (
             "Number of defined repos/branches/api_paths did not match the "
             "number of templates."
         )
+        raise ValueError(msg)
     inner_deps = _list_arg_to_length(dependencies, num_deps)
     inner_api_paths = _list_arg_to_length(api_path, num_deps)
     inner_repos = _list_arg_to_length(repo, num_deps)
@@ -170,7 +170,7 @@ def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
     try:
         repo = Repo(repo_path)
         if repo.active_branch.name != ref:
-            raise ValueError()
+            raise ValueError
         repo.remotes.origin.pull()
     except Exception:
         # if it fails, delete and clone again
@@ -186,8 +186,7 @@
     source: Path,
     destination: Path,
 ) -> None:
-    """
-    Copies a repo, ignoring git folders.
+    """Copies a repo, ignoring git folders.
 
     Raises FileNotFound error if it can't find source
     """
diff --git a/libs/cli/langchain_cli/utils/github.py b/libs/cli/langchain_cli/utils/github.py
index edaea5525d7..fe2c6b3d638 100644
--- a/libs/cli/langchain_cli/utils/github.py
+++ b/libs/cli/langchain_cli/utils/github.py
@@ -23,7 +23,4 @@ def list_packages(*, contains: Optional[str] = None):
     package_names = [
         p["name"] for p in data if p["type"] == "dir" and p["name"] != "docs"
     ]
-    package_names_filtered = (
-        [p for p in package_names if contains in p] if contains else package_names
-    )
-    return package_names_filtered
+    return [p for p in package_names if contains in p] if contains else package_names
diff --git a/libs/cli/langchain_cli/utils/packages.py b/libs/cli/langchain_cli/utils/packages.py
index 6dbef342c0b..c8d00722579 100644
--- a/libs/cli/langchain_cli/utils/packages.py
+++ b/libs/cli/langchain_cli/utils/packages.py
@@ -15,12 +15,12 @@ def get_package_root(cwd: Optional[Path] = None) -> Path:
         if pyproject_path.exists():
             return package_root
         package_root = package_root.parent
-    raise FileNotFoundError("No pyproject.toml found")
+    msg = "No pyproject.toml found"
+    raise FileNotFoundError(msg)
 
 
 class LangServeExport(TypedDict):
-    """
-    Fields from pyproject.toml that are relevant to LangServe
+    """Fields from pyproject.toml that are relevant to LangServe.
 
     Attributes:
         module: The module to import from, tool.langserve.export_module
@@ -41,5 +41,6 @@ def get_langserve_export(filepath: Path) -> LangServeExport:
         attr = data["tool"]["langserve"]["export_attr"]
         package_name = data["tool"]["poetry"]["name"]
     except KeyError as e:
-        raise KeyError("Invalid LangServe PyProject.toml") from e
+        msg = "Invalid LangServe PyProject.toml"
+        raise KeyError(msg) from e
     return LangServeExport(module=module, attr=attr, package_name=package_name)
diff --git a/libs/cli/langchain_cli/utils/pyproject.py b/libs/cli/langchain_cli/utils/pyproject.py
index e967d4a4c34..61a86c45a93 100644
--- a/libs/cli/langchain_cli/utils/pyproject.py
+++ b/libs/cli/langchain_cli/utils/pyproject.py
@@ -1,3 +1,4 @@
+import contextlib
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any
@@ -38,9 +39,7 @@ def remove_dependencies_from_pyproject_toml(
     # tomlkit types aren't amazing - treat as Dict instead
     dependencies = pyproject["tool"]["poetry"]["dependencies"]
     for name in local_editable_dependencies:
-        try:
+        with contextlib.suppress(KeyError):
             del dependencies[name]
-        except KeyError:
-            pass
     with open(pyproject_toml, "w", encoding="utf-8") as f:
         dump(pyproject, f)
diff --git a/libs/cli/scripts/generate_migrations.py b/libs/cli/scripts/generate_migrations.py
index b625d2528a2..1dcd5650a2b 100644
--- a/libs/cli/scripts/generate_migrations.py
+++ b/libs/cli/scripts/generate_migrations.py
@@ -4,6 +4,7 @@
 import json
 import os
 import pkgutil
+from typing import Optional
 
 import click
 
@@ -19,9 +20,8 @@ from langchain_cli.namespaces.migrate.generate.partner import (
 
 
 @click.group()
-def cli():
+def cli() -> None:
     """Migration scripts management."""
-    pass
 
 
 @cli.command()
@@ -73,7 +73,7 @@ def generic(
         f.write(dumped)
 
 
-def handle_partner(pkg: str, output: str = None):
+def handle_partner(pkg: str, output: Optional[str] = None) -> None:
     migrations = get_migrations_for_partner_package(pkg)
     # Run with python 3.9+
     name = pkg.removeprefix("langchain_")
diff --git a/libs/cli/tests/integration_tests/test_compile.py b/libs/cli/tests/integration_tests/test_compile.py
index 33ecccdfa0f..f315e45f521 100644
--- a/libs/cli/tests/integration_tests/test_compile.py
+++ b/libs/cli/tests/integration_tests/test_compile.py
@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
diff --git a/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py b/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py
index 3cc57112ead..e435547b16b 100644
--- a/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py
+++ b/libs/cli/tests/unit_tests/migrate/cli_runner/test_cli.py
@@ -1,19 +1,17 @@
-# ruff: noqa: E402
 from __future__ import annotations
 
-import pytest
-
-pytest.importorskip("gritql")
-
 import difflib
 from pathlib import Path
 
+import pytest
 from typer.testing import CliRunner
 
 from langchain_cli.cli import app
 
 from tests.unit_tests.migrate.cli_runner.cases import before, expected
 from tests.unit_tests.migrate.cli_runner.folder import Folder
 
+pytest.importorskip("gritql")
+
 
 def find_issue(current: Folder, expected: Folder) -> str:
     for current_file, expected_file in zip(current.files, expected.files):
         if current_file != expected_file:
             return (
                 f"Files are different: {current_file.name} != {expected_file.name}"
             )
         if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
             return find_issue(current_file, expected_file)
-        elif isinstance(current_file, Folder) or isinstance(expected_file, Folder):
+        if isinstance(current_file, Folder) or isinstance(expected_file, Folder):
             return (
                 f"One of the files is a "
                 f"folder: {current_file.name} != {expected_file.name}"
             )
diff --git a/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py b/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py
index 77347912d0b..1fd8a88161a 100644
--- a/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py
+++ b/libs/cli/tests/unit_tests/migrate/generate/test_langchain_migration.py
@@ -10,49 +10,45 @@ from langchain_cli.namespaces.migrate.generate.generic import (
 @pytest.mark.xfail(reason="Unknown reason")
 def test_create_json_agent_migration() -> None:
     """Test the migration of create_json_agent from langchain to langchain_community."""
-    with sup1():
-        with sup2():
-            raw_migrations = generate_simplified_migrations(
-                from_package="langchain", to_package="langchain_community"
-            )
-            json_agent_migrations = [
-                migration
-                for migration in raw_migrations
-                if "create_json_agent" in migration[0]
-            ]
-            assert json_agent_migrations == [
-                (
-                    "langchain.agents.create_json_agent",
-                    "langchain_community.agent_toolkits.create_json_agent",
-                ),
-                (
-                    "langchain.agents.agent_toolkits.create_json_agent",
-                    "langchain_community.agent_toolkits.create_json_agent",
-                ),
-                (
-                    "langchain.agents.agent_toolkits.json.base.create_json_agent",
-                    "langchain_community.agent_toolkits.create_json_agent",
-                ),
-            ]
+    with sup1(), sup2():
+        raw_migrations = generate_simplified_migrations(
+            from_package="langchain", to_package="langchain_community"
+        )
+        json_agent_migrations = [
+            migration
+            for migration in raw_migrations
+            if "create_json_agent" in migration[0]
+        ]
+        assert json_agent_migrations == [
+            (
+                "langchain.agents.create_json_agent",
+                "langchain_community.agent_toolkits.create_json_agent",
+            ),
+            (
+                "langchain.agents.agent_toolkits.create_json_agent",
+                "langchain_community.agent_toolkits.create_json_agent",
+            ),
+            (
+                "langchain.agents.agent_toolkits.json.base.create_json_agent",
+                "langchain_community.agent_toolkits.create_json_agent",
+            ),
+        ]
 
 
 @pytest.mark.xfail(reason="Unknown reason")
 def test_create_single_store_retriever_db() -> None:
-    """Test migration from langchain to langchain_core"""
-    with sup1():
-        with sup2():
-            raw_migrations = generate_simplified_migrations(
-                from_package="langchain", to_package="langchain_core"
-            )
-            # SingleStore was an old name for VectorStoreRetriever
-            single_store_migration = [
-                migration
-                for migration in raw_migrations
-                if "SingleStore" in migration[0]
-            ]
-            assert single_store_migration == [
-                (
-                    "langchain.vectorstores.singlestoredb.SingleStoreDBRetriever",
-                    "langchain_core.vectorstores.VectorStoreRetriever",
-                ),
-            ]
+    """Test migration from langchain to langchain_core."""
+    with sup1(), sup2():
+        raw_migrations = generate_simplified_migrations(
+            from_package="langchain", to_package="langchain_core"
+        )
+        # SingleStore was an old name for VectorStoreRetriever
+        single_store_migration = [
+            migration for migration in raw_migrations if "SingleStore" in migration[0]
+        ]
+        assert single_store_migration == [
+            (
+                "langchain.vectorstores.singlestoredb.SingleStoreDBRetriever",
+                "langchain_core.vectorstores.VectorStoreRetriever",
+            ),
+        ]