mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
cli[minor]: update code to generate migrations from langchain to community (#20946)
Updates code that generates migrations from langchain to community
This commit is contained in:
parent
078c5d9bc6
commit
2fa0ff1a2d
File diff suppressed because it is too large
Load Diff
@ -1,52 +1,129 @@
|
||||
"""Generate migrations from langchain to langchain-community or core packages."""
|
||||
import glob
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain_cli.namespaces.migrate.generate.utils import (
|
||||
_get_current_module,
|
||||
find_imports_from_package,
|
||||
)
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
PKGS_ROOT = HERE.parent.parent.parent
|
||||
LANGCHAIN_PKG = PKGS_ROOT / "langchain"
|
||||
COMMUNITY_PKG = PKGS_ROOT / "community"
|
||||
PARTNER_PKGS = PKGS_ROOT / "partners"
|
||||
def generate_raw_migrations_to_community() -> List[Tuple[str, str]]:
|
||||
"""Scan the `langchain` package and generate migrations for all modules."""
|
||||
import langchain as package
|
||||
|
||||
to_package = "langchain_community"
|
||||
|
||||
items = []
|
||||
for importer, modname, ispkg in pkgutil.walk_packages(
|
||||
package.__path__, package.__name__ + "."
|
||||
):
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
except ModuleNotFoundError:
|
||||
continue
|
||||
|
||||
# Check if the module is an __init__ file and evaluate __all__
|
||||
try:
|
||||
has_all = hasattr(module, "__all__")
|
||||
except ImportError:
|
||||
has_all = False
|
||||
|
||||
if has_all:
|
||||
all_objects = module.__all__
|
||||
for name in all_objects:
|
||||
# Attempt to fetch each object declared in __all__
|
||||
try:
|
||||
obj = getattr(module, name, None)
|
||||
except ImportError:
|
||||
continue
|
||||
if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
|
||||
if obj.__module__.startswith(to_package):
|
||||
items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))
|
||||
|
||||
# Iterate over all members of the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
# Check if it's a class or function
|
||||
if inspect.isclass(obj) or inspect.isfunction(obj):
|
||||
# Check if the module name of the obj starts with 'langchain_community'
|
||||
if obj.__module__.startswith(to_package):
|
||||
items.append((f"{modname}.{name}", f"{obj.__module__}.{name}"))
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def _generate_migrations_from_file(
|
||||
source_module: str, code: str, *, from_package: str
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Generate migrations"""
|
||||
imports = find_imports_from_package(code, from_package=from_package)
|
||||
return [
|
||||
# Rewrite in a list comprehension
|
||||
(f"{source_module}.{item}", f"{new_path}.{item}")
|
||||
for new_path, item in imports
|
||||
]
|
||||
def generate_top_level_imports_community() -> List[Tuple[str, str]]:
|
||||
"""This code will look at all the top level modules in langchain_community.
|
||||
|
||||
It'll attempt to import everything from each __init__ file
|
||||
|
||||
for example,
|
||||
|
||||
langchain_community/
|
||||
chat_models/
|
||||
__init__.py # <-- import everything from here
|
||||
llm/
|
||||
__init__.py # <-- import everything from here
|
||||
|
||||
|
||||
def _generate_migrations_from_file_in_pkg(
|
||||
file: str, root_pkg: str
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Generate migrations for a file that's relative to langchain pkg."""
|
||||
# Read the file.
|
||||
with open(file, encoding="utf-8") as f:
|
||||
code = f.read()
|
||||
It'll collect all the imports, import the classes / functions it can find
|
||||
there. It'll return a list of 2-tuples
|
||||
|
||||
module_name = _get_current_module(file, root_pkg)
|
||||
return _generate_migrations_from_file(
|
||||
module_name, code, from_package="langchain_community"
|
||||
)
|
||||
Each tuple will contain the fully qualified path of the class / function to where
|
||||
its logic is defined
|
||||
(e.g., langchain_community.chat_models.xyz_implementation.ver2.XYZ)
|
||||
and the second tuple will contain the path
|
||||
to importing it from the top level namespaces
|
||||
(e.g., langchain_community.chat_models.XYZ)
|
||||
"""
|
||||
import langchain_community as package
|
||||
|
||||
items = []
|
||||
# Only iterate through top-level modules/packages
|
||||
for finder, modname, ispkg in pkgutil.iter_modules(
|
||||
package.__path__, package.__name__ + "."
|
||||
):
|
||||
if ispkg:
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
except ModuleNotFoundError:
|
||||
continue
|
||||
|
||||
if hasattr(module, "__all__"):
|
||||
all_objects = getattr(module, "__all__")
|
||||
for name in all_objects:
|
||||
# Attempt to fetch each object declared in __all__
|
||||
obj = getattr(module, name, None)
|
||||
if obj and (inspect.isclass(obj) or inspect.isfunction(obj)):
|
||||
# Capture the fully qualified name of the object
|
||||
original_module = obj.__module__
|
||||
original_name = obj.__name__
|
||||
# Form the new import path from the top-level namespace
|
||||
top_level_import = f"{modname}.{name}"
|
||||
# Append the tuple with original and top-level paths
|
||||
items.append(
|
||||
(f"{original_module}.{original_name}", top_level_import)
|
||||
)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def generate_migrations_from_langchain_to_community() -> List[Tuple[str, str]]:
|
||||
"""Generate migrations from langchain to langchain-community."""
|
||||
migrations = []
|
||||
# scanning files in pkg
|
||||
for file_path in glob.glob(str(LANGCHAIN_PKG) + "**/*.py"):
|
||||
migrations.extend(
|
||||
_generate_migrations_from_file_in_pkg(file_path, str(LANGCHAIN_PKG))
|
||||
)
|
||||
return migrations
|
||||
def generate_simplified_migrations() -> List[Tuple[str, str]]:
|
||||
"""Get all the raw migrations, then simplify them if possible."""
|
||||
raw_migrations = generate_raw_migrations_to_community()
|
||||
top_level_simplifications = generate_top_level_imports_community()
|
||||
top_level_dict = {full: top_level for full, top_level in top_level_simplifications}
|
||||
simple_migrations = []
|
||||
for migration in raw_migrations:
|
||||
original, new = migration
|
||||
replacement = top_level_dict.get(new, new)
|
||||
simple_migrations.append((original, replacement))
|
||||
|
||||
# Now let's deduplicate the list based on the original path (which is
|
||||
# the 1st element of the tuple)
|
||||
deduped_migrations = []
|
||||
seen = set()
|
||||
for migration in simple_migrations:
|
||||
original = migration[0]
|
||||
if original not in seen:
|
||||
deduped_migrations.append(migration)
|
||||
seen.add(original)
|
||||
|
||||
return deduped_migrations
|
||||
|
@ -5,7 +5,7 @@ import pkgutil
|
||||
import click
|
||||
|
||||
from langchain_cli.namespaces.migrate.generate.langchain import (
|
||||
generate_migrations_from_langchain_to_community,
|
||||
generate_simplified_migrations,
|
||||
)
|
||||
from langchain_cli.namespaces.migrate.generate.partner import (
|
||||
get_migrations_for_partner_package,
|
||||
@ -27,9 +27,9 @@ def cli():
|
||||
def langchain(output: str) -> None:
|
||||
"""Generate a migration script."""
|
||||
click.echo("Migration script generated.")
|
||||
migrations = generate_migrations_from_langchain_to_community()
|
||||
migrations = generate_simplified_migrations()
|
||||
with open(output, "w") as f:
|
||||
f.write(json.dumps(migrations))
|
||||
f.write(json.dumps(migrations, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -0,0 +1,25 @@
|
||||
from langchain_cli.namespaces.migrate.generate.langchain import (
|
||||
generate_simplified_migrations,
|
||||
)
|
||||
|
||||
|
||||
def test_create_json_agent_migration() -> None:
|
||||
"""Test the migration of create_json_agent from langchain to langchain_community."""
|
||||
raw_migrations = generate_simplified_migrations()
|
||||
json_agent_migrations = [
|
||||
migration for migration in raw_migrations if "create_json_agent" in migration[0]
|
||||
]
|
||||
assert json_agent_migrations == [
|
||||
(
|
||||
"langchain.agents.create_json_agent",
|
||||
"langchain_community.agent_toolkits.create_json_agent",
|
||||
),
|
||||
(
|
||||
"langchain.agents.agent_toolkits.create_json_agent",
|
||||
"langchain_community.agent_toolkits.create_json_agent",
|
||||
),
|
||||
(
|
||||
"langchain.agents.agent_toolkits.json.base.create_json_agent",
|
||||
"langchain_community.agent_toolkits.create_json_agent",
|
||||
),
|
||||
]
|
Loading…
Reference in New Issue
Block a user