langchain/libs/model-profiles/langchain_model_profiles/cli.py
Mason Daugherty 2f64d80cc6 fix(core,model-profiles): add missing ModelProfile fields, warn on schema drift (#36129)
PR #35788 added 7 new fields to the `langchain-profiles` CLI output
(`name`, `status`, `release_date`, `last_updated`, `open_weights`,
`attachment`, `temperature`) but didn't update `ModelProfile` in
`langchain-core`. Partner packages like `langchain-aws` that set
`extra="forbid"` on their Pydantic models hit `extra_forbidden`
validation errors when Pydantic encounters undeclared TypedDict keys at
construction time. This adds the missing fields, makes `ModelProfile`
forward-compatible, provides a base-class hook so partners can stop
duplicating model-profile validator boilerplate, migrates all in-repo
partners to the new hook, and adds runtime + CI-time warnings for schema
drift.

## Changes

### `langchain-core`
- Add `__pydantic_config__ = ConfigDict(extra="allow")` to
`ModelProfile` so unknown profile keys pass Pydantic validation even on
models with `extra="forbid"` — forward-compatibility for when the CLI
schema evolves ahead of core
- Declare the 7 missing fields on `ModelProfile`: `name`, `status`,
`release_date`, `last_updated`, `open_weights` (metadata) and
`attachment`, `temperature` (capabilities) — see the sketch after this list
- Add `_warn_unknown_profile_keys()` in `model_profile.py` — emits a
`UserWarning` when a profile dict contains keys not in `ModelProfile`,
suggesting a core upgrade. Wrapped in a bare `except` so introspection
failures never crash model construction
- Add `BaseChatModel._resolve_model_profile()` hook that returns `None`
by default. Partners can override this single method instead of
redefining the full `_set_model_profile` validator — the base validator
calls it automatically
- Add `BaseChatModel._check_profile_keys` as a separate
`model_validator` that calls `_warn_unknown_profile_keys`. Uses a
distinct method name so partner overrides of `_set_model_profile` don't
inadvertently suppress the check
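
A minimal sketch of the resulting `ModelProfile` shape. The field subset, `total=False`, and the comments are illustrative; the real definition in `langchain_core.language_models.model_profile` declares many more capability fields:

```python
# typing.TypedDict works on 3.12+; typing_extensions is the safer choice
# for Pydantic support on older Python versions.
from typing_extensions import TypedDict

from pydantic import ConfigDict


class ModelProfile(TypedDict, total=False):
    """Sketch of the profile schema; pre-existing fields omitted for brevity."""

    # Unknown keys pass validation even when a partner model sets
    # extra="forbid", so newer CLI output can't break an older core.
    __pydantic_config__ = ConfigDict(extra="allow")

    # Newly declared metadata fields
    name: str
    status: str
    release_date: str
    last_updated: str
    open_weights: bool

    # Newly declared capability fields
    attachment: bool
    temperature: bool
```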

### `langchain-profiles` CLI
- Add `_warn_undeclared_profile_keys()` to the CLI (`cli.py`), called
after merging augmentations in `refresh()` — warns at profile-generation
time (not just runtime) when emitted keys aren't declared in
`ModelProfile`. Gracefully skips if `langchain-core` isn't installed
- Add guard test
`test_model_data_to_profile_keys_subset_of_model_profile` in
model-profiles — feeds a fully-populated model dict to
`_model_data_to_profile()` and asserts every emitted key exists in
`ModelProfile.__annotations__`. CI fails before any release if someone
adds a CLI field without updating the TypedDict (sketched below)
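
A sketch of what that guard test checks — the fixture contents are illustrative; the real test may populate more models.dev fields:

```python
from langchain_core.language_models.model_profile import ModelProfile
from langchain_model_profiles.cli import _model_data_to_profile


def test_model_data_to_profile_keys_subset_of_model_profile() -> None:
    # Models.dev-style record; values are illustrative.
    model_data = {
        "name": "example-model",
        "status": "stable",
        "release_date": "2025-01-01",
        "last_updated": "2025-01-01",
        "open_weights": False,
        "limit": {"context": 200_000, "output": 8_192},
        "modalities": {"input": ["text", "image"], "output": ["text"]},
        "reasoning": True,
        "tool_call": True,
        "tool_choice": True,
        "structured_output": True,
        "attachment": True,
        "temperature": True,
    }
    emitted = set(_model_data_to_profile(model_data))
    declared = set(ModelProfile.__annotations__)
    assert emitted <= declared, f"Undeclared profile keys: {sorted(emitted - declared)}"
```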

### Partner packages
- Migrate all 10 in-repo partners to the `_resolve_model_profile()`
hook, replacing duplicated `@model_validator` / `_set_model_profile`
overrides: anthropic, deepseek, fireworks, groq, huggingface, mistralai,
openai (base + azure), openrouter, perplexity, xai — see the sketch after this list
- Anthropic retains custom logic (context-1m beta → `max_input_tokens`
override); all others reduce to a one-liner
- Add `pr_lint.yml` scope for the new `model-profiles` package
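
The migration pattern looks roughly like this — the class, model name, and default value are illustrative and not taken from any specific partner package:

```python
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.model_profile import ModelProfile

# `_profiles.py` is the module generated by `langchain-profiles refresh`.
from ._profiles import _PROFILES


class ChatExample(BaseChatModel):
    """Hypothetical partner chat model (required abstract methods omitted)."""

    model: str = "example-model-latest"

    def _resolve_model_profile(self) -> ModelProfile | None:
        # The base `_set_model_profile` validator calls this automatically,
        # replacing the duplicated @model_validator boilerplate.
        return _PROFILES.get(self.model)
```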

"""CLI for refreshing model profile data from models.dev."""
import argparse
import json
import re
import sys
import tempfile
import warnings
from pathlib import Path
from typing import Any, get_type_hints
import httpx
try:
import tomllib # type: ignore[import-not-found] # Python 3.11+
except ImportError:
import tomli as tomllib # type: ignore[import-not-found,no-redef]
def _validate_data_dir(data_dir: Path) -> Path:
"""Validate and canonicalize data directory path.
Args:
data_dir: User-provided data directory path.
Returns:
Resolved, canonical path.
Raises:
SystemExit: If user declines to write outside current directory.
"""
# Resolve to absolute, canonical path (follows symlinks)
try:
resolved = data_dir.resolve(strict=False)
except (OSError, RuntimeError) as e:
msg = f"Invalid data directory path: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
# Warn if writing outside current directory
cwd = Path.cwd().resolve()
try:
resolved.relative_to(cwd)
except ValueError:
# Not relative to cwd
print("⚠️ WARNING: Writing outside current directory", file=sys.stderr)
print(f" Current directory: {cwd}", file=sys.stderr)
print(f" Target directory: {resolved}", file=sys.stderr)
print(file=sys.stderr)
response = input("Continue? (y/N): ")
if response.lower() != "y":
print("Aborted.", file=sys.stderr)
sys.exit(1)
return resolved
def _load_augmentations(
data_dir: Path,
) -> tuple[dict[str, Any], dict[str, dict[str, Any]]]:
"""Load augmentations from `profile_augmentations.toml`.
Args:
data_dir: Directory containing `profile_augmentations.toml`.
Returns:
Tuple of `(provider_augmentations, model_augmentations)`.
"""
aug_file = data_dir / "profile_augmentations.toml"
if not aug_file.exists():
return {}, {}
try:
with aug_file.open("rb") as f:
data = tomllib.load(f)
except PermissionError:
msg = f"Permission denied reading augmentations file: {aug_file}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except tomllib.TOMLDecodeError as e:
msg = f"Invalid TOML syntax in augmentations file: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except OSError as e:
msg = f"Failed to read augmentations file: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
overrides = data.get("overrides", {})
provider_aug: dict[str, Any] = {}
model_augs: dict[str, dict[str, Any]] = {}
for key, value in overrides.items():
if isinstance(value, dict):
model_augs[key] = value
else:
provider_aug[key] = value
return provider_aug, model_augs
def _model_data_to_profile(model_data: dict[str, Any]) -> dict[str, Any]:
"""Convert raw models.dev data into the canonical profile structure."""
limit = model_data.get("limit") or {}
modalities = model_data.get("modalities") or {}
input_modalities = modalities.get("input") or []
output_modalities = modalities.get("output") or []
profile = {
"name": model_data.get("name"),
"status": model_data.get("status"),
"release_date": model_data.get("release_date"),
"last_updated": model_data.get("last_updated"),
"open_weights": model_data.get("open_weights"),
"max_input_tokens": limit.get("context"),
"max_output_tokens": limit.get("output"),
"text_inputs": "text" in input_modalities,
"image_inputs": "image" in input_modalities,
"audio_inputs": "audio" in input_modalities,
"pdf_inputs": "pdf" in input_modalities or model_data.get("pdf_inputs"),
"video_inputs": "video" in input_modalities,
"text_outputs": "text" in output_modalities,
"image_outputs": "image" in output_modalities,
"audio_outputs": "audio" in output_modalities,
"video_outputs": "video" in output_modalities,
"reasoning_output": model_data.get("reasoning"),
"tool_calling": model_data.get("tool_call"),
"tool_choice": model_data.get("tool_choice"),
"structured_output": model_data.get("structured_output"),
"attachment": model_data.get("attachment"),
"temperature": model_data.get("temperature"),
"image_url_inputs": model_data.get("image_url_inputs"),
"image_tool_message": model_data.get("image_tool_message"),
"pdf_tool_message": model_data.get("pdf_tool_message"),
}
return {k: v for k, v in profile.items() if v is not None}
def _apply_overrides(
profile: dict[str, Any], *overrides: dict[str, Any] | None
) -> dict[str, Any]:
"""Merge provider and model overrides onto the canonical profile."""
merged = dict(profile)
for override in overrides:
if not override:
continue
for key, value in override.items():
if value is not None:
merged[key] = value # noqa: PERF403
return merged
def _warn_undeclared_profile_keys(
profiles: dict[str, dict[str, Any]],
) -> None:
"""Warn if any profile keys are not declared in `ModelProfile`.
Args:
profiles: Mapping of model IDs to their profile dicts.
"""
try:
from langchain_core.language_models.model_profile import ModelProfile
except ImportError:
# langchain-core may not be installed or importable; skip check.
return
try:
declared = set(get_type_hints(ModelProfile).keys())
except (TypeError, NameError):
# get_type_hints raises NameError on unresolvable forward refs and
# TypeError when annotations evaluate to non-type objects.
return
extra = sorted({k for p in profiles.values() for k in p} - declared)
if extra:
warnings.warn(
f"Profile keys not declared in langchain_core ModelProfile: {extra}. "
f"Add these fields to "
f"langchain_core.language_models.model_profile.ModelProfile and "
f"release langchain-core before publishing partner packages that "
f"use these profiles.",
stacklevel=2,
)
def _ensure_safe_output_path(base_dir: Path, output_file: Path) -> None:
"""Ensure the resolved output path remains inside the expected directory."""
if base_dir.exists() and base_dir.is_symlink():
msg = f"Data directory {base_dir} is a symlink; refusing to write profiles."
print(f"{msg}", file=sys.stderr)
sys.exit(1)
if output_file.exists() and output_file.is_symlink():
msg = (
f"profiles.py at {output_file} is a symlink; refusing to overwrite it.\n"
"Delete the symlink or point --data-dir to a safe location."
)
print(f"{msg}", file=sys.stderr)
sys.exit(1)
try:
output_file.resolve(strict=False).relative_to(base_dir.resolve())
except (OSError, RuntimeError) as e:
msg = f"Failed to resolve output path: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except ValueError:
msg = f"Refusing to write outside of data directory: {output_file}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
def _write_profiles_file(output_file: Path, contents: str) -> None:
"""Write the generated module atomically without following symlinks."""
_ensure_safe_output_path(output_file.parent, output_file)
temp_path: Path | None = None
try:
with tempfile.NamedTemporaryFile(
mode="w", encoding="utf-8", dir=output_file.parent, delete=False
) as tmp_file:
tmp_file.write(contents)
temp_path = Path(tmp_file.name)
temp_path.replace(output_file)
except PermissionError:
msg = f"Permission denied writing file: {output_file}"
print(f"{msg}", file=sys.stderr)
if temp_path:
temp_path.unlink(missing_ok=True)
sys.exit(1)
except OSError as e:
msg = f"Failed to write file: {e}"
print(f"{msg}", file=sys.stderr)
if temp_path:
temp_path.unlink(missing_ok=True)
sys.exit(1)
MODULE_ADMONITION = """Auto-generated model profiles.
DO NOT EDIT THIS FILE MANUALLY.
This file is generated by the langchain-profiles CLI tool.
It contains data derived from the models.dev project.
Source: https://github.com/sst/models.dev
License: MIT License
To update these data, refer to the instructions here:
https://docs.langchain.com/oss/python/langchain/models#updating-or-overwriting-profile-data
"""
def refresh(provider: str, data_dir: Path) -> None: # noqa: C901, PLR0915
"""Download and merge model profile data for a specific provider.
Args:
provider: Provider ID from models.dev (e.g., `'anthropic'`, `'openai'`).
data_dir: Directory containing `profile_augmentations.toml` and where
`profiles.py` will be written.
"""
# Validate and canonicalize data directory path
data_dir = _validate_data_dir(data_dir)
api_url = "https://models.dev/api.json"
print(f"Provider: {provider}")
print(f"Data directory: {data_dir}")
print()
# Download data from models.dev
print(f"Downloading data from {api_url}...")
try:
response = httpx.get(api_url, timeout=30)
response.raise_for_status()
except httpx.TimeoutException:
msg = f"Request timed out connecting to {api_url}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except httpx.HTTPStatusError as e:
msg = f"HTTP error {e.response.status_code} from {api_url}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except httpx.RequestError as e:
msg = f"Failed to connect to {api_url}: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
try:
all_data = response.json()
except json.JSONDecodeError as e:
msg = f"Invalid JSON response from API: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
# Basic validation
if not isinstance(all_data, dict):
msg = "Expected API response to be a dictionary"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
provider_count = len(all_data)
model_count = sum(len(p.get("models", {})) for p in all_data.values())
print(f"Downloaded {provider_count} providers with {model_count} models")
# Extract data for this provider
if provider not in all_data:
msg = f"Provider '{provider}' not found in models.dev data"
print(msg, file=sys.stderr)
sys.exit(1)
provider_data = all_data[provider]
models = provider_data.get("models", {})
print(f"Extracted {len(models)} models for {provider}")
# Load augmentations
print("Loading augmentations...")
provider_aug, model_augs = _load_augmentations(data_dir)
# Merge and convert to profiles
profiles: dict[str, dict[str, Any]] = {}
for model_id, model_data in models.items():
base_profile = _model_data_to_profile(model_data)
profiles[model_id] = _apply_overrides(
base_profile, provider_aug, model_augs.get(model_id)
)
# Include new models defined purely via augmentations
extra_models = set(model_augs) - set(models)
if extra_models:
print(f"Adding {len(extra_models)} models from augmentations only...")
for model_id in sorted(extra_models):
profiles[model_id] = _apply_overrides({}, provider_aug, model_augs[model_id])
_warn_undeclared_profile_keys(profiles)
# Ensure directory exists
try:
data_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
except PermissionError:
msg = f"Permission denied creating directory: {data_dir}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except OSError as e:
msg = f"Failed to create directory: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
# Write as Python module
output_file = data_dir / "_profiles.py"
print(f"Writing to {output_file}...")
module_content = [f'"""{MODULE_ADMONITION}"""\n\n', "from typing import Any\n\n"]
module_content.append("_PROFILES: dict[str, dict[str, Any]] = ")
json_str = json.dumps(dict(sorted(profiles.items())), indent=4)
json_str = (
json_str.replace("true", "True")
.replace("false", "False")
.replace("null", "None")
)
# Add trailing commas for ruff format compliance
json_str = re.sub(r"([^\s,{\[])(?=\n\s*[\}\]])", r"\1,", json_str)
module_content.append(f"{json_str}\n")
_write_profiles_file(output_file, "".join(module_content))
print(
f"✓ Successfully refreshed {len(profiles)} model profiles "
f"({output_file.stat().st_size:,} bytes)"
)
def main() -> None:
"""CLI entrypoint."""
parser = argparse.ArgumentParser(
description="Refresh model profile data from models.dev",
prog="langchain-profiles",
)
subparsers = parser.add_subparsers(dest="command", required=True)
# refresh command
refresh_parser = subparsers.add_parser(
"refresh", help="Download and merge model profile data for a provider"
)
refresh_parser.add_argument(
"--provider",
required=True,
help="Provider ID from models.dev (e.g., 'anthropic', 'openai', 'google')",
)
refresh_parser.add_argument(
"--data-dir",
required=True,
type=Path,
help="Data directory containing profile_augmentations.toml",
)
args = parser.parse_args()
if args.command == "refresh":
refresh(args.provider, args.data_dir)
if __name__ == "__main__":
main()
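
For completeness, the refresh pipeline can also be driven from Python rather than the console script — a minimal sketch, where the data directory path is illustrative:

```python
from pathlib import Path

from langchain_model_profiles.cli import refresh

# Downloads models.dev data, merges profile_augmentations.toml overrides,
# and writes <data-dir>/_profiles.py (same as `langchain-profiles refresh`).
refresh("anthropic", Path("libs/partners/anthropic/langchain_anthropic/data"))
```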