cli[minor]: Improve partner migrations (#20938)

This auto generates partner migrations.

At the moment the migration is from community -> partner.

So one would need to run the migration script twice to go from langchain to partner.
This commit is contained in:
Eugene Yurtsev
2024-04-26 12:30:15 -04:00
committed by GitHub
parent 5653f36adc
commit 12c906f6ce
17 changed files with 265 additions and 50 deletions

View File

@@ -2,6 +2,7 @@
import importlib
from typing import List, Tuple
from langchain_core.documents import BaseDocumentCompressor, BaseDocumentTransformer
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
@@ -11,6 +12,7 @@ from langchain_cli.namespaces.migrate.generate.utils import (
COMMUNITY_PKG,
find_subclasses_in_module,
list_classes_by_package,
list_init_imports_by_package,
)
# PUBLIC API
@@ -29,13 +31,24 @@ def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
"""
package = importlib.import_module(pkg_name)
classes_ = find_subclasses_in_module(
package, [BaseLanguageModel, Embeddings, BaseRetriever, VectorStore]
package,
[
BaseLanguageModel,
Embeddings,
BaseRetriever,
VectorStore,
BaseDocumentTransformer,
BaseDocumentCompressor,
],
)
community_classes = list_classes_by_package(str(COMMUNITY_PKG))
imports_for_pkg = list_init_imports_by_package(str(COMMUNITY_PKG))
old_paths = community_classes + imports_for_pkg
migrations = [
(f"{community_module}.{community_class}", f"{pkg_name}.{community_class}")
for community_module, community_class in community_classes
if community_class in classes_
(f"{module}.{item}", f"{pkg_name}.{item}")
for module, item in old_paths
if item in classes_
]
return migrations

View File

@@ -3,7 +3,7 @@ import inspect
import os
import pathlib
from pathlib import Path
from typing import Any, List, Tuple, Type
from typing import Any, List, Optional, Tuple, Type
HERE = Path(__file__).parent
# Should bring us to [root]/src
@@ -15,13 +15,15 @@ PARTNER_PKGS = PKGS_ROOT / "partners"
class ImportExtractor(ast.NodeVisitor):
def __init__(self, *, from_package: str) -> None:
"""Extract all imports from the given package."""
def __init__(self, *, from_package: Optional[str] = None) -> None:
"""Extract all imports from the given code, optionally filtering by package."""
self.imports = []
self.package = from_package
def visit_ImportFrom(self, node):
if node.module and str(node.module).startswith(self.package):
if node.module and (
self.package is None or str(node.module).startswith(self.package)
):
for alias in node.names:
self.imports.append((node.module, alias.name))
self.generic_visit(node)
@@ -72,13 +74,40 @@ def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
code = f.read()
module_name = _get_current_module(file, pkg)
class_names = _get_class_names(code)
return [(module_name, class_name) for class_name in class_names]
def identify_all_imports_in_file(
file: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
"""Let's also identify all the imports in the given file."""
with open(file, encoding="utf-8") as f:
code = f.read()
return find_imports_from_package(code, from_package=from_package)
def identify_pkg_source(pkg_root: str) -> pathlib.Path:
"""Identify the source of the package.
Args:
pkg_root: the root of the package. This contains source + tests, and other
things like pyproject.toml, lock files etc
Returns:
Returns the path to the source code for the package.
"""
dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
assert len(matching_dirs) == 1, "There should be only one langchain package."
return matching_dirs[0]
def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
"""List all classes in a package."""
module_classes = []
files = list(Path(pkg_root).rglob("*.py"))
pkg_source = identify_pkg_source(pkg_root)
files = list(pkg_source.rglob("*.py"))
for file in files:
rel_path = os.path.relpath(file, pkg_root)
@@ -88,11 +117,29 @@ def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
return module_classes
def find_imports_from_package(code: str, *, from_package: str) -> List[Tuple[str, str]]:
def list_init_imports_by_package(pkg_root: str) -> List[Tuple[str, str]]:
"""List all the things that are being imported in a package by module."""
imports = []
pkg_source = identify_pkg_source(pkg_root)
# Scan all the files in the package
files = list(Path(pkg_source).rglob("*.py"))
for file in files:
if not file.name == "__init__.py":
continue
import_in_file = identify_all_imports_in_file(str(file))
module_name = _get_current_module(file, pkg_root)
imports.extend([(module_name, item) for _, item in import_in_file])
return imports
def find_imports_from_package(
code: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
# Parse the code into an AST
tree = ast.parse(code)
# Create an instance of the visitor
extractor = ImportExtractor(from_package="langchain_community")
extractor = ImportExtractor(from_package=from_package)
# Use the visitor to update the imports list
extractor.visit(tree)
return extractor.imports