mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-12 23:42:51 +00:00
chore: merge master, resolve profile/version conflicts
This commit is contained in:
3
.github/workflows/pr_lint.yml
vendored
3
.github/workflows/pr_lint.yml
vendored
@@ -31,7 +31,7 @@
|
||||
# core, langchain, langchain-classic, model-profiles,
|
||||
# standard-tests, text-splitters, docs, anthropic, chroma, deepseek, exa,
|
||||
# fireworks, groq, huggingface, mistralai, nomic, ollama, openai,
|
||||
# perplexity, qdrant, xai, infra, deps
|
||||
# perplexity, qdrant, xai, infra, deps, partners
|
||||
#
|
||||
# Multiple scopes can be used by separating them with a comma. For example:
|
||||
#
|
||||
@@ -119,6 +119,7 @@ jobs:
|
||||
xai
|
||||
infra
|
||||
deps
|
||||
partners
|
||||
requireScope: false
|
||||
disallowScopes: |
|
||||
release
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -11,8 +12,8 @@ from functools import cached_property
|
||||
from operator import itemgetter
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
@@ -32,7 +33,10 @@ from langchain_core.language_models.base import (
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.language_models.model_profile import ModelProfile
|
||||
from langchain_core.language_models.model_profile import (
|
||||
ModelProfile,
|
||||
_warn_unknown_profile_keys,
|
||||
)
|
||||
from langchain_core.load import dumpd, dumps
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -357,6 +361,54 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
"""Return the default model profile, or `None` if unavailable.
|
||||
|
||||
Override this in subclasses instead of `_set_model_profile`. The base
|
||||
validator calls it automatically and handles assignment. This avoids
|
||||
coupling partner code to Pydantic validator mechanics.
|
||||
|
||||
Each partner needs its own override because things can vary per-partner,
|
||||
such as the attribute that identifies the model (e.g., `model`,
|
||||
`model_name`, `model_id`, `deployment_name`) and the partner-local
|
||||
`_get_default_model_profile` function that reads from each partner's own
|
||||
profile data.
|
||||
"""
|
||||
# TODO: consider adding a `_model_identifier` property on BaseChatModel
|
||||
# to standardize how partners identify their model, which could allow a
|
||||
# default implementation here that calls a shared
|
||||
# profile-loading mechanism.
|
||||
return None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Populate `profile` from `_resolve_model_profile` if not provided.
|
||||
|
||||
Partners should override `_resolve_model_profile` rather than this
|
||||
validator. Overriding this with a new `@model_validator` replaces the
|
||||
base validator (Pydantic v2 behavior), bypassing the standard resolution
|
||||
path. A plain method override does not prevent the base validator from
|
||||
running.
|
||||
"""
|
||||
if self.profile is None:
|
||||
# Suppress errors from partner overrides (e.g., missing profile
|
||||
# files, broken imports) so model construction never fails over an
|
||||
# optional field.
|
||||
with contextlib.suppress(Exception):
|
||||
self.profile = self._resolve_model_profile()
|
||||
return self
|
||||
|
||||
# NOTE: _check_profile_keys must be defined AFTER _set_model_profile.
|
||||
# Pydantic v2 runs mode="after" validators in definition order.
|
||||
@model_validator(mode="after")
|
||||
def _check_profile_keys(self) -> Self:
|
||||
"""Warn on unrecognized profile keys."""
|
||||
# isinstance guard: ModelProfile is a TypedDict (always a dict), but
|
||||
# protects against unexpected types from partner overrides.
|
||||
if self.profile and isinstance(self.profile, dict):
|
||||
_warn_unknown_profile_keys(self.profile)
|
||||
return self
|
||||
|
||||
@cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
# self is always a Serializable object in this case, thus the result is
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
"""Model profile types and utilities."""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import get_type_hints
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProfile(TypedDict, total=False):
|
||||
"""Model profile.
|
||||
@@ -14,6 +21,25 @@ class ModelProfile(TypedDict, total=False):
|
||||
and supported features.
|
||||
"""
|
||||
|
||||
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore[misc]
|
||||
|
||||
# --- Model metadata ---
|
||||
|
||||
name: str
|
||||
"""Human-readable model name."""
|
||||
|
||||
status: str
|
||||
"""Model status (e.g., `'active'`, `'deprecated'`)."""
|
||||
|
||||
release_date: str
|
||||
"""Model release date (ISO 8601 format, e.g., `'2025-06-01'`)."""
|
||||
|
||||
last_updated: str
|
||||
"""Date the model was last updated (ISO 8601 format)."""
|
||||
|
||||
open_weights: bool
|
||||
"""Whether the model weights are openly available."""
|
||||
|
||||
# --- Input constraints ---
|
||||
|
||||
max_input_tokens: int
|
||||
@@ -86,6 +112,45 @@ class ModelProfile(TypedDict, total=False):
|
||||
"""Whether the model supports a native [structured output](https://docs.langchain.com/oss/python/langchain/models#structured-outputs)
|
||||
feature"""
|
||||
|
||||
# --- Other capabilities ---
|
||||
|
||||
attachment: bool
|
||||
"""Whether the model supports file attachments."""
|
||||
|
||||
temperature: bool
|
||||
"""Whether the model supports a temperature parameter."""
|
||||
|
||||
|
||||
ModelProfileRegistry = dict[str, ModelProfile]
|
||||
"""Registry mapping model identifiers or names to their ModelProfile."""
|
||||
|
||||
|
||||
def _warn_unknown_profile_keys(profile: ModelProfile) -> None:
|
||||
"""Warn if `profile` contains keys not declared on `ModelProfile`.
|
||||
|
||||
Args:
|
||||
profile: The model profile dict to check for undeclared keys.
|
||||
"""
|
||||
if not isinstance(profile, dict):
|
||||
return
|
||||
|
||||
try:
|
||||
declared = frozenset(get_type_hints(ModelProfile).keys())
|
||||
except (TypeError, NameError):
|
||||
# get_type_hints raises NameError on unresolvable forward refs and
|
||||
# TypeError when annotations evaluate to non-type objects.
|
||||
logger.debug(
|
||||
"Could not resolve type hints for ModelProfile; "
|
||||
"skipping unknown-key check.",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
extra = sorted(set(profile) - declared)
|
||||
if extra:
|
||||
warnings.warn(
|
||||
f"Unrecognized keys in model profile: {extra}. "
|
||||
f"This may indicate a version mismatch between langchain-core "
|
||||
f"and your provider package. Consider upgrading langchain-core.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,8 @@ from collections.abc import AsyncIterator, Iterator
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -22,6 +23,7 @@ from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModelError,
|
||||
GenericFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models.model_profile import ModelProfile
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -1320,6 +1322,76 @@ def test_model_profiles() -> None:
|
||||
assert model_with_profile.profile == {"max_input_tokens": 100}
|
||||
|
||||
|
||||
def test_resolve_model_profile_hook_populates_profile() -> None:
|
||||
"""_resolve_model_profile is called when profile is None."""
|
||||
|
||||
class ResolverModel(GenericFakeChatModel):
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return {"max_input_tokens": 500}
|
||||
|
||||
model = ResolverModel(messages=iter([]))
|
||||
assert model.profile == {"max_input_tokens": 500}
|
||||
|
||||
|
||||
def test_resolve_model_profile_hook_skipped_when_explicit() -> None:
|
||||
"""_resolve_model_profile is NOT called when profile is set explicitly."""
|
||||
|
||||
class ResolverModel(GenericFakeChatModel):
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return {"max_input_tokens": 500}
|
||||
|
||||
model = ResolverModel(messages=iter([]), profile={"max_input_tokens": 999})
|
||||
assert model.profile is not None
|
||||
assert model.profile["max_input_tokens"] == 999
|
||||
|
||||
|
||||
def test_resolve_model_profile_hook_exception_is_caught() -> None:
|
||||
"""Model is still usable if _resolve_model_profile raises."""
|
||||
|
||||
class BrokenProfileModel(GenericFakeChatModel):
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
msg = "profile file not found"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
model = BrokenProfileModel(messages=iter([]))
|
||||
|
||||
assert model.profile is None
|
||||
|
||||
|
||||
def test_check_profile_keys_runs_despite_partner_override() -> None:
|
||||
"""Verify _check_profile_keys fires even when _set_model_profile is overridden.
|
||||
|
||||
Because _check_profile_keys has a distinct validator name from
|
||||
_set_model_profile, a partner override of the latter does not suppress
|
||||
the key-checking validator.
|
||||
"""
|
||||
|
||||
class PartnerModel(GenericFakeChatModel):
|
||||
"""Simulates a partner that overrides _set_model_profile."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
if self.profile is None:
|
||||
profile: dict[str, Any] = {
|
||||
"max_input_tokens": 100,
|
||||
"partner_only_field": True,
|
||||
}
|
||||
self.profile = profile # type: ignore[assignment]
|
||||
return self
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
model = PartnerModel(messages=iter([]))
|
||||
|
||||
assert model.profile is not None
|
||||
assert model.profile.get("partner_only_field") is True
|
||||
profile_warnings = [x for x in w if "Unrecognized keys" in str(x.message)]
|
||||
assert len(profile_warnings) == 1
|
||||
assert "partner_only_field" in str(profile_warnings[0].message)
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Mock response for testing _generate_response_from_error."""
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for model profile types and utilities."""
|
||||
|
||||
import warnings
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from langchain_core.language_models.model_profile import (
|
||||
ModelProfile,
|
||||
_warn_unknown_profile_keys,
|
||||
)
|
||||
|
||||
|
||||
class TestModelProfileExtraAllow:
|
||||
"""Verify extra='allow' on ModelProfile TypedDict."""
|
||||
|
||||
def test_accepts_declared_keys(self) -> None:
|
||||
profile: ModelProfile = {"max_input_tokens": 100, "tool_calling": True}
|
||||
assert profile["max_input_tokens"] == 100
|
||||
|
||||
def test_extra_keys_accepted_via_typed_dict(self) -> None:
|
||||
"""ModelProfile TypedDict allows extra keys at construction."""
|
||||
profile = ModelProfile(
|
||||
max_input_tokens=100,
|
||||
unknown_future_field="value", # type: ignore[typeddict-unknown-key]
|
||||
)
|
||||
assert profile["unknown_future_field"] == "value" # type: ignore[typeddict-item]
|
||||
|
||||
def test_extra_keys_survive_pydantic_validation(self) -> None:
|
||||
"""Extra keys pass through even when parent model forbids extras."""
|
||||
|
||||
class StrictModel(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
profile: ModelProfile | None = Field(default=None)
|
||||
|
||||
m = StrictModel(
|
||||
profile={
|
||||
"max_input_tokens": 100,
|
||||
"unknown_future_field": True,
|
||||
}
|
||||
)
|
||||
assert m.profile is not None
|
||||
assert m.profile.get("unknown_future_field") is True
|
||||
|
||||
|
||||
class TestWarnUnknownProfileKeys:
|
||||
"""Tests for _warn_unknown_profile_keys."""
|
||||
|
||||
def test_warns_on_extra_keys(self) -> None:
|
||||
profile: dict[str, Any] = {
|
||||
"max_input_tokens": 100,
|
||||
"future_field": True,
|
||||
"another": "val",
|
||||
}
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
_warn_unknown_profile_keys(profile) # type: ignore[arg-type]
|
||||
|
||||
assert len(w) == 1
|
||||
assert "another" in str(w[0].message)
|
||||
assert "future_field" in str(w[0].message)
|
||||
assert "upgrading langchain-core" in str(w[0].message)
|
||||
|
||||
def test_silent_on_declared_keys_only(self) -> None:
|
||||
profile: ModelProfile = {"max_input_tokens": 100, "tool_calling": True}
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
_warn_unknown_profile_keys(profile)
|
||||
|
||||
assert len(w) == 0
|
||||
|
||||
def test_silent_on_empty_profile(self) -> None:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
_warn_unknown_profile_keys({})
|
||||
|
||||
assert len(w) == 0
|
||||
|
||||
def test_survives_get_type_hints_failure(self) -> None:
|
||||
"""Falls back to silent skip on TypeError from get_type_hints."""
|
||||
profile: dict[str, Any] = {"max_input_tokens": 100, "extra": True}
|
||||
with patch(
|
||||
"langchain_core.language_models.model_profile.get_type_hints",
|
||||
side_effect=TypeError("broken"),
|
||||
):
|
||||
_warn_unknown_profile_keys(profile) # type: ignore[arg-type]
|
||||
@@ -388,36 +388,13 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> N
|
||||
|
||||
|
||||
def test_summarization_middleware_missing_profile() -> None:
|
||||
"""Ensure automatic profile inference falls back when profiles are unavailable."""
|
||||
|
||||
class ImportErrorProfileModel(BaseChatModel):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
# NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
if name == "profile":
|
||||
msg = "Profile not available"
|
||||
raise AttributeError(msg)
|
||||
return super().__getattribute__(name)
|
||||
|
||||
"""Ensure fractional limits fail when model has no profile data."""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Model profile information is required to use fractional token limits",
|
||||
):
|
||||
_ = SummarizationMiddleware(
|
||||
model=ImportErrorProfileModel(), trigger=("fraction", 0.5), keep=("messages", 1)
|
||||
model=MockChatModel(), trigger=("fraction", 0.5), keep=("messages", 1)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ import json
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -150,6 +151,38 @@ def _apply_overrides(
|
||||
return merged
|
||||
|
||||
|
||||
def _warn_undeclared_profile_keys(
|
||||
profiles: dict[str, dict[str, Any]],
|
||||
) -> None:
|
||||
"""Warn if any profile keys are not declared in `ModelProfile`.
|
||||
|
||||
Args:
|
||||
profiles: Mapping of model IDs to their profile dicts.
|
||||
"""
|
||||
try:
|
||||
from langchain_core.language_models.model_profile import ModelProfile
|
||||
except ImportError:
|
||||
# langchain-core may not be installed or importable; skip check.
|
||||
return
|
||||
|
||||
try:
|
||||
declared = set(get_type_hints(ModelProfile).keys())
|
||||
except (TypeError, NameError):
|
||||
# get_type_hints raises NameError on unresolvable forward refs and
|
||||
# TypeError when annotations evaluate to non-type objects.
|
||||
return
|
||||
extra = sorted({k for p in profiles.values() for k in p} - declared)
|
||||
if extra:
|
||||
warnings.warn(
|
||||
f"Profile keys not declared in langchain_core ModelProfile: {extra}. "
|
||||
f"Add these fields to "
|
||||
f"langchain_core.language_models.model_profile.ModelProfile and "
|
||||
f"release langchain-core before publishing partner packages that "
|
||||
f"use these profiles.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
def _ensure_safe_output_path(base_dir: Path, output_file: Path) -> None:
|
||||
"""Ensure the resolved output path remains inside the expected directory."""
|
||||
if base_dir.exists() and base_dir.is_symlink():
|
||||
@@ -300,6 +333,8 @@ def refresh(provider: str, data_dir: Path) -> None: # noqa: C901, PLR0915
|
||||
for model_id in sorted(extra_models):
|
||||
profiles[model_id] = _apply_overrides({}, provider_aug, model_augs[model_id])
|
||||
|
||||
_warn_undeclared_profile_keys(profiles)
|
||||
|
||||
# Ensure directory exists
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
"""Tests for CLI functionality."""
|
||||
|
||||
import importlib.util
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, get_type_hints
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.model_profile import ModelProfile
|
||||
|
||||
from langchain_model_profiles.cli import _model_data_to_profile, refresh
|
||||
from langchain_model_profiles.cli import (
|
||||
_model_data_to_profile,
|
||||
_warn_undeclared_profile_keys,
|
||||
refresh,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -364,3 +371,104 @@ def test_model_data_to_profile_text_modalities() -> None:
|
||||
profile = _model_data_to_profile(image_gen_model)
|
||||
assert profile["text_inputs"] is True
|
||||
assert profile["text_outputs"] is False
|
||||
|
||||
|
||||
def test_model_data_to_profile_keys_subset_of_model_profile() -> None:
|
||||
"""All CLI-emitted profile keys must be declared in `ModelProfile`."""
|
||||
# Build a model_data dict with every possible field populated so
|
||||
# _model_data_to_profile includes all keys it can emit.
|
||||
model_data = {
|
||||
"id": "test-model",
|
||||
"name": "Test Model",
|
||||
"status": "active",
|
||||
"release_date": "2025-01-01",
|
||||
"last_updated": "2025-01-01",
|
||||
"open_weights": True,
|
||||
"reasoning": True,
|
||||
"tool_call": True,
|
||||
"tool_choice": True,
|
||||
"structured_output": True,
|
||||
"attachment": True,
|
||||
"temperature": True,
|
||||
"image_url_inputs": True,
|
||||
"image_tool_message": True,
|
||||
"pdf_tool_message": True,
|
||||
"pdf_inputs": True,
|
||||
"limit": {"context": 100000, "output": 4096},
|
||||
"modalities": {
|
||||
"input": ["text", "image", "audio", "video", "pdf"],
|
||||
"output": ["text", "image", "audio", "video"],
|
||||
},
|
||||
}
|
||||
|
||||
profile = _model_data_to_profile(model_data)
|
||||
declared_fields = set(get_type_hints(ModelProfile).keys())
|
||||
emitted_fields = set(profile.keys())
|
||||
extra = emitted_fields - declared_fields
|
||||
|
||||
assert not extra, (
|
||||
f"CLI emits profile keys not declared in ModelProfile: {sorted(extra)}. "
|
||||
f"Add these fields to langchain_core.language_models.model_profile."
|
||||
f"ModelProfile and release langchain-core before refreshing partner "
|
||||
f"profiles."
|
||||
)
|
||||
|
||||
|
||||
class TestWarnUndeclaredProfileKeys:
|
||||
"""Tests for _warn_undeclared_profile_keys."""
|
||||
|
||||
def test_warns_on_undeclared_keys(self) -> None:
|
||||
"""Extra keys across profiles trigger a single warning."""
|
||||
profiles: dict[str, dict[str, Any]] = {
|
||||
"model-a": {"max_input_tokens": 100, "future_key": True},
|
||||
"model-b": {"another_key": "val"},
|
||||
}
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
_warn_undeclared_profile_keys(profiles)
|
||||
|
||||
assert len(w) == 1
|
||||
assert "another_key" in str(w[0].message)
|
||||
assert "future_key" in str(w[0].message)
|
||||
|
||||
def test_silent_on_declared_keys_only(self) -> None:
|
||||
"""No warning when all keys are declared in ModelProfile."""
|
||||
profiles: dict[str, dict[str, Any]] = {
|
||||
"model-a": {"max_input_tokens": 100, "tool_calling": True},
|
||||
}
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
_warn_undeclared_profile_keys(profiles)
|
||||
|
||||
assert len(w) == 0
|
||||
|
||||
def test_silent_when_langchain_core_not_installed(self) -> None:
|
||||
"""Gracefully skips when langchain-core is not importable."""
|
||||
import sys
|
||||
|
||||
profiles: dict[str, dict[str, Any]] = {
|
||||
"model-a": {"unknown": True},
|
||||
}
|
||||
with (
|
||||
patch.dict(
|
||||
sys.modules,
|
||||
{"langchain_core.language_models.model_profile": None},
|
||||
),
|
||||
warnings.catch_warnings(record=True) as w,
|
||||
):
|
||||
warnings.simplefilter("always")
|
||||
_warn_undeclared_profile_keys(profiles)
|
||||
|
||||
undeclared_warnings = [x for x in w if "not declared" in str(x.message)]
|
||||
assert len(undeclared_warnings) == 0
|
||||
|
||||
def test_survives_get_type_hints_failure(self) -> None:
|
||||
"""Gracefully handles TypeError from get_type_hints."""
|
||||
profiles: dict[str, dict[str, Any]] = {
|
||||
"model-a": {"unknown": True},
|
||||
}
|
||||
with patch(
|
||||
"langchain_model_profiles.cli.get_type_hints",
|
||||
side_effect=TypeError("broken"),
|
||||
):
|
||||
_warn_undeclared_profile_keys(profiles)
|
||||
|
||||
@@ -973,18 +973,11 @@ class ChatAnthropic(BaseChatModel):
|
||||
self._add_version("langchain-anthropic", __version__)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model)
|
||||
if (
|
||||
self.profile is not None
|
||||
and self.betas
|
||||
and "context-1m-2025-08-07" in self.betas
|
||||
):
|
||||
self.profile["max_input_tokens"] = 1_000_000
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
profile = _get_default_model_profile(self.model) or None
|
||||
if profile is not None and self.betas and "context-1m-2025-08-07" in self.betas:
|
||||
profile["max_input_tokens"] = 1_000_000
|
||||
return profile
|
||||
|
||||
@cached_property
|
||||
def _client_params(self) -> dict[str, Any]:
|
||||
|
||||
@@ -269,12 +269,8 @@ class ChatDeepSeek(BaseChatOpenAI):
|
||||
self.async_client = self.root_async_client.chat.completions
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model_name) or None
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
|
||||
@@ -431,12 +431,8 @@ class ChatFireworks(BaseChatModel):
|
||||
self.async_client._max_retries = self.max_retries
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model_name) or None
|
||||
|
||||
@property
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
|
||||
@@ -549,12 +549,8 @@ class ChatGroq(BaseChatModel):
|
||||
self._add_version("langchain-groq", __version__)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model_name) or None
|
||||
|
||||
#
|
||||
# Serializable class method overrides
|
||||
|
||||
@@ -603,12 +603,10 @@ class ChatHuggingFace(BaseChatModel):
|
||||
raise TypeError(msg)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None and self.model_id:
|
||||
self.profile = _get_default_model_profile(self.model_id)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
if self.model_id:
|
||||
return _get_default_model_profile(self.model_id) or None
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_model_id(
|
||||
|
||||
@@ -653,12 +653,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model) or None
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
|
||||
from typing import Any, Literal, TypeAlias, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
@@ -19,6 +19,9 @@ from typing_extensions import Self
|
||||
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI, _get_default_model_profile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import ModelProfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -701,12 +704,10 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
self.async_client = self.root_async_client.chat.completions
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None and self.deployment_name is not None:
|
||||
self.profile = _get_default_model_profile(self.deployment_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
if self.deployment_name is not None:
|
||||
return _get_default_model_profile(self.deployment_name) or None
|
||||
return None
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
|
||||
@@ -1109,12 +1109,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
self.async_client = self.root_async_client.chat.completions
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model_name) or None
|
||||
|
||||
@property
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
|
||||
@@ -357,12 +357,8 @@ class ChatOpenRouter(BaseChatModel):
|
||||
self.client = openrouter.OpenRouter(**client_kwargs)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model_name) or None
|
||||
|
||||
#
|
||||
# Serializable class method overrides
|
||||
|
||||
@@ -2971,3 +2971,9 @@ class TestStreamUsage:
|
||||
assert usage["input_tokens"] == 10
|
||||
assert usage["output_tokens"] == 5
|
||||
assert usage["total_tokens"] == 15
|
||||
|
||||
|
||||
def test_profile() -> None:
|
||||
"""Test that the model has a profile."""
|
||||
model = _make_model()
|
||||
assert model.profile
|
||||
|
||||
@@ -312,12 +312,8 @@ class ChatPerplexity(BaseChatModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model) or None
|
||||
|
||||
@property
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
|
||||
@@ -543,12 +543,8 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_model_profile(self) -> Self:
|
||||
"""Set model profile if not overridden."""
|
||||
if self.profile is None:
|
||||
self.profile = _get_default_model_profile(self.model_name)
|
||||
return self
|
||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||
return _get_default_model_profile(self.model_name) or None
|
||||
|
||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
|
||||
"""Route to Chat Completions or Responses API."""
|
||||
|
||||
Reference in New Issue
Block a user