cli[minor]: Improve partner migrations (#20938)

This auto-generates partner migrations.

At the moment the generated migration is from community -> partner, so one would need to run the migration script twice to go from langchain to a partner package (first langchain -> community, then community -> partner).
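
As a quick illustration (not part of the commit itself, and assuming the relevant packages are installed), this is what the two passes do to a single import, using the mappings in the fixtures and tests below:

# Legacy import, before any migration.
from langchain.chat_models import ChatOpenAI

# After the first pass (langchain -> langchain_community):
from langchain_community.chat_models import ChatOpenAI

# After the second pass (community -> partner):
from langchain_openai import ChatOpenAI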
Eugene Yurtsev 2024-04-26 12:30:15 -04:00 committed by GitHub
parent 5653f36adc
commit 12c906f6ce
17 changed files with 265 additions and 50 deletions

View File

@@ -0,0 +1,18 @@
[
[
"langchain_community.llms.anthropic.Anthropic",
"langchain_anthropic.Anthropic"
],
[
"langchain_community.chat_models.anthropic.ChatAnthropic",
"langchain_anthropic.ChatAnthropic"
],
[
"langchain_community.llms.Anthropic",
"langchain_anthropic.Anthropic"
],
[
"langchain_community.chat_models.ChatAnthropic",
"langchain_anthropic.ChatAnthropic"
]
]
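
Each fixture entry is an [old import path, new import path] pair; both the fully qualified module path and the package-level re-export are listed so that either spelling of the import gets rewritten. A minimal sketch (illustrative only, not the codemod's actual parsing code) of how one pair is split into module and symbol:

old_path, new_path = (
    "langchain_community.chat_models.ChatAnthropic",
    "langchain_anthropic.ChatAnthropic",
)
# Split each dotted path into its module and the imported symbol.
old_module, old_name = old_path.rsplit(".", 1)
new_module, new_name = new_path.rsplit(".", 1)
# The codemod can then rewrite
#   from langchain_community.chat_models import ChatAnthropic
# into
#   from langchain_anthropic import ChatAnthropic
print(old_module, old_name, new_module, new_name)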

View File

@@ -0,0 +1,18 @@
[
[
"langchain_community.llms.fireworks.Fireworks",
"langchain_fireworks.Fireworks"
],
[
"langchain_community.chat_models.fireworks.ChatFireworks",
"langchain_fireworks.ChatFireworks"
],
[
"langchain_community.llms.Fireworks",
"langchain_fireworks.Fireworks"
],
[
"langchain_community.chat_models.ChatFireworks",
"langchain_fireworks.ChatFireworks"
]
]

View File

@@ -0,0 +1,10 @@
[
[
"langchain_community.llms.watsonxllm.WatsonxLLM",
"langchain_ibm.WatsonxLLM"
],
[
"langchain_community.llms.WatsonxLLM",
"langchain_ibm.WatsonxLLM"
]
]

View File

@@ -0,0 +1,50 @@
[
[
"langchain_community.llms.openai.OpenAI",
"langchain_openai.OpenAI"
],
[
"langchain_community.llms.openai.AzureOpenAI",
"langchain_openai.AzureOpenAI"
],
[
"langchain_community.embeddings.openai.OpenAIEmbeddings",
"langchain_openai.OpenAIEmbeddings"
],
[
"langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings",
"langchain_openai.AzureOpenAIEmbeddings"
],
[
"langchain_community.chat_models.openai.ChatOpenAI",
"langchain_openai.ChatOpenAI"
],
[
"langchain_community.chat_models.azure_openai.AzureChatOpenAI",
"langchain_openai.AzureChatOpenAI"
],
[
"langchain_community.llms.AzureOpenAI",
"langchain_openai.AzureOpenAI"
],
[
"langchain_community.llms.OpenAI",
"langchain_openai.OpenAI"
],
[
"langchain_community.embeddings.AzureOpenAIEmbeddings",
"langchain_openai.AzureOpenAIEmbeddings"
],
[
"langchain_community.embeddings.OpenAIEmbeddings",
"langchain_openai.OpenAIEmbeddings"
],
[
"langchain_community.chat_models.AzureChatOpenAI",
"langchain_openai.AzureChatOpenAI"
],
[
"langchain_community.chat_models.ChatOpenAI",
"langchain_openai.ChatOpenAI"
]
]

View File

@@ -0,0 +1,10 @@
[
[
"langchain_community.vectorstores.pinecone.Pinecone",
"langchain_pinecone.Pinecone"
],
[
"langchain_community.vectorstores.Pinecone",
"langchain_pinecone.Pinecone"
]
]

View File

@@ -1,14 +0,0 @@
[
[
"langchain.chat_models.ChatOpenAI",
"langchain_openai.ChatOpenAI"
],
[
"langchain.chat_models.ChatOpenAI",
"langchain_openai.ChatOpenAI"
],
[
"langchain.chat_models.ChatAnthropic",
"langchain_anthropic.ChatAnthropic"
]
]

View File

@@ -26,7 +26,7 @@ HERE = os.path.dirname(__file__)
 
 def _load_migrations_by_file(path: str):
-    migrations_path = os.path.join(HERE, path)
+    migrations_path = os.path.join(HERE, "migrations", path)
     with open(migrations_path, "r", encoding="utf-8") as f:
         data = json.load(f)
         return data
@@ -43,21 +43,30 @@ def _deduplicate_in_order(
     return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
 
 
-def _load_migrations():
-    """Load the migrations from the JSON file."""
-    # Later earlier ones have higher precedence.
-    paths = [
-        "migrations_v0.2_partner.json",
-        "migrations_v0.2.json",
-    ]
+PARTNERS = [
+    "anthropic.json",
+    "ibm.json",
+    "openai.json",
+    "pinecone.json",
+    "fireworks.json",
+]
+
+
+def _load_migrations_from_fixtures() -> List[Tuple[str, str]]:
+    """Load migrations from fixtures."""
+    paths: List[str] = PARTNERS + ["langchain.json"]
 
     data = []
     for path in paths:
         data.extend(_load_migrations_by_file(path))
 
     data = _deduplicate_in_order(data, key=lambda x: x[0])
+    return data
+
+
+def _load_migrations():
+    """Load the migrations from the JSON file."""
+    # Later earlier ones have higher precedence.
     imports: Dict[str, Tuple[str, str]] = {}
+    data = _load_migrations_from_fixtures()
 
     for old_path, new_path in data:
         # Parse the old parse which is of the format 'langchain.chat_models.ChatOpenAI'
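
To spell out the precedence rule encoded above: `_deduplicate_in_order` keeps only the first mapping seen for each old path, and the partner fixtures are listed before langchain.json, so a partner mapping wins over any later duplicate. A standalone sketch of that behaviour (the `seen`/`seen_add` setup is assumed from the usual idiom, and the second mapping is a purely hypothetical duplicate):

from typing import Callable, List, Tuple


def _deduplicate_in_order(
    seq: List[Tuple[str, str]], key: Callable = lambda x: x
) -> List[Tuple[str, str]]:
    # Keep only the first occurrence of each key, preserving order.
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]


partner = [("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI")]
# Hypothetical duplicate for the same old path, loaded later from langchain.json.
later = [("langchain_community.chat_models.ChatOpenAI", "some_other.Target")]

assert _deduplicate_in_order(partner + later, key=lambda x: x[0]) == partner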

View File

@@ -2,6 +2,7 @@
 import importlib
 from typing import List, Tuple
 
+from langchain_core.documents import BaseDocumentCompressor, BaseDocumentTransformer
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.retrievers import BaseRetriever
@@ -11,6 +12,7 @@ from langchain_cli.namespaces.migrate.generate.utils import (
     COMMUNITY_PKG,
     find_subclasses_in_module,
     list_classes_by_package,
+    list_init_imports_by_package,
 )
 
 # PUBLIC API
@@ -29,13 +31,24 @@ def get_migrations_for_partner_package(pkg_name: str) -> List[Tuple[str, str]]:
     """
     package = importlib.import_module(pkg_name)
     classes_ = find_subclasses_in_module(
-        package, [BaseLanguageModel, Embeddings, BaseRetriever, VectorStore]
+        package,
+        [
+            BaseLanguageModel,
+            Embeddings,
+            BaseRetriever,
+            VectorStore,
+            BaseDocumentTransformer,
+            BaseDocumentCompressor,
+        ],
     )
     community_classes = list_classes_by_package(str(COMMUNITY_PKG))
+    imports_for_pkg = list_init_imports_by_package(str(COMMUNITY_PKG))
+    old_paths = community_classes + imports_for_pkg
     migrations = [
-        (f"{community_module}.{community_class}", f"{pkg_name}.{community_class}")
-        for community_module, community_class in community_classes
-        if community_class in classes_
+        (f"{module}.{item}", f"{pkg_name}.{item}")
+        for module, item in old_paths
+        if item in classes_
     ]
     return migrations
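
A hedged usage sketch of the updated generator (the import path is assumed from the repository layout, and langchain_openai plus langchain_community must be installed in the environment); the expected pairs mirror openai.json above and the unit test further down:

from langchain_cli.namespaces.migrate.generate.partner import (
    get_migrations_for_partner_package,
)

# Returns (old_path, new_path) pairs, now also covering package-level
# re-exports such as langchain_community.chat_models.ChatOpenAI.
migrations = get_migrations_for_partner_package("langchain_openai")
for old_path, new_path in migrations:
    print(f"{old_path} -> {new_path}")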

View File

@@ -3,7 +3,7 @@ import inspect
 import os
 import pathlib
 from pathlib import Path
-from typing import Any, List, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type
 
 HERE = Path(__file__).parent
 # Should bring us to [root]/src
@@ -15,13 +15,15 @@ PARTNER_PKGS = PKGS_ROOT / "partners"
 
 class ImportExtractor(ast.NodeVisitor):
-    def __init__(self, *, from_package: str) -> None:
-        """Extract all imports from the given package."""
+    def __init__(self, *, from_package: Optional[str] = None) -> None:
+        """Extract all imports from the given code, optionally filtering by package."""
         self.imports = []
         self.package = from_package
 
     def visit_ImportFrom(self, node):
-        if node.module and str(node.module).startswith(self.package):
+        if node.module and (
+            self.package is None or str(node.module).startswith(self.package)
+        ):
             for alias in node.names:
                 self.imports.append((node.module, alias.name))
         self.generic_visit(node)
@@ -72,13 +74,40 @@ def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
         code = f.read()
     module_name = _get_current_module(file, pkg)
     class_names = _get_class_names(code)
     return [(module_name, class_name) for class_name in class_names]
 
 
+def identify_all_imports_in_file(
+    file: str, *, from_package: Optional[str] = None
+) -> List[Tuple[str, str]]:
+    """Let's also identify all the imports in the given file."""
+    with open(file, encoding="utf-8") as f:
+        code = f.read()
+    return find_imports_from_package(code, from_package=from_package)
+
+
+def identify_pkg_source(pkg_root: str) -> pathlib.Path:
+    """Identify the source of the package.
+
+    Args:
+        pkg_root: the root of the package. This contains source + tests, and other
+            things like pyproject.toml, lock files etc
+
+    Returns:
+        Returns the path to the source code for the package.
+    """
+    dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
+    matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
+    assert len(matching_dirs) == 1, "There should be only one langchain package."
+    return matching_dirs[0]
+
+
 def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
     """List all classes in a package."""
     module_classes = []
-    files = list(Path(pkg_root).rglob("*.py"))
+    pkg_source = identify_pkg_source(pkg_root)
+    files = list(pkg_source.rglob("*.py"))
 
     for file in files:
         rel_path = os.path.relpath(file, pkg_root)
@@ -88,11 +117,29 @@ def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
     return module_classes
 
 
-def find_imports_from_package(code: str, *, from_package: str) -> List[Tuple[str, str]]:
+def list_init_imports_by_package(pkg_root: str) -> List[Tuple[str, str]]:
+    """List all the things that are being imported in a package by module."""
+    imports = []
+    pkg_source = identify_pkg_source(pkg_root)
+    # Scan all the files in the package
+    files = list(Path(pkg_source).rglob("*.py"))
+    for file in files:
+        if not file.name == "__init__.py":
+            continue
+        import_in_file = identify_all_imports_in_file(str(file))
+        module_name = _get_current_module(file, pkg_root)
+        imports.extend([(module_name, item) for _, item in import_in_file])
+    return imports
+
+
+def find_imports_from_package(
+    code: str, *, from_package: Optional[str] = None
+) -> List[Tuple[str, str]]:
     # Parse the code into an AST
     tree = ast.parse(code)
 
     # Create an instance of the visitor
-    extractor = ImportExtractor(from_package="langchain_community")
+    extractor = ImportExtractor(from_package=from_package)
 
     # Use the visitor to update the imports list
     extractor.visit(tree)
 
     return extractor.imports
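
A small sketch of the now-optional package filter (the import path is taken from the module shown above; behaviour follows ImportExtractor, which records (module, name) pairs for from-imports):

from langchain_cli.namespaces.migrate.generate.utils import find_imports_from_package

code = "from langchain_community.chat_models import ChatOpenAI, ChatAnthropic\n"

# With a filter, only imports from the given package are reported.
print(find_imports_from_package(code, from_package="langchain_community"))
# [('langchain_community.chat_models', 'ChatOpenAI'),
#  ('langchain_community.chat_models', 'ChatAnthropic')]

# With from_package=None (the new default), every from-import is collected.
print(find_imports_from_package(code))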

View File

@@ -50,7 +50,7 @@ select = [
 ]
 
 [tool.poe.tasks]
-test = "poetry run pytest"
+test = "poetry run pytest tests"
 watch = "poetry run ptw"
 version = "poetry version --short"
 bump = ["_bump_1", "_bump_2"]

View File

@@ -1,5 +1,6 @@
 """Script to generate migrations for the migration script."""
 import json
+import pkgutil
 
 import click
@@ -38,10 +39,39 @@ def partner(pkg: str, output: str) -> None:
     """Generate migration scripts specifically for LangChain modules."""
     click.echo("Migration script for LangChain generated.")
     migrations = get_migrations_for_partner_package(pkg)
-    output_name = f"partner_{pkg}.json" if output is None else output
-    with open(output_name, "w") as f:
-        f.write(json.dumps(migrations, indent=2, sort_keys=True))
-    click.secho(f"LangChain migration script saved to {output_name}")
+    # Run with python 3.9+
+    output_name = f"{pkg.removeprefix('langchain_')}.json" if output is None else output
+    if migrations:
+        with open(output_name, "w") as f:
+            f.write(json.dumps(migrations, indent=2, sort_keys=True))
+        click.secho(f"LangChain migration script saved to {output_name}")
+    else:
+        click.secho(f"No migrations found for {pkg}", fg="yellow")
+
+
+@cli.command()
+def all_installed_partner_pkgs() -> None:
+    """Generate migration scripts for all LangChain modules."""
+    # Will generate migrations for all pather packages.
+    # Define as "langchain_<partner_name>".
+    # First let's determine which packages are installed in the environment
+    # and then generate migrations for them.
+    langchain_pkgs = [
+        name
+        for _, name, _ in pkgutil.iter_modules()
+        if name.startswith("langchain_")
+        and name not in {"langchain_core", "langchain_cli", "langchain_community"}
+    ]
+    for pkg in langchain_pkgs:
+        migrations = get_migrations_for_partner_package(pkg)
+        # Run with python 3.9+
+        output_name = f"{pkg.removeprefix('langchain_')}.json"
+        if migrations:
+            with open(output_name, "w") as f:
+                f.write(json.dumps(migrations, indent=2, sort_keys=True))
+            click.secho(f"LangChain migration script saved to {output_name}")
+        else:
+            click.secho(f"No migrations found for {pkg}", fg="yellow")
 
 
 if __name__ == "__main__":
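
As context for the `# Run with python 3.9+` comments in this script: `str.removeprefix` was added in Python 3.9, and it is what turns an installed package name into the fixture filename. A tiny sketch of the naming it produces:

# str.removeprefix is only available on Python 3.9+.
for pkg in ["langchain_openai", "langchain_anthropic", "langchain_ibm"]:
    print(f"{pkg.removeprefix('langchain_')}.json")
# openai.json
# anthropic.json
# ibm.json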

View File

@@ -1 +0,0 @@
[["langchain_community.embeddings.openai.OpenAIEmbeddings", "langchain_openai.embeddings.base.OpenAIEmbeddings"], ["langchain_community.embeddings.azure_openai.AzureOpenAIEmbeddings", "langchain_openai.embeddings.azure.AzureOpenAIEmbeddings"], ["langchain_community.chat_models.openai.ChatOpenAI", "langchain_openai.chat_models.base.ChatOpenAI"], ["langchain_community.chat_models.azure_openai.AzureChatOpenAI", "langchain_openai.chat_models.azure.AzureChatOpenAI"]]

View File

@@ -1,7 +1,7 @@
-from tests.unit_tests.migrate.integration.case import Case
-from tests.unit_tests.migrate.integration.cases import imports
-from tests.unit_tests.migrate.integration.file import File
-from tests.unit_tests.migrate.integration.folder import Folder
+from tests.unit_tests.migrate.cli_runner.case import Case
+from tests.unit_tests.migrate.cli_runner.cases import imports
+from tests.unit_tests.migrate.cli_runner.file import File
+from tests.unit_tests.migrate.cli_runner.folder import Folder
 
 cases = [
     Case(

View File

@@ -1,5 +1,5 @@
-from tests.unit_tests.migrate.integration.case import Case
-from tests.unit_tests.migrate.integration.file import File
+from tests.unit_tests.migrate.cli_runner.case import Case
+from tests.unit_tests.migrate.cli_runner.file import File
 
 cases = [
     Case(
@@ -7,7 +7,7 @@ cases = [
         source=File(
             "app.py",
             content=[
-                "from langchain.chat_models import ChatOpenAI",
+                "from langchain_community.chat_models import ChatOpenAI",
                 "",
                 "",
                 "class foo:",

View File

@@ -28,4 +28,19 @@ def test_generate_migrations() -> None:
             "langchain_community.chat_models.azure_openai.AzureChatOpenAI",
             "langchain_openai.AzureChatOpenAI",
         ),
+        ("langchain_community.llms.AzureOpenAI", "langchain_openai.AzureOpenAI"),
+        ("langchain_community.llms.OpenAI", "langchain_openai.OpenAI"),
+        (
+            "langchain_community.embeddings.AzureOpenAIEmbeddings",
+            "langchain_openai.AzureOpenAIEmbeddings",
+        ),
+        (
+            "langchain_community.embeddings.OpenAIEmbeddings",
+            "langchain_openai.OpenAIEmbeddings",
+        ),
+        (
+            "langchain_community.chat_models.AzureChatOpenAI",
+            "langchain_openai.AzureChatOpenAI",
+        ),
+        ("langchain_community.chat_models.ChatOpenAI", "langchain_openai.ChatOpenAI"),
     ]

View File

@@ -19,22 +19,32 @@ class TestReplaceImportsCommand(CodemodTest):
         from langchain.chat_models import ChatOpenAI
         """
         after = """
+        from langchain_community.chat_models import ChatOpenAI
+        """
+        self.assertCodemod(before, after)
+
+    def test_from_community_to_partner(self) -> None:
+        """Test that we can replace imports from community to partner."""
+        before = """
+        from langchain_community.chat_models import ChatOpenAI
+        """
+        after = """
         from langchain_openai import ChatOpenAI
         """
         self.assertCodemod(before, after)
 
     def test_noop_import(self) -> None:
         code = """
         from foo import ChatOpenAI
         """
         self.assertCodemod(code, code)
 
     def test_mixed_imports(self) -> None:
         before = """
-        from langchain.chat_models import ChatOpenAI, ChatAnthropic, foo
+        from langchain_community.chat_models import ChatOpenAI, ChatAnthropic, foo
         """
         after = """
-        from langchain.chat_models import foo
+        from langchain_community.chat_models import foo
         from langchain_anthropic import ChatAnthropic
         from langchain_openai import ChatOpenAI
         """