chore(cli): fix some DOC rules (preview) (#32839)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Christophe Bornet
2025-09-08 16:36:22 +02:00
committed by GitHub
parent e0aaaccb61
commit 20401df25d
10 changed files with 218 additions and 22 deletions

View File

@@ -14,7 +14,19 @@ def create_demo_server(
     config_keys: Sequence[str] = (),
     playground_type: Literal["default", "chat"] = "default",
 ) -> FastAPI:
-    """Create a demo server for the current template."""
+    """Create a demo server for the current template.
+
+    Args:
+        config_keys: Optional sequence of config keys to expose in the playground.
+        playground_type: The type of playground to use. Can be `'default'` or `'chat'`.
+
+    Returns:
+        The demo server.
+
+    Raises:
+        KeyError: If the `pyproject.toml` file is missing required fields.
+        ImportError: If the module defined in `pyproject.toml` cannot be imported.
+    """
     app = FastAPI()
     package_root = get_package_root()
     pyproject = package_root / "pyproject.toml"
@@ -41,10 +53,18 @@ def create_demo_server(
 def create_demo_server_configurable() -> FastAPI:
-    """Create a configurable demo server."""
+    """Create a configurable demo server.
+
+    Returns:
+        The configurable demo server.
+    """
     return create_demo_server(config_keys=["configurable"])


 def create_demo_server_chat() -> FastAPI:
-    """Create a chat demo server."""
+    """Create a chat demo server.
+
+    Returns:
+        The chat demo server.
+    """
     return create_demo_server(playground_type="chat")
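
For orientation, a hedged usage sketch of these factory functions follows. The module path (`langchain_cli.dev_scripts`) and the uvicorn wiring are assumptions; it also presumes the current package's `pyproject.toml` defines a LangServe export.

# Hypothetical usage sketch; module path and server setup are assumptions.
import uvicorn

from langchain_cli.dev_scripts import create_demo_server_chat

# Build the FastAPI app with the chat playground, then serve it locally.
app = create_demo_server_chat()

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)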

View File

@@ -11,7 +11,16 @@ def generate_raw_migrations(
     to_package: str,
     filter_by_all: bool = False,  # noqa: FBT001, FBT002
 ) -> list[tuple[str, str]]:
-    """Scan the `langchain` package and generate migrations for all modules."""
+    """Scan the `langchain` package and generate migrations for all modules.
+
+    Args:
+        from_package: The package to migrate from.
+        to_package: The package to migrate to.
+        filter_by_all: Whether to only consider items in `__all__`.
+
+    Returns:
+        A list of tuples containing the original import path and the new import path.
+    """
     package = importlib.import_module(from_package)

     items = []
@@ -84,6 +93,13 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
     and the second tuple will contain the path
     to importing it from the top level namespaces
     (e.g., ``langchain_community.chat_models.XYZ``)
+
+    Args:
+        pkg: The package to scan.
+
+    Returns:
+        A list of tuples containing the fully qualified path and the top-level
+        import path.
     """
     package = importlib.import_module(pkg)
@@ -130,7 +146,17 @@ def generate_simplified_migrations(
     to_package: str,
     filter_by_all: bool = True,  # noqa: FBT001, FBT002
 ) -> list[tuple[str, str]]:
-    """Get all the raw migrations, then simplify them if possible."""
+    """Get all the raw migrations, then simplify them if possible.
+
+    Args:
+        from_package: The package to migrate from.
+        to_package: The package to migrate to.
+        filter_by_all: Whether to only consider items in `__all__`.
+
+    Returns:
+        A list of tuples containing the original import path and the simplified
+        import path.
+    """
     raw_migrations = generate_raw_migrations(
         from_package,
         to_package,
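
A hedged usage sketch of the simplified-migrations helper: the import path below is an assumption, and the output depends on which packages are importable in the environment.

# Hypothetical usage sketch; the module path is an assumption.
from langchain_cli.namespaces.migrate.generate.generic import (
    generate_simplified_migrations,
)

# Both packages must be importable in the current environment.
pairs = generate_simplified_migrations(
    from_package="langchain",
    to_package="langchain_community",
    filter_by_all=True,
)
for old_path, new_path in pairs:
    print(f"{old_path} -> {new_path}")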

View File

@@ -2,13 +2,28 @@
def split_package(package: str) -> tuple[str, str]: def split_package(package: str) -> tuple[str, str]:
"""Split a package name into the containing package and the final name.""" """Split a package name into the containing package and the final name.
Args:
package: The full package name.
Returns:
A tuple of `(containing_package, final_name)`.
"""
parts = package.split(".") parts = package.split(".")
return ".".join(parts[:-1]), parts[-1] return ".".join(parts[:-1]), parts[-1]
def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]) -> str: def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]) -> str:
"""Dump the migration pairs as a Grit file.""" """Dump the migration pairs as a Grit file.
Args:
name: The name of the migration.
migration_pairs: A list of tuples `(from_module, to_module)`.
Returns:
The Grit file as a string.
"""
remapped = ",\n".join( remapped = ",\n".join(
[ [
f""" f"""

View File

@@ -59,7 +59,15 @@ def _get_class_names(code: str) -> list[str]:
 def is_subclass(class_obj: type, classes_: list[type]) -> bool:
-    """Check if the given class object is a subclass of any class in list classes."""
+    """Check if the given class object is a subclass of any class in list classes.
+
+    Args:
+        class_obj: The class to check.
+        classes_: A list of classes to check against.
+
+    Returns:
+        True if `class_obj` is a subclass of any class in `classes_`, False otherwise.
+    """
     return any(
         issubclass(class_obj, kls)
         for kls in classes_
@@ -68,7 +76,15 @@ def is_subclass(class_obj: type, classes_: list[type]) -> bool:
 def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]:
-    """Find all classes in the module that inherit from one of the classes."""
+    """Find all classes in the module that inherit from one of the classes.
+
+    Args:
+        module: The module to inspect.
+        classes_: A list of classes to check against.
+
+    Returns:
+        A list of class names that are subclasses of any class in `classes_`.
+    """
     subclasses = []
     # Iterate over all attributes of the module that are classes
     for _name, obj in inspect.getmembers(module, inspect.isclass):
@@ -91,7 +107,15 @@ def identify_all_imports_in_file(
     *,
     from_package: Optional[str] = None,
 ) -> list[tuple[str, str]]:
-    """Let's also identify all the imports in the given file."""
+    """Identify all the imports in the given file.
+
+    Args:
+        file: The file to analyze.
+        from_package: If provided, only return imports from this package.
+
+    Returns:
+        A list of tuples `(module, name)` representing the imports found in the file.
+    """
     code = Path(file).read_text(encoding="utf-8")
     return find_imports_from_package(code, from_package=from_package)
@@ -106,6 +130,9 @@ def identify_pkg_source(pkg_root: str) -> pathlib.Path:
     Returns:
         Returns the path to the source code for the package.
+
+    Raises:
+        ValueError: If there is not exactly one directory starting with `'langchain_'`
+            in the package root.
     """
     dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
     matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
@@ -116,7 +143,15 @@ def identify_pkg_source(pkg_root: str) -> pathlib.Path:
 def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
-    """List all classes in a package."""
+    """List all classes in a package.
+
+    Args:
+        pkg_root: The root of the package.
+
+    Returns:
+        A list of tuples `(module, class_name)` representing all classes found in the
+        package, excluding test files.
+    """
     module_classes = []
     pkg_source = identify_pkg_source(pkg_root)
     files = list(pkg_source.rglob("*.py"))
@@ -130,7 +165,15 @@ def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
 def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
-    """List all the things that are being imported in a package by module."""
+    """List all the things that are being imported in a package by module.
+
+    Args:
+        pkg_root: The root of the package.
+
+    Returns:
+        A list of tuples `(module, name)` representing the imports found in
+        `__init__.py` files.
+    """
     imports = []
     pkg_source = identify_pkg_source(pkg_root)
     # Scan all the files in the package
@@ -150,7 +193,15 @@ def find_imports_from_package(
     *,
     from_package: Optional[str] = None,
 ) -> list[tuple[str, str]]:
-    """Find imports in code."""
+    """Find imports in code.
+
+    Args:
+        code: The code to analyze.
+        from_package: If provided, only return imports from this package.
+
+    Returns:
+        A list of tuples `(module, name)` representing the imports found.
+    """
     # Parse the code into an AST
     tree = ast.parse(code)
     # Create an instance of the visitor
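
The last hunk refers to an AST visitor that collects imports. Below is a minimal, self-contained sketch of that approach; the visitor class name and the exact filtering logic are assumptions, not the CLI's actual implementation.

import ast
from typing import Optional


class _ImportVisitor(ast.NodeVisitor):
    """Collect (module, name) pairs from `from X import Y` statements."""

    def __init__(self, from_package: Optional[str] = None) -> None:
        self.from_package = from_package
        self.imports: list[tuple[str, str]] = []

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # Record the imports, optionally filtered by a package prefix.
        if node.module and (
            self.from_package is None or node.module.startswith(self.from_package)
        ):
            self.imports.extend((node.module, alias.name) for alias in node.names)
        self.generic_visit(node)


tree = ast.parse("from langchain.chat_models import ChatOpenAI")
visitor = _ImportVisitor(from_package="langchain")
visitor.visit(tree)
print(visitor.imports)  # [('langchain.chat_models', 'ChatOpenAI')]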

View File

@@ -22,7 +22,14 @@ class EventDict(TypedDict):
 def create_events(events: list[EventDict]) -> Optional[dict[str, Any]]:
-    """Create events."""
+    """Create events.
+
+    Args:
+        events: A list of event dictionaries.
+
+    Returns:
+        The response from the event tracking service, or None if there was an error.
+    """
     try:
         data = {
             "events": [

View File

@@ -4,7 +4,15 @@ from pathlib import Path
 def find_and_replace(source: str, replacements: dict[str, str]) -> str:
-    """Find and replace text in a string."""
+    """Find and replace text in a string.
+
+    Args:
+        source: The source string.
+        replacements: A dictionary of `{find: replace}` pairs.
+
+    Returns:
+        The modified string.
+    """
     rtn = source

     # replace keys in deterministic alphabetical order
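
A minimal sketch of the replacement loop implied by the trailing comment ("replace keys in deterministic alphabetical order"); the shipped function may differ in details.

def find_and_replace(source: str, replacements: dict[str, str]) -> str:
    rtn = source
    # Replace keys in sorted order so the result does not depend on dict order.
    for key in sorted(replacements):
        rtn = rtn.replace(key, replacements[key])
    return rtn


print(find_and_replace("__app_name__ says hi", {"__app_name__": "my-app"}))
# my-app says hi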

View File

@@ -36,7 +36,20 @@ def parse_dependency_string(
     branch: Optional[str],
     api_path: Optional[str],
 ) -> DependencySource:
-    """Parse a dependency string into a DependencySource."""
+    """Parse a dependency string into a DependencySource.
+
+    Args:
+        dep: The dependency string.
+        repo: Optional repository.
+        branch: Optional branch.
+        api_path: Optional API path.
+
+    Returns:
+        The parsed dependency source information.
+
+    Raises:
+        ValueError: If the dependency string is invalid.
+    """
     if dep is not None and dep.startswith("git+"):
         if repo is not None or branch is not None:
             msg = (
@@ -129,7 +142,22 @@ def parse_dependencies(
     branch: list[str],
     api_path: list[str],
 ) -> list[DependencySource]:
-    """Parse dependencies."""
+    """Parse dependencies.
+
+    Args:
+        dependencies: The dependencies to parse.
+        repo: The repositories to use.
+        branch: The branches to use.
+        api_path: The API paths to use.
+
+    Returns:
+        A list of DependencySource objects.
+
+    Raises:
+        ValueError: If the numbers of `dependencies`, `repos`, `branches`, and
+            `api_paths` do not match.
+    """
     num_deps = max(
         len(dependencies) if dependencies is not None else 0,
         len(repo),
@@ -177,7 +205,18 @@ def _get_repo_path(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
 def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
-    """Update a git repository to the specified ref."""
+    """Update a git repository to the specified ref.
+
+    Tries to pull if the repo already exists, otherwise clones it.
+
+    Args:
+        gitstring: The git repository URL.
+        ref: The git reference.
+        repo_dir: The directory to clone the repository into.
+
+    Returns:
+        The path to the cloned repository.
+    """
     # see if path already saved
     repo_path = _get_repo_path(gitstring, ref, repo_dir)
     if repo_path.exists():
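
As a rough illustration of the consistency check that the `parse_dependencies` docstring describes, here is a hedged sketch; the helper name and the exact rule (each option list empty, a single shared value, or one value per dependency) are assumptions about the CLI's behavior, not its actual code.

# Illustrative sketch only; not the CLI's actual validation logic.
def _check_option_lengths(
    dependencies: list[str],
    repo: list[str],
    branch: list[str],
    api_path: list[str],
) -> None:
    num_deps = max(len(dependencies), len(repo), len(branch), len(api_path))
    for name, values in (("repo", repo), ("branch", branch), ("api_path", api_path)):
        # Each option may be omitted, given once, or given once per dependency.
        if len(values) not in (0, 1, num_deps):
            msg = f"Number of {name} values must be 0, 1, or match the number of dependencies ({num_deps})."
            raise ValueError(msg)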

View File

@@ -6,7 +6,14 @@ from typing import Optional
 def list_packages(*, contains: Optional[str] = None) -> list[str]:
-    """List all packages in the langchain repository templates directory."""
+    """List all packages in the langchain repository templates directory.
+
+    Args:
+        contains: Optional substring that the package name must contain.
+
+    Returns:
+        A list of package names.
+    """
     conn = http.client.HTTPSConnection("api.github.com")
     try:
         headers = {
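
For context, a hedged sketch of the kind of GitHub contents-API call the function wraps; the exact endpoint, headers, and filtering used by the CLI are assumptions.

# Illustrative only; the real list_packages may use a different endpoint or headers.
import http.client
import json

conn = http.client.HTTPSConnection("api.github.com")
try:
    conn.request(
        "GET",
        "/repos/langchain-ai/langchain/contents/templates",  # assumed path
        headers={"User-Agent": "langchain-cli", "Accept": "application/vnd.github+json"},
    )
    entries = json.loads(conn.getresponse().read())
    names = [e["name"] for e in entries if e.get("type") == "dir"]
finally:
    conn.close()
print(names)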

View File

@@ -7,7 +7,19 @@ from tomlkit import load
 def get_package_root(cwd: Optional[Path] = None) -> Path:
-    """Get package root directory."""
+    """Get package root directory.
+
+    Args:
+        cwd: The current working directory to start the search from.
+            If None, uses the current working directory of the process.
+
+    Returns:
+        The path to the package root directory.
+
+    Raises:
+        FileNotFoundError: If no `pyproject.toml` file is found in the directory
+            hierarchy.
+    """
     # traverse path for routes to host (any directory holding a pyproject.toml file)
     package_root = Path.cwd() if cwd is None else cwd
     visited: set[Path] = set()
@@ -38,7 +50,17 @@ class LangServeExport(TypedDict):
 def get_langserve_export(filepath: Path) -> LangServeExport:
-    """Get LangServe export information from a pyproject.toml file."""
+    """Get LangServe export information from a `pyproject.toml` file.
+
+    Args:
+        filepath: Path to the `pyproject.toml` file.
+
+    Returns:
+        The LangServeExport information.
+
+    Raises:
+        KeyError: If the `pyproject.toml` file is missing required fields.
+    """
     with filepath.open() as f:
         data: dict[str, Any] = load(f)
     try:
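
The upward search that the `get_package_root` docstring describes can be sketched as below, building on the lines shown in the hunk; this is a minimal reconstruction, and the actual implementation may handle edge cases differently.

from pathlib import Path
from typing import Optional


def get_package_root(cwd: Optional[Path] = None) -> Path:
    # Walk up from cwd until a directory containing pyproject.toml is found.
    package_root = Path.cwd() if cwd is None else cwd
    visited: set[Path] = set()
    while package_root not in visited:
        visited.add(package_root)
        if (package_root / "pyproject.toml").exists():
            return package_root
        package_root = package_root.parent
    msg = "No pyproject.toml found"
    raise FileNotFoundError(msg)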

View File

@@ -52,6 +52,7 @@ select = [ "ALL",]
 ignore = [
     "C90", # McCabe complexity
     "COM812", # Messes with the formatter
+    "CPY", # No copyright
     "FIX002", # Line contains TODO
     "PERF203", # Rarely useful
     "PLR09", # Too many something (arg, statements, etc)
@@ -78,7 +79,7 @@ pydocstyle.convention = "google"
 pyupgrade.keep-runtime-typing = true

 [tool.ruff.lint.per-file-ignores]
-"tests/**" = [ "D1", "S", "SLF",]
+"tests/**" = [ "D1", "DOC", "S", "SLF",]
 "scripts/**" = [ "INP", "S",]

 [tool.mypy]