mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 02:06:44 +00:00
feat(core): deprecate problematic dict() method (#31685)
`dict()` is a problematic method name as it clashes with the builtin `dict` used as a type annotation. This PR replaces it with an `asdict` method (inspired by dataclasses). It also fixes a few places where `dict` must be replaced by `builtins.dict` until the `dict()` method is removed. --------- Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
f9f11527f6
commit
74c23741b0
@@ -15,6 +15,7 @@ import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable, Generator
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -75,6 +76,15 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
"""A class for issuing deprecation warnings for LangChain users."""
|
||||
|
||||
|
||||
# Tracks when callers intentionally silence LangChain deprecation warnings.
|
||||
# Suppressed warnings should not consume a deprecated callable's one-time
|
||||
# warning state; otherwise an internal compatibility path can prevent the first
|
||||
# user-visible call from warning.
|
||||
_SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING = ContextVar(
|
||||
"_SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING", default=False
|
||||
)
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
@@ -220,16 +230,20 @@ def deprecated(
|
||||
"""
|
||||
nonlocal warned
|
||||
if not warned and not is_caller_internal():
|
||||
warned = True
|
||||
emit_warning()
|
||||
# Only mark the warning as emitted if it was not intentionally
|
||||
# suppressed by `suppress_langchain_deprecation_warning()`.
|
||||
warned = not _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.get()
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Same as warning_emitting_wrapper, but for async functions."""
|
||||
nonlocal warned
|
||||
if not warned and not is_caller_internal():
|
||||
warned = True
|
||||
emit_warning()
|
||||
# Only mark the warning as emitted if it was not intentionally
|
||||
# suppressed by `suppress_langchain_deprecation_warning()`.
|
||||
warned = not _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.get()
|
||||
return await wrapped(*args, **kwargs)
|
||||
|
||||
_package = _package or obj.__module__.split(".")[0].replace("_", "-")
|
||||
@@ -253,8 +267,10 @@ def deprecated(
|
||||
"""Warn that the class is in beta."""
|
||||
nonlocal warned
|
||||
if not warned and type(self) is obj and not is_caller_internal():
|
||||
warned = True
|
||||
emit_warning()
|
||||
# Only mark the warning as emitted if it was not intentionally
|
||||
# suppressed by `suppress_langchain_deprecation_warning()`.
|
||||
warned = not _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.get()
|
||||
return wrapped(self, *args, **kwargs)
|
||||
|
||||
obj.__init__ = functools.wraps(obj.__init__)( # type: ignore[misc]
|
||||
@@ -451,10 +467,14 @@ def deprecated(
|
||||
@contextlib.contextmanager
|
||||
def suppress_langchain_deprecation_warning() -> Generator[None, None, None]:
|
||||
"""Context manager to suppress `LangChainDeprecationWarning`."""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", LangChainDeprecationWarning)
|
||||
warnings.simplefilter("ignore", LangChainPendingDeprecationWarning)
|
||||
yield
|
||||
token = _SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.set(True)
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", LangChainDeprecationWarning)
|
||||
warnings.simplefilter("ignore", LangChainPendingDeprecationWarning)
|
||||
yield
|
||||
finally:
|
||||
_SUPPRESSING_LANGCHAIN_DEPRECATION_WARNING.reset(token)
|
||||
|
||||
|
||||
def warn_deprecated(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins # noqa: TC003
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
@@ -16,7 +17,7 @@ from langchain_protocol.protocol import MessageFinishData
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core._api import beta, deprecated, suppress_langchain_deprecation_warning
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
@@ -93,7 +94,6 @@ from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
|
||||
from langchain_core.utils.utils import LC_ID_PREFIX, from_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import builtins
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
|
||||
@@ -430,11 +430,11 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
return self
|
||||
|
||||
@cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
def _serialized(self) -> builtins.dict[str, Any]:
|
||||
# self is always a Serializable object in this case, thus the result is
|
||||
# guaranteed to be a dict since dumps uses the default callback, which uses
|
||||
# guaranteed to be a dict since dumpd uses the default callback, which uses
|
||||
# obj.to_json which always returns TypedDict subclasses
|
||||
return cast("dict[str, Any]", dumpd(self))
|
||||
return cast("builtins.dict[str, Any]", dumpd(self))
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
@@ -1371,7 +1371,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
|
||||
# --- Custom methods ---
|
||||
|
||||
def _combine_llm_outputs(self, _llm_outputs: list[dict | None], /) -> dict:
|
||||
def _combine_llm_outputs(
|
||||
self, _llm_outputs: list[builtins.dict | None], /
|
||||
) -> builtins.dict:
|
||||
return {}
|
||||
|
||||
def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
|
||||
@@ -1464,8 +1466,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
self,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
params = self.dict()
|
||||
) -> builtins.dict:
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
@@ -1566,7 +1568,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
metadata: builtins.dict[str, Any] | None = None,
|
||||
run_name: str | None = None,
|
||||
run_id: uuid.UUID | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -1692,7 +1694,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
metadata: builtins.dict[str, Any] | None = None,
|
||||
run_name: str | None = None,
|
||||
run_id: uuid.UUID | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -2303,13 +2305,26 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
||||
@deprecated("1.4.2", alternative="asdict", removal="2.0.0")
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
def dict(self, **_kwargs: Any) -> builtins.dict[str, Any]:
|
||||
"""DEPRECATED - use `asdict()` instead.
|
||||
|
||||
Return a dictionary representation of the chat model.
|
||||
"""
|
||||
return self.asdict()
|
||||
|
||||
def asdict(self) -> builtins.dict[str, Any]:
|
||||
"""Return a dictionary representation of the chat model."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
def _dict_for_compat(self) -> builtins.dict[str, Any]:
|
||||
"""Return the chat model dictionary while preserving deprecated overrides."""
|
||||
with suppress_langchain_deprecation_warning():
|
||||
return self.dict()
|
||||
|
||||
@override
|
||||
def bind(self, **kwargs: Any) -> _ChatModelBinding:
|
||||
"""Bind kwargs to this chat model, returning a typed `_ChatModelBinding`.
|
||||
|
||||
@@ -6,6 +6,7 @@ These are traditionally older models (newer models generally are chat models).
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins # noqa: TC003
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
@@ -32,6 +33,7 @@ from tenacity import (
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import deprecated, suppress_langchain_deprecation_warning
|
||||
from langchain_core.caches import BaseCache
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
@@ -301,11 +303,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
def _serialized(self) -> builtins.dict[str, Any]:
|
||||
# self is always a Serializable object in this case, thus the result is
|
||||
# guaranteed to be a dict since dumps uses the default callback, which uses
|
||||
# guaranteed to be a dict since dumpd uses the default callback, which uses
|
||||
# obj.to_json which always returns TypedDict subclasses
|
||||
return cast("dict[str, Any]", dumpd(self))
|
||||
return cast("builtins.dict[str, Any]", dumpd(self))
|
||||
|
||||
# --- Runnable methods ---
|
||||
|
||||
@@ -522,7 +524,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
else:
|
||||
prompt = self._convert_input(input).to_string()
|
||||
config = ensure_config(config)
|
||||
params = self.dict()
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
options = {"stop": stop}
|
||||
@@ -595,7 +597,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
prompt = self._convert_input(input).to_string()
|
||||
config = ensure_config(config)
|
||||
params = self.dict()
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
params = {**params, **kwargs}
|
||||
options = {"stop": stop}
|
||||
@@ -853,7 +855,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
callbacks: Callbacks | list[Callbacks] | None = None,
|
||||
*,
|
||||
tags: list[str] | list[list[str]] | None = None,
|
||||
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
|
||||
metadata: builtins.dict[str, Any] | list[builtins.dict[str, Any]] | None = None,
|
||||
run_name: str | list[str] | None = None,
|
||||
run_id: uuid.UUID | list[uuid.UUID | None] | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -957,7 +959,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
run_name_list = run_name or cast(
|
||||
"list[str | None]", ([None] * len(prompts))
|
||||
)
|
||||
params = self.dict()
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
@@ -978,7 +980,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
]
|
||||
else:
|
||||
# We've received a single callbacks arg to apply to all inputs
|
||||
params = self.dict()
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
@@ -1136,7 +1138,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
callbacks: Callbacks | list[Callbacks] | None = None,
|
||||
*,
|
||||
tags: list[str] | list[list[str]] | None = None,
|
||||
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
|
||||
metadata: builtins.dict[str, Any] | list[builtins.dict[str, Any]] | None = None,
|
||||
run_name: str | list[str] | None = None,
|
||||
run_id: uuid.UUID | list[uuid.UUID | None] | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -1229,7 +1231,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
run_name_list = run_name or cast(
|
||||
"list[str | None]", ([None] * len(prompts))
|
||||
)
|
||||
params = self.dict()
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
@@ -1250,7 +1252,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
]
|
||||
else:
|
||||
# We've received a single callbacks arg to apply to all inputs
|
||||
params = self.dict()
|
||||
params = self._dict_for_compat()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
@@ -1358,7 +1360,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
metadata: builtins.dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input."""
|
||||
@@ -1382,13 +1384,26 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
|
||||
@deprecated("1.4.2", alternative="asdict", removal="2.0.0")
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
def dict(self, **_kwargs: Any) -> builtins.dict[str, Any]:
|
||||
"""DEPRECATED - use `asdict()` instead.
|
||||
|
||||
Return a dictionary representation of the LLM.
|
||||
"""
|
||||
return self.asdict()
|
||||
|
||||
def asdict(self) -> builtins.dict[str, Any]:
|
||||
"""Return a dictionary representation of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
def _dict_for_compat(self) -> builtins.dict[str, Any]:
|
||||
"""Return the LLM dictionary while preserving deprecated overrides."""
|
||||
with suppress_langchain_deprecation_warning():
|
||||
return self.dict()
|
||||
|
||||
def save(self, file_path: Path | str) -> None:
|
||||
"""Save the LLM.
|
||||
|
||||
@@ -1410,7 +1425,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
prompt_dict = self._dict_for_compat()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with save_path.open("w", encoding="utf-8") as f:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins # noqa: TC003
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
@@ -14,6 +15,7 @@ from typing import (
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.language_models import LanguageModelOutput
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
@@ -340,8 +342,17 @@ class BaseOutputParser(
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
@deprecated("1.4.2", alternative="asdict", removal="2.0.0")
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> builtins.dict[str, Any]:
|
||||
"""DEPRECATED - use `asdict()` instead.
|
||||
|
||||
Return a dictionary representation of the output parser.
|
||||
"""
|
||||
return self.asdict(**kwargs)
|
||||
|
||||
def asdict(self, **kwargs: Any) -> builtins.dict[str, Any]:
|
||||
"""Return a dictionary representation of the output parser."""
|
||||
output_parser_dict = super().model_dump(**kwargs)
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
output_parser_dict["_type"] = self._type
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins # noqa: TC003
|
||||
import builtins
|
||||
import contextlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping # noqa: TC003
|
||||
from collections.abc import Callable, Mapping
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||
@@ -15,22 +15,20 @@ import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core._api import deprecated, suppress_langchain_deprecation_warning
|
||||
from langchain_core.exceptions import ErrorCode, create_message
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.output_parsers.base import BaseOutputParser # noqa: TC001
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
PromptValue,
|
||||
StringPromptValue,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.config import ensure_config
|
||||
from langchain_core.runnables.base import RunnableSerializable
|
||||
from langchain_core.runnables.config import RunnableConfig, ensure_config
|
||||
from langchain_core.utils.pydantic import create_model_v2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
@@ -123,11 +121,11 @@ class BasePromptTemplate(
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
def _serialized(self) -> builtins.dict[str, Any]:
|
||||
# self is always a Serializable object in this case, thus the result is
|
||||
# guaranteed to be a dict since dumpd uses the default callback, which uses
|
||||
# obj.to_json which always returns TypedDict subclasses
|
||||
return cast("dict[str, Any]", dumpd(self))
|
||||
return cast("builtins.dict[str, Any]", dumpd(self))
|
||||
|
||||
@property
|
||||
@override
|
||||
@@ -157,7 +155,7 @@ class BasePromptTemplate(
|
||||
field_definitions={**required_input_variables, **optional_input_variables},
|
||||
)
|
||||
|
||||
def _validate_input(self, inner_input: Any) -> dict:
|
||||
def _validate_input(self, inner_input: Any) -> builtins.dict:
|
||||
if not isinstance(inner_input, dict):
|
||||
if len(self.input_variables) == 1:
|
||||
var_name = self.input_variables[0]
|
||||
@@ -193,19 +191,23 @@ class BasePromptTemplate(
|
||||
)
|
||||
return inner_input_
|
||||
|
||||
def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
|
||||
def _format_prompt_with_error_handling(
|
||||
self,
|
||||
inner_input: builtins.dict,
|
||||
) -> PromptValue:
|
||||
inner_input_ = self._validate_input(inner_input)
|
||||
return self.format_prompt(**inner_input_)
|
||||
|
||||
async def _aformat_prompt_with_error_handling(
|
||||
self, inner_input: dict
|
||||
self,
|
||||
inner_input: builtins.dict,
|
||||
) -> PromptValue:
|
||||
inner_input_ = self._validate_input(inner_input)
|
||||
return await self.aformat_prompt(**inner_input_)
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: dict, config: RunnableConfig | None = None, **kwargs: Any
|
||||
self, input: builtins.dict, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> PromptValue:
|
||||
"""Invoke the prompt.
|
||||
|
||||
@@ -231,7 +233,7 @@ class BasePromptTemplate(
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
self, input: dict, config: RunnableConfig | None = None, **kwargs: Any
|
||||
self, input: builtins.dict, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> PromptValue:
|
||||
"""Async invoke the prompt.
|
||||
|
||||
@@ -293,7 +295,9 @@ class BasePromptTemplate(
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
|
||||
def _merge_partial_and_user_variables(
|
||||
self, **kwargs: Any
|
||||
) -> builtins.dict[str, Any]:
|
||||
# Get partial params:
|
||||
partial_kwargs = {
|
||||
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
|
||||
@@ -337,8 +341,17 @@ class BasePromptTemplate(
|
||||
"""Return the prompt type key."""
|
||||
raise NotImplementedError
|
||||
|
||||
def dict(self, **kwargs: Any) -> dict:
|
||||
"""Return dictionary representation of prompt.
|
||||
@deprecated("1.4.2", alternative="asdict", removal="2.0.0")
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> builtins.dict[str, Any]:
|
||||
"""DEPRECATED - use `asdict()` instead.
|
||||
|
||||
Return a dictionary representation of the prompt.
|
||||
"""
|
||||
return self.asdict(**kwargs)
|
||||
|
||||
def asdict(self, **kwargs: Any) -> builtins.dict[str, Any]:
|
||||
"""Return a dictionary representation of the prompt.
|
||||
|
||||
Args:
|
||||
**kwargs: Any additional arguments to pass to the dictionary.
|
||||
@@ -351,6 +364,11 @@ class BasePromptTemplate(
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
return prompt_dict
|
||||
|
||||
def _dict_for_compat(self) -> builtins.dict[str, Any]:
|
||||
"""Return the prompt dictionary while preserving deprecated overrides."""
|
||||
with suppress_langchain_deprecation_warning():
|
||||
return self.dict()
|
||||
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
@@ -377,8 +395,9 @@ class BasePromptTemplate(
|
||||
msg = "Cannot save prompt with partial variables."
|
||||
raise ValueError(msg)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
# Fetch dictionary to save. Preserve deprecated `dict()` overrides until
|
||||
# `dict()` is removed.
|
||||
prompt_dict = self._dict_for_compat()
|
||||
if "_type" not in prompt_dict:
|
||||
msg = f"Prompt {self} does not support saving."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@@ -8,8 +8,10 @@ import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core._api.deprecation import (
|
||||
LangChainDeprecationWarning,
|
||||
deprecated,
|
||||
rename_parameter,
|
||||
suppress_langchain_deprecation_warning,
|
||||
warn_deprecated,
|
||||
)
|
||||
|
||||
@@ -130,6 +132,25 @@ class ClassWithDeprecatedMethods:
|
||||
return "This is a deprecated property."
|
||||
|
||||
|
||||
def test_suppressed_deprecation_warning_does_not_consume_warning() -> None:
|
||||
"""Suppressed calls should not block a later user-visible warning.
|
||||
|
||||
For example, an internal compatibility path may call a deprecated method while
|
||||
saving/loading an object with warnings suppressed. That hidden call should not
|
||||
prevent the user's later direct call from seeing the deprecation warning.
|
||||
"""
|
||||
|
||||
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
|
||||
def local_deprecated_function() -> str:
|
||||
return "deprecated"
|
||||
|
||||
with suppress_langchain_deprecation_warning():
|
||||
assert local_deprecated_function() == "deprecated"
|
||||
|
||||
with pytest.warns(LangChainDeprecationWarning):
|
||||
assert local_deprecated_function() == "deprecated"
|
||||
|
||||
|
||||
def test_deprecated_function() -> None:
|
||||
"""Test deprecated function."""
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
|
||||
@@ -4,13 +4,15 @@ import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_type_hints
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langsmith.env import get_langchain_env_var_metadata
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core._api import LangChainDeprecationWarning
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
BaseCallbackHandler,
|
||||
@@ -92,6 +94,41 @@ def _content_blocks_equal_ignore_id(
|
||||
return True
|
||||
|
||||
|
||||
def test_asdict_replaces_deprecated_dict() -> None:
|
||||
model = FakeListChatModel(responses=["foo"])
|
||||
|
||||
expected = {"responses": ["foo"], "_type": "fake-list-chat-model"}
|
||||
assert model.asdict() == expected
|
||||
with pytest.warns(LangChainDeprecationWarning, match="asdict"):
|
||||
assert model.dict() == expected
|
||||
|
||||
|
||||
def test_base_chat_model_type_hints_resolve() -> None:
|
||||
assert get_type_hints(BaseChatModel.asdict)["return"] == dict[str, Any]
|
||||
|
||||
|
||||
def test_invoke_preserves_deprecated_dict_override() -> None:
|
||||
"""Invoking should preserve `dict()` overrides until `dict()` is removed."""
|
||||
|
||||
class CustomDictChatModel(FakeListChatModel):
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict[str, Any]:
|
||||
data = super().dict(**kwargs)
|
||||
data["custom_trace_param"] = "custom"
|
||||
return data
|
||||
|
||||
model = CustomDictChatModel(responses=["foo"])
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", LangChainDeprecationWarning)
|
||||
with collect_runs() as cb:
|
||||
assert model.invoke("hello").content == "foo"
|
||||
|
||||
assert cb.traced_runs[0].extra is not None
|
||||
assert cb.traced_runs[0].extra["invocation_params"]["custom_trace_param"] == (
|
||||
"custom"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def messages() -> list[BaseMessage]:
|
||||
return [
|
||||
@@ -1529,10 +1566,12 @@ def test_invocation_params_passed_to_tracer_metadata() -> None:
|
||||
assert len(collector.runs) == 1
|
||||
run = collector.runs[0]
|
||||
|
||||
key = "LANGSMITH_LANGGRAPH_API_VARIANT"
|
||||
|
||||
if key in run.extra["metadata"]:
|
||||
del run.extra["metadata"][key]
|
||||
# LangSmith injects environment-derived keys (e.g. `revision_id`,
|
||||
# `LANGCHAIN_TESTS_USER_AGENT`, `LANGSMITH_*`) into run metadata. These vary
|
||||
# by environment and are not the subject of this test, so strip them before
|
||||
# the exact-equality comparison.
|
||||
for env_key in get_langchain_env_var_metadata():
|
||||
run.extra["metadata"].pop(env_key, None)
|
||||
|
||||
assert run.extra == {
|
||||
"batch_size": 1,
|
||||
@@ -1551,7 +1590,6 @@ def test_invocation_params_passed_to_tracer_metadata() -> None:
|
||||
"ls_model_type": "chat",
|
||||
"ls_provider": "fakechatmodelwithinvocationparams",
|
||||
"ls_temperature": 0.7,
|
||||
"revision_id": run.extra["metadata"]["revision_id"],
|
||||
"stop": None,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import LangChainDeprecationWarning
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -23,6 +25,41 @@ from tests.unit_tests.fake.callbacks import (
|
||||
)
|
||||
|
||||
|
||||
def test_asdict_replaces_deprecated_dict() -> None:
|
||||
llm = FakeListLLM(responses=["foo"])
|
||||
|
||||
expected = {"responses": ["foo"], "_type": "fake-list"}
|
||||
assert llm.asdict() == expected
|
||||
with pytest.warns(LangChainDeprecationWarning, match="asdict"):
|
||||
assert llm.dict() == expected
|
||||
|
||||
|
||||
def test_base_llm_type_hints_resolve() -> None:
|
||||
assert get_type_hints(BaseLLM.asdict)["return"] == dict[str, Any]
|
||||
|
||||
|
||||
def test_invoke_preserves_deprecated_dict_override() -> None:
|
||||
"""Invoking should preserve `dict()` overrides until `dict()` is removed."""
|
||||
|
||||
class CustomDictLLM(FakeListLLM):
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict[str, Any]:
|
||||
data = super().dict(**kwargs)
|
||||
data["custom_trace_param"] = "custom"
|
||||
return data
|
||||
|
||||
llm = CustomDictLLM(responses=["foo"])
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", LangChainDeprecationWarning)
|
||||
with collect_runs() as cb:
|
||||
assert llm.invoke("hello") == "foo"
|
||||
|
||||
assert cb.traced_runs[0].extra is not None
|
||||
assert cb.traced_runs[0].extra["invocation_params"]["custom_trace_param"] == (
|
||||
"custom"
|
||||
)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
llm = FakeListLLM(responses=["foo"] * 3)
|
||||
output = llm.batch(["foo", "bar", "foo"])
|
||||
|
||||
@@ -1,17 +1,38 @@
|
||||
"""Module to test base parser implementations."""
|
||||
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import LangChainDeprecationWarning
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.output_parsers import (
|
||||
BaseGenerationOutputParser,
|
||||
BaseOutputParser,
|
||||
BaseTransformOutputParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
|
||||
|
||||
def test_asdict_replaces_deprecated_dict() -> None:
|
||||
class StrInvertCase(BaseTransformOutputParser[str]):
|
||||
def parse(self, text: str) -> str:
|
||||
return text.swapcase()
|
||||
|
||||
parser = StrInvertCase()
|
||||
parser_dict = parser.asdict(exclude_none=True)
|
||||
assert parser_dict == {}
|
||||
with pytest.warns(LangChainDeprecationWarning, match="asdict"):
|
||||
assert parser.dict(exclude_none=True) == parser_dict
|
||||
|
||||
|
||||
def test_base_output_parser_type_hints_resolve() -> None:
|
||||
assert get_type_hints(BaseOutputParser.asdict)["return"] == dict[str, Any]
|
||||
|
||||
|
||||
def test_base_generation_parser() -> None:
|
||||
"""Test Base Generation Output Parser."""
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import os
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import suppress_langchain_deprecation_warning
|
||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||
@@ -105,6 +107,25 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||
assert loaded_prompt == few_shot_prompt
|
||||
|
||||
|
||||
def test_save_preserves_deprecated_dict_override(tmp_path: Path) -> None:
|
||||
"""Saving should preserve `dict()` overrides until `dict()` is removed."""
|
||||
|
||||
class CustomDictPrompt(PromptTemplate):
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> dict[str, Any]:
|
||||
data = super().dict(**kwargs)
|
||||
data["custom_save_param"] = "custom"
|
||||
return data
|
||||
|
||||
prompt = CustomDictPrompt(input_variables=["name"], template="Hello {name}")
|
||||
output_path = tmp_path / "prompt.json"
|
||||
|
||||
with suppress_langchain_deprecation_warning():
|
||||
prompt.save(output_path)
|
||||
|
||||
assert json.loads(output_path.read_text())["custom_save_param"] == "custom"
|
||||
|
||||
|
||||
def test_loading_with_template_as_file() -> None:
|
||||
"""Test loading when the template is a file."""
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
from packaging import version
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain_core._api import LangChainDeprecationWarning
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import PromptTemplateFormat
|
||||
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||
@@ -18,6 +19,15 @@ from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
||||
|
||||
|
||||
def test_asdict_replaces_deprecated_dict() -> None:
|
||||
prompt = PromptTemplate.from_template("This is a {foo} test.")
|
||||
|
||||
prompt_dict = prompt.asdict()
|
||||
assert prompt_dict["_type"] == "prompt"
|
||||
with pytest.warns(LangChainDeprecationWarning, match="asdict"):
|
||||
assert prompt.dict() == prompt_dict
|
||||
|
||||
|
||||
def test_prompt_valid() -> None:
|
||||
"""Test prompts can be constructed."""
|
||||
template = "This is a {foo} test."
|
||||
@@ -217,11 +227,7 @@ def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
|
||||
{{/foo}}is a test."""
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(foo=[{"bar": "yo"}, {"bar": "hello"}]) == (
|
||||
"""This
|
||||
yo
|
||||
|
||||
hello
|
||||
is a test.""" # noqa: W293
|
||||
"This\n yo\n \n hello\n is a test."
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
if PYDANTIC_VERSION_AT_LEAST_29:
|
||||
|
||||
Reference in New Issue
Block a user