diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 602732e7456..c37aef3d94a 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -14,7 +14,17 @@ import contextlib import functools import inspect import warnings -from typing import Any, Callable, Generator, Type, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generator, + Type, + TypeVar, + Union, + cast, +) + +from typing_extensions import ParamSpec from langchain_core._api.internal import is_caller_internal @@ -448,3 +458,51 @@ def surface_langchain_deprecation_warnings() -> None: "default", category=LangChainDeprecationWarning, ) + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def rename_parameter( + *, + since: str, + removal: str, + old: str, + new: str, +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + """Decorator indicating that parameter *old* of *func* is renamed to *new*. + + The actual implementation of *func* should use *new*, not *old*. If *old* + is passed to *func*, a DeprecationWarning is emitted, and its value is + used, even if *new* is also passed by keyword. + + Example: + + .. code-block:: python + + @_api.rename_parameter("3.1", "bad_name", "good_name") + def func(good_name): ... + """ + + def decorator(f: Callable[_P, _R]) -> Callable[_P, _R]: + @functools.wraps(f) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if new in kwargs and old in kwargs: + raise TypeError( + f"{f.__name__}() got multiple values for argument {new!r}" + ) + if old in kwargs: + warn_deprecated( + since, + removal=removal, + message=f"The parameter `{old}` of `{f.__name__}` was " + f"deprecated in {since} and will be removed " + f"in {removal} Use `{new}` instead.", + ) + kwargs[new] = kwargs.pop(old) + return f(*args, **kwargs) + + return wrapper + + return decorator diff --git a/libs/core/tests/unit_tests/_api/test_deprecation.py b/libs/core/tests/unit_tests/_api/test_deprecation.py index 6888fb1bb21..e1a28496354 100644 --- a/libs/core/tests/unit_tests/_api/test_deprecation.py +++ b/libs/core/tests/unit_tests/_api/test_deprecation.py @@ -4,7 +4,11 @@ from typing import Any, Dict import pytest -from langchain_core._api.deprecation import deprecated, warn_deprecated +from langchain_core._api.deprecation import ( + deprecated, + rename_parameter, + warn_deprecated, +) from langchain_core.pydantic_v1 import BaseModel @@ -412,3 +416,84 @@ def test_raise_error_for_bad_decorator() -> None: def deprecated_function() -> str: """original doc""" return "This is a deprecated function." + + +def test_rename_parameter() -> None: + """Test rename parameter.""" + + @rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name") + def foo(new_name: str) -> str: + """original doc""" + return new_name + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + assert foo(old_name="hello") == "hello" # type: ignore[call-arg] + assert len(warning_list) == 1 + + assert foo(new_name="hello") == "hello" + assert foo("hello") == "hello" + assert foo.__doc__ == "original doc" + with pytest.raises(TypeError): + foo(meow="hello") # type: ignore[call-arg] + with pytest.raises(TypeError): + assert foo("hello", old_name="hello") # type: ignore[call-arg] + + with pytest.raises(TypeError): + assert foo(old_name="goodbye", new_name="hello") # type: ignore[call-arg] + + +async def test_rename_parameter_for_async_func() -> None: + """Test rename parameter.""" + + @rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name") + async def foo(new_name: str) -> str: + """original doc""" + return new_name + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + assert await foo(old_name="hello") == "hello" # type: ignore[call-arg] + assert len(warning_list) == 1 + assert await foo(new_name="hello") == "hello" + assert await foo("hello") == "hello" + assert foo.__doc__ == "original doc" + with pytest.raises(TypeError): + await foo(meow="hello") # type: ignore[call-arg] + with pytest.raises(TypeError): + assert await foo("hello", old_name="hello") # type: ignore[call-arg] + + with pytest.raises(TypeError): + assert await foo(old_name="a", new_name="hello") # type: ignore[call-arg] + + +def test_rename_parameter_method() -> None: + """Test that it works for a method.""" + + class Foo: + @rename_parameter( + since="2.0.0", removal="3.0.0", old="old_name", new="new_name" + ) + def a(self, new_name: str) -> str: + return new_name + + foo = Foo() + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + assert foo.a(old_name="hello") == "hello" # type: ignore[call-arg] + assert len(warning_list) == 1 + assert str(warning_list[0].message) == ( + "The parameter `old_name` of `a` was deprecated in 2.0.0 and will be " + "removed " + "in 3.0.0 Use `new_name` instead." + ) + + assert foo.a(new_name="hello") == "hello" + assert foo.a("hello") == "hello" + + with pytest.raises(TypeError): + foo.a(meow="hello") # type: ignore[call-arg] + + with pytest.raises(TypeError): + assert foo.a("hello", old_name="hello") # type: ignore[call-arg]