"""This script checks documentation for broken import statements."""

import importlib
import json
import logging
import os
import re
import warnings
from pathlib import Path
from typing import Iterator, List, Sequence, Tuple, Union

from langchain_core._api import LangChainDeprecationWarning

logger = logging.getLogger(__name__)

DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
import_pattern = re.compile(
    r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))",
    re.DOTALL,
)
# Matches a ``from x import (`` whose parenthesis has not been closed yet, i.e.
# the statement continues on the following physical line(s).
_open_paren_import_pattern = re.compile(r"from\s+[\w\.]+\s+import\s+\([^)]*$")

# Top-level packages whose imports this script validates.
_RECOGNIZED_PACKAGES = frozenset(
    {
        "langchain",
        "langchain_core",
        "langchain_community",
        "langchain_experimental",
        "langchain_text_splitters",
    }
)

# A (module, item) pair; ``item`` is "" for a plain ``import module`` statement.
Import = Tuple[str, str]


def _iter_logical_lines(code_lines: Union[str, Sequence[str]]) -> Iterator[str]:
    """Yield source lines with backslash and parenthesised continuations joined.

    Blank lines and lines that start with ``#`` are skipped. A notebook cell's
    ``source`` may be a single string instead of a list of lines; it is split
    into lines first so it is not iterated character by character.
    """
    if isinstance(code_lines, str):
        code_lines = code_lines.splitlines()

    pending = ""
    in_parens = False
    for raw_line in code_lines:
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if in_parens:
            # Inline comments inside ``import ( ... )`` would otherwise be
            # parsed as imported names.
            line = line.split("#", 1)[0].rstrip()
        if line.endswith("\\"):
            # Keep the partial statement and glue the next line onto it.
            pending += line[:-1].rstrip() + " "
            continue
        pending += line
        candidate = pending.split("#", 1)[0].rstrip()
        in_parens = _open_paren_import_pattern.search(candidate) is not None
        if in_parens:
            pending = candidate + " "
            continue
        yield pending
        pending = ""

    # Flush a statement left unterminated at the end of the cell.
    if pending.strip():
        yield pending


def _split_import_items(items: str) -> List[str]:
    """Split the names of a ``from x import ...`` clause into bare identifiers.

    Surrounding parentheses, whitespace, ``as alias`` suffixes and empty
    entries (e.g. from a trailing comma) are dropped.
    """
    names = []
    for chunk in items.strip().strip("()").split(","):
        tokens = chunk.split()
        if tokens:
            # The first token is the imported name; anything after is an alias.
            names.append(tokens[0])
    return names


def _get_imports_from_code_cell(code_lines: Union[str, Sequence[str]]) -> List[Import]:
    """Get (module, import) statements from a single code cell.

    Args:
        code_lines: The cell source, either as a list of lines or one string.

    Returns:
        One ``(module, item)`` pair per imported name. ``item`` is ``""`` for
        plain ``import module`` statements.
    """
    import_statements: List[Import] = []
    for logical_line in _iter_logical_lines(code_lines):
        for plain_module, from_module, items in import_pattern.findall(logical_line):
            if plain_module:
                # simple import statement
                import_statements.append((plain_module, ""))
            else:
                # from ___ import statement
                import_statements.extend(
                    (from_module, name) for name in _split_import_items(items)
                )
    return import_statements


def _extract_import_statements(notebook_path: str) -> List[Import]:
    """Get (module, import) statements from a Jupyter notebook.

    Args:
        notebook_path: Path of the ``.ipynb`` file to read (parsed as JSON).

    Returns:
        The imports found in all cells whose ``cell_type`` is ``"code"``.
    """
    with open(notebook_path, "r", encoding="utf-8") as file:
        notebook = json.load(file)
    import_statements: List[Import] = []
    for cell in notebook["cells"]:
        if cell["cell_type"] == "code":
            import_statements.extend(_get_imports_from_code_cell(cell["source"]))
    return import_statements


def _import_resolves(module: str, item: str) -> bool:
    """Return whether ``item`` can be obtained from ``module``.

    ``item`` is first tried as a submodule and, if no such module is found, as
    an attribute of ``module``. An empty ``item`` just imports ``module``.
    Errors raised while importing ``module`` itself propagate to the caller.
    """
    if not item:
        importlib.import_module(module)
        return True
    try:
        # submodule
        importlib.import_module(f"{module}.{item}")
    except ModuleNotFoundError:
        # attribute
        try:
            getattr(importlib.import_module(module), item)
        except AttributeError:
            return False
    except Exception:
        return False
    return True


def _get_bad_imports(import_statements: List[Import]) -> Tuple[List[Import], List[Import]]:
    """Sort imports into those that fail and those that are deprecated.

    Args:
        import_statements: ``(module, item)`` pairs to try importing.

    Returns:
        A tuple ``(offending_imports, deprecated_imports)``. An import is
        offending when it cannot be resolved or raises while being imported,
        and deprecated when importing it emits a
        ``LangChainDeprecationWarning``. Each import is listed at most once
        per list entry it produces, however many warnings it emitted.
    """
    offending_imports: List[Import] = []
    deprecated_imports: List[Import] = []
    for module, item in import_statements:
        try:
            with warnings.catch_warnings(record=True) as caught_warnings:
                # Record every warning, including ones already shown once.
                warnings.simplefilter("always")
                if not _import_resolves(module, item):
                    offending_imports.append((module, item))
                # Check for deprecation warnings; `any` avoids one duplicate
                # entry per captured warning.
                if any(
                    issubclass(caught.category, LangChainDeprecationWarning)
                    for caught in caught_warnings
                ):
                    deprecated_imports.append((module, item))
        except Exception:
            offending_imports.append((module, item))

    return offending_imports, deprecated_imports


def _is_relevant_import(module: str) -> bool:
    """Check if module is recognized."""
    # Ignore things like langchain_{bla}, where bla is unrecognized.
    return module.split(".")[0] in _RECOGNIZED_PACKAGES


def _serialize_bad_imports(bad_files: list) -> str:
    """Serialize bad imports to a string.

    Args:
        bad_files: ``(file, imports)`` pairs, where ``imports`` is a list of
            ``(module, item)`` tuples.

    Returns:
        One ``File: ...`` line per file followed by one line per import.
        Plain ``import module`` entries (empty ``item``) are written without
        a trailing dot.
    """
    parts = []
    for file, bad_imports in bad_files:
        parts.append(f"File: {file}\n")
        for module, item in bad_imports:
            parts.append(f" {module}.{item}\n" if item else f" {module}\n")
    return "".join(parts)


def check_notebooks(directory: Union[str, Path]) -> Tuple[list, list]:
    """Check notebooks for broken import statements.

    Args:
        directory: Root directory that is walked recursively for ``.ipynb``
            files. Files ending in ``-checkpoint.ipynb`` are skipped.

    Returns:
        A tuple ``(bad_files, deprecated_files)``. Each is a list of
        ``(display_path, imports)`` pairs, where ``display_path`` is the part
        of the notebook path after the last ``docs/``.
    """
    bad_files = []
    deprecated_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".ipynb") or file.endswith("-checkpoint.ipynb"):
                continue
            notebook_path = os.path.join(root, file)
            display_path = notebook_path.split("docs/")[-1]
            import_statements = [
                (module, item)
                for module, item in _extract_import_statements(notebook_path)
                if _is_relevant_import(module)
            ]
            bad_imports, deprecated_imports = _get_bad_imports(import_statements)
            if bad_imports:
                bad_files.append((display_path, bad_imports))
            if deprecated_imports:
                deprecated_files.append((display_path, deprecated_imports))

    return bad_files, deprecated_files
if __name__ == "__main__":
    # Scan every notebook under the docs directory.
    broken, deprecated = check_notebooks(DOCS_DIR)

    # Deprecated imports are only reported; they do not stop the script.
    if deprecated:
        deprecation_report = _serialize_bad_imports(deprecated)
        logger.warning("Found files with deprecated imports:\n" + deprecation_report)

    # Broken imports abort the script with an ImportError listing them.
    if broken:
        broken_report = _serialize_bad_imports(broken)
        raise ImportError("Found bad imports:\n" + broken_report)