core[patch]: preserve inspect.iscoroutinefunction with @deprecated decorator (#16295)

Adjusted `deprecate` decorator to make sure decorated async functions
are still recognized as "coroutinefunction" by `inspect`.

Before change, functions such as `LLMChain.acall` which are decorated as
deprecated are not recognized as coroutine functions. After the change,
they are recognized:

```python
import inspect
from langchain import LLMChain

# Is false before change but true after.
inspect.iscoroutinefunction(LLMChain.acall)
```
This commit is contained in:
Piotr Mardziel
2024-01-22 11:34:13 -08:00
committed by GitHub
parent 01c2f27ffa
commit 1b9001db47
2 changed files with 74 additions and 1 deletions

View File

@@ -1,3 +1,4 @@
import inspect
import warnings
from typing import Any, Dict
@@ -74,6 +75,12 @@ def deprecated_function() -> str:
return "This is a deprecated function."
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
async def deprecated_async_function() -> str:
"""original doc"""
return "This is a deprecated async function."
class ClassWithDeprecatedMethods:
def __init__(self) -> None:
"""original doc"""
@@ -84,6 +91,11 @@ class ClassWithDeprecatedMethods:
"""original doc"""
return "This is a deprecated method."
@deprecated(since="2.0.0", removal="3.0.0")
async def deprecated_async_method(self) -> str:
"""original doc"""
return "This is a deprecated async method."
@classmethod
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_classmethod(cls) -> str:
@@ -119,6 +131,30 @@ def test_deprecated_function() -> None:
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert not inspect.iscoroutinefunction(deprecated_function)
@pytest.mark.asyncio
async def test_deprecated_async_function() -> None:
"""Test deprecated async function."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert (
await deprecated_async_function() == "This is a deprecated async function."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `deprecated_async_function` was deprecated "
"in LangChain 2.0.0 and will be removed in 3.0.0"
)
doc = deprecated_function.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert inspect.iscoroutinefunction(deprecated_async_function)
def test_deprecated_method() -> None:
"""Test deprecated method."""
@@ -137,6 +173,31 @@ def test_deprecated_method() -> None:
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert not inspect.iscoroutinefunction(obj.deprecated_method)
@pytest.mark.asyncio
async def test_deprecated_async_method() -> None:
"""Test deprecated async method."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = ClassWithDeprecatedMethods()
assert (
await obj.deprecated_async_method() == "This is a deprecated async method."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `deprecated_async_method` was deprecated in "
"LangChain 2.0.0 and will be removed in 3.0.0"
)
doc = obj.deprecated_method.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert inspect.iscoroutinefunction(obj.deprecated_async_method)
def test_deprecated_classmethod() -> None:
"""Test deprecated classmethod."""