mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-26 05:10:22 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			173 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			173 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Generate migrations utilities."""
 | |
| 
 | |
| import ast
 | |
| import inspect
 | |
| import os
 | |
| import pathlib
 | |
| from pathlib import Path
 | |
| from types import ModuleType
 | |
| from typing import Optional
 | |
| 
 | |
| from typing_extensions import override
 | |
| 
 | |
| HERE = Path(__file__).parent
 | |
| # Should bring us to [root]/src
 | |
| PKGS_ROOT = HERE.parent.parent.parent.parent.parent
 | |
| 
 | |
| LANGCHAIN_PKG = PKGS_ROOT / "langchain"
 | |
| COMMUNITY_PKG = PKGS_ROOT / "community"
 | |
| PARTNER_PKGS = PKGS_ROOT / "partners"
 | |
| 
 | |
| 
 | |
| class ImportExtractor(ast.NodeVisitor):
 | |
|     """Import extractor."""
 | |
| 
 | |
|     def __init__(self, *, from_package: Optional[str] = None) -> None:
 | |
|         """Extract all imports from the given code, optionally filtering by package."""
 | |
|         self.imports: list[tuple[str, str]] = []
 | |
|         self.package = from_package
 | |
| 
 | |
|     @override
 | |
|     def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
 | |
|         if node.module and (
 | |
|             self.package is None or str(node.module).startswith(self.package)
 | |
|         ):
 | |
|             for alias in node.names:
 | |
|                 self.imports.append((node.module, alias.name))
 | |
|         self.generic_visit(node)
 | |
| 
 | |
| 
 | |
| def _get_class_names(code: str) -> list[str]:
 | |
|     """Extract class names from a code string."""
 | |
|     # Parse the content of the file into an AST
 | |
|     tree = ast.parse(code)
 | |
| 
 | |
|     # Initialize a list to hold all class names
 | |
|     class_names = []
 | |
| 
 | |
|     # Define a node visitor class to collect class names
 | |
|     class ClassVisitor(ast.NodeVisitor):
 | |
|         @override
 | |
|         def visit_ClassDef(self, node: ast.ClassDef) -> None:
 | |
|             class_names.append(node.name)
 | |
|             self.generic_visit(node)
 | |
| 
 | |
|     # Create an instance of the visitor and visit the AST
 | |
|     visitor = ClassVisitor()
 | |
|     visitor.visit(tree)
 | |
|     return class_names
 | |
| 
 | |
| 
 | |
| def is_subclass(class_obj: type, classes_: list[type]) -> bool:
 | |
|     """Check if the given class object is a subclass of any class in list classes."""
 | |
|     return any(
 | |
|         issubclass(class_obj, kls)
 | |
|         for kls in classes_
 | |
|         if inspect.isclass(class_obj) and inspect.isclass(kls)
 | |
|     )
 | |
| 
 | |
| 
 | |
| def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]:
 | |
|     """Find all classes in the module that inherit from one of the classes."""
 | |
|     subclasses = []
 | |
|     # Iterate over all attributes of the module that are classes
 | |
|     for _name, obj in inspect.getmembers(module, inspect.isclass):
 | |
|         if is_subclass(obj, classes_):
 | |
|             subclasses.append(obj.__name__)
 | |
|     return subclasses
 | |
| 
 | |
| 
 | |
| def _get_all_classnames_from_file(file: Path, pkg: str) -> list[tuple[str, str]]:
 | |
|     """Extract all class names from a file."""
 | |
|     code = Path(file).read_text(encoding="utf-8")
 | |
|     module_name = _get_current_module(file, pkg)
 | |
|     class_names = _get_class_names(code)
 | |
| 
 | |
|     return [(module_name, class_name) for class_name in class_names]
 | |
| 
 | |
| 
 | |
| def identify_all_imports_in_file(
 | |
|     file: str,
 | |
|     *,
 | |
|     from_package: Optional[str] = None,
 | |
| ) -> list[tuple[str, str]]:
 | |
|     """Let's also identify all the imports in the given file."""
 | |
|     code = Path(file).read_text(encoding="utf-8")
 | |
|     return find_imports_from_package(code, from_package=from_package)
 | |
| 
 | |
| 
 | |
| def identify_pkg_source(pkg_root: str) -> pathlib.Path:
 | |
|     """Identify the source of the package.
 | |
| 
 | |
|     Args:
 | |
|         pkg_root: the root of the package. This contains source + tests, and other
 | |
|             things like pyproject.toml, lock files etc
 | |
| 
 | |
|     Returns:
 | |
|         Returns the path to the source code for the package.
 | |
| 
 | |
|     """
 | |
|     dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
 | |
|     matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
 | |
|     if len(matching_dirs) != 1:
 | |
|         msg = "There should be only one langchain package."
 | |
|         raise ValueError(msg)
 | |
|     return matching_dirs[0]
 | |
| 
 | |
| 
 | |
| def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
 | |
|     """List all classes in a package."""
 | |
|     module_classes = []
 | |
|     pkg_source = identify_pkg_source(pkg_root)
 | |
|     files = list(pkg_source.rglob("*.py"))
 | |
| 
 | |
|     for file in files:
 | |
|         rel_path = os.path.relpath(file, pkg_root)
 | |
|         if rel_path.startswith("tests"):
 | |
|             continue
 | |
|         module_classes.extend(_get_all_classnames_from_file(file, pkg_root))
 | |
|     return module_classes
 | |
| 
 | |
| 
 | |
| def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
 | |
|     """List all the things that are being imported in a package by module."""
 | |
|     imports = []
 | |
|     pkg_source = identify_pkg_source(pkg_root)
 | |
|     # Scan all the files in the package
 | |
|     files = list(Path(pkg_source).rglob("*.py"))
 | |
| 
 | |
|     for file in files:
 | |
|         if file.name != "__init__.py":
 | |
|             continue
 | |
|         import_in_file = identify_all_imports_in_file(str(file))
 | |
|         module_name = _get_current_module(file, pkg_root)
 | |
|         imports.extend([(module_name, item) for _, item in import_in_file])
 | |
|     return imports
 | |
| 
 | |
| 
 | |
| def find_imports_from_package(
 | |
|     code: str,
 | |
|     *,
 | |
|     from_package: Optional[str] = None,
 | |
| ) -> list[tuple[str, str]]:
 | |
|     """Find imports in code."""
 | |
|     # Parse the code into an AST
 | |
|     tree = ast.parse(code)
 | |
|     # Create an instance of the visitor
 | |
|     extractor = ImportExtractor(from_package=from_package)
 | |
|     # Use the visitor to update the imports list
 | |
|     extractor.visit(tree)
 | |
|     return extractor.imports
 | |
| 
 | |
| 
 | |
| def _get_current_module(path: Path, pkg_root: str) -> str:
 | |
|     """Convert a path to a module name."""
 | |
|     relative_path = path.relative_to(pkg_root).with_suffix("")
 | |
|     posix_path = relative_path.as_posix()
 | |
|     norm_path = os.path.normpath(str(posix_path))
 | |
|     fully_qualified_module = norm_path.replace("/", ".")
 | |
|     # Strip __init__ if present
 | |
|     if fully_qualified_module.endswith(".__init__"):
 | |
|         return fully_qualified_module[:-9]
 | |
|     return fully_qualified_module
 |