mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-31 16:08:59 +00:00 
			
		
		
		
	Trying to unblock documentation build pipeline * Bump langgraph dep in docs * Update langgraph in lock file (resolves an issue in API reference generation)
		
			
				
	
	
		
			132 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			132 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """This script checks documentation for broken import statements."""
 | |
| 
 | |
| import importlib
 | |
| import json
 | |
| import logging
 | |
| import os
 | |
| import re
 | |
| import warnings
 | |
| from pathlib import Path
 | |
| from typing import List, Tuple
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
 | |
| import_pattern = re.compile(
 | |
|     r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
 | |
| )
 | |
| 
 | |
| 
 | |
| def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
 | |
|     """Get (module, import) statements from a single code cell."""
 | |
|     import_statements = []
 | |
|     for line in code_lines:
 | |
|         line = line.strip()
 | |
|         if line.startswith("#") or not line:
 | |
|             continue
 | |
|         # Join lines that end with a backslash
 | |
|         if line.endswith("\\"):
 | |
|             line = line[:-1].rstrip() + " "
 | |
|             continue
 | |
|         matches = import_pattern.findall(line)
 | |
|         for match in matches:
 | |
|             if match[0]:  # simple import statement
 | |
|                 import_statements.append((match[0], ""))
 | |
|             else:  # from ___ import statement
 | |
|                 module, items = match[1], match[2]
 | |
|                 items_list = items.replace(" ", "").split(",")
 | |
|                 for item in items_list:
 | |
|                     import_statements.append((module, item))
 | |
|     return import_statements
 | |
| 
 | |
| 
 | |
| def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
 | |
|     """Get (module, import) statements from a Jupyter notebook."""
 | |
|     with open(notebook_path, "r", encoding="utf-8") as file:
 | |
|         notebook = json.load(file)
 | |
|     code_cells = [cell for cell in notebook["cells"] if cell["cell_type"] == "code"]
 | |
|     import_statements = []
 | |
|     for cell in code_cells:
 | |
|         code_lines = cell["source"]
 | |
|         import_statements.extend(_get_imports_from_code_cell(code_lines))
 | |
|     return import_statements
 | |
| 
 | |
| 
 | |
| def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 | |
|     """Collect offending import statements."""
 | |
|     offending_imports = []
 | |
|     for module, item in import_statements:
 | |
|         try:
 | |
|             if item:
 | |
|                 try:
 | |
|                     # submodule
 | |
|                     full_module_name = f"{module}.{item}"
 | |
|                     importlib.import_module(full_module_name)
 | |
|                 except ModuleNotFoundError:
 | |
|                     # attribute
 | |
|                     try:
 | |
|                         imported_module = importlib.import_module(module)
 | |
|                         getattr(imported_module, item)
 | |
|                     except AttributeError:
 | |
|                         offending_imports.append((module, item))
 | |
|                 except Exception:
 | |
|                     offending_imports.append((module, item))
 | |
|             else:
 | |
|                 importlib.import_module(module)
 | |
|         except Exception:
 | |
|             offending_imports.append((module, item))
 | |
| 
 | |
|     return offending_imports
 | |
| 
 | |
| 
 | |
| def _is_relevant_import(module: str) -> bool:
 | |
|     """Check if module is recognized."""
 | |
|     # Ignore things like langchain_{bla}, where bla is unrecognized.
 | |
|     recognized_packages = [
 | |
|         "langchain",
 | |
|         "langchain_core",
 | |
|         "langchain_community",
 | |
|         "langchain_experimental",
 | |
|         "langchain_text_splitters",
 | |
|     ]
 | |
|     return module.split(".")[0] in recognized_packages
 | |
| 
 | |
| 
 | |
| def _serialize_bad_imports(bad_files: list) -> str:
 | |
|     """Serialize bad imports to a string."""
 | |
|     bad_imports_str = ""
 | |
|     for file, bad_imports in bad_files:
 | |
|         bad_imports_str += f"File: {file}\n"
 | |
|         for module, item in bad_imports:
 | |
|             bad_imports_str += f"    {module}.{item}\n"
 | |
|     return bad_imports_str
 | |
| 
 | |
| 
 | |
| def check_notebooks(directory: str) -> list:
 | |
|     """Check notebooks for broken import statements."""
 | |
|     bad_files = []
 | |
|     for root, _, files in os.walk(directory):
 | |
|         for file in files:
 | |
|             if file.endswith(".ipynb") and not file.endswith("-checkpoint.ipynb"):
 | |
|                 notebook_path = os.path.join(root, file)
 | |
|                 import_statements = [
 | |
|                     (module, item)
 | |
|                     for module, item in _extract_import_statements(notebook_path)
 | |
|                     if _is_relevant_import(module)
 | |
|                 ]
 | |
|                 bad_imports = _get_bad_imports(import_statements)
 | |
|                 if bad_imports:
 | |
|                     bad_files.append(
 | |
|                         (
 | |
|                             os.path.join(root, file),
 | |
|                             bad_imports,
 | |
|                         )
 | |
|                     )
 | |
|     return bad_files
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     bad_files = check_notebooks(DOCS_DIR)
 | |
|     if bad_files:
 | |
|         raise ImportError(f"Found bad imports:\n{_serialize_bad_imports(bad_files)}")
 |