Files
langchain/libs/cli/langchain_cli/namespaces/migrate/generate/utils.py
Christophe Bornet e3c4aeaea1 chore(cli): add mypy strict checking (#32386)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-08-30 13:02:45 -05:00

173 lines
5.6 KiB
Python

"""Generate migrations utilities."""
import ast
import inspect
import os
import pathlib
from pathlib import Path
from types import ModuleType
from typing import Optional
from typing_extensions import override
# Directory that contains this module.
HERE = Path(__file__).parent
# Five directory levels above HERE.
# NOTE(review): this used to be described as "[root]/src"; judging by the
# sibling constants below it is presumably the directory that holds the
# individual packages -- confirm against the repository layout.
PKGS_ROOT = HERE.parent.parent.parent.parent.parent
# Package directories expected directly under PKGS_ROOT.
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
COMMUNITY_PKG = PKGS_ROOT / "community"
PARTNER_PKGS = PKGS_ROOT / "partners"
class ImportExtractor(ast.NodeVisitor):
    """AST visitor that records names brought in by ``from X import Y``.

    Every match is stored in ``self.imports`` as a ``(module, name)`` pair,
    in the order the statements are visited.
    """

    def __init__(self, *, from_package: Optional[str] = None) -> None:
        """Create an extractor with an empty result list.

        Args:
            from_package: When given, only modules whose dotted name starts
                with this string are recorded. ``None`` records every module.
        """
        self.package = from_package
        self.imports: list[tuple[str, str]] = []

    @override
    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record the imported names of ``node`` if its module passes the filter."""
        module = node.module
        # ``module`` is None for statements such as ``from . import x``;
        # those carry no module name and are skipped.
        if module and (self.package is None or str(module).startswith(self.package)):
            self.imports += [(module, alias.name) for alias in node.names]
        self.generic_visit(node)
def _get_class_names(code: str) -> list[str]:
    """Return the names of all classes defined in ``code``.

    Nested classes and classes defined inside functions are included. Names
    are listed in depth-first, source order.

    Args:
        code: Python source text to parse.

    Returns:
        The class names found, possibly empty.
    """
    found: list[str] = []

    def _collect(node: ast.AST) -> None:
        # Pre-order walk: record the class first, then descend into its body
        # so that nested definitions follow their enclosing class.
        if isinstance(node, ast.ClassDef):
            found.append(node.name)
        for child in ast.iter_child_nodes(node):
            _collect(child)

    _collect(ast.parse(code))
    return found
def is_subclass(class_obj: type, classes_: list[type]) -> bool:
    """Tell whether ``class_obj`` derives from at least one entry of ``classes_``.

    Args:
        class_obj: The object to test. Anything that is not a class yields
            ``False``.
        classes_: Candidate base classes. Entries that are not classes are
            ignored.

    Returns:
        ``True`` if ``class_obj`` is a subclass of any class in ``classes_``
        (a class counts as a subclass of itself), otherwise ``False``.
    """
    if not inspect.isclass(class_obj):
        return False
    for candidate in classes_:
        if inspect.isclass(candidate) and issubclass(class_obj, candidate):
            return True
    return False
def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]:
    """Collect the names of classes in ``module`` that derive from ``classes_``.

    Args:
        module: Module whose class attributes are inspected.
        classes_: Base classes to match against.

    Returns:
        The ``__name__`` of every class attribute of ``module`` for which
        ``is_subclass`` reports a match.
    """
    return [
        cls.__name__
        for _, cls in inspect.getmembers(module, inspect.isclass)
        if is_subclass(cls, classes_)
    ]
def _get_all_classnames_from_file(file: Path, pkg: str) -> list[tuple[str, str]]:
    """Pair every class defined in ``file`` with the file's module name.

    Args:
        file: Python source file to read (decoded as UTF-8).
        pkg: Directory that the module name is computed relative to.

    Returns:
        One ``(module_name, class_name)`` tuple per class found in the file.
    """
    source = Path(file).read_text(encoding="utf-8")
    module = _get_current_module(file, pkg)
    return [(module, name) for name in _get_class_names(source)]
def identify_all_imports_in_file(
    file: str,
    *,
    from_package: Optional[str] = None,
) -> list[tuple[str, str]]:
    """Read ``file`` and list the ``from ... import ...`` names it contains.

    Args:
        file: Path of the Python source file to read (decoded as UTF-8).
        from_package: Optional filter forwarded to
            ``find_imports_from_package``.

    Returns:
        The ``(module, name)`` pairs reported by ``find_imports_from_package``.
    """
    with Path(file).open(encoding="utf-8") as handle:
        source = handle.read()
    return find_imports_from_package(source, from_package=from_package)
def identify_pkg_source(pkg_root: str) -> pathlib.Path:
    """Locate the source directory inside a package root.

    Args:
        pkg_root: The root of the package. This contains source + tests, and
            other things like pyproject.toml, lock files etc.

    Returns:
        The single sub-directory of ``pkg_root`` whose name starts with
        ``langchain_``.

    Raises:
        ValueError: If there is no such directory, or more than one.
    """
    candidates = [
        entry
        for entry in Path(pkg_root).iterdir()
        if entry.is_dir() and entry.name.startswith("langchain_")
    ]
    if len(candidates) == 1:
        return candidates[0]
    msg = "There should be only one langchain package."
    raise ValueError(msg)
def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
    """List every class defined in the package's source directory.

    Args:
        pkg_root: The root of the package; its ``langchain_*`` source
            directory is scanned recursively for ``*.py`` files.

    Returns:
        ``(module_name, class_name)`` tuples, with module names computed
        relative to ``pkg_root``.
    """
    found: list[tuple[str, str]] = []
    source_dir = identify_pkg_source(pkg_root)
    for py_file in source_dir.rglob("*.py"):
        # NOTE(review): paths are taken relative to pkg_root and every file
        # comes from the ``langchain_*`` directory, so this prefix check looks
        # like it can never match -- confirm whether tests should be skipped
        # some other way.
        in_tests = os.path.relpath(py_file, pkg_root).startswith("tests")
        if not in_tests:
            found += _get_all_classnames_from_file(py_file, pkg_root)
    return found
def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
    """List the names imported by each ``__init__.py`` of a package.

    Args:
        pkg_root: The root of the package; its ``langchain_*`` source
            directory is scanned recursively.

    Returns:
        ``(module_name, imported_name)`` tuples, where ``module_name`` is the
        module of the ``__init__.py`` doing the import (not the module the
        name was imported from).
    """
    collected: list[tuple[str, str]] = []
    source_dir = identify_pkg_source(pkg_root)
    for candidate in Path(source_dir).rglob("*.py"):
        # Only package initialisers are of interest here.
        if candidate.name == "__init__.py":
            found = identify_all_imports_in_file(str(candidate))
            module = _get_current_module(candidate, pkg_root)
            collected += [(module, imported) for _, imported in found]
    return collected
def find_imports_from_package(
    code: str,
    *,
    from_package: Optional[str] = None,
) -> list[tuple[str, str]]:
    """Extract ``from ... import ...`` names from source text.

    Args:
        code: Python source text to parse.
        from_package: When given, only imports from modules whose name starts
            with this string are returned.

    Returns:
        ``(module, name)`` pairs collected by ``ImportExtractor``.
    """
    visitor = ImportExtractor(from_package=from_package)
    visitor.visit(ast.parse(code))
    return visitor.imports
def _get_current_module(path: Path, pkg_root: str) -> str:
    """Convert a file path into a dotted module name relative to ``pkg_root``.

    Args:
        path: Path located under ``pkg_root`` (typically a ``.py`` file).
        pkg_root: Directory the module name is computed relative to.

    Returns:
        The dotted module name. A trailing ``.__init__`` is dropped so that a
        package's ``__init__.py`` maps to the package itself.

    Raises:
        ValueError: If ``path`` is not located under ``pkg_root`` (raised by
            ``Path.relative_to``).
    """
    relative_path = path.relative_to(pkg_root).with_suffix("")
    # Normalise first (collapses segments such as ``a/../b``), then join the
    # path components with dots. Splitting via ``Path.parts`` instead of
    # replacing "/" keeps the result correct on Windows, where
    # ``os.path.normpath`` rewrites separators to backslashes.
    norm_path = os.path.normpath(str(relative_path))
    fully_qualified_module = ".".join(Path(norm_path).parts)
    # Strip __init__ if present
    return fully_qualified_module.removesuffix(".__init__")