mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 19:47:13 +00:00
cli[minor]: Add first version of migrate (#20902)
Adds a first version of the migrate script.
This commit is contained in:
0
libs/cli/langchain_cli/namespaces/migrate.py
Normal file
0
libs/cli/langchain_cli/namespaces/migrate.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from enum import Enum
|
||||
from typing import List, Type
|
||||
|
||||
from libcst.codemod import ContextAwareTransformer
|
||||
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
|
||||
|
||||
from langchain_cli.namespaces.migrate.codemods.replace_imports import (
|
||||
ReplaceImportsCodemod,
|
||||
)
|
||||
|
||||
|
||||
class Rule(str, Enum):
    """Identifiers for migration rules that can be disabled on the CLI."""

    R001 = "R001"
    """Replace imports that have been moved."""
|
||||
|
||||
|
||||
def gather_codemods(disabled: List[Rule]) -> List[Type[ContextAwareTransformer]]:
    """Collect the codemod classes to run, honoring disabled rules.

    Args:
        disabled: rules switched off by the user; their codemods are skipped.

    Returns:
        Codemod classes in execution order, with the import bookkeeping
        visitors always appended last.
    """
    selected: List[Type[ContextAwareTransformer]] = []

    if Rule.R001 not in disabled:
        selected.append(ReplaceImportsCodemod)

    # The import add/remove visitors must run after every other codemod so
    # they see the final set of needed/removed imports.
    selected.extend([RemoveImportsVisitor, AddImportsVisitor])
    return selected
|
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,14 @@
|
||||
[
|
||||
[
|
||||
"langchain.chat_models.ChatOpenAI",
|
||||
"langchain_openai.ChatOpenAI"
|
||||
],
|
||||
[
|
||||
"langchain.chat_models.ChatOpenAI",
|
||||
"langchain_openai.ChatOpenAI"
|
||||
],
|
||||
[
|
||||
"langchain.chat_models.ChatAnthropic",
|
||||
"langchain_anthropic.ChatAnthropic"
|
||||
]
|
||||
]
|
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
# Adapted from bump-pydantic
|
||||
# https://github.com/pydantic/bump-pydantic
|
||||
|
||||
This codemod deals with the following cases:
|
||||
|
||||
1. `from pydantic import BaseSettings`
|
||||
2. `from pydantic.settings import BaseSettings`
|
||||
3. `from pydantic import BaseSettings as <name>`
|
||||
4. `from pydantic.settings import BaseSettings as <name>` # TODO: This is not working.
|
||||
5. `import pydantic` -> `pydantic.BaseSettings`
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Iterable, List, Sequence, Tuple, TypeVar
|
||||
|
||||
import libcst as cst
|
||||
import libcst.matchers as m
|
||||
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
|
||||
from libcst.codemod.visitors import AddImportsVisitor
|
||||
|
||||
HERE = os.path.dirname(__file__)
|
||||
|
||||
|
||||
def _load_migrations_by_file(path: str):
|
||||
migrations_path = os.path.join(HERE, path)
|
||||
with open(migrations_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _deduplicate_in_order(
|
||||
seq: Iterable[T], key: Callable[[T], str] = lambda x: x
|
||||
) -> List[T]:
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
return [x for x in seq if not (key(x) in seen or seen_add(key(x)))]
|
||||
|
||||
|
||||
def _load_migrations():
    """Load the migration pairs and index them by their old import path.

    Earlier files in the list take precedence when two files define a
    migration for the same old path.
    """
    paths = [
        "migrations_v0.2_partner.json",
        "migrations_v0.2.json",
    ]

    rows = []
    for path in paths:
        rows.extend(_load_migrations_by_file(path))

    # Keep only the first occurrence of each old path.
    rows = _deduplicate_in_order(rows, key=lambda row: row[0])

    imports: Dict[str, Tuple[str, str]] = {}

    for old_path, new_path in rows:
        # Split 'pkg.module.ClassName' into its module and class parts; the
        # key format is 'pkg.module:ClassName'.
        old_module, _, old_class = old_path.rpartition(".")
        new_module, _, new_class = new_path.rpartition(".")
        # Values are (new_module, new_class) 2-tuples.
        imports[f"{old_module}:{old_class}"] = (new_module, new_class)

    return imports
|
||||
|
||||
|
||||
# Mapping of "old.module:OldClass" -> (new_module, new_class), loaded once
# at import time.
IMPORTS = _load_migrations()
|
||||
|
||||
|
||||
def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
    """Converts a list of module parts to a `Name` or `Attribute` matcher.

    For example, ``["a", "b", "c"]`` becomes a matcher for the dotted
    expression ``a.b.c``.

    The input list is left unmodified; the previous implementation popped
    from it, mutating the caller's list as a side effect.
    """
    if len(module_parts) == 1:
        return m.Name(module_parts[0])
    if len(module_parts) == 2:
        first, last = module_parts
        return m.Attribute(value=m.Name(first), attr=m.Name(last))
    # Recurse on a slice so the caller's list is untouched.
    attr = resolve_module_parts(module_parts[:-1])
    return m.Attribute(value=attr, attr=m.Name(module_parts[-1]))
|
||||
|
||||
|
||||
def get_import_from_from_str(import_str: str) -> m.ImportFrom:
    """Build an ``ImportFrom`` matcher from a ``"module:name"`` string.

    Examples:
        >>> get_import_from_from_str("pydantic:BaseSettings")
        ImportFrom(
            module=Name("pydantic"),
            names=[ImportAlias(name=Name("BaseSettings"))],
        )
        >>> get_import_from_from_str("pydantic.settings:BaseSettings")
        ImportFrom(
            module=Attribute(value=Name("pydantic"), attr=Name("settings")),
            names=[ImportAlias(name=Name("BaseSettings"))],
        )
        >>> get_import_from_from_str("a.b.c:d")
        ImportFrom(
            module=Attribute(
                value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")
            ),
            names=[ImportAlias(name=Name("d"))],
        )
    """
    module_str, object_name = import_str.split(":")
    module_matcher = resolve_module_parts(module_str.split("."))
    # The ZeroOrMore sentinels let the alias match anywhere in the
    # statement's name list, not just when it is the only import.
    return m.ImportFrom(
        module=module_matcher,
        names=[
            m.ZeroOrMore(),
            m.ImportAlias(name=m.Name(value=object_name)),
            m.ZeroOrMore(),
        ],
    )
|
||||
|
||||
|
||||
@dataclass
class ImportInfo:
    """A single import migration: a matcher plus old and new locations."""

    # Matcher for the old `from <module> import <name>` statement.
    import_from: m.ImportFrom
    # Old location as "module:ClassName".
    import_str: str
    # New location as a (module, class_name) tuple.
    to_import_str: tuple[str, str]
|
||||
|
||||
|
||||
# Precompute one ImportInfo per migration entry so each matcher is built once.
IMPORT_INFOS = [
    ImportInfo(
        import_from=get_import_from_from_str(import_str),
        import_str=import_str,
        to_import_str=to_import_str,
    )
    for import_str, to_import_str in IMPORTS.items()
]
# Matcher that fires on any import statement covered by some migration.
IMPORT_MATCH = m.OneOf(*[info.import_from for info in IMPORT_INFOS])
|
||||
|
||||
|
||||
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
    """Rewrite `from old.module import X` statements to their new locations.

    Removal happens here; the replacement import is queued on the context
    via ``AddImportsVisitor.add_needed_import`` and materialized by a later
    ``AddImportsVisitor`` pass.
    """

    @m.leave(IMPORT_MATCH)
    def leave_replace_import(
        self, _: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        for import_info in IMPORT_INFOS:
            if m.matches(updated_node, import_info.import_from):
                aliases: Sequence[cst.ImportAlias] = updated_node.names  # type: ignore
                # If multiple objects are imported in a single import statement,
                # we need to remove only the one we're replacing.
                AddImportsVisitor.add_needed_import(
                    self.context, *import_info.to_import_str
                )
                if len(updated_node.names) > 1:  # type: ignore
                    # NOTE(review): this filters on the NEW class name
                    # (to_import_str[-1]); it only removes the old alias when
                    # the class keeps its name across the move — TODO confirm
                    # whether the old name (import_str after ':') was intended.
                    names = [
                        alias
                        for alias in aliases
                        if alias.name.value != import_info.to_import_str[-1]
                    ]
                    # Drop the trailing comma left behind by the removal.
                    names[-1] = names[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
                    updated_node = updated_node.with_changes(names=names)
                else:
                    # Sole import on the line: remove the whole statement.
                    return cst.RemoveFromParent()  # type: ignore[return-value]
        return updated_node
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import textwrap
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
source = textwrap.dedent(
|
||||
"""
|
||||
from pydantic.settings import BaseSettings
|
||||
from pydantic.color import Color
|
||||
from pydantic.payment import PaymentCardNumber, PaymentCardBrand
|
||||
from pydantic import Color
|
||||
from pydantic import Color as Potato
|
||||
|
||||
|
||||
class Potato(BaseSettings):
|
||||
color: Color
|
||||
payment: PaymentCardNumber
|
||||
brand: PaymentCardBrand
|
||||
potato: Potato
|
||||
"""
|
||||
)
|
||||
console.print(source)
|
||||
console.print("=" * 80)
|
||||
|
||||
mod = cst.parse_module(source)
|
||||
context = CodemodContext(filename="main.py")
|
||||
wrapper = cst.MetadataWrapper(mod)
|
||||
command = ReplaceImportsCodemod(context=context)
|
||||
|
||||
mod = wrapper.visit(command)
|
||||
wrapper = cst.MetadataWrapper(mod)
|
||||
command = AddImportsVisitor(context=context) # type: ignore[assignment]
|
||||
mod = wrapper.visit(command)
|
||||
console.print(mod.code)
|
52
libs/cli/langchain_cli/namespaces/migrate/glob_helpers.py
Normal file
52
libs/cli/langchain_cli/namespaces/migrate/glob_helpers.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Adapted from bump-pydantic
|
||||
# https://github.com/pydantic/bump-pydantic
|
||||
import fnmatch
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
MATCH_SEP = r"(?:/|\\)"
|
||||
MATCH_SEP_OR_END = r"(?:/|\\|\Z)"
|
||||
MATCH_NON_RECURSIVE = r"[^/\\]*"
|
||||
MATCH_RECURSIVE = r"(?:.*)"
|
||||
|
||||
|
||||
def glob_to_re(pattern: str) -> str:
|
||||
"""Translate a glob pattern to a regular expression for matching."""
|
||||
fragments: List[str] = []
|
||||
for segment in re.split(r"/|\\", pattern):
|
||||
if segment == "":
|
||||
continue
|
||||
if segment == "**":
|
||||
# Remove previous separator match, so the recursive match c
|
||||
# can match zero or more segments.
|
||||
if fragments and fragments[-1] == MATCH_SEP:
|
||||
fragments.pop()
|
||||
fragments.append(MATCH_RECURSIVE)
|
||||
elif "**" in segment:
|
||||
raise ValueError(
|
||||
"invalid pattern: '**' can only be an entire path component"
|
||||
)
|
||||
else:
|
||||
fragment = fnmatch.translate(segment)
|
||||
fragment = fragment.replace(r"(?s:", r"(?:")
|
||||
fragment = fragment.replace(r".*", MATCH_NON_RECURSIVE)
|
||||
fragment = fragment.replace(r"\Z", r"")
|
||||
fragments.append(fragment)
|
||||
fragments.append(MATCH_SEP)
|
||||
# Remove trailing MATCH_SEP, so it can be replaced with MATCH_SEP_OR_END.
|
||||
if fragments and fragments[-1] == MATCH_SEP:
|
||||
fragments.pop()
|
||||
fragments.append(MATCH_SEP_OR_END)
|
||||
return rf"(?s:{''.join(fragments)})"
|
||||
|
||||
|
||||
def match_glob(path: Path, pattern: str) -> bool:
    """Check if a path matches a glob pattern.

    If the pattern ends with a directory separator, the path must be a directory.
    """
    is_match = re.fullmatch(glob_to_re(pattern), str(path)) is not None
    if pattern.endswith(("/", "\\")):
        return is_match and path.is_dir()
    return is_match
|
194
libs/cli/langchain_cli/namespaces/migrate/main.py
Normal file
194
libs/cli/langchain_cli/namespaces/migrate/main.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Migrate LangChain to the most recent version."""
|
||||
# Adapted from bump-pydantic
|
||||
# https://github.com/pydantic/bump-pydantic
|
||||
import difflib
|
||||
import functools
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||
|
||||
import libcst as cst
|
||||
import rich
|
||||
import typer
|
||||
from libcst.codemod import CodemodContext, ContextAwareTransformer
|
||||
from libcst.helpers import calculate_module_and_package
|
||||
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress
|
||||
from typer import Argument, Exit, Option, Typer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods
|
||||
from langchain_cli.namespaces.migrate.glob_helpers import match_glob
|
||||
|
||||
# Typer application; invoke_without_command lets `migrate` run as a bare
# command without a subcommand.
app = Typer(invoke_without_command=True, add_completion=False)

P = ParamSpec("P")
T = TypeVar("T")

# Glob patterns skipped by default when collecting files to migrate.
DEFAULT_IGNORES = [".venv/**"]
|
||||
|
||||
|
||||
@app.callback()
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
):
    """Migrate langchain to the most recent version."""
    # In write mode, warn and ask for confirmation before touching files.
    if not diff:
        rich.print("[bold red]Alert![/ bold red] langchain-cli migrate", end=": ")
        if not typer.confirm(
            "The migration process will modify your files. "
            "The migration is a `best-effort` process and is not expected to "
            "be perfect. "
            "Do you want to continue?"
        ):
            raise Exit()
    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    # A single file migrates just that file; a directory migrates every
    # .py file beneath it.
    if os.path.isfile(path):
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))

    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    # NOTE(review): relative_to(".") raises ValueError for absolute paths —
    # TODO confirm `path` is always given relative to the current directory.
    files = [str(file.relative_to(".")) for file in filtered_files]

    if len(files) == 1:
        console.log("Found 1 file to process.")
    elif len(files) > 1:
        console.log(f"Found {len(files)} files to process.")
    else:
        console.log("No files to process.")
        raise Exit()

    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    # Recorded so modified files can be detected by mtime afterwards.
    start_time = time.time()

    codemods = gather_codemods(disabled=disable)

    # NOTE(review): log_fp is never closed explicitly; it lives until process
    # exit.
    log_fp = log_file.open("a+", encoding="utf8")
    partial_run_codemods = functools.partial(
        run_codemods, codemods, metadata_manager, scratch, package, diff
    )
    with Progress(*Progress.get_default_columns(), transient=True) as progress:
        task = progress.add_task(description="Executing codemods...", total=len(files))
        count_errors = 0
        difflines: List[List[str]] = []
        # Fan the files out across worker processes; results arrive unordered.
        with multiprocessing.Pool() as pool:
            for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
                progress.advance(task)

                if _difflines is not None:
                    difflines.append(_difflines)

                if error is not None:
                    count_errors += 1
                    log_fp.writelines(error)

    # Heuristic: a file whose mtime moved past start_time was rewritten.
    modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]

    if not diff:
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    # In diff mode a non-empty diff means changes are pending; exit non-zero
    # so callers (e.g. CI) can fail on it.
    if difflines:
        raise Exit(1)
|
||||
|
||||
|
||||
def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply every codemod to one file.

    Args:
        codemods: transformer classes to instantiate and run, in order.
        metadata_manager: shared libcst metadata cache for the repo.
        scratch: extra context data copied into the codemod context.
        package: root package path, used to compute the module name.
        diff: when True, return a unified diff instead of writing the file.
        filename: path of the file to process.

    Returns:
        ``(error, difflines)`` — ``error`` is a log message or None;
        ``difflines`` is the unified diff (diff mode only) or None.
    """
    try:
        module_and_package = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=module_and_package.name,
            full_package_name=module_and_package.package,
        )
        context.scratch.update(scratch)

        file_path = Path(filename)
        # "r+" so the rewritten source can be written back in place.
        with file_path.open("r+", encoding="utf-8") as fp:
            code = fp.read()
            fp.seek(0)

            input_tree = cst.parse_module(code)

            # Chain the codemods: each one transforms the previous output.
            for codemod in codemods:
                transformer = codemod(context=context)
                output_tree = transformer.transform_module(input_tree)
                input_tree = output_tree

            output_code = input_tree.code
            if code != output_code:
                if diff:
                    lines = difflib.unified_diff(
                        code.splitlines(keepends=True),
                        output_code.splitlines(keepends=True),
                        fromfile=filename,
                        tofile=filename,
                    )
                    return None, list(lines)
                else:
                    fp.write(output_code)
                    fp.truncate()
        return None, None
    except cst.ParserSyntaxError as exc:
        # Include the offending filename so the log entry is actionable
        # (the message previously said "(unknown)").
        return (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n"
            f"{exc}"
        ), None
    except Exception:
        return f"An error happened on {filename}.\n{traceback.format_exc()}", None
|
||||
|
||||
|
||||
def color_diff(console: Console, lines: Iterable[str]) -> None:
    """Print unified-diff lines with conventional colors (+green, -red, ^blue)."""
    styles = {"+": "green", "-": "red", "^": "blue"}
    for raw_line in lines:
        text = raw_line.rstrip("\n")
        # Dispatch on the first character; anything unrecognized stays white.
        console.print(text, style=styles.get(text[:1], "white"))
|
Reference in New Issue
Block a user