From d55355f00a55020e0ff805d7304a50862031aec1 Mon Sep 17 00:00:00 2001
From: Chester Curme
Date: Sun, 31 Mar 2024 10:03:02 -0400
Subject: [PATCH] check for deprecated imports

---
 docs/scripts/check_imports.py | 73 +++++++++++++++++++++++------------
 1 file changed, 49 insertions(+), 24 deletions(-)

diff --git a/docs/scripts/check_imports.py b/docs/scripts/check_imports.py
index c416932432d..d40c323a963 100644
--- a/docs/scripts/check_imports.py
+++ b/docs/scripts/check_imports.py
@@ -4,9 +4,11 @@ import json
 import logging
 import os
 import re
-import warnings
 from pathlib import Path
 from typing import List, Tuple
+import warnings
+
+from langchain_core._api import LangChainDeprecationWarning
 
 logger = logging.getLogger(__name__)
 
@@ -14,9 +16,10 @@ DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
 import_pattern = re.compile(
     r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
 )
+Import = Tuple[str, str]
 
 
-def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
+def _get_imports_from_code_cell(code_lines: str) -> List[Import]:
     """Get (module, import) statements from a single code cell."""
     import_statements = []
     for line in code_lines:
@@ -39,7 +42,7 @@ def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
     return import_statements
 
 
-def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
+def _extract_import_statements(notebook_path: str) -> List[Import]:
     """Get (module, import) statements from a Jupyter notebook."""
     with open(notebook_path, "r", encoding="utf-8") as file:
         notebook = json.load(file)
@@ -51,31 +54,43 @@ def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
     return import_statements
 
 
-def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
-    """Collect offending import statements."""
+def _get_bad_imports(import_statements: List[Import]) -> Tuple[List[Import], List[Import]]:
     offending_imports = []
+    deprecated_imports = []
+
     for module, item in import_statements:
         try:
-            if item:
-                try:
-                    # submodule
-                    full_module_name = f"{module}.{item}"
-                    importlib.import_module(full_module_name)
-                except ModuleNotFoundError:
-                    # attribute
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                warnings.simplefilter("always")
+
+                if item:
                     try:
-                        imported_module = importlib.import_module(module)
-                        getattr(imported_module, item)
-                    except AttributeError:
+                        # submodule
+                        full_module_name = f"{module}.{item}"
+                        importlib.import_module(full_module_name)
+                    except ModuleNotFoundError:
+                        # attribute
+                        try:
+                            imported_module = importlib.import_module(module)
+                            getattr(imported_module, item)
+                        except AttributeError:
+                            offending_imports.append((module, item))
+                    except Exception:
                         offending_imports.append((module, item))
-                except Exception:
-                    offending_imports.append((module, item))
-            else:
-                importlib.import_module(module)
+                else:
+                    importlib.import_module(module)
+
+                # Check for deprecation warnings
+
+                for warning in caught_warnings:
+                    if issubclass(warning.category, LangChainDeprecationWarning):
+                        deprecated_imports.append((module, item))
+
         except Exception:
             offending_imports.append((module, item))
 
-    return offending_imports
+    return offending_imports, deprecated_imports
+
 
 
 def _is_relevant_import(module: str) -> bool:
@@ -104,6 +119,7 @@ def _serialize_bad_imports(bad_files: list) -> str:
 def check_notebooks(directory: str) -> list:
     """Check notebooks for broken import statements."""
     bad_files = []
+    deprecated_files = []
     for root, _, files in os.walk(directory):
         for file in files:
             if file.endswith(".ipynb") and not file.endswith("-checkpoint.ipynb"):
@@ -113,18 +129,27 @@ def check_notebooks(directory: str) -> list:
                     for module, item in _extract_import_statements(notebook_path)
                     if _is_relevant_import(module)
                 ]
-                bad_imports = _get_bad_imports(import_statements)
+                bad_imports, deprecated_imports = _get_bad_imports(import_statements)
                 if bad_imports:
                     bad_files.append(
                         (
-                            os.path.join(root, file),
+                            os.path.join(root, file).split("docs/")[-1],
                             bad_imports,
                         )
                     )
-    return bad_files
+                if deprecated_imports:
+                    deprecated_files.append(
+                        (
+                            os.path.join(root, file).split("docs/")[-1],
+                            deprecated_imports,
+                        )
+                    )
+    return bad_files, deprecated_files
 
 
 if __name__ == "__main__":
-    bad_files = check_notebooks(DOCS_DIR)
+    bad_files, deprecated_files = check_notebooks(DOCS_DIR)
+    if deprecated_files:
+        logger.warning("Found files with deprecated imports:\n" f"{_serialize_bad_imports(deprecated_files)}")
     if bad_files:
         raise ImportError("Found bad imports:\n" f"{_serialize_bad_imports(bad_files)}")