mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 07:50:47 +00:00
cli[minor]: Add ipynb support, add text_splitters (#20963)
This commit is contained in:
@@ -8,7 +8,7 @@ import os
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import libcst as cst
|
||||
import rich
|
||||
@@ -41,6 +41,9 @@ def main(
|
||||
default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
|
||||
),
|
||||
log_file: Path = Option("log.txt", help="Log errors to this file."),
|
||||
include_ipynb: bool = Option(
|
||||
False, help="Include Jupyter Notebook files in the migration."
|
||||
),
|
||||
):
|
||||
"""Migrate langchain to the most recent version."""
|
||||
if not diff:
|
||||
@@ -63,6 +66,8 @@ def main(
|
||||
else:
|
||||
package = path
|
||||
all_files = sorted(package.glob("**/*.py"))
|
||||
if include_ipynb:
|
||||
all_files.extend(sorted(package.glob("**/*.ipynb")))
|
||||
|
||||
filtered_files = [
|
||||
file
|
||||
@@ -86,11 +91,9 @@ def main(
|
||||
scratch: dict[str, Any] = {}
|
||||
start_time = time.time()
|
||||
|
||||
codemods = gather_codemods(disabled=disable)
|
||||
|
||||
log_fp = log_file.open("a+", encoding="utf8")
|
||||
partial_run_codemods = functools.partial(
|
||||
run_codemods, codemods, metadata_manager, scratch, package, diff
|
||||
get_and_run_codemods, disable, metadata_manager, scratch, package, diff
|
||||
)
|
||||
with Progress(*Progress.get_default_columns(), transient=True) as progress:
|
||||
task = progress.add_task(description="Executing codemods...", total=len(files))
|
||||
@@ -127,6 +130,121 @@ def main(
|
||||
raise Exit(1)
|
||||
|
||||
|
||||
def get_and_run_codemods(
    disabled_rules: List[Rule],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Gather the enabled codemods and run them against one file.

    Thin wrapper around ``run_codemods`` meant to be used with
    ``multiprocessing.Pool``: the codemods are built here from
    ``disabled_rules`` instead of being passed in by the caller.
    NOTE(review): presumably this avoids shipping the codemod classes to
    worker processes -- confirm against the call site.

    Args:
        disabled_rules: Rules forwarded to ``gather_codemods`` as ``disabled``.
        metadata_manager: Forwarded unchanged to ``run_codemods``.
        scratch: Forwarded unchanged to ``run_codemods``.
        package: Forwarded unchanged to ``run_codemods``.
        diff: Forwarded unchanged to ``run_codemods``.
        filename: Path of the file to process.

    Returns:
        Whatever ``run_codemods`` returns for ``filename``.
    """
    active_codemods = gather_codemods(disabled=disabled_rules)
    outcome = run_codemods(
        active_codemods, metadata_manager, scratch, package, diff, filename
    )
    return outcome
|
||||
|
||||
|
||||
def _rewrite_file(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply the codemods to a single Python source file.

    The file is parsed with libcst, every codemod is applied in order (each
    one receives the output of the previous one), and the result is compared
    with the original text.

    Args:
        filename: Path of the source file to transform.
        codemods: Transformer classes; each is instantiated with ``context``.
        diff: When True, do not modify the file; return a unified diff of the
            would-be changes instead.
        context: Codemod context shared by all transformers for this file.

    Returns:
        ``(None, diff_lines)`` when ``diff`` is True and the code changed;
        ``(None, None)`` otherwise (nothing changed, or the file was
        rewritten). The first element is always ``None`` here.

    Raises:
        Any exception raised by ``cst.parse_module`` or by the transformers
        propagates to the caller.
    """
    file_path = Path(filename)

    # Read-only access is enough to compute the result. Opening with "r+"
    # would demand write permission even for a diff-only (dry) run.
    with file_path.open("r", encoding="utf-8") as fp:
        code = fp.read()

    tree = cst.parse_module(code)
    for codemod in codemods:
        tree = codemod(context=context).transform_module(tree)

    output_code = tree.code
    if code == output_code:
        return None, None

    if diff:
        lines = difflib.unified_diff(
            code.splitlines(keepends=True),
            output_code.splitlines(keepends=True),
            fromfile=filename,
            tofile=filename,
        )
        return None, list(lines)

    # Only now is write access required; "w" truncates before writing, which
    # matches the previous seek(0) + write + truncate sequence.
    with file_path.open("w", encoding="utf-8") as fp:
        fp.write(output_code)
    return None, None
|
||||
|
||||
|
||||
def _rewrite_notebook(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Optional[str], Optional[List[str]]]:
    """Try to rewrite a Jupyter Notebook file.

    Each code cell is parsed and transformed independently. Cells containing
    a line that begins with ``!`` or ``%`` (shell / magic commands) are
    skipped because they are not parsed as Python here.

    Args:
        filename: Path of the notebook; must have the ``.ipynb`` suffix.
        codemods: Transformer classes applied in order to every code cell.
        diff: When True, do not modify the notebook; return the concatenated
            unified diffs of all changed cells instead.
        context: Accepted for signature parity with ``_rewrite_file`` but not
            used; a fresh ``CodemodContext`` is created per cell (see the
            TODO below).

    Returns:
        ``(None, diff_lines)`` when ``diff`` is True (``diff_lines`` may be
        empty); ``(None, None)`` otherwise.

    Raises:
        ValueError: If ``filename`` does not end in ``.ipynb``.
        Any exception raised by ``cst.parse_module`` or by the transformers
        propagates to the caller.
    """
    import nbformat

    file_path = Path(filename)
    if file_path.suffix != ".ipynb":
        raise ValueError("Only Jupyter Notebook files (.ipynb) are supported.")

    with file_path.open("r", encoding="utf-8") as fp:
        notebook = nbformat.read(fp, as_version=4)

    diffs: List[str] = []
    # Tracks whether any cell was actually modified, so that an untouched
    # notebook is never written back to disk.
    changed = False

    for cell in notebook.cells:
        if cell.cell_type != "code":
            continue

        code = "".join(cell.source)

        # Skip code if any of the lines begin with a magic command or
        # a ! command.
        # We can try to handle later.
        if any(line.startswith(("!", "%")) for line in code.splitlines()):
            continue

        tree = cst.parse_module(code)

        # TODO(Team): Quick hack, need to figure out
        # how to handle this correctly.
        # This prevents the code from trying to re-insert the imports
        # for every cell in the notebook.
        local_context = CodemodContext()

        for codemod in codemods:
            tree = codemod(context=local_context).transform_module(tree)

        output_code = tree.code
        if code == output_code:
            continue

        changed = True
        cell.source = output_code.splitlines(keepends=True)

        if diff:
            diffs.extend(
                difflib.unified_diff(
                    code.splitlines(keepends=True),
                    output_code.splitlines(keepends=True),
                    fromfile=filename,
                    tofile=filename,
                )
            )

    if diff:
        return None, diffs

    # Write only when a cell changed, mirroring _rewrite_file; rewriting an
    # unchanged notebook would needlessly touch (and possibly re-serialize)
    # the file.
    if changed:
        with file_path.open("w", encoding="utf-8") as fp:
            nbformat.write(notebook, fp)

    return None, None
|
||||
|
||||
|
||||
def run_codemods(
|
||||
codemods: List[Type[ContextAwareTransformer]],
|
||||
metadata_manager: FullRepoManager,
|
||||
@@ -145,32 +263,10 @@ def run_codemods(
|
||||
)
|
||||
context.scratch.update(scratch)
|
||||
|
||||
file_path = Path(filename)
|
||||
with file_path.open("r+", encoding="utf-8") as fp:
|
||||
code = fp.read()
|
||||
fp.seek(0)
|
||||
|
||||
input_tree = cst.parse_module(code)
|
||||
|
||||
for codemod in codemods:
|
||||
transformer = codemod(context=context)
|
||||
output_tree = transformer.transform_module(input_tree)
|
||||
input_tree = output_tree
|
||||
|
||||
output_code = input_tree.code
|
||||
if code != output_code:
|
||||
if diff:
|
||||
lines = difflib.unified_diff(
|
||||
code.splitlines(keepends=True),
|
||||
output_code.splitlines(keepends=True),
|
||||
fromfile=filename,
|
||||
tofile=filename,
|
||||
)
|
||||
return None, list(lines)
|
||||
else:
|
||||
fp.write(output_code)
|
||||
fp.truncate()
|
||||
return None, None
|
||||
if filename.endswith(".ipynb"):
|
||||
return _rewrite_notebook(filename, codemods, diff, context)
|
||||
else:
|
||||
return _rewrite_file(filename, codemods, diff, context)
|
||||
except cst.ParserSyntaxError as exc:
|
||||
return (
|
||||
f"A syntax error happened on {filename}. This file cannot be "
|
||||
|
Reference in New Issue
Block a user