"""Generate migrations from langchain to langchain-community or core packages.""" import importlib import inspect import pkgutil def generate_raw_migrations( from_package: str, to_package: str, filter_by_all: bool = False, # noqa: FBT001, FBT002 ) -> list[tuple[str, str]]: """Scan the `langchain` package and generate migrations for all modules.""" package = importlib.import_module(from_package) items = [] for _importer, modname, _ispkg in pkgutil.walk_packages( package.__path__, package.__name__ + ".", ): try: module = importlib.import_module(modname) except ModuleNotFoundError: continue # Check if the module is an __init__ file and evaluate __all__ try: has_all = hasattr(module, "__all__") except ImportError: has_all = False if has_all: all_objects = module.__all__ for name in all_objects: # Attempt to fetch each object declared in __all__ try: obj = getattr(module, name, None) except ImportError: continue if ( obj and (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__.startswith(to_package) ): items.append( (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}"), ) if not filter_by_all: # Iterate over all members of the module for name, obj in inspect.getmembers(module): # Check if it's a class or function # Check if the module name of the obj starts with # 'langchain_community' if inspect.isclass(obj) or ( inspect.isfunction(obj) and obj.__module__.startswith(to_package) ): items.append( (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}"), ) return items def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]: """Look at all the top level modules in langchain_community. Attempt to import everything from each ``__init__`` file. For example, langchain_community/ chat_models/ __init__.py # <-- import everything from here llm/ __init__.py # <-- import everything from here It'll collect all the imports, import the classes / functions it can find there. It'll return a list of 2-tuples Each tuple will contain the fully qualified path of the class / function to where its logic is defined (e.g., ``langchain_community.chat_models.xyz_implementation.ver2.XYZ``) and the second tuple will contain the path to importing it from the top level namespaces (e.g., ``langchain_community.chat_models.XYZ``) """ package = importlib.import_module(pkg) items = [] # Function to handle importing from modules def handle_module(module, module_name) -> None: if hasattr(module, "__all__"): all_objects = module.__all__ for name in all_objects: # Attempt to fetch each object declared in __all__ obj = getattr(module, name, None) if obj and (inspect.isclass(obj) or inspect.isfunction(obj)): # Capture the fully qualified name of the object original_module = obj.__module__ original_name = obj.__name__ # Form the new import path from the top-level namespace top_level_import = f"{module_name}.{name}" # Append the tuple with original and top-level paths items.append( (f"{original_module}.{original_name}", top_level_import), ) # Handle the package itself (root level) handle_module(package, pkg) # Only iterate through top-level modules/packages for _finder, modname, ispkg in pkgutil.iter_modules( package.__path__, package.__name__ + ".", ): if ispkg: try: module = importlib.import_module(modname) handle_module(module, modname) except ModuleNotFoundError: continue return items def generate_simplified_migrations( from_package: str, to_package: str, filter_by_all: bool = True, # noqa: FBT001, FBT002 ) -> list[tuple[str, str]]: """Get all the raw migrations, then simplify them if possible.""" raw_migrations = generate_raw_migrations( from_package, to_package, filter_by_all=filter_by_all, ) top_level_simplifications = generate_top_level_imports(to_package) top_level_dict = dict(top_level_simplifications) simple_migrations = [] for migration in raw_migrations: original, new = migration replacement = top_level_dict.get(new, new) simple_migrations.append((original, replacement)) # Now let's deduplicate the list based on the original path (which is # the 1st element of the tuple) deduped_migrations = [] seen = set() for migration in simple_migrations: original = migration[0] if original not in seen: deduped_migrations.append(migration) seen.add(original) return deduped_migrations