chore(cli): fix some DOC rules (preview) (#32839)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2025-09-08 16:36:22 +02:00
committed by GitHub
parent e0aaaccb61
commit 20401df25d
10 changed files with 218 additions and 22 deletions

View File

@@ -14,7 +14,19 @@ def create_demo_server(
config_keys: Sequence[str] = (),
playground_type: Literal["default", "chat"] = "default",
) -> FastAPI:
"""Create a demo server for the current template."""
"""Create a demo server for the current template.
Args:
config_keys: Optional sequence of config keys to expose in the playground.
playground_type: The type of playground to use. Can be `'default'` or `'chat'`.
Returns:
The demo server.
Raises:
KeyError: If the `pyproject.toml` file is missing required fields.
ImportError: If the module defined in `pyproject.toml` cannot be imported.
"""
app = FastAPI()
package_root = get_package_root()
pyproject = package_root / "pyproject.toml"
@@ -41,10 +53,18 @@ def create_demo_server(
def create_demo_server_configurable() -> FastAPI:
"""Create a configurable demo server."""
"""Create a configurable demo server.
Returns:
The configurable demo server.
"""
return create_demo_server(config_keys=["configurable"])
def create_demo_server_chat() -> FastAPI:
"""Create a chat demo server."""
"""Create a chat demo server.
Returns:
The chat demo server.
"""
return create_demo_server(playground_type="chat")

View File

@@ -11,7 +11,16 @@ def generate_raw_migrations(
to_package: str,
filter_by_all: bool = False, # noqa: FBT001, FBT002
) -> list[tuple[str, str]]:
"""Scan the `langchain` package and generate migrations for all modules."""
"""Scan the `langchain` package and generate migrations for all modules.
Args:
from_package: The package to migrate from.
to_package: The package to migrate to.
filter_by_all: Whether to only consider items in `__all__`.
Returns:
A list of tuples containing the original import path and the new import path.
"""
package = importlib.import_module(from_package)
items = []
@@ -84,6 +93,13 @@ def generate_top_level_imports(pkg: str) -> list[tuple[str, str]]:
and the second tuple will contain the path
to importing it from the top level namespaces
(e.g., ``langchain_community.chat_models.XYZ``)
Args:
pkg: The package to scan.
Returns:
A list of tuples containing the fully qualified path and the top-level
import path.
"""
package = importlib.import_module(pkg)
@@ -130,7 +146,17 @@ def generate_simplified_migrations(
to_package: str,
filter_by_all: bool = True, # noqa: FBT001, FBT002
) -> list[tuple[str, str]]:
"""Get all the raw migrations, then simplify them if possible."""
"""Get all the raw migrations, then simplify them if possible.
Args:
from_package: The package to migrate from.
to_package: The package to migrate to.
filter_by_all: Whether to only consider items in `__all__`.
Returns:
A list of tuples containing the original import path and the simplified
import path.
"""
raw_migrations = generate_raw_migrations(
from_package,
to_package,

View File

@@ -2,13 +2,28 @@
def split_package(package: str) -> tuple[str, str]:
"""Split a package name into the containing package and the final name."""
"""Split a package name into the containing package and the final name.
Args:
package: The full package name.
Returns:
A tuple of `(containing_package, final_name)`.
"""
parts = package.split(".")
return ".".join(parts[:-1]), parts[-1]
def dump_migrations_as_grit(name: str, migration_pairs: list[tuple[str, str]]) -> str:
"""Dump the migration pairs as a Grit file."""
"""Dump the migration pairs as a Grit file.
Args:
name: The name of the migration.
migration_pairs: A list of tuples `(from_module, to_module)`.
Returns:
The Grit file as a string.
"""
remapped = ",\n".join(
[
f"""

View File

@@ -59,7 +59,15 @@ def _get_class_names(code: str) -> list[str]:
def is_subclass(class_obj: type, classes_: list[type]) -> bool:
"""Check if the given class object is a subclass of any class in list classes."""
"""Check if the given class object is a subclass of any class in list classes.
Args:
class_obj: The class to check.
classes_: A list of classes to check against.
Returns:
True if `class_obj` is a subclass of any class in `classes_`, False otherwise.
"""
return any(
issubclass(class_obj, kls)
for kls in classes_
@@ -68,7 +76,15 @@ def is_subclass(class_obj: type, classes_: list[type]) -> bool:
def find_subclasses_in_module(module: ModuleType, classes_: list[type]) -> list[str]:
"""Find all classes in the module that inherit from one of the classes."""
"""Find all classes in the module that inherit from one of the classes.
Args:
module: The module to inspect.
classes_: A list of classes to check against.
Returns:
A list of class names that are subclasses of any class in `classes_`.
"""
subclasses = []
# Iterate over all attributes of the module that are classes
for _name, obj in inspect.getmembers(module, inspect.isclass):
@@ -91,7 +107,15 @@ def identify_all_imports_in_file(
*,
from_package: Optional[str] = None,
) -> list[tuple[str, str]]:
"""Let's also identify all the imports in the given file."""
"""Identify all the imports in the given file.
Args:
file: The file to analyze.
from_package: If provided, only return imports from this package.
Returns:
A list of tuples `(module, name)` representing the imports found in the file.
"""
code = Path(file).read_text(encoding="utf-8")
return find_imports_from_package(code, from_package=from_package)
@@ -106,6 +130,9 @@ def identify_pkg_source(pkg_root: str) -> pathlib.Path:
Returns:
Returns the path to the source code for the package.
Raises:
ValueError: If there is not exactly one directory starting with `'langchain_'`
in the package root.
"""
dirs = [d for d in Path(pkg_root).iterdir() if d.is_dir()]
matching_dirs = [d for d in dirs if d.name.startswith("langchain_")]
@@ -116,7 +143,15 @@ def identify_pkg_source(pkg_root: str) -> pathlib.Path:
def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
"""List all classes in a package."""
"""List all classes in a package.
Args:
pkg_root: the root of the package.
Returns:
A list of tuples `(module, class_name)` representing all classes found in the
package, excluding test files.
"""
module_classes = []
pkg_source = identify_pkg_source(pkg_root)
files = list(pkg_source.rglob("*.py"))
@@ -130,7 +165,15 @@ def list_classes_by_package(pkg_root: str) -> list[tuple[str, str]]:
def list_init_imports_by_package(pkg_root: str) -> list[tuple[str, str]]:
"""List all the things that are being imported in a package by module."""
"""List all the things that are being imported in a package by module.
Args:
pkg_root: the root of the package.
Returns:
A list of tuples `(module, name)` representing the imports found in
`__init__.py` files.
"""
imports = []
pkg_source = identify_pkg_source(pkg_root)
# Scan all the files in the package
@@ -150,7 +193,15 @@ def find_imports_from_package(
*,
from_package: Optional[str] = None,
) -> list[tuple[str, str]]:
"""Find imports in code."""
"""Find imports in code.
Args:
code: The code to analyze.
from_package: If provided, only return imports from this package.
Returns:
A list of tuples `(module, name)` representing the imports found.
"""
# Parse the code into an AST
tree = ast.parse(code)
# Create an instance of the visitor

View File

@@ -22,7 +22,14 @@ class EventDict(TypedDict):
def create_events(events: list[EventDict]) -> Optional[dict[str, Any]]:
"""Create events."""
"""Create events.
Args:
events: A list of event dictionaries.
Returns:
The response from the event tracking service, or None if there was an error.
"""
try:
data = {
"events": [

View File

@@ -4,7 +4,15 @@ from pathlib import Path
def find_and_replace(source: str, replacements: dict[str, str]) -> str:
"""Find and replace text in a string."""
"""Find and replace text in a string.
Args:
source: The source string.
replacements: A dictionary of `{find: replace}` pairs.
Returns:
The modified string.
"""
rtn = source
# replace keys in deterministic alphabetical order

View File

@@ -36,7 +36,20 @@ def parse_dependency_string(
branch: Optional[str],
api_path: Optional[str],
) -> DependencySource:
"""Parse a dependency string into a DependencySource."""
"""Parse a dependency string into a DependencySource.
Args:
dep: the dependency string.
repo: optional repository.
branch: optional branch.
api_path: optional API path.
Returns:
The parsed dependency source information.
Raises:
ValueError: if the dependency string is invalid.
"""
if dep is not None and dep.startswith("git+"):
if repo is not None or branch is not None:
msg = (
@@ -129,7 +142,22 @@ def parse_dependencies(
branch: list[str],
api_path: list[str],
) -> list[DependencySource]:
"""Parse dependencies."""
"""Parse dependencies.
Args:
dependencies: the dependencies to parse
repo: the repositories to use
branch: the branches to use
api_path: the api paths to use
Returns:
A list of DependencySource objects.
Raises:
ValueError: if the number of `dependencies`, `repos`, `branches`, or `api_paths`
do not match.
"""
num_deps = max(
len(dependencies) if dependencies is not None else 0,
len(repo),
@@ -177,7 +205,18 @@ def _get_repo_path(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
"""Update a git repository to the specified ref."""
"""Update a git repository to the specified ref.
Tries to pull if the repo already exists, otherwise clones it.
Args:
gitstring: The git repository URL.
ref: The git reference.
repo_dir: The directory to clone the repository into.
Returns:
The path to the cloned repository.
"""
# see if path already saved
repo_path = _get_repo_path(gitstring, ref, repo_dir)
if repo_path.exists():

View File

@@ -6,7 +6,14 @@ from typing import Optional
def list_packages(*, contains: Optional[str] = None) -> list[str]:
"""List all packages in the langchain repository templates directory."""
"""List all packages in the langchain repository templates directory.
Args:
contains: Optional substring that the package name must contain.
Returns:
A list of package names.
"""
conn = http.client.HTTPSConnection("api.github.com")
try:
headers = {

View File

@@ -7,7 +7,19 @@ from tomlkit import load
def get_package_root(cwd: Optional[Path] = None) -> Path:
"""Get package root directory."""
"""Get package root directory.
Args:
cwd: The current working directory to start the search from.
If None, uses the current working directory of the process.
Returns:
The path to the package root directory.
Raises:
FileNotFoundError: If no `pyproject.toml` file is found in the directory
hierarchy.
"""
# traverse path for routes to host (any directory holding a pyproject.toml file)
package_root = Path.cwd() if cwd is None else cwd
visited: set[Path] = set()
@@ -38,7 +50,17 @@ class LangServeExport(TypedDict):
def get_langserve_export(filepath: Path) -> LangServeExport:
"""Get LangServe export information from a pyproject.toml file."""
"""Get LangServe export information from a `pyproject.toml` file.
Args:
filepath: Path to the `pyproject.toml` file.
Returns:
The LangServeExport information.
Raises:
KeyError: If the `pyproject.toml` file is missing required fields.
"""
with filepath.open() as f:
data: dict[str, Any] = load(f)
try:

View File

@@ -52,6 +52,7 @@ select = [ "ALL",]
ignore = [
"C90", # McCabe complexity
"COM812", # Messes with the formatter
"CPY", # No copyright
"FIX002", # Line contains TODO
"PERF203", # Rarely useful
"PLR09", # Too many something (arg, statements, etc)
@@ -78,7 +79,7 @@ pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
[tool.ruff.lint.per-file-ignores]
"tests/**" = [ "D1", "S", "SLF",]
"tests/**" = [ "D1", "DOC", "S", "SLF",]
"scripts/**" = [ "INP", "S",]
[tool.mypy]