diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index ccc31d6b9a5..45930c3e6cb 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -15,6 +15,7 @@ import inspect import sys import warnings from collections.abc import Callable, Generator +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -75,6 +76,15 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning): """A class for issuing deprecation warnings for LangChain users.""" +# Tracks when callers intentionally silence LangChain deprecation warnings. +# Suppressed warnings should not consume a deprecated callable's one-time +# warning state; otherwise an internal compatibility path can prevent the first +# user-visible call from warning. +_SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING = ContextVar( + "_SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING", default=False +) + + # PUBLIC API @@ -220,16 +230,20 @@ def deprecated( """ nonlocal warned if not warned and not is_caller_internal(): - warned = True emit_warning() + # Only mark the warning as emitted if it was not intentionally + # suppressed by `suppress_langchain_deprecation_warning()`. + warned = not _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.get() return wrapped(*args, **kwargs) async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any: """Same as warning_emitting_wrapper, but for async functions.""" nonlocal warned if not warned and not is_caller_internal(): - warned = True emit_warning() + # Only mark the warning as emitted if it was not intentionally + # suppressed by `suppress_langchain_deprecation_warning()`. + warned = not _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.get() return await wrapped(*args, **kwargs) _package = _package or obj.__module__.split(".")[0].replace("_", "-") @@ -253,8 +267,10 @@ def deprecated( """Warn that the class is in beta.""" nonlocal warned if not warned and type(self) is obj and not is_caller_internal(): - warned = True emit_warning() + # Only mark the warning as emitted if it was not intentionally + # suppressed by `suppress_langchain_deprecation_warning()`. + warned = not _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.get() return wrapped(self, *args, **kwargs) obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc] @@ -451,10 +467,14 @@ def deprecated( @contextlib.contextmanager def suppress_langchain_deprecation_warning() -> Generator[None, None, None]: """Context manager to suppress `LangChainDeprecationWarning`.""" - with warnings.catch_warnings(): - warnings.simplefilter("ignore", LangChainDeprecationWarning) - warnings.simplefilter("ignore", LangChainPendingDeprecationWarning) - yield + token = _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.set(True) + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", LangChainDeprecationWarning) + warnings.simplefilter("ignore", LangChainPendingDeprecationWarning) + yield + finally: + _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.reset(token) def warn_deprecated( diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 4738a4e22f7..76c8728c1b3 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import builtins # noqa: TC003 import contextlib import inspect import json @@ -16,7 +17,7 @@ from langchain_protocol.protocol import MessageFinishData from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self, override -from langchain_core._api import beta +from langchain_core._api import beta, deprecated, suppress_langchain_deprecation_warning from langchain_core.caches import BaseCache from langchain_core.callbacks import ( AsyncCallbackManager, @@ -93,7 +94,6 @@ from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass from langchain_core.utils.utils import LC_ID_PREFIX, from_env if TYPE_CHECKING: - import builtins import uuid from collections.abc import Awaitable @@ -430,11 +430,11 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): return self @cached_property - def _serialized(self) -> dict[str, Any]: + def _serialized(self) -> builtins.dict[str, Any]: # self is always a Serializable object in this case, thus the result is - # guaranteed to be a dict since dumps uses the default callback, which uses + # guaranteed to be a dict since dumpd uses the default callback, which uses # obj.to_json which always returns TypedDict subclasses - return cast("dict[str, Any]", dumpd(self)) + return cast("builtins.dict[str, Any]", dumpd(self)) # --- Runnable methods --- @@ -1371,7 +1371,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): # --- Custom methods --- - def _combine_llm_outputs(self, _llm_outputs: list[dict | None], /) -> dict: + def _combine_llm_outputs( + self, _llm_outputs: list[builtins.dict | None], / + ) -> builtins.dict: return {} def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]: @@ -1464,8 +1466,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): self, stop: list[str] | None = None, **kwargs: Any, - ) -> dict: - params = self.dict() + ) -> builtins.dict: + params = self._dict_for_compat() params["stop"] = stop return {**params, **kwargs} @@ -1566,7 +1568,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): callbacks: Callbacks = None, *, tags: list[str] | None = None, - metadata: dict[str, Any] | None = None, + metadata: builtins.dict[str, Any] | None = None, run_name: str | None = None, run_id: uuid.UUID | None = None, **kwargs: Any, @@ -1692,7 +1694,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): callbacks: Callbacks = None, *, tags: list[str] | None = None, - metadata: dict[str, Any] | None = None, + metadata: builtins.dict[str, Any] | None = None, run_name: str | None = None, run_id: uuid.UUID | None = None, **kwargs: Any, @@ -2303,13 +2305,26 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def _llm_type(self) -> str: """Return type of chat model.""" + @deprecated("1.4.2", alternative="asdict", removal="2.0.0") @override - def dict(self, **kwargs: Any) -> dict: - """Return a dictionary of the LLM.""" + def dict(self, **_kwargs: Any) -> builtins.dict[str, Any]: + """DEPRECATED - use `asdict()` instead. + + Return a dictionary representation of the chat model. + """ + return self.asdict() + + def asdict(self) -> builtins.dict[str, Any]: + """Return a dictionary representation of the chat model.""" starter_dict = dict(self._identifying_params) starter_dict["_type"] = self._llm_type return starter_dict + def _dict_for_compat(self) -> builtins.dict[str, Any]: + """Return the chat model dictionary while preserving deprecated overrides.""" + with suppress_langchain_deprecation_warning(): + return self.dict() + @override def bind(self, **kwargs: Any) -> _ChatModelBinding: """Bind kwargs to this chat model, returning a typed `_ChatModelBinding`. diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 1ace9cb554a..614e12ec16e 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -6,6 +6,7 @@ These are traditionally older models (newer models generally are chat models). from __future__ import annotations import asyncio +import builtins # noqa: TC003 import functools import inspect import json @@ -32,6 +33,7 @@ from tenacity import ( ) from typing_extensions import override +from langchain_core._api import deprecated, suppress_langchain_deprecation_warning from langchain_core.caches import BaseCache from langchain_core.callbacks import ( AsyncCallbackManager, @@ -301,11 +303,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) @functools.cached_property - def _serialized(self) -> dict[str, Any]: + def _serialized(self) -> builtins.dict[str, Any]: # self is always a Serializable object in this case, thus the result is - # guaranteed to be a dict since dumps uses the default callback, which uses + # guaranteed to be a dict since dumpd uses the default callback, which uses # obj.to_json which always returns TypedDict subclasses - return cast("dict[str, Any]", dumpd(self)) + return cast("builtins.dict[str, Any]", dumpd(self)) # --- Runnable methods --- @@ -522,7 +524,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): else: prompt = self._convert_input(input).to_string() config = ensure_config(config) - params = self.dict() + params = self._dict_for_compat() params["stop"] = stop params = {**params, **kwargs} options = {"stop": stop} @@ -595,7 +597,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): prompt = self._convert_input(input).to_string() config = ensure_config(config) - params = self.dict() + params = self._dict_for_compat() params["stop"] = stop params = {**params, **kwargs} options = {"stop": stop} @@ -853,7 +855,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks: Callbacks | list[Callbacks] | None = None, *, tags: list[str] | list[list[str]] | None = None, - metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + metadata: builtins.dict[str, Any] | list[builtins.dict[str, Any]] | None = None, run_name: str | list[str] | None = None, run_id: uuid.UUID | list[uuid.UUID | None] | None = None, **kwargs: Any, @@ -957,7 +959,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_name_list = run_name or cast( "list[str | None]", ([None] * len(prompts)) ) - params = self.dict() + params = self._dict_for_compat() params["stop"] = stop callback_managers = [ CallbackManager.configure( @@ -978,7 +980,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): ] else: # We've received a single callbacks arg to apply to all inputs - params = self.dict() + params = self._dict_for_compat() params["stop"] = stop callback_managers = [ CallbackManager.configure( @@ -1136,7 +1138,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks: Callbacks | list[Callbacks] | None = None, *, tags: list[str] | list[list[str]] | None = None, - metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + metadata: builtins.dict[str, Any] | list[builtins.dict[str, Any]] | None = None, run_name: str | list[str] | None = None, run_id: uuid.UUID | list[uuid.UUID | None] | None = None, **kwargs: Any, @@ -1229,7 +1231,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_name_list = run_name or cast( "list[str | None]", ([None] * len(prompts)) ) - params = self.dict() + params = self._dict_for_compat() params["stop"] = stop callback_managers = [ AsyncCallbackManager.configure( @@ -1250,7 +1252,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): ] else: # We've received a single callbacks arg to apply to all inputs - params = self.dict() + params = self._dict_for_compat() params["stop"] = stop callback_managers = [ AsyncCallbackManager.configure( @@ -1358,7 +1360,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): callbacks: Callbacks = None, *, tags: list[str] | None = None, - metadata: dict[str, Any] | None = None, + metadata: builtins.dict[str, Any] | None = None, **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" @@ -1382,13 +1384,26 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _llm_type(self) -> str: """Return type of llm.""" + @deprecated("1.4.2", alternative="asdict", removal="2.0.0") @override - def dict(self, **kwargs: Any) -> dict: - """Return a dictionary of the LLM.""" + def dict(self, **_kwargs: Any) -> builtins.dict[str, Any]: + """DEPRECATED - use `asdict()` instead. + + Return a dictionary representation of the LLM. + """ + return self.asdict() + + def asdict(self) -> builtins.dict[str, Any]: + """Return a dictionary representation of the LLM.""" starter_dict = dict(self._identifying_params) starter_dict["_type"] = self._llm_type return starter_dict + def _dict_for_compat(self) -> builtins.dict[str, Any]: + """Return the LLM dictionary while preserving deprecated overrides.""" + with suppress_langchain_deprecation_warning(): + return self.dict() + def save(self, file_path: Path | str) -> None: """Save the LLM. @@ -1410,7 +1425,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): directory_path.mkdir(parents=True, exist_ok=True) # Fetch dictionary to save - prompt_dict = self.dict() + prompt_dict = self._dict_for_compat() if save_path.suffix == ".json": with save_path.open("w", encoding="utf-8") as f: diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 861e8ba7777..b316abad92d 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import builtins # noqa: TC003 import contextlib from abc import ABC, abstractmethod from typing import ( @@ -14,6 +15,7 @@ from typing import ( from typing_extensions import override +from langchain_core._api import deprecated from langchain_core.language_models import LanguageModelOutput from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.outputs import ChatGeneration, Generation @@ -340,8 +342,17 @@ class BaseOutputParser( ) raise NotImplementedError(msg) - def dict(self, **kwargs: Any) -> dict: - """Return dictionary representation of output parser.""" + @deprecated("1.4.2", alternative="asdict", removal="2.0.0") + @override + def dict(self, **kwargs: Any) -> builtins.dict[str, Any]: + """DEPRECATED - use `asdict()` instead. + + Return a dictionary representation of the output parser. + """ + return self.asdict(**kwargs) + + def asdict(self, **kwargs: Any) -> builtins.dict[str, Any]: + """Return a dictionary representation of the output parser.""" output_parser_dict = super().model_dump(**kwargs) with contextlib.suppress(NotImplementedError): output_parser_dict["_type"] = self._type diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index c96c936d80f..9a9cc30242d 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -2,11 +2,11 @@ from __future__ import annotations -import builtins # noqa: TC003 +import builtins import contextlib import json from abc import ABC, abstractmethod -from collections.abc import Mapping # noqa: TC003 +from collections.abc import Callable, Mapping from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast @@ -15,22 +15,20 @@ import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self, override -from langchain_core._api import deprecated +from langchain_core._api import deprecated, suppress_langchain_deprecation_warning from langchain_core.exceptions import ErrorCode, create_message from langchain_core.load import dumpd -from langchain_core.output_parsers.base import BaseOutputParser # noqa: TC001 +from langchain_core.output_parsers.base import BaseOutputParser from langchain_core.prompt_values import ( ChatPromptValueConcrete, PromptValue, StringPromptValue, ) -from langchain_core.runnables import RunnableConfig, RunnableSerializable -from langchain_core.runnables.config import ensure_config +from langchain_core.runnables.base import RunnableSerializable +from langchain_core.runnables.config import RunnableConfig, ensure_config from langchain_core.utils.pydantic import create_model_v2 if TYPE_CHECKING: - from collections.abc import Callable - from langchain_core.documents import Document @@ -123,11 +121,11 @@ class BasePromptTemplate( ) @cached_property - def _serialized(self) -> dict[str, Any]: + def _serialized(self) -> builtins.dict[str, Any]: # self is always a Serializable object in this case, thus the result is # guaranteed to be a dict since dumpd uses the default callback, which uses # obj.to_json which always returns TypedDict subclasses - return cast("dict[str, Any]", dumpd(self)) + return cast("builtins.dict[str, Any]", dumpd(self)) @property @override @@ -157,7 +155,7 @@ class BasePromptTemplate( field_definitions={**required_input_variables, **optional_input_variables}, ) - def _validate_input(self, inner_input: Any) -> dict: + def _validate_input(self, inner_input: Any) -> builtins.dict: if not isinstance(inner_input, dict): if len(self.input_variables) == 1: var_name = self.input_variables[0] @@ -193,19 +191,23 @@ class BasePromptTemplate( ) return inner_input_ - def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue: + def _format_prompt_with_error_handling( + self, + inner_input: builtins.dict, + ) -> PromptValue: inner_input_ = self._validate_input(inner_input) return self.format_prompt(**inner_input_) async def _aformat_prompt_with_error_handling( - self, inner_input: dict + self, + inner_input: builtins.dict, ) -> PromptValue: inner_input_ = self._validate_input(inner_input) return await self.aformat_prompt(**inner_input_) @override def invoke( - self, input: dict, config: RunnableConfig | None = None, **kwargs: Any + self, input: builtins.dict, config: RunnableConfig | None = None, **kwargs: Any ) -> PromptValue: """Invoke the prompt. @@ -231,7 +233,7 @@ class BasePromptTemplate( @override async def ainvoke( - self, input: dict, config: RunnableConfig | None = None, **kwargs: Any + self, input: builtins.dict, config: RunnableConfig | None = None, **kwargs: Any ) -> PromptValue: """Async invoke the prompt. @@ -293,7 +295,9 @@ class BasePromptTemplate( prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} return type(self)(**prompt_dict) - def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]: + def _merge_partial_and_user_variables( + self, **kwargs: Any + ) -> builtins.dict[str, Any]: # Get partial params: partial_kwargs = { k: v if not callable(v) else v() for k, v in self.partial_variables.items() @@ -337,8 +341,17 @@ class BasePromptTemplate( """Return the prompt type key.""" raise NotImplementedError - def dict(self, **kwargs: Any) -> dict: - """Return dictionary representation of prompt. + @deprecated("1.4.2", alternative="asdict", removal="2.0.0") + @override + def dict(self, **kwargs: Any) -> builtins.dict[str, Any]: + """DEPRECATED - use `asdict()` instead. + + Return a dictionary representation of the prompt. + """ + return self.asdict(**kwargs) + + def asdict(self, **kwargs: Any) -> builtins.dict[str, Any]: + """Return a dictionary representation of the prompt. Args: **kwargs: Any additional arguments to pass to the dictionary. @@ -351,6 +364,11 @@ class BasePromptTemplate( prompt_dict["_type"] = self._prompt_type return prompt_dict + def _dict_for_compat(self) -> builtins.dict[str, Any]: + """Return the prompt dictionary while preserving deprecated overrides.""" + with suppress_langchain_deprecation_warning(): + return self.dict() + @deprecated( since="1.2.21", removal="2.0.0", @@ -377,8 +395,9 @@ class BasePromptTemplate( msg = "Cannot save prompt with partial variables." raise ValueError(msg) - # Fetch dictionary to save - prompt_dict = self.dict() + # Fetch dictionary to save. Preserve deprecated `dict()` overrides until + # `dict()` is removed. + prompt_dict = self._dict_for_compat() if "_type" not in prompt_dict: msg = f"Prompt {self} does not support saving." raise NotImplementedError(msg) diff --git a/libs/core/tests/unit_tests/_api/test_deprecation.py b/libs/core/tests/unit_tests/_api/test_deprecation.py index 7993a823380..fd23462e46c 100644 --- a/libs/core/tests/unit_tests/_api/test_deprecation.py +++ b/libs/core/tests/unit_tests/_api/test_deprecation.py @@ -8,8 +8,10 @@ import pytest from pydantic import BaseModel from langchain_core._api.deprecation import ( + LangChainDeprecationWarning, deprecated, rename_parameter, + suppress_langchain_deprecation_warning, warn_deprecated, ) @@ -130,6 +132,25 @@ class ClassWithDeprecatedMethods: return "This is a deprecated property." +def test_suppressed_deprecation_warning_does_not_consume_warning() -> None: + """Suppressed calls should not block a later user-visible warning. + + For example, an internal compatibility path may call a deprecated method while + saving/loading an object with warnings suppressed. That hidden call should not + prevent the user's later direct call from seeing the deprecation warning. + """ + + @deprecated(since="2.0.0", removal="3.0.0", pending=False) + def local_deprecated_function() -> str: + return "deprecated" + + with suppress_langchain_deprecation_warning(): + assert local_deprecated_function() == "deprecated" + + with pytest.warns(LangChainDeprecationWarning): + assert local_deprecated_function() == "deprecated" + + def test_deprecated_function() -> None: """Test deprecated function.""" with warnings.catch_warnings(record=True) as warning_list: diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 9594ad23d9a..68f401501a5 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -4,13 +4,15 @@ import uuid import warnings from collections.abc import AsyncIterator, Iterator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, get_type_hints from unittest.mock import patch import pytest +from langsmith.env import get_langchain_env_var_metadata from pydantic import model_validator from typing_extensions import Self, override +from langchain_core._api import LangChainDeprecationWarning from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, BaseCallbackHandler, @@ -92,6 +94,41 @@ def _content_blocks_equal_ignore_id( return True +def test_asdict_replaces_deprecated_dict() -> None: + model = FakeListChatModel(responses=["foo"]) + + expected = {"responses": ["foo"], "_type": "fake-list-chat-model"} + assert model.asdict() == expected + with pytest.warns(LangChainDeprecationWarning, match="asdict"): + assert model.dict() == expected + + +def test_base_chat_model_type_hints_resolve() -> None: + assert get_type_hints(BaseChatModel.asdict)["return"] == dict[str, Any] + + +def test_invoke_preserves_deprecated_dict_override() -> None: + """Invoking should preserve `dict()` overrides until `dict()` is removed.""" + + class CustomDictChatModel(FakeListChatModel): + @override + def dict(self, **kwargs: Any) -> dict[str, Any]: + data = super().dict(**kwargs) + data["custom_trace_param"] = "custom" + return data + + model = CustomDictChatModel(responses=["foo"]) + with warnings.catch_warnings(): + warnings.simplefilter("error", LangChainDeprecationWarning) + with collect_runs() as cb: + assert model.invoke("hello").content == "foo" + + assert cb.traced_runs[0].extra is not None + assert cb.traced_runs[0].extra["invocation_params"]["custom_trace_param"] == ( + "custom" + ) + + @pytest.fixture def messages() -> list[BaseMessage]: return [ @@ -1529,10 +1566,12 @@ def test_invocation_params_passed_to_tracer_metadata() -> None: assert len(collector.runs) == 1 run = collector.runs[0] - key = "LANGSMITH_LANGGRAPH_API_VARIANT" - - if key in run.extra["metadata"]: - del run.extra["metadata"][key] + # LangSmith injects environment-derived keys (e.g. `revision_id`, + # `LANGCHAIN_TESTS_USER_AGENT`, `LANGSMITH_*`) into run metadata. These vary + # by environment and are not the subject of this test, so strip them before + # the exact-equality comparison. + for env_key in get_langchain_env_var_metadata(): + run.extra["metadata"].pop(env_key, None) assert run.extra == { "batch_size": 1, @@ -1551,7 +1590,6 @@ def test_invocation_params_passed_to_tracer_metadata() -> None: "ls_model_type": "chat", "ls_provider": "fakechatmodelwithinvocationparams", "ls_temperature": 0.7, - "revision_id": run.extra["metadata"]["revision_id"], "stop": None, "temperature": 0.7, }, diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py index 5b99cbdd26d..550827cde66 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_base.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py @@ -1,9 +1,11 @@ +import warnings from collections.abc import AsyncIterator, Iterator -from typing import Any +from typing import Any, get_type_hints import pytest from typing_extensions import override +from langchain_core._api import LangChainDeprecationWarning from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -23,6 +25,41 @@ from tests.unit_tests.fake.callbacks import ( ) +def test_asdict_replaces_deprecated_dict() -> None: + llm = FakeListLLM(responses=["foo"]) + + expected = {"responses": ["foo"], "_type": "fake-list"} + assert llm.asdict() == expected + with pytest.warns(LangChainDeprecationWarning, match="asdict"): + assert llm.dict() == expected + + +def test_base_llm_type_hints_resolve() -> None: + assert get_type_hints(BaseLLM.asdict)["return"] == dict[str, Any] + + +def test_invoke_preserves_deprecated_dict_override() -> None: + """Invoking should preserve `dict()` overrides until `dict()` is removed.""" + + class CustomDictLLM(FakeListLLM): + @override + def dict(self, **kwargs: Any) -> dict[str, Any]: + data = super().dict(**kwargs) + data["custom_trace_param"] = "custom" + return data + + llm = CustomDictLLM(responses=["foo"]) + with warnings.catch_warnings(): + warnings.simplefilter("error", LangChainDeprecationWarning) + with collect_runs() as cb: + assert llm.invoke("hello") == "foo" + + assert cb.traced_runs[0].extra is not None + assert cb.traced_runs[0].extra["invocation_params"]["custom_trace_param"] == ( + "custom" + ) + + def test_batch() -> None: llm = FakeListLLM(responses=["foo"] * 3) output = llm.batch(["foo", "bar", "foo"]) diff --git a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py index 2013790aa05..45cf6ddcea5 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py +++ b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py @@ -1,17 +1,38 @@ """Module to test base parser implementations.""" +from typing import Any, get_type_hints + +import pytest from typing_extensions import override +from langchain_core._api import LangChainDeprecationWarning from langchain_core.exceptions import OutputParserException from langchain_core.language_models import GenericFakeChatModel from langchain_core.messages import AIMessage from langchain_core.output_parsers import ( BaseGenerationOutputParser, + BaseOutputParser, BaseTransformOutputParser, ) from langchain_core.outputs import ChatGeneration, Generation +def test_asdict_replaces_deprecated_dict() -> None: + class StrInvertCase(BaseTransformOutputParser[str]): + def parse(self, text: str) -> str: + return text.swapcase() + + parser = StrInvertCase() + parser_dict = parser.asdict(exclude_none=True) + assert parser_dict == {} + with pytest.warns(LangChainDeprecationWarning, match="asdict"): + assert parser.dict(exclude_none=True) == parser_dict + + +def test_base_output_parser_type_hints_resolve() -> None: + assert get_type_hints(BaseOutputParser.asdict)["return"] == dict[str, Any] + + def test_base_generation_parser() -> None: """Test Base Generation Output Parser.""" diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index 20c7399c1db..6c0569ef8b3 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -5,8 +5,10 @@ import os from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path +from typing import Any import pytest +from typing_extensions import override from langchain_core._api import suppress_langchain_deprecation_warning from langchain_core.prompts.few_shot import FewShotPromptTemplate @@ -105,6 +107,25 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None: assert loaded_prompt == few_shot_prompt +def test_save_preserves_deprecated_dict_override(tmp_path: Path) -> None: + """Saving should preserve `dict()` overrides until `dict()` is removed.""" + + class CustomDictPrompt(PromptTemplate): + @override + def dict(self, **kwargs: Any) -> dict[str, Any]: + data = super().dict(**kwargs) + data["custom_save_param"] = "custom" + return data + + prompt = CustomDictPrompt(input_variables=["name"], template="Hello {name}") + output_path = tmp_path / "prompt.json" + + with suppress_langchain_deprecation_warning(): + prompt.save(output_path) + + assert json.loads(output_path.read_text())["custom_save_param"] == "custom" + + def test_loading_with_template_as_file() -> None: """Test loading when the template is a file.""" with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning(): diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index f2bce469740..9d084aeb55d 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -9,6 +9,7 @@ import pytest from packaging import version from syrupy.assertion import SnapshotAssertion +from langchain_core._api import LangChainDeprecationWarning from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import PromptTemplateFormat from langchain_core.tracers.run_collector import RunCollectorCallbackHandler @@ -18,6 +19,15 @@ from tests.unit_tests.pydantic_utils import _normalize_schema PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION +def test_asdict_replaces_deprecated_dict() -> None: + prompt = PromptTemplate.from_template("This is a {foo} test.") + + prompt_dict = prompt.asdict() + assert prompt_dict["_type"] == "prompt" + with pytest.warns(LangChainDeprecationWarning, match="asdict"): + assert prompt.dict() == prompt_dict + + def test_prompt_valid() -> None: """Test prompts can be constructed.""" template = "This is a {foo} test." @@ -217,11 +227,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None: {{/foo}}is a test.""" prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(foo=[{"bar": "yo"}, {"bar": "hello"}]) == ( - """This - yo - - hello - is a test.""" # noqa: W293 + "This\n yo\n \n hello\n is a test." ) assert prompt.input_variables == ["foo"] if PYDANTIC_VERSION_AT_LEAST_29: