mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
add script to check imports (#19611)
This commit is contained in:
parent
2319212d54
commit
0dbd5f5012
50
.github/workflows/_test_doc_imports.yml
vendored
Normal file
50
.github/workflows/_test_doc_imports.yml
vendored
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# Reusable workflow: verify that import statements in the documentation
# notebooks actually resolve against the current source tree.
name: test_doc_imports

on:
  # Invoked from other workflows only (see check_diffs.yml), never directly.
  workflow_call:

env:
  POETRY_VERSION: "1.7.1"

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.11"
    name: "check doc imports #${{ matrix.python-version }}"
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
        # Repo-local composite action that installs Poetry and restores caches.
        uses: "./.github/actions/poetry_setup"
        with:
          python-version: ${{ matrix.python-version }}
          poetry-version: ${{ env.POETRY_VERSION }}
          cache-key: core

      - name: Install dependencies
        shell: bash
        run: poetry install --with test

      - name: Install langchain editable
        # Editable installs so the notebooks import the checked-out code,
        # not a published release.
        run: |
          poetry run pip install -e libs/core libs/langchain libs/community libs/experimental

      - name: Check doc imports
        shell: bash
        run: |
          poetry run python docs/scripts/check_imports.py

      - name: Ensure the test did not create any additional files
        shell: bash
        run: |
          set -eu

          STATUS="$(git status)"
          echo "$STATUS"

          # grep will exit non-zero if the target message isn't found,
          # and `set -e` above will cause the step to fail.
          echo "$STATUS" | grep 'nothing to commit, working tree clean'
|
6
.github/workflows/check_diffs.yml
vendored
6
.github/workflows/check_diffs.yml
vendored
@@ -60,6 +60,12 @@ jobs:
     working-directory: ${{ matrix.working-directory }}
     secrets: inherit

+  test_doc_imports:
+    needs: [ build ]
+    if: ${{ needs.build.outputs.dirs-to-test != '[]' }}
+    uses: ./.github/workflows/_test_doc_imports.yml
+    secrets: inherit
+
   compile-integration-tests:
     name: cd ${{ matrix.working-directory }}
     needs: [ build ]
|
130
docs/scripts/check_imports.py
Normal file
130
docs/scripts/check_imports.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
"""This script checks documentation for broken import statements."""
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
# Module-level logger for this script.
logger = logging.getLogger(__name__)

# Notebook tree to scan: this file lives at docs/scripts/check_imports.py,
# so parents[1] is docs/ and the target is docs/docs.
DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
|
||||||
|
import_pattern = re.compile(
|
||||||
|
r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
|
||||||
|
"""Get (module, import) statements from a single code cell."""
|
||||||
|
import_statements = []
|
||||||
|
for line in code_lines:
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith("#") or not line:
|
||||||
|
continue
|
||||||
|
# Join lines that end with a backslash
|
||||||
|
if line.endswith("\\"):
|
||||||
|
line = line[:-1].rstrip() + " "
|
||||||
|
continue
|
||||||
|
matches = import_pattern.findall(line)
|
||||||
|
for match in matches:
|
||||||
|
if match[0]: # simple import statement
|
||||||
|
import_statements.append((match[0], ""))
|
||||||
|
else: # from ___ import statement
|
||||||
|
module, items = match[1], match[2]
|
||||||
|
items_list = items.replace(" ", "").split(",")
|
||||||
|
for item in items_list:
|
||||||
|
import_statements.append((module, item))
|
||||||
|
return import_statements
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
    """Collect every (module, import) pair from a Jupyter notebook's code cells."""
    with open(notebook_path, "r", encoding="utf-8") as fp:
        notebook = json.load(fp)

    statements: List[Tuple[str, str]] = []
    for cell in notebook["cells"]:
        # Markdown/raw cells cannot contain import statements.
        if cell["cell_type"] != "code":
            continue
        statements.extend(_get_imports_from_code_cell(cell["source"]))
    return statements
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Collect offending import statements.

    Each ``(module, item)`` pair is verified by actually importing it, so
    this must run in an environment where the target packages are installed.
    A pair with an empty ``item`` is a bare ``import module`` statement.
    """
    offending_imports = []
    for module, item in import_statements:
        try:
            if item:
                try:
                    # First interpret `item` as a submodule: `module.item`.
                    full_module_name = f"{module}.{item}"
                    importlib.import_module(full_module_name)
                except ModuleNotFoundError:
                    # Not a submodule — fall back to treating `item` as an
                    # attribute of `module`.
                    try:
                        imported_module = importlib.import_module(module)
                        getattr(imported_module, item)
                    except AttributeError:
                        offending_imports.append((module, item))
                except Exception:
                    # Submodule import failed for some reason other than
                    # the module being absent (e.g. an ImportError raised
                    # while executing it).
                    offending_imports.append((module, item))
            else:
                # Bare `import module` statement.
                importlib.import_module(module)
        except Exception:
            # Anything else that escaped — including exceptions raised
            # inside the handlers above, which sibling except clauses do
            # not catch — marks the pair as offending.
            offending_imports.append((module, item))

    return offending_imports
|
||||||
|
|
||||||
|
|
||||||
|
def _is_relevant_import(module: str) -> bool:
|
||||||
|
"""Check if module is recognized."""
|
||||||
|
# Ignore things like langchain_{bla}, where bla is unrecognized.
|
||||||
|
recognized_packages = [
|
||||||
|
"langchain",
|
||||||
|
"langchain_core",
|
||||||
|
"langchain_community",
|
||||||
|
"langchain_experimental",
|
||||||
|
"langchain_text_splitters",
|
||||||
|
]
|
||||||
|
return module.split(".")[0] in recognized_packages
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_bad_imports(bad_files: list) -> str:
|
||||||
|
"""Serialize bad imports to a string."""
|
||||||
|
bad_imports_str = ""
|
||||||
|
for file, bad_imports in bad_files:
|
||||||
|
bad_imports_str += f"File: {file}\n"
|
||||||
|
for module, item in bad_imports:
|
||||||
|
bad_imports_str += f" {module}.{item}\n"
|
||||||
|
return bad_imports_str
|
||||||
|
|
||||||
|
|
||||||
|
def check_notebooks(directory: str) -> list:
    """Walk *directory* and report notebooks containing broken imports.

    Returns a list of ``(notebook_path, bad_imports)`` pairs; notebooks
    with no offending imports are omitted.
    """
    bad_files = []
    for root, _, files in os.walk(directory):
        for name in files:
            # Skip non-notebooks and Jupyter autosave checkpoints.
            if not name.endswith(".ipynb") or name.endswith("-checkpoint.ipynb"):
                continue
            notebook_path = os.path.join(root, name)
            relevant_imports = [
                pair
                for pair in _extract_import_statements(notebook_path)
                if _is_relevant_import(pair[0])
            ]
            bad_imports = _get_bad_imports(relevant_imports)
            if bad_imports:
                bad_files.append((notebook_path, bad_imports))
    return bad_files
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Scan the docs tree; a non-empty result aborts the run so CI fails.
    bad_files = check_notebooks(DOCS_DIR)
    if bad_files:
        raise ImportError(f"Found bad imports:\n{_serialize_bad_imports(bad_files)}")
|
Loading…
Reference in New Issue
Block a user