cli[minor]: Add first version of migrate (#20902)

Adds a first version of the migrate script.
This commit is contained in:
Eugene Yurtsev
2024-04-26 10:50:21 -04:00
committed by GitHub
parent d95e9fb67f
commit 6598757037
24 changed files with 6294 additions and 3 deletions

View File

@@ -0,0 +1,25 @@
from enum import Enum
from typing import List, Type
from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
ReplaceImportsCodemod,
)
class Rule(str, Enum):
    """Identifiers of the migration rules known to the migrate command."""

    # Each member's value equals its name, and the class also subclasses
    # ``str``, so members compare equal to their plain string form.
    R001 = "R001"
    """Replace imports that have been moved."""
def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
    """Return the codemod classes to run, in execution order.

    Args:
        disabled: Rules whose codemods must be left out.

    Returns:
        The enabled rule codemods followed by the import clean-up visitors.
    """
    rule_codemods: List[Type[ContextAwareTransformer]] = (
        [] if Rule.R001 in disabled else [ReplaceImportsCodemod]
    )
    # The import visitors always go last, after every rule codemod.
    return [*rule_codemods, RemoveImportsVisitor, AddImportsVisitor]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,14 @@
[
[
"langchain.chat_models.ChatOpenAI",
"langchain_openai.ChatOpenAI"
],
[
"langchain.chat_models.ChatOpenAI",
"langchain_openai.ChatOpenAI"
],
[
"langchain.chat_models.ChatAnthropic",
"langchain_anthropic.ChatAnthropic"
]
]

View File

@@ -0,0 +1,205 @@
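"""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic

This codemod rewrites `from <module> import <name>` statements whose target has
moved, using the (old path, new path) pairs loaded from the migration JSON files
that live next to this module.

The list below is inherited from bump-pydantic and uses pydantic names only to
illustrate the import shapes that codemod dealt with; not every shape is
necessarily handled here:
1. `from pydantic import BaseSettings`
2. `from pydantic.settings import BaseSettings`
3. `from pydantic import BaseSettings as <name>`
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
5. `import pydantic` -> `pydantic.BaseSettings`
"""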
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
# Directory containing this module; migration files are resolved against it.
HERE = os.path.dirname(__file__)


def _load_migrations_by_file(path: str):
    """Read one migration JSON file and return its decoded content.

    Args:
        path: File name (or path) joined onto this module's directory.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    with open(os.path.join(HERE, path), "r", encoding="utf-8") as handle:
        return json.load(handle)
T = TypeVar("T")


def _deduplicate_in_order(
    seq: Iterable[T], key: Callable[[T], str] = lambda x: x
) -> List[T]:
    """Drop repeated items while keeping the first occurrence of each.

    Args:
        seq: Items to filter.
        key: Maps an item to the hashable value used to detect repeats;
            defaults to the item itself.

    Returns:
        The surviving items in their original relative order.
    """
    already_seen = set()
    unique: List[T] = []
    for item in seq:
        marker = key(item)
        if marker in already_seen:
            continue
        already_seen.add(marker)
        unique.append(item)
    return unique
def _load_migrations():
    """Load the migrations from the JSON files.

    Returns:
        A dict mapping ``"<old module>:<old name>"`` to the 2-tuple
        ``(<new module>, <new name>)``.
    """
    # Files listed earlier take precedence: de-duplication below keeps the
    # first entry seen for each old path.
    sources = ("migrations_v0.2_partner.json", "migrations_v0.2.json")
    entries = []
    for source in sources:
        entries.extend(_load_migrations_by_file(source))
    entries = _deduplicate_in_order(entries, key=lambda pair: pair[0])

    imports: Dict[str, Tuple[str, str]] = {}
    for old_path, new_path in entries:
        # A dotted path such as 'langchain.chat_models.ChatOpenAI' splits at
        # its last dot into the module part and the imported name.
        old_module, _, old_class = old_path.rpartition(".")
        new_module, _, new_class = new_path.rpartition(".")
        imports[f"{old_module}:{old_class}"] = (new_module, new_class)
    return imports
# Migration table, built once when this module is imported (this reads the
# migration JSON files from disk).
IMPORTS = _load_migrations()
def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
    """Converts a list of module parts to a `Name` or `Attribute` node.

    The parts are folded left to right, so ``["a", "b", "c"]`` becomes
    ``Attribute(Attribute(Name("a"), Name("b")), Name("c"))``.

    Args:
        module_parts: Components of a dotted module path. The list is only
            read; it is not modified.

    Returns:
        A matcher ``Name`` for a single part, otherwise nested ``Attribute``
        matchers.

    Raises:
        IndexError: If ``module_parts`` is empty.
    """
    node: m.Attribute | m.Name = m.Name(module_parts[0])
    # Iterating over a slice (instead of popping) leaves the caller's list intact.
    for part in module_parts[1:]:
        node = m.Attribute(value=node, attr=m.Name(part))
    return node
def get_import_from_from_str(import_str: str) -> m.ImportFrom:
    """Build an ``ImportFrom`` matcher from a ``"<module>:<name>"`` string.

    The matcher accepts any ``from <module> import ...`` statement whose
    imported names include ``<name>``; other names may appear before or
    after it.

    Illustrative shapes (``...`` stands for the ``ZeroOrMore`` wildcards):

        "pydantic:BaseSettings" ->
            ImportFrom(
                module=Name("pydantic"),
                names=[..., ImportAlias(name=Name("BaseSettings")), ...],
            )
        "pydantic.settings:BaseSettings" ->
            ImportFrom(
                module=Attribute(value=Name("pydantic"), attr=Name("settings")),
                names=[..., ImportAlias(name=Name("BaseSettings")), ...],
            )
        "a.b.c:d" ->
            ImportFrom(
                module=Attribute(
                    value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
                ),
                names=[..., ImportAlias(name=Name("d")), ...],
            )
    """
    module_path, imported_name = import_str.split(":")
    wanted_alias = m.ImportAlias(name=m.Name(value=imported_name))
    return m.ImportFrom(
        module=resolve_module_parts(module_path.split(".")),
        names=[m.ZeroOrMore(), wanted_alias, m.ZeroOrMore()],
    )
@dataclass
class ImportInfo:
    """One migration entry: how to recognise the old import and what replaces it."""

    # Matcher for `from <old module> import <old name>` statements.
    import_from: m.ImportFrom
    # The old location, formatted as "<old module>:<old name>".
    import_str: str
    # The new location as a (module, name) pair.
    to_import_str: tuple[str, str]
# One ImportInfo per entry of the migration table, in table order.
IMPORT_INFOS = [
    ImportInfo(
        import_from=get_import_from_from_str(import_str),
        import_str=import_str,
        to_import_str=to_import_str,
    )
    for import_str, to_import_str in IMPORTS.items()
]
# Single matcher that accepts an import statement matching any migration entry.
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
    """Rewrite `from ... import ...` statements whose targets have moved.

    For every migration entry matched by a statement, the old name is dropped
    from that statement and the new import is scheduled through
    ``AddImportsVisitor``.
    """

    @m.leave(IMPORT_MATCH)
    def leave_replace_import(
        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Strip migrated names from an import and schedule their replacements.

        Args:
            _: The original node (unused).
            updated_node: The node to rewrite.

        Returns:
            The import with the migrated names removed, or the
            remove-from-parent sentinel when no names are left.
        """
        for import_info in IMPORT_INFOS:
            if not m.matches(updated_node, import_info.import_from):
                continue
            aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
            # The alias to drop is identified by its OLD name; the new name may
            # differ when the symbol was renamed as part of the move.
            old_name = import_info.import_str.rpartition(":")[2]
            new_module, new_name = import_info.to_import_str

            remaining = []
            for alias in aliases:
                if alias.name.value != old_name:
                    # If multiple objects are imported in a single import
                    # statement, only the one being replaced is removed.
                    remaining.append(alias)
                    continue
                # Carry over an `as <alias>` so existing references keep working.
                AddImportsVisitor.add_needed_import(
                    self.context,
                    new_module,
                    new_name,
                    asname=alias.evaluated_alias,
                )

            if not remaining:
                return cst.RemoveFromParent()  # type: ignore[return-value]
            # The new last name must not keep the comma it had mid-list.
            remaining[-1] = remaining[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
            updated_node = updated_node.with_changes(names=remaining)
        return updated_node
if __name__ == "__main__":
    # Manual smoke test: run the codemod and then AddImportsVisitor over a
    # sample snippet, printing the source before and after.
    # NOTE(review): the sample only contains pydantic imports (it looks
    # inherited from bump-pydantic) — confirm it still exercises any entry of
    # the migration table.
    import textwrap

    from rich.console import Console

    console = Console()

    source = textwrap.dedent(
        """
        from pydantic.settings import BaseSettings
        from pydantic.color import Color
        from pydantic.payment import PaymentCardNumber, PaymentCardBrand
        from pydantic import Color
        from pydantic import Color as Potato
        class Potato(BaseSettings):
            color: Color
            payment: PaymentCardNumber
            brand: PaymentCardBrand
            potato: Potato
        """
    )
    console.print(source)
    console.print("=" * 80)

    mod = cst.parse_module(source)
    context = CodemodContext(filename="main.py")
    wrapper = cst.MetadataWrapper(mod)
    command = ReplaceImportsCodemod(context=context)

    mod = wrapper.visit(command)
    # Second pass: materialise the imports scheduled by the codemod above.
    wrapper = cst.MetadataWrapper(mod)
    command = AddImportsVisitor(context=context)  # type: ignore[assignment]
    mod = wrapper.visit(command)
    console.print(mod.code)

View File

@@ -0,0 +1,52 @@
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import fnmatch
import re
from pathlib import Path
from typing import List
# Regex building blocks; both "/" and "\" are accepted as path separators.
MATCH_SEP = r"(?:/|\\)"
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
MATCH_NON_RECURSIVE = r"[^/\\]*"
MATCH_RECURSIVE = r"(?:.*)"


def glob_to_re(pattern: str) -> str:
    """Translate a glob pattern to a regular expression for matching.

    The pattern is split on ``/`` and ``\\``; empty segments are ignored. A
    ``**`` segment matches across separators, while wildcards in any other
    segment stay within a single path component.

    Args:
        pattern: The glob pattern.

    Returns:
        The source of a regular expression, meant to be used with
        ``re.fullmatch``.

    Raises:
        ValueError: If ``**`` appears as only part of a path component.
    """
    fragments: List[str] = []
    for segment in re.split(r"/|\\", pattern):
        if segment == "":
            continue
        if segment == "**":
            # Remove previous separator match, so the recursive match
            # can match zero or more segments.
            if fragments and fragments[-1] == MATCH_SEP:
                fragments.pop()
            fragments.append(MATCH_RECURSIVE)
        elif "**" in segment:
            raise ValueError(
                "invalid pattern: '**' can only be an entire path component"
            )
        else:
            fragment = fnmatch.translate(segment)
            fragment = fragment.replace(r"(?s:", r"(?:")
            fragment = fragment.replace(r".*", MATCH_NON_RECURSIVE)
            # Drop the end-of-string anchor added by fnmatch.translate: `\Z`,
            # or `\z` as emitted by newer Python releases. Left in place it
            # would sit in the middle of the combined expression.
            for anchor in (r"\Z", r"\z"):
                fragment = fragment.replace(anchor, "")
            fragments.append(fragment)
        fragments.append(MATCH_SEP)
    # Remove trailing MATCH_SEP, so it can be replaced with MATCH_SEP_OR_END.
    if fragments and fragments[-1] == MATCH_SEP:
        fragments.pop()
    fragments.append(MATCH_SEP_OR_END)
    return rf"(?s:{''.join(fragments)})"
def match_glob(path: Path, pattern: str) -> bool:
    """Tell whether ``path`` matches the glob ``pattern``.

    A pattern ending in a directory separator additionally requires the
    path to be an existing directory.
    """
    if re.fullmatch(glob_to_re(pattern), str(path)) is None:
        return False
    if pattern.endswith(("/", "\\")):
        return path.is_dir()
    return True

View File

@@ -0,0 +1,194 @@
"""Migrate LangChain to the most recent version."""
# Adapted from bump-pydantic
# https://github.com/pydantic/bump-pydantic
import difflib
import functools
import multiprocessing
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
import libcst as cst
import rich
import typer
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
from rich.console import Console
from rich.progress import Progress
from typer import Argument, Exit, Option, Typer
from typing_extensions import ParamSpec
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
# Typer application; `invoke_without_command=True` lets the callback registered
# below run without a sub-command.
app = Typer(invoke_without_command=True, add_completion=False)

# NOTE(review): P and T are not referenced in the visible part of this module —
# confirm whether they are still needed.
P = ParamSpec("P")
T = TypeVar("T")

# Glob patterns used as the default value of the `ignore` option.
DEFAULT_IGNORES = [".venv/**"]
@app.callback()
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
):
    """Migrate langchain to the most recent version."""
    # Flow: confirm (unless only diffing), collect the .py files under `path`
    # minus ignored ones, run the codemods over them in a process pool, then
    # report. Errors are appended to `log_file`. Raises `Exit` when the user
    # declines or no files are found, and `Exit(1)` when diffs were produced.
    if not diff:
        rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
        if not typer.confirm(
            "The migration process will modify your files. "
            "The migration is a `best-effort` process and is not expected to "
            "be perfect. "
            "Do you want to continue?"
        ):
            raise Exit()
    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    if os.path.isfile(path):
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))

    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    files = [str(file.relative_to(".")) for file in filtered_files]

    if len(files) == 1:
        console.log("Found 1 file to process.")
    elif len(files) > 1:
        console.log(f"Found {len(files)} files to process.")
    else:
        console.log("No files to process.")
        raise Exit()

    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    # Recorded before any codemod runs; compared against file mtimes below to
    # count the files that were rewritten.
    start_time = time.time()

    codemods = gather_codemods(disabled=disable)

    partial_run_codemods = functools.partial(
        run_codemods, codemods, metadata_manager, scratch, package, diff
    )

    count_errors = 0
    difflines: List[List[str]] = []
    # The context manager guarantees the log file is flushed and closed, even
    # if processing fails part-way through.
    with log_file.open("a+", encoding="utf8") as log_fp:
        with Progress(*Progress.get_default_columns(), transient=True) as progress:
            task = progress.add_task(
                description="Executing codemods...", total=len(files)
            )
            with multiprocessing.Pool() as pool:
                for error, _difflines in pool.imap_unordered(
                    partial_run_codemods, files
                ):
                    progress.advance(task)

                    if _difflines is not None:
                        difflines.append(_difflines)

                    if error is not None:
                        count_errors += 1
                        log_fp.write(error)
                        # Keep consecutive entries from running into each other.
                        if not error.endswith("\n"):
                            log_fp.write("\n")

    modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

    if not diff:
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    if difflines:
        raise Exit(1)
def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply every codemod, in order, to a single file.

    Args:
        codemods: Transformer classes to instantiate and run in sequence.
        metadata_manager: Passed to the codemod context.
        scratch: Values copied into the codemod context's scratch space.
        package: Root used to compute the file's module and package names.
        diff: When true, return a unified diff instead of rewriting the file.
        filename: Path of the file to process.

    Returns:
        An ``(error, difflines)`` pair: ``error`` is a message when processing
        failed, ``difflines`` holds the diff lines when ``diff`` is set and the
        code changed. Both are ``None`` otherwise.
    """
    # NOTE(review): the file is opened in text mode with default newline
    # handling, so line endings of rewritten files may get normalised —
    # confirm whether `newline=""` is wanted here.
    try:
        located = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=located.name,
            full_package_name=located.package,
        )
        context.scratch.update(scratch)

        with Path(filename).open("r+", encoding="utf-8") as fp:
            code = fp.read()
            # Rewind so that a later write overwrites from the start.
            fp.seek(0)

            tree = cst.parse_module(code)
            for codemod in codemods:
                tree = codemod(context=context).transform_module(tree)
            output_code = tree.code

            if code == output_code:
                return None, None
            if diff:
                changes = difflib.unified_diff(
                    code.splitlines(keepends=True),
                    output_code.splitlines(keepends=True),
                    fromfile=filename,
                    tofile=filename,
                )
                return None, list(changes)
            fp.write(output_code)
            # Cut off leftovers in case the new code is shorter than the old.
            fp.truncate()
        return None, None
    except cst.ParserSyntaxError as exc:
        return (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n"
            f"{exc}"
        ), None
    except Exception:
        return f"An error happened on {filename}.\n{traceback.format_exc()}", None
def color_diff(console: Console, lines: Iterable[str]) -> None:
    """Print diff lines to ``console``, styled by their first character.

    Lines starting with ``+`` are green, ``-`` red, ``^`` blue; everything
    else is white. Trailing newlines are stripped before printing.
    """
    # NOTE(review): the lines passed in come from difflib.unified_diff, whose
    # hunk headers start with "@@" — confirm the "^" prefix is intended.
    palette = (("+", "green"), ("-", "red"), ("^", "blue"))
    for raw in lines:
        text = raw.rstrip("\n")
        style = next(
            (color for prefix, color in palette if text.startswith(prefix)),
            "white",
        )
        console.print(text, style=style)