From 0dbd5f501294f36a2c2f134d334642877c75c747 Mon Sep 17 00:00:00 2001 From: ccurme Date: Fri, 29 Mar 2024 13:30:20 -0400 Subject: [PATCH] add script to check imports (#19611) --- .github/workflows/_test_doc_imports.yml | 50 +++++++++ .github/workflows/check_diffs.yml | 6 ++ docs/scripts/check_imports.py | 130 ++++++++++++++++++++++++ 3 files changed, 186 insertions(+) create mode 100644 .github/workflows/_test_doc_imports.yml create mode 100644 docs/scripts/check_imports.py diff --git a/.github/workflows/_test_doc_imports.yml b/.github/workflows/_test_doc_imports.yml new file mode 100644 index 00000000000..b61c41b8914 --- /dev/null +++ b/.github/workflows/_test_doc_imports.yml @@ -0,0 +1,50 @@ +name: test_doc_imports + +on: + workflow_call: + +env: + POETRY_VERSION: "1.7.1" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.11" + name: "check doc imports #${{ matrix.python-version }}" + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ matrix.python-version }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: core + + - name: Install dependencies + shell: bash + run: poetry install --with test + + - name: Install langchain editable + run: | + poetry run pip install -e libs/core libs/langchain libs/community libs/experimental + + - name: Check doc imports + shell: bash + run: | + poetry run python docs/scripts/check_imports.py + + - name: Ensure the test did not create any additional files + shell: bash + run: | + set -eu + + STATUS="$(git status)" + echo "$STATUS" + + # grep will exit non-zero if the target message isn't found, + # and `set -e` above will cause the step to fail. + echo "$STATUS" | grep 'nothing to commit, working tree clean' diff --git a/.github/workflows/check_diffs.yml b/.github/workflows/check_diffs.yml index c4bd8b448b8..764cbf7c98e 100644 --- a/.github/workflows/check_diffs.yml +++ b/.github/workflows/check_diffs.yml @@ -60,6 +60,12 @@ jobs: working-directory: ${{ matrix.working-directory }} secrets: inherit + test_doc_imports: + needs: [ build ] + if: ${{ needs.build.outputs.dirs-to-test != '[]' }} + uses: ./.github/workflows/_test_doc_imports.yml + secrets: inherit + compile-integration-tests: name: cd ${{ matrix.working-directory }} needs: [ build ] diff --git a/docs/scripts/check_imports.py b/docs/scripts/check_imports.py new file mode 100644 index 00000000000..c416932432d --- /dev/null +++ b/docs/scripts/check_imports.py @@ -0,0 +1,130 @@ +"""This script checks documentation for broken import statements.""" +import importlib +import json +import logging +import os +import re +import warnings +from pathlib import Path +from typing import List, Tuple + +logger = logging.getLogger(__name__) + +DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs" +import_pattern = re.compile( + r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL +) + + +def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]: + """Get (module, import) statements from a single code cell.""" + import_statements = [] + for line in code_lines: + line = line.strip() + if line.startswith("#") or not line: + continue + # Join lines that end with a backslash + if line.endswith("\\"): + line = line[:-1].rstrip() + " " + continue + matches = import_pattern.findall(line) + for match in matches: + if match[0]: # simple import statement + import_statements.append((match[0], "")) + else: # from ___ import statement + module, items = match[1], match[2] + items_list = items.replace(" ", "").split(",") + for item in items_list: + import_statements.append((module, item)) + return import_statements + + +def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]: + """Get (module, import) statements from a Jupyter notebook.""" + with open(notebook_path, "r", encoding="utf-8") as file: + notebook = json.load(file) + code_cells = [cell for cell in notebook["cells"] if cell["cell_type"] == "code"] + import_statements = [] + for cell in code_cells: + code_lines = cell["source"] + import_statements.extend(_get_imports_from_code_cell(code_lines)) + return import_statements + + +def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + """Collect offending import statements.""" + offending_imports = [] + for module, item in import_statements: + try: + if item: + try: + # submodule + full_module_name = f"{module}.{item}" + importlib.import_module(full_module_name) + except ModuleNotFoundError: + # attribute + try: + imported_module = importlib.import_module(module) + getattr(imported_module, item) + except AttributeError: + offending_imports.append((module, item)) + except Exception: + offending_imports.append((module, item)) + else: + importlib.import_module(module) + except Exception: + offending_imports.append((module, item)) + + return offending_imports + + +def _is_relevant_import(module: str) -> bool: + """Check if module is recognized.""" + # Ignore things like langchain_{bla}, where bla is unrecognized. + recognized_packages = [ + "langchain", + "langchain_core", + "langchain_community", + "langchain_experimental", + "langchain_text_splitters", + ] + return module.split(".")[0] in recognized_packages + + +def _serialize_bad_imports(bad_files: list) -> str: + """Serialize bad imports to a string.""" + bad_imports_str = "" + for file, bad_imports in bad_files: + bad_imports_str += f"File: {file}\n" + for module, item in bad_imports: + bad_imports_str += f" {module}.{item}\n" + return bad_imports_str + + +def check_notebooks(directory: str) -> list: + """Check notebooks for broken import statements.""" + bad_files = [] + for root, _, files in os.walk(directory): + for file in files: + if file.endswith(".ipynb") and not file.endswith("-checkpoint.ipynb"): + notebook_path = os.path.join(root, file) + import_statements = [ + (module, item) + for module, item in _extract_import_statements(notebook_path) + if _is_relevant_import(module) + ] + bad_imports = _get_bad_imports(import_statements) + if bad_imports: + bad_files.append( + ( + os.path.join(root, file), + bad_imports, + ) + ) + return bad_files + + +if __name__ == "__main__": + bad_files = check_notebooks(DOCS_DIR) + if bad_files: + raise ImportError("Found bad imports:\n" f"{_serialize_bad_imports(bad_files)}")