cli: Ruff autofixes (#31863)

Auto-fixes from ruff with rule ALL
This commit is contained in:
Christophe Bornet 2025-07-07 16:06:34 +02:00 committed by GitHub
parent 451c90fefa
commit a46a2b8bda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 157 additions and 196 deletions

View File

@ -35,7 +35,7 @@ app.command(
def version_callback(show_version: bool) -> None: def version_callback(show_version: bool) -> None:
if show_version: if show_version:
typer.echo(f"langchain-cli {__version__}") typer.echo(f"langchain-cli {__version__}")
raise typer.Exit() raise typer.Exit
@app.callback() @app.callback()
@ -48,7 +48,7 @@ def main(
callback=version_callback, callback=version_callback,
is_eager=True, is_eager=True,
), ),
): ) -> None:
pass pass
@ -62,10 +62,7 @@ def serve(
Optional[str], typer.Option(help="The host to run the server on") Optional[str], typer.Option(help="The host to run the server on")
] = None, ] = None,
) -> None: ) -> None:
""" """Start the LangServe app, whether it's a template or an app."""
Start the LangServe app, whether it's a template or an app.
"""
# see if is a template # see if is a template
try: try:
project_dir = get_package_root() project_dir = get_package_root()

View File

@ -1,7 +1,5 @@
# type: ignore # type: ignore
""" """Development Scripts for template packages."""
Development Scripts for template packages
"""
from collections.abc import Sequence from collections.abc import Sequence
@ -16,9 +14,7 @@ def create_demo_server(
config_keys: Sequence[str] = (), config_keys: Sequence[str] = (),
playground_type: str = "default", playground_type: str = "default",
): ):
""" """Creates a demo server for the current template."""
Creates a demo server for the current template.
"""
app = FastAPI() app = FastAPI()
package_root = get_package_root() package_root = get_package_root()
pyproject = package_root / "pyproject.toml" pyproject = package_root / "pyproject.toml"
@ -35,9 +31,11 @@ def create_demo_server(
playground_type=playground_type, playground_type=playground_type,
) )
except KeyError as e: except KeyError as e:
raise KeyError("Missing fields from pyproject.toml") from e msg = "Missing fields from pyproject.toml"
raise KeyError(msg) from e
except ImportError as e: except ImportError as e:
raise ImportError("Could not import module defined in pyproject.toml") from e msg = "Could not import module defined in pyproject.toml"
raise ImportError(msg) from e
return app return app

View File

@ -1,6 +1,4 @@
""" """Manage LangChain apps."""
Manage LangChain apps
"""
import shutil import shutil
import subprocess import subprocess
@ -62,15 +60,14 @@ def new(
is_flag=True, is_flag=True,
), ),
] = False, ] = False,
): ) -> None:
""" """Create a new LangServe application."""
Create a new LangServe application.
"""
has_packages = package is not None and len(package) > 0 has_packages = package is not None and len(package) > 0
if noninteractive: if noninteractive:
if name is None: if name is None:
raise typer.BadParameter("name is required when --non-interactive is set") msg = "name is required when --non-interactive is set"
raise typer.BadParameter(msg)
name_str = name name_str = name
pip_bool = bool(pip) # None should be false pip_bool = bool(pip) # None should be false
else: else:
@ -154,9 +151,8 @@ def add(
prompt="Would you like to `pip install -e` the template(s)?", prompt="Would you like to `pip install -e` the template(s)?",
), ),
], ],
): ) -> None:
""" """Adds the specified template to the current LangServe app.
Adds the specified template to the current LangServe app.
e.g.: e.g.:
langchain app add extraction-openai-functions langchain app add extraction-openai-functions
@ -166,7 +162,8 @@ def add(
if not branch and not repo: if not branch and not repo:
warnings.warn( warnings.warn(
"Adding templates from the default branch and repo is deprecated." "Adding templates from the default branch and repo is deprecated."
" At a minimum, you will have to add `--branch v0.2` for this to work" " At a minimum, you will have to add `--branch v0.2` for this to work",
stacklevel=2,
) )
parsed_deps = parse_dependencies(dependencies, repo, branch, api_path) parsed_deps = parse_dependencies(dependencies, repo, branch, api_path)
@ -176,7 +173,7 @@ def add(
package_dir = project_root / "packages" package_dir = project_root / "packages"
create_events( create_events(
[{"event": "serve add", "properties": dict(parsed_dep=d)} for d in parsed_deps] [{"event": "serve add", "properties": {"parsed_dep": d}} for d in parsed_deps]
) )
# group by repo/ref # group by repo/ref
@ -216,7 +213,7 @@ def add(
destination_path = package_dir / inner_api_path destination_path = package_dir / inner_api_path
if destination_path.exists(): if destination_path.exists():
typer.echo( typer.echo(
f"Folder {str(inner_api_path)} already exists. Skipping...", f"Folder {inner_api_path} already exists. Skipping...",
) )
continue continue
copy_repo(source_path, destination_path) copy_repo(source_path, destination_path)
@ -248,7 +245,7 @@ def add(
typer.echo("Failed to print install command, continuing...") typer.echo("Failed to print install command, continuing...")
else: else:
if pip: if pip:
cmd = ["pip", "install", "-e"] + installed_destination_strs cmd = ["pip", "install", "-e", *installed_destination_strs]
cmd_str = " \\\n ".join(installed_destination_strs) cmd_str = " \\\n ".join(installed_destination_strs)
typer.echo(f"Running: pip install -e \\\n {cmd_str}") typer.echo(f"Running: pip install -e \\\n {cmd_str}")
subprocess.run(cmd, cwd=cwd) subprocess.run(cmd, cwd=cwd)
@ -282,13 +279,15 @@ def add(
if len(chain_names) == 1 if len(chain_names) == 1
else f"these {len(chain_names)} templates" else f"these {len(chain_names)} templates"
) )
lines = ( lines = [
["", f"To use {t}, add the following to your app:\n\n```", ""] "",
+ imports f"To use {t}, add the following to your app:\n\n```",
+ [""] "",
+ routes *imports,
+ ["```"] "",
) *routes,
"```",
]
typer.echo("\n".join(lines)) typer.echo("\n".join(lines))
@ -299,11 +298,8 @@ def remove(
project_dir: Annotated[ project_dir: Annotated[
Optional[Path], typer.Option(help="The project directory") Optional[Path], typer.Option(help="The project directory")
] = None, ] = None,
): ) -> None:
""" """Removes the specified package from the current LangServe app."""
Removes the specified package from the current LangServe app.
"""
project_root = get_package_root(project_dir) project_root = get_package_root(project_dir)
project_pyproject = project_root / "pyproject.toml" project_pyproject = project_root / "pyproject.toml"
@ -347,10 +343,7 @@ def serve(
Optional[str], typer.Option(help="The app to run, e.g. `app.server:app`") Optional[str], typer.Option(help="The app to run, e.g. `app.server:app`")
] = None, ] = None,
) -> None: ) -> None:
""" """Starts the LangServe app."""
Starts the LangServe app.
"""
# add current dir as first entry of path # add current dir as first entry of path
sys.path.append(str(Path.cwd())) sys.path.append(str(Path.cwd()))

View File

@ -1,6 +1,4 @@
""" """Develop integration packages for LangChain."""
Develop integration packages for LangChain.
"""
import re import re
import shutil import shutil
@ -28,18 +26,20 @@ class Replacements(TypedDict):
def _process_name(name: str, *, community: bool = False) -> Replacements: def _process_name(name: str, *, community: bool = False) -> Replacements:
preprocessed = name.replace("_", "-").lower() preprocessed = name.replace("_", "-").lower()
if preprocessed.startswith("langchain-"): preprocessed = preprocessed.removeprefix("langchain-")
preprocessed = preprocessed[len("langchain-") :]
if not re.match(r"^[a-z][a-z0-9-]*$", preprocessed): if not re.match(r"^[a-z][a-z0-9-]*$", preprocessed):
raise ValueError( msg = (
"Name should only contain lowercase letters (a-z), numbers, and hyphens" "Name should only contain lowercase letters (a-z), numbers, and hyphens"
", and start with a letter." ", and start with a letter."
) )
raise ValueError(msg)
if preprocessed.endswith("-"): if preprocessed.endswith("-"):
raise ValueError("Name should not end with `-`.") msg = "Name should not end with `-`."
raise ValueError(msg)
if preprocessed.find("--") != -1: if preprocessed.find("--") != -1:
raise ValueError("Name should not contain consecutive hyphens.") msg = "Name should not contain consecutive hyphens."
raise ValueError(msg)
replacements: Replacements = { replacements: Replacements = {
"__package_name__": f"langchain-{preprocessed}", "__package_name__": f"langchain-{preprocessed}",
"__module_name__": "langchain_" + preprocessed.replace("-", "_"), "__module_name__": "langchain_" + preprocessed.replace("-", "_"),
@ -84,11 +84,8 @@ def new(
". e.g. `my-integration/my_integration.py`", ". e.g. `my-integration/my_integration.py`",
), ),
] = None, ] = None,
): ) -> None:
""" """Creates a new integration package."""
Creates a new integration package.
"""
try: try:
replacements = _process_name(name) replacements = _process_name(name)
except ValueError as e: except ValueError as e:
@ -123,7 +120,7 @@ def new(
shutil.move(destination_dir / "integration_template", package_dir) shutil.move(destination_dir / "integration_template", package_dir)
# replacements in files # replacements in files
replace_glob(destination_dir, "**/*", cast(dict[str, str], replacements)) replace_glob(destination_dir, "**/*", cast("dict[str, str]", replacements))
# poetry install # poetry install
subprocess.run( subprocess.run(
@ -167,7 +164,7 @@ def new(
for src_path, dst_path in zip(src_paths, dst_paths): for src_path, dst_path in zip(src_paths, dst_paths):
shutil.copy(src_path, dst_path) shutil.copy(src_path, dst_path)
replace_file(dst_path, cast(dict[str, str], replacements)) replace_file(dst_path, cast("dict[str, str]", replacements))
TEMPLATE_MAP: dict[str, str] = { TEMPLATE_MAP: dict[str, str] = {
@ -183,7 +180,7 @@ TEMPLATE_MAP: dict[str, str] = {
"Retriever": "retrievers.ipynb", "Retriever": "retrievers.ipynb",
} }
_component_types_str = ", ".join(f"`{k}`" for k in TEMPLATE_MAP.keys()) _component_types_str = ", ".join(f"`{k}`" for k in TEMPLATE_MAP)
@integration_cli.command() @integration_cli.command()
@ -226,10 +223,8 @@ def create_doc(
prompt="The relative path to the docs directory to place the new file in.", prompt="The relative path to the docs directory to place the new file in.",
), ),
] = "docs/docs/integrations/chat/", ] = "docs/docs/integrations/chat/",
): ) -> None:
""" """Creates a new integration doc."""
Creates a new integration doc.
"""
if component_type not in TEMPLATE_MAP: if component_type not in TEMPLATE_MAP:
typer.echo( typer.echo(
f"Unrecognized {component_type=}. Expected one of {_component_types_str}." f"Unrecognized {component_type=}. Expected one of {_component_types_str}."

View File

@ -12,7 +12,7 @@ def generate_raw_migrations(
package = importlib.import_module(from_package) package = importlib.import_module(from_package)
items = [] items = []
for importer, modname, ispkg in pkgutil.walk_packages( for _importer, modname, _ispkg in pkgutil.walk_packages(
package.__path__, package.__name__ + "." package.__path__, package.__name__ + "."
): ):
try: try:
@ -84,9 +84,9 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
items = [] items = []
# Function to handle importing from modules # Function to handle importing from modules
def handle_module(module, module_name): def handle_module(module, module_name) -> None:
if hasattr(module, "__all__"): if hasattr(module, "__all__"):
all_objects = getattr(module, "__all__") all_objects = module.__all__
for name in all_objects: for name in all_objects:
# Attempt to fetch each object declared in __all__ # Attempt to fetch each object declared in __all__
obj = getattr(module, name, None) obj = getattr(module, name, None)
@ -105,7 +105,7 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
handle_module(package, pkg) handle_module(package, pkg)
# Only iterate through top-level modules/packages # Only iterate through top-level modules/packages
for finder, modname, ispkg in pkgutil.iter_modules( for _finder, modname, ispkg in pkgutil.iter_modules(
package.__path__, package.__name__ + "." package.__path__, package.__name__ + "."
): ):
if ispkg: if ispkg:
@ -126,7 +126,7 @@ def generate_simplified_migrations(
from_package, to_package, filter_by_all=filter_by_all from_package, to_package, filter_by_all=filter_by_all
) )
top_level_simplifications = generate_top_level_imports(to_package) top_level_simplifications = generate_top_level_imports(to_package)
top_level_dict = {full: top_level for full, top_level in top_level_simplifications} top_level_dict = dict(top_level_simplifications)
simple_migrations = [] simple_migrations = []
for migration in raw_migrations: for migration in raw_migrations:
original, new = migration original, new = migration

View File

@ -1,12 +1,11 @@
def split_package(package: str) -> tuple[str, str]: def split_package(package: str) -> tuple[str, str]:
"""Split a package name into the containing package and the final name""" """Split a package name into the containing package and the final name."""
parts = package.split(".") parts = package.split(".")
return ".".join(parts[:-1]), parts[-1] return ".".join(parts[:-1]), parts[-1]
def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]): def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]) -> str:
"""Dump the migration pairs as a Grit file.""" """Dump the migration pairs as a Grit file."""
output = "language python"
remapped = ",\n".join( remapped = ",\n".join(
[ [
f""" f"""
@ -21,7 +20,7 @@ def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]):
] ]
) )
pattern_name = f"langchain_migrate_{name}" pattern_name = f"langchain_migrate_{name}"
output = f""" return f"""
language python language python
// This migration is generated automatically - do not manually edit this file // This migration is generated automatically - do not manually edit this file
@ -34,4 +33,3 @@ pattern {pattern_name}() {{
// Add this for invoking directly // Add this for invoking directly
{pattern_name}() {pattern_name}()
""" """
return output

View File

@ -46,9 +46,8 @@ def get_migrations_for_partner_package(pkg_name: str) -> list[tuple[str, str]]:
old_paths = community_classes + imports_for_pkg old_paths = community_classes + imports_for_pkg
migrations = [ return [
(f"{module}.{item}", f"{pkg_name}.{item}") (f"{module}.{item}", f"{pkg_name}.{item}")
for module, item in old_paths for module, item in old_paths
if item in classes_ if item in classes_
] ]
return migrations

View File

@ -20,7 +20,7 @@ class ImportExtractor(ast.NodeVisitor):
self.imports: list = [] self.imports: list = []
self.package = from_package self.package = from_package
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node) -> None:
if node.module and ( if node.module and (
self.package is None or str(node.module).startswith(self.package) self.package is None or str(node.module).startswith(self.package)
): ):
@ -39,7 +39,7 @@ def _get_class_names(code: str) -> list[str]:
# Define a node visitor class to collect class names # Define a node visitor class to collect class names
class ClassVisitor(ast.NodeVisitor): class ClassVisitor(ast.NodeVisitor):
def visit_ClassDef(self, node): def visit_ClassDef(self, node) -> None:
class_names.append(node.name) class_names.append(node.name)
self.generic_visit(node) self.generic_visit(node)
@ -62,7 +62,7 @@ def find_subclasses_in_module(module, classes_: list[type]) -> list[str]:
"""Find all classes in the module that inherit from one of the classes.""" """Find all classes in the module that inherit from one of the classes."""
subclasses = [] subclasses = []
# Iterate over all attributes of the module that are classes # Iterate over all attributes of the module that are classes
for name, obj in inspect.getmembers(module, inspect.isclass): for _name, obj in inspect.getmembers(module, inspect.isclass):
if is_subclass(obj, classes_): if is_subclass(obj, classes_):
subclasses.append(obj.__name__) subclasses.append(obj.__name__)
return subclasses return subclasses
@ -125,7 +125,7 @@ def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
files = list(Path(pkg_source).rglob("*.py")) files = list(Path(pkg_source).rglob("*.py"))
for file in files: for file in files:
if not file.name == "__init__.py": if file.name != "__init__.py":
continue continue
import_in_file = identify_all_imports_in_file(str(file)) import_in_file = identify_all_imports_in_file(str(file))
module_name = _get_current_module(file, pkg_root) module_name = _get_current_module(file, pkg_root)

View File

@ -1,6 +1,4 @@
""" """Develop installable templates."""
Develop installable templates.
"""
import re import re
import shutil import shutil
@ -22,10 +20,8 @@ def new(
bool, bool,
typer.Option("--with-poetry/--no-poetry", help="Don't run poetry install"), typer.Option("--with-poetry/--no-poetry", help="Don't run poetry install"),
] = False, ] = False,
): ) -> None:
""" """Creates a new template package."""
Creates a new template package.
"""
computed_name = name if name != "." else Path.cwd().name computed_name = name if name != "." else Path.cwd().name
destination_dir = Path.cwd() / name if name != "." else Path.cwd() destination_dir = Path.cwd() / name if name != "." else Path.cwd()
@ -108,9 +104,7 @@ def serve(
), ),
] = False, ] = False,
) -> None: ) -> None:
""" """Starts a demo app for this template."""
Starts a demo app for this template.
"""
# load pyproject.toml # load pyproject.toml
project_dir = get_package_root() project_dir = get_package_root()
pyproject = project_dir / "pyproject.toml" pyproject = project_dir / "pyproject.toml"
@ -143,9 +137,7 @@ def serve(
@package_cli.command() @package_cli.command()
def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None: def list(contains: Annotated[Optional[str], typer.Argument()] = None) -> None:
""" """List all or search for available templates."""
List all or search for available templates.
"""
from langchain_cli.utils.github import list_packages from langchain_cli.utils.github import list_packages
packages = list_packages(contains=contains) packages = list_packages(contains=contains)

View File

@ -31,10 +31,11 @@ def parse_dependency_string(
) -> DependencySource: ) -> DependencySource:
if dep is not None and dep.startswith("git+"): if dep is not None and dep.startswith("git+"):
if repo is not None or branch is not None: if repo is not None or branch is not None:
raise ValueError( msg = (
"If a dependency starts with git+, you cannot manually specify " "If a dependency starts with git+, you cannot manually specify "
"a repo or branch." "a repo or branch."
) )
raise ValueError(msg)
# remove git+ # remove git+
gitstring = dep[4:] gitstring = dep[4:]
subdirectory = None subdirectory = None
@ -43,9 +44,8 @@ def parse_dependency_string(
if "#subdirectory=" in gitstring: if "#subdirectory=" in gitstring:
gitstring, subdirectory = gitstring.split("#subdirectory=") gitstring, subdirectory = gitstring.split("#subdirectory=")
if "#" in subdirectory or "@" in subdirectory: if "#" in subdirectory or "@" in subdirectory:
raise ValueError( msg = "#subdirectory must be the last part of the dependency string"
"#subdirectory must be the last part of the dependency string" raise ValueError(msg)
)
# find first slash after :// # find first slash after ://
# find @ or # after that slash # find @ or # after that slash
@ -54,9 +54,8 @@ def parse_dependency_string(
# find first slash after :// # find first slash after ://
if "://" not in gitstring: if "://" not in gitstring:
raise ValueError( msg = "git+ dependencies must start with git+https:// or git+ssh://"
"git+ dependencies must start with git+https:// or git+ssh://" raise ValueError(msg)
)
_, find_slash = gitstring.split("://", 1) _, find_slash = gitstring.split("://", 1)
@ -79,9 +78,9 @@ def parse_dependency_string(
event_metadata={"dependency_string": dep}, event_metadata={"dependency_string": dep},
) )
elif dep is not None and dep.startswith("https://"): if dep is not None and dep.startswith("https://"):
raise ValueError("Only git dependencies are supported") msg = "Only git dependencies are supported"
else: raise ValueError(msg)
# if repo is none, use default, including subdirectory # if repo is none, use default, including subdirectory
base_subdir = Path(DEFAULT_GIT_SUBDIRECTORY) if repo is None else Path() base_subdir = Path(DEFAULT_GIT_SUBDIRECTORY) if repo is None else Path()
subdir = str(base_subdir / dep) if dep is not None else None subdir = str(base_subdir / dep) if dep is not None else None
@ -108,12 +107,12 @@ def parse_dependency_string(
def _list_arg_to_length(arg: Optional[list[str]], num: int) -> Sequence[Optional[str]]: def _list_arg_to_length(arg: Optional[list[str]], num: int) -> Sequence[Optional[str]]:
if not arg: if not arg:
return [None] * num return [None] * num
elif len(arg) == 1: if len(arg) == 1:
return arg * num return arg * num
elif len(arg) == num: if len(arg) == num:
return arg return arg
else: msg = f"Argument must be of length 1 or {num}"
raise ValueError(f"Argument must be of length 1 or {num}") raise ValueError(msg)
def parse_dependencies( def parse_dependencies(
@ -131,10 +130,11 @@ def parse_dependencies(
or (repo and len(repo) not in [1, num_deps]) or (repo and len(repo) not in [1, num_deps])
or (branch and len(branch) not in [1, num_deps]) or (branch and len(branch) not in [1, num_deps])
): ):
raise ValueError( msg = (
"Number of defined repos/branches/api_paths did not match the " "Number of defined repos/branches/api_paths did not match the "
"number of templates." "number of templates."
) )
raise ValueError(msg)
inner_deps = _list_arg_to_length(dependencies, num_deps) inner_deps = _list_arg_to_length(dependencies, num_deps)
inner_api_paths = _list_arg_to_length(api_path, num_deps) inner_api_paths = _list_arg_to_length(api_path, num_deps)
inner_repos = _list_arg_to_length(repo, num_deps) inner_repos = _list_arg_to_length(repo, num_deps)
@ -170,7 +170,7 @@ def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
try: try:
repo = Repo(repo_path) repo = Repo(repo_path)
if repo.active_branch.name != ref: if repo.active_branch.name != ref:
raise ValueError() raise ValueError
repo.remotes.origin.pull() repo.remotes.origin.pull()
except Exception: except Exception:
# if it fails, delete and clone again # if it fails, delete and clone again
@ -186,8 +186,7 @@ def copy_repo(
source: Path, source: Path,
destination: Path, destination: Path,
) -> None: ) -> None:
""" """Copies a repo, ignoring git folders.
Copies a repo, ignoring git folders.
Raises FileNotFound error if it can't find source Raises FileNotFound error if it can't find source
""" """

View File

@ -23,7 +23,4 @@ def list_packages(*, contains: Optional[str] = None):
package_names = [ package_names = [
p["name"] for p in data if p["type"] == "dir" and p["name"] != "docs" p["name"] for p in data if p["type"] == "dir" and p["name"] != "docs"
] ]
package_names_filtered = ( return [p for p in package_names if contains in p] if contains else package_names
[p for p in package_names if contains in p] if contains else package_names
)
return package_names_filtered

View File

@ -15,12 +15,12 @@ def get_package_root(cwd: Optional[Path] = None) -> Path:
if pyproject_path.exists(): if pyproject_path.exists():
return package_root return package_root
package_root = package_root.parent package_root = package_root.parent
raise FileNotFoundError("No pyproject.toml found") msg = "No pyproject.toml found"
raise FileNotFoundError(msg)
class LangServeExport(TypedDict): class LangServeExport(TypedDict):
""" """Fields from pyproject.toml that are relevant to LangServe.
Fields from pyproject.toml that are relevant to LangServe
Attributes: Attributes:
module: The module to import from, tool.langserve.export_module module: The module to import from, tool.langserve.export_module
@ -41,5 +41,6 @@ def get_langserve_export(filepath: Path) -> LangServeExport:
attr = data["tool"]["langserve"]["export_attr"] attr = data["tool"]["langserve"]["export_attr"]
package_name = data["tool"]["poetry"]["name"] package_name = data["tool"]["poetry"]["name"]
except KeyError as e: except KeyError as e:
raise KeyError("Invalid LangServe PyProject.toml") from e msg = "Invalid LangServe PyProject.toml"
raise KeyError(msg) from e
return LangServeExport(module=module, attr=attr, package_name=package_name) return LangServeExport(module=module, attr=attr, package_name=package_name)

View File

@ -1,3 +1,4 @@
import contextlib
from collections.abc import Iterable from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -38,9 +39,7 @@ def remove_dependencies_from_pyproject_toml(
# tomlkit types aren't amazing - treat as Dict instead # tomlkit types aren't amazing - treat as Dict instead
dependencies = pyproject["tool"]["poetry"]["dependencies"] dependencies = pyproject["tool"]["poetry"]["dependencies"]
for name in local_editable_dependencies: for name in local_editable_dependencies:
try: with contextlib.suppress(KeyError):
del dependencies[name] del dependencies[name]
except KeyError:
pass
with open(pyproject_toml, "w", encoding="utf-8") as f: with open(pyproject_toml, "w", encoding="utf-8") as f:
dump(pyproject, f) dump(pyproject, f)

View File

@ -4,6 +4,7 @@
import json import json
import os import os
import pkgutil import pkgutil
from typing import Optional
import click import click
@ -19,9 +20,8 @@ from langchain_cli.namespaces.migrate.generate.partner import (
@click.group() @click.group()
def cli(): def cli() -> None:
"""Migration scripts management.""" """Migration scripts management."""
pass
@cli.command() @cli.command()
@ -73,7 +73,7 @@ def generic(
f.write(dumped) f.write(dumped)
def handle_partner(pkg: str, output: str = None): def handle_partner(pkg: str, output: Optional[str] = None) -> None:
migrations = get_migrations_for_partner_package(pkg) migrations = get_migrations_for_partner_package(pkg)
# Run with python 3.9+ # Run with python 3.9+
name = pkg.removeprefix("langchain_") name = pkg.removeprefix("langchain_")

View File

@ -4,4 +4,3 @@ import pytest
@pytest.mark.compile @pytest.mark.compile
def test_placeholder() -> None: def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests.""" """Used for compiling integration tests without running any real tests."""
pass

View File

@ -1,19 +1,17 @@
# ruff: noqa: E402
from __future__ import annotations from __future__ import annotations
import pytest
pytest.importorskip("gritql")
import difflib import difflib
from pathlib import Path from pathlib import Path
import pytest
from typer.testing import CliRunner from typer.testing import CliRunner
from langchain_cli.cli import app from langchain_cli.cli import app
from tests.unit_tests.migrate.cli_runner.cases import before, expected from tests.unit_tests.migrate.cli_runner.cases import before, expected
from tests.unit_tests.migrate.cli_runner.folder import Folder from tests.unit_tests.migrate.cli_runner.folder import Folder
pytest.importorskip("gritql")
def find_issue(current: Folder, expected: Folder) -> str: def find_issue(current: Folder, expected: Folder) -> str:
for current_file, expected_file in zip(current.files, expected.files): for current_file, expected_file in zip(current.files, expected.files):
@ -25,7 +23,7 @@ def find_issue(current: Folder, expected: Folder) -> str:
) )
if isinstance(current_file, Folder) and isinstance(expected_file, Folder): if isinstance(current_file, Folder) and isinstance(expected_file, Folder):
return find_issue(current_file, expected_file) return find_issue(current_file, expected_file)
elif isinstance(current_file, Folder) or isinstance(expected_file, Folder): if isinstance(current_file, Folder) or isinstance(expected_file, Folder):
return ( return (
f"One of the files is a " f"One of the files is a "
f"folder: {current_file.name} != {expected_file.name}" f"folder: {current_file.name} != {expected_file.name}"

View File

@ -10,8 +10,7 @@ from langchain_cli.namespaces.migrate.generate.generic import (
@pytest.mark.xfail(reason="Unknown reason") @pytest.mark.xfail(reason="Unknown reason")
def test_create_json_agent_migration() -> None: def test_create_json_agent_migration() -> None:
"""Test the migration of create_json_agent from langchain to langchain_community.""" """Test the migration of create_json_agent from langchain to langchain_community."""
with sup1(): with sup1(), sup2():
with sup2():
raw_migrations = generate_simplified_migrations( raw_migrations = generate_simplified_migrations(
from_package="langchain", to_package="langchain_community" from_package="langchain", to_package="langchain_community"
) )
@ -38,17 +37,14 @@ def test_create_json_agent_migration() -> None:
@pytest.mark.xfail(reason="Unknown reason") @pytest.mark.xfail(reason="Unknown reason")
def test_create_single_store_retriever_db() -> None: def test_create_single_store_retriever_db() -> None:
"""Test migration from langchain to langchain_core""" """Test migration from langchain to langchain_core."""
with sup1(): with sup1(), sup2():
with sup2():
raw_migrations = generate_simplified_migrations( raw_migrations = generate_simplified_migrations(
from_package="langchain", to_package="langchain_core" from_package="langchain", to_package="langchain_core"
) )
# SingleStore was an old name for VectorStoreRetriever # SingleStore was an old name for VectorStoreRetriever
single_store_migration = [ single_store_migration = [
migration migration for migration in raw_migrations if "SingleStore" in migration[0]
for migration in raw_migrations
if "SingleStore" in migration[0]
] ]
assert single_store_migration == [ assert single_store_migration == [
( (