langchain/libs/cli/langchain_cli/namespaces/migrate/generate/generic.py
Mason Daugherty e7eac27241
ruff: more rules across the board & fixes (#31898)
* standardizes ruff dep version across all `pyproject.toml` files
* cli: ruff rules and corrections
* langchain: rules and corrections
2025-07-07 17:48:01 -04:00

157 lines
5.4 KiB
Python

"""Generate migrations from langchain to langchain-community or core packages."""
import importlib
import inspect
import pkgutil
def generate_raw_migrations(
    from_package: str,
    to_package: str,
    filter_by_all: bool = False,  # noqa: FBT001, FBT002
) -> list[tuple[str, str]]:
    """Scan ``from_package`` and generate migrations for all of its modules.

    Every submodule found by ``pkgutil.walk_packages`` is imported and searched
    for classes and functions whose defining module (``obj.__module__``) starts
    with ``to_package``.

    Args:
        from_package: Importable name of the package to scan.
        to_package: Prefix that ``obj.__module__`` must start with for the
            object to be reported.
        filter_by_all: If ``True``, only names listed in a module's ``__all__``
            are considered. If ``False``, every member returned by
            ``inspect.getmembers`` is considered as well, so a name that is
            also listed in ``__all__`` is reported twice.

    Returns:
        A list of ``(old_path, new_path)`` tuples, where ``old_path`` is
        ``"<scanned module>.<name>"`` and ``new_path`` is
        ``"<obj.__module__>.<obj.__name__>"``.
    """
    package = importlib.import_module(from_package)

    items: list[tuple[str, str]] = []
    for _importer, modname, _ispkg in pkgutil.walk_packages(
        package.__path__,
        package.__name__ + ".",
    ):
        try:
            module = importlib.import_module(modname)
        except ModuleNotFoundError:
            # Best effort: modules with missing dependencies are skipped.
            continue

        # ``hasattr`` only swallows ``AttributeError``; an ``ImportError``
        # raised during the attribute lookup is treated as "no ``__all__``".
        try:
            has_all = hasattr(module, "__all__")
        except ImportError:
            has_all = False

        if has_all:
            for name in module.__all__:
                # Attempt to fetch each object declared in __all__
                try:
                    obj = getattr(module, name, None)
                except ImportError:
                    continue
                if (
                    obj
                    and (inspect.isclass(obj) or inspect.isfunction(obj))
                    and obj.__module__.startswith(to_package)
                ):
                    items.append(
                        (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}"),
                    )

        if not filter_by_all:
            # Iterate over all members of the module
            for name, obj in inspect.getmembers(module):
                # The ``to_package`` prefix check must apply to classes AND
                # functions. Without the parentheses every class visible in
                # the module (stdlib, third party, ...) would be reported.
                if (
                    inspect.isclass(obj) or inspect.isfunction(obj)
                ) and obj.__module__.startswith(to_package):
                    items.append(
                        (f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}"),
                    )

    return items
def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
    """Map objects re-exported by ``pkg`` to the place they are defined.

    The package itself and each of its direct sub-packages (plain modules are
    ignored) are inspected. For example, with::

        langchain_community/
            chat_models/
                __init__.py  # <-- ``__all__`` is read from here
            llm/
                __init__.py  # <-- ``__all__`` is read from here

    every class or function named in an ``__all__`` yields one 2-tuple:

    * the first element is the fully qualified path where the object is
      defined
      (e.g., ``langchain_community.chat_models.xyz_implementation.ver2.XYZ``);
    * the second element is the path for importing it from the inspected
      namespace (e.g., ``langchain_community.chat_models.XYZ``).

    A ``ModuleNotFoundError`` raised while importing a sub-package or while
    reading its exports is suppressed; entries recorded before the error are
    kept.
    """
    root = importlib.import_module(pkg)
    collected = []

    def collect_exports(mod, exported_from) -> None:
        """Record each class/function listed in ``mod.__all__``."""
        if not hasattr(mod, "__all__"):
            return
        for export_name in mod.__all__:
            # Names listed in ``__all__`` but absent on the module give None.
            target = getattr(mod, export_name, None)
            if not target:
                continue
            if not (inspect.isclass(target) or inspect.isfunction(target)):
                continue
            defined_at = f"{target.__module__}.{target.__name__}"
            collected.append((defined_at, f"{exported_from}.{export_name}"))

    # The root namespace comes first, followed by its direct sub-packages.
    collect_exports(root, pkg)

    for _finder, child_name, is_package in pkgutil.iter_modules(
        root.__path__,
        root.__name__ + ".",
    ):
        if not is_package:
            continue
        try:
            collect_exports(importlib.import_module(child_name), child_name)
        except ModuleNotFoundError:
            continue

    return collected
def generate_simplified_migrations(
    from_package: str,
    to_package: str,
    filter_by_all: bool = True,  # noqa: FBT001, FBT002
) -> list[tuple[str, str]]:
    """Build raw migrations and shorten their targets where possible.

    Each raw ``(original, new)`` pair from ``generate_raw_migrations`` has its
    ``new`` path replaced by the matching top-level import path reported by
    ``generate_top_level_imports(to_package)``; paths without a match are left
    unchanged. When the same ``original`` path occurs more than once, only its
    first occurrence is kept, and the order of first occurrences is preserved.
    """
    shortcuts = dict(generate_top_level_imports(to_package))

    # dicts keep insertion order, so ``setdefault`` both de-duplicates on the
    # original path and preserves the order in which the paths first appeared.
    first_by_source: dict[str, tuple[str, str]] = {}
    for source_path, target_path in generate_raw_migrations(
        from_package,
        to_package,
        filter_by_all=filter_by_all,
    ):
        first_by_source.setdefault(
            source_path,
            (source_path, shortcuts.get(target_path, target_path)),
        )

    return list(first_by_source.values())