cli[minor]: Add ipynb support, add text_splitters (#20963)

This commit is contained in:
Eugene Yurtsev
2024-04-29 10:11:21 -04:00
committed by GitHub
parent 5e0b6b3e75
commit d781560722
13 changed files with 2632 additions and 6525 deletions

View File

@@ -8,7 +8,7 @@ import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
import libcst as cst
import rich
@@ -41,6 +41,9 @@ def main(
default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
),
log_file: Path = Option("log.txt", help="Log errors to this file."),
include_ipynb: bool = Option(
False, help="Include Jupyter Notebook files in the migration."
),
):
"""Migrate langchain to the most recent version."""
if not diff:
@@ -63,6 +66,8 @@ def main(
else:
package = path
all_files = sorted(package.glob("**/*.py"))
if include_ipynb:
all_files.extend(sorted(package.glob("**/*.ipynb")))
filtered_files = [
file
@@ -86,11 +91,9 @@ def main(
scratch: dict[str, Any] = {}
start_time = time.time()
codemods = gather_codemods(disabled=disable)
log_fp = log_file.open("a+", encoding="utf8")
partial_run_codemods = functools.partial(
run_codemods, codemods, metadata_manager, scratch, package, diff
get_and_run_codemods, disable, metadata_manager, scratch, package, diff
)
with Progress(*Progress.get_default_columns(), transient=True) as progress:
task = progress.add_task(description="Executing codemods...", total=len(files))
@@ -127,6 +130,121 @@ def main(
raise Exit(1)
def get_and_run_codemods(
    disabled_rules: List[Rule],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Gather the enabled codemods and run them against a single file.

    Thin wrapper around ``run_codemods`` intended for use with
    ``multiprocessing.Pool``: the codemods are gathered here from
    ``disabled_rules`` on every call instead of being passed in by the
    caller.

    Args:
        disabled_rules: Rules forwarded to ``gather_codemods`` as ``disabled``.
        metadata_manager: Forwarded unchanged to ``run_codemods``.
        scratch: Forwarded unchanged to ``run_codemods``.
        package: Forwarded unchanged to ``run_codemods``.
        diff: Forwarded unchanged to ``run_codemods``.
        filename: Path of the file to process.

    Returns:
        Whatever ``run_codemods`` returns for ``filename``.
    """
    return run_codemods(
        gather_codemods(disabled=disabled_rules),
        metadata_manager,
        scratch,
        package,
        diff,
        filename,
    )
def _rewrite_file(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply the codemods to one Python source file.

    The file is parsed with libcst, every codemod is applied in order (each
    one receives the tree produced by the previous one), and the resulting
    code is compared with the original text.

    Args:
        filename: Path of the source file to transform.
        codemods: Transformer classes; each is instantiated with ``context``.
        diff: When true, the file is left untouched and a unified diff of
            the would-be change is returned instead.
        context: Codemod context shared by all transformers for this file.

    Returns:
        ``(None, diff_lines)`` when ``diff`` is true and the code changed;
        ``(None, None)`` otherwise (nothing changed, or the change was
        written back to the file).
    """
    file_path = Path(filename)
    # Read-only access is enough here; the file is opened for writing only
    # if a changed result actually has to be persisted. This keeps diff
    # (dry-run) mode usable on files that are not writable.
    code = file_path.read_text(encoding="utf-8")

    tree = cst.parse_module(code)
    for codemod in codemods:
        tree = codemod(context=context).transform_module(tree)
    output_code = tree.code

    if code == output_code:
        return None, None

    if diff:
        lines = difflib.unified_diff(
            code.splitlines(keepends=True),
            output_code.splitlines(keepends=True),
            fromfile=filename,
            tofile=filename,
        )
        return None, list(lines)

    file_path.write_text(output_code, encoding="utf-8")
    return None, None
def _rewrite_notebook(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Optional[str], Optional[List[str]]]:
    """Try to rewrite a Jupyter Notebook file.

    Each code cell is parsed and transformed on its own. Cells containing a
    line that starts with ``!`` or ``%`` (shell / magic commands) are
    skipped entirely, and non-code cells are never touched.

    Args:
        filename: Path of the notebook; must have the ``.ipynb`` suffix.
        codemods: Transformer classes applied, in order, to every code cell.
        diff: When true, the notebook is not written and the per-cell
            unified diffs are returned instead.
        context: Currently unused; every cell gets a fresh
            ``CodemodContext`` (see the TODO in the body).

    Returns:
        ``(None, diff_lines)`` when ``diff`` is true (``diff_lines`` is empty
        if no cell changed); ``(None, None)`` otherwise.

    Raises:
        ValueError: If ``filename`` does not end in ``.ipynb``.
    """
    import nbformat

    file_path = Path(filename)
    if file_path.suffix != ".ipynb":
        raise ValueError("Only Jupyter Notebook files (.ipynb) are supported.")

    with file_path.open("r", encoding="utf-8") as fp:
        notebook = nbformat.read(fp, as_version=4)

    diffs: List[str] = []
    modified = False

    for cell in notebook.cells:
        if cell.cell_type != "code":
            continue

        code = "".join(cell.source)

        # Skip code if any of the lines begin with a magic command or
        # a ! command.
        # We can try to handle later.
        if any(line.startswith(("!", "%")) for line in code.splitlines()):
            continue

        tree = cst.parse_module(code)

        # TODO(Team): Quick hack, need to figure out
        # how to handle this correctly.
        # This prevents the code from trying to re-insert the imports
        # for every cell in the notebook.
        local_context = CodemodContext()
        for codemod in codemods:
            tree = codemod(context=local_context).transform_module(tree)
        output_code = tree.code

        if code == output_code:
            continue

        modified = True
        cell.source = output_code.splitlines(keepends=True)
        if diff:
            diffs.extend(
                difflib.unified_diff(
                    code.splitlines(keepends=True),
                    output_code.splitlines(keepends=True),
                    fromfile=filename,
                    tofile=filename,
                )
            )

    if diff:
        return None, diffs

    # Only touch the file when a cell actually changed; re-serializing an
    # unchanged notebook could still alter its on-disk formatting.
    if modified:
        with file_path.open("w", encoding="utf-8") as fp:
            nbformat.write(notebook, fp)

    return None, None
def run_codemods(
codemods: List[Type[ContextAwareTransformer]],
metadata_manager: FullRepoManager,
@@ -145,32 +263,10 @@ def run_codemods(
)
context.scratch.update(scratch)
file_path = Path(filename)
with file_path.open("r+", encoding="utf-8") as fp:
code = fp.read()
fp.seek(0)
input_tree = cst.parse_module(code)
for codemod in codemods:
transformer = codemod(context=context)
output_tree = transformer.transform_module(input_tree)
input_tree = output_tree
output_code = input_tree.code
if code != output_code:
if diff:
lines = difflib.unified_diff(
code.splitlines(keepends=True),
output_code.splitlines(keepends=True),
fromfile=filename,
tofile=filename,
)
return None, list(lines)
else:
fp.write(output_code)
fp.truncate()
return None, None
if filename.endswith(".ipynb"):
return _rewrite_notebook(filename, codemods, diff, context)
else:
return _rewrite_file(filename, codemods, diff, context)
except cst.ParserSyntaxError as exc:
return (
f"A syntax error happened on {filename}. This file cannot be "