Files
Mason Daugherty 5d9568b5f5 feat(model-profiles): new fields + Makefile target (#35788)
Extract additional fields from models.dev into `_model_data_to_profile`:
`name`, `status`, `release_date`, `last_updated`, `open_weights`,
`attachment`, `temperature`

Move the model profile refresh logic from an inline bash script in the
GitHub Actions workflow into a `make refresh-profiles` target in
`libs/model-profiles/Makefile`. This makes it runnable locally with a
single command and keeps the provider map in one place instead of
duplicated between CI and developer docs.
2026-03-12 13:56:25 +00:00

369 lines
12 KiB
Python

"""CLI for refreshing model profile data from models.dev."""
import argparse
import json
import re
import sys
import tempfile
from pathlib import Path
from typing import Any
import httpx
try:
import tomllib # type: ignore[import-not-found] # Python 3.11+
except ImportError:
import tomli as tomllib # type: ignore[import-not-found,no-redef]
def _validate_data_dir(data_dir: Path) -> Path:
"""Validate and canonicalize data directory path.
Args:
data_dir: User-provided data directory path.
Returns:
Resolved, canonical path.
Raises:
SystemExit: If user declines to write outside current directory.
"""
# Resolve to absolute, canonical path (follows symlinks)
try:
resolved = data_dir.resolve(strict=False)
except (OSError, RuntimeError) as e:
msg = f"Invalid data directory path: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
# Warn if writing outside current directory
cwd = Path.cwd().resolve()
try:
resolved.relative_to(cwd)
except ValueError:
# Not relative to cwd
print("⚠️ WARNING: Writing outside current directory", file=sys.stderr)
print(f" Current directory: {cwd}", file=sys.stderr)
print(f" Target directory: {resolved}", file=sys.stderr)
print(file=sys.stderr)
response = input("Continue? (y/N): ")
if response.lower() != "y":
print("Aborted.", file=sys.stderr)
sys.exit(1)
return resolved
def _load_augmentations(
data_dir: Path,
) -> tuple[dict[str, Any], dict[str, dict[str, Any]]]:
"""Load augmentations from `profile_augmentations.toml`.
Args:
data_dir: Directory containing `profile_augmentations.toml`.
Returns:
Tuple of `(provider_augmentations, model_augmentations)`.
"""
aug_file = data_dir / "profile_augmentations.toml"
if not aug_file.exists():
return {}, {}
try:
with aug_file.open("rb") as f:
data = tomllib.load(f)
except PermissionError:
msg = f"Permission denied reading augmentations file: {aug_file}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except tomllib.TOMLDecodeError as e:
msg = f"Invalid TOML syntax in augmentations file: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except OSError as e:
msg = f"Failed to read augmentations file: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
overrides = data.get("overrides", {})
provider_aug: dict[str, Any] = {}
model_augs: dict[str, dict[str, Any]] = {}
for key, value in overrides.items():
if isinstance(value, dict):
model_augs[key] = value
else:
provider_aug[key] = value
return provider_aug, model_augs
def _model_data_to_profile(model_data: dict[str, Any]) -> dict[str, Any]:
"""Convert raw models.dev data into the canonical profile structure."""
limit = model_data.get("limit") or {}
modalities = model_data.get("modalities") or {}
input_modalities = modalities.get("input") or []
output_modalities = modalities.get("output") or []
profile = {
"name": model_data.get("name"),
"status": model_data.get("status"),
"release_date": model_data.get("release_date"),
"last_updated": model_data.get("last_updated"),
"open_weights": model_data.get("open_weights"),
"max_input_tokens": limit.get("context"),
"max_output_tokens": limit.get("output"),
"text_inputs": "text" in input_modalities,
"image_inputs": "image" in input_modalities,
"audio_inputs": "audio" in input_modalities,
"pdf_inputs": "pdf" in input_modalities or model_data.get("pdf_inputs"),
"video_inputs": "video" in input_modalities,
"text_outputs": "text" in output_modalities,
"image_outputs": "image" in output_modalities,
"audio_outputs": "audio" in output_modalities,
"video_outputs": "video" in output_modalities,
"reasoning_output": model_data.get("reasoning"),
"tool_calling": model_data.get("tool_call"),
"tool_choice": model_data.get("tool_choice"),
"structured_output": model_data.get("structured_output"),
"attachment": model_data.get("attachment"),
"temperature": model_data.get("temperature"),
"image_url_inputs": model_data.get("image_url_inputs"),
"image_tool_message": model_data.get("image_tool_message"),
"pdf_tool_message": model_data.get("pdf_tool_message"),
}
return {k: v for k, v in profile.items() if v is not None}
def _apply_overrides(
profile: dict[str, Any], *overrides: dict[str, Any] | None
) -> dict[str, Any]:
"""Merge provider and model overrides onto the canonical profile."""
merged = dict(profile)
for override in overrides:
if not override:
continue
for key, value in override.items():
if value is not None:
merged[key] = value # noqa: PERF403
return merged
def _ensure_safe_output_path(base_dir: Path, output_file: Path) -> None:
"""Ensure the resolved output path remains inside the expected directory."""
if base_dir.exists() and base_dir.is_symlink():
msg = f"Data directory {base_dir} is a symlink; refusing to write profiles."
print(f"{msg}", file=sys.stderr)
sys.exit(1)
if output_file.exists() and output_file.is_symlink():
msg = (
f"profiles.py at {output_file} is a symlink; refusing to overwrite it.\n"
"Delete the symlink or point --data-dir to a safe location."
)
print(f"{msg}", file=sys.stderr)
sys.exit(1)
try:
output_file.resolve(strict=False).relative_to(base_dir.resolve())
except (OSError, RuntimeError) as e:
msg = f"Failed to resolve output path: {e}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
except ValueError:
msg = f"Refusing to write outside of data directory: {output_file}"
print(f"{msg}", file=sys.stderr)
sys.exit(1)
def _write_profiles_file(output_file: Path, contents: str) -> None:
    """Write the generated module atomically without following symlinks.

    The content is staged in a named temp file inside the destination
    directory, then renamed over `output_file`, so readers never observe
    a partially written module. The staged file is removed on failure.
    """
    _ensure_safe_output_path(output_file.parent, output_file)
    staged: Path | None = None
    try:
        # Staging in the same directory keeps the rename on one filesystem.
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", dir=output_file.parent, delete=False
        ) as handle:
            handle.write(contents)
            staged = Path(handle.name)
        staged.replace(output_file)
    except PermissionError:
        print(f"Permission denied writing file: {output_file}", file=sys.stderr)
        if staged:
            staged.unlink(missing_ok=True)
        sys.exit(1)
    except OSError as e:
        print(f"Failed to write file: {e}", file=sys.stderr)
        if staged:
            staged.unlink(missing_ok=True)
        sys.exit(1)
# Docstring emitted at the top of the generated profiles module. It warns
# readers the file is machine-generated and credits the models.dev source.
# NOTE: this is a runtime string written into the output file — edits here
# change the generated module's header.
MODULE_ADMONITION = """Auto-generated model profiles.
DO NOT EDIT THIS FILE MANUALLY.
This file is generated by the langchain-profiles CLI tool.
It contains data derived from the models.dev project.
Source: https://github.com/sst/models.dev
License: MIT License
To update these data, refer to the instructions here:
https://docs.langchain.com/oss/python/langchain/models#updating-or-overwriting-profile-data
"""
def refresh(provider: str, data_dir: Path) -> None:  # noqa: C901, PLR0915
    """Download and merge model profile data for a specific provider.

    Fetches the full models.dev dataset, extracts the requested provider,
    applies local augmentations, and writes the merged profiles as a
    Python module in `data_dir`.

    Args:
        provider: Provider ID from models.dev (e.g., `'anthropic'`, `'openai'`).
        data_dir: Directory containing `profile_augmentations.toml` and where
            `profiles.py` will be written.

    Raises:
        SystemExit: On network failure, malformed API data, an unknown
            provider, or any filesystem error while writing.
    """
    # Validate and canonicalize data directory path
    data_dir = _validate_data_dir(data_dir)
    api_url = "https://models.dev/api.json"
    print(f"Provider: {provider}")
    print(f"Data directory: {data_dir}")
    print()

    # Download data from models.dev
    print(f"Downloading data from {api_url}...")
    try:
        response = httpx.get(api_url, timeout=30)
        response.raise_for_status()
    except httpx.TimeoutException:
        msg = f"Request timed out connecting to {api_url}"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)
    except httpx.HTTPStatusError as e:
        msg = f"HTTP error {e.response.status_code} from {api_url}"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)
    except httpx.RequestError as e:
        msg = f"Failed to connect to {api_url}: {e}"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)
    try:
        all_data = response.json()
    except json.JSONDecodeError as e:
        msg = f"Invalid JSON response from API: {e}"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)

    # Basic validation: the API returns a mapping of provider id -> data.
    if not isinstance(all_data, dict):
        msg = "Expected API response to be a dictionary"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)
    provider_count = len(all_data)
    model_count = sum(len(p.get("models", {})) for p in all_data.values())
    print(f"Downloaded {provider_count} providers with {model_count} models")

    # Extract data for this provider
    if provider not in all_data:
        msg = f"Provider '{provider}' not found in models.dev data"
        print(msg, file=sys.stderr)
        sys.exit(1)
    provider_data = all_data[provider]
    models = provider_data.get("models", {})
    print(f"Extracted {len(models)} models for {provider}")

    # Load augmentations (local overrides layered on top of API data)
    print("Loading augmentations...")
    provider_aug, model_augs = _load_augmentations(data_dir)

    # Merge and convert to profiles. Model-specific overrides win over
    # provider-wide ones (applied last in _apply_overrides).
    profiles: dict[str, dict[str, Any]] = {}
    for model_id, model_data in models.items():
        base_profile = _model_data_to_profile(model_data)
        profiles[model_id] = _apply_overrides(
            base_profile, provider_aug, model_augs.get(model_id)
        )

    # Include new models defined purely via augmentations
    extra_models = set(model_augs) - set(models)
    if extra_models:
        print(f"Adding {len(extra_models)} models from augmentations only...")
        for model_id in sorted(extra_models):
            profiles[model_id] = _apply_overrides({}, provider_aug, model_augs[model_id])

    # Ensure directory exists
    try:
        data_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
    except PermissionError:
        msg = f"Permission denied creating directory: {data_dir}"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)
    except OSError as e:
        msg = f"Failed to create directory: {e}"
        print(f"{msg}", file=sys.stderr)
        sys.exit(1)

    # Write as Python module. NOTE(review): the docstring says `profiles.py`
    # but the file is actually written as `_profiles.py` — confirm which name
    # callers expect.
    output_file = data_dir / "_profiles.py"
    print(f"Writing to {output_file}...")
    module_content = [f'"""{MODULE_ADMONITION}"""\n\n', "from typing import Any\n\n"]
    module_content.append("_PROFILES: dict[str, dict[str, Any]] = ")
    # Models are sorted by id so the generated file is deterministic.
    json_str = json.dumps(dict(sorted(profiles.items())), indent=4)
    # NOTE(review): these plain substring replacements rewrite "true"/"false"/
    # "null" anywhere in the JSON text, including inside string values (e.g. a
    # model name containing "null") — confirm upstream data can never contain
    # these substrings, or switch to a token-aware transform.
    json_str = (
        json_str.replace("true", "True")
        .replace("false", "False")
        .replace("null", "None")
    )
    # Add trailing commas for ruff format compliance
    json_str = re.sub(r"([^\s,{\[])(?=\n\s*[\}\]])", r"\1,", json_str)
    module_content.append(f"{json_str}\n")
    _write_profiles_file(output_file, "".join(module_content))
    print(
        f"✓ Successfully refreshed {len(profiles)} model profiles "
        f"({output_file.stat().st_size:,} bytes)"
    )
def main() -> None:
    """CLI entrypoint: parse arguments and dispatch to the subcommand."""
    parser = argparse.ArgumentParser(
        prog="langchain-profiles",
        description="Refresh model profile data from models.dev",
    )
    commands = parser.add_subparsers(dest="command", required=True)

    # `refresh` — the only subcommand for now.
    refresh_cmd = commands.add_parser(
        "refresh", help="Download and merge model profile data for a provider"
    )
    refresh_cmd.add_argument(
        "--provider",
        required=True,
        help="Provider ID from models.dev (e.g., 'anthropic', 'openai', 'google')",
    )
    refresh_cmd.add_argument(
        "--data-dir",
        required=True,
        type=Path,
        help="Data directory containing profile_augmentations.toml",
    )

    args = parser.parse_args()
    if args.command == "refresh":
        refresh(args.provider, args.data_dir)
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()