cli[minor]: Fix bug to account for name changes (#20948)

* Fix bug to account for name changes / aliases
* Generate migration list from langchain to langchain_core
This commit is contained in:
Eugene Yurtsev
2024-04-26 15:45:11 -04:00
committed by GitHub
parent 989e4a92c2
commit 8ed150b2fe
6 changed files with 5760 additions and 26 deletions

View File

@@ -5,11 +5,11 @@ import pkgutil
from typing import List, Tuple
def generate_raw_migrations_to_community() -> List[Tuple[str, str]]:
def generate_raw_migrations(
from_package: str, to_package: str
) -> List[Tuple[str, str]]:
"""Scan the `langchain` package and generate migrations for all modules."""
import langchain as package
to_package = "langchain_community"
package = importlib.import_module(from_package)
items = []
for importer, modname, ispkg in pkgutil.walk_packages(
@@ -36,7 +36,9 @@ def generate_raw_migrations_to_community() -> List[Tuple[str, str]]:
continue
if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
if obj.__module__.startswith(to_package):
items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))
items.append(
(f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
)
# Iterate over all members of the module
for name, obj in inspect.getmembers(module):
@@ -44,12 +46,14 @@ def generate_raw_migrations_to_community() -> List[Tuple[str, str]]:
if inspect.isclass(obj) or inspect.isfunction(obj):
# Check if the module name of the obj starts with 'langchain_community'
if obj.__module__.startswith(to_package):
items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))
items.append(
(f"{modname}.{name}", f"{obj.__module__}.{obj.__name__}")
)
return items
def generate_top_level_imports_community() -> List[Tuple[str, str]]:
def generate_top_level_imports(pkg: str) -> List[Tuple[str, str]]:
"""This code will look at all the top level modules in langchain_community.
It'll attempt to import everything from each __init__ file
@@ -73,7 +77,9 @@ def generate_top_level_imports_community() -> List[Tuple[str, str]]:
to importing it from the top level namespaces
(e.g., langchain_community.chat_models.XYZ)
"""
import langchain_community as package
import importlib
package = importlib.import_module(pkg)
items = []
# Only iterate through top-level modules/packages
@@ -105,10 +111,12 @@ def generate_top_level_imports_community() -> List[Tuple[str, str]]:
return items
def generate_simplified_migrations() -> List[Tuple[str, str]]:
def generate_simplified_migrations(
from_package: str, to_package: str
) -> List[Tuple[str, str]]:
"""Get all the raw migrations, then simplify them if possible."""
raw_migrations = generate_raw_migrations_to_community()
top_level_simplifications = generate_top_level_imports_community()
raw_migrations = generate_raw_migrations(from_package, to_package)
top_level_simplifications = generate_top_level_imports(to_package)
top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
simple_migrations = []
for migration in raw_migrations: