core[patch]: preserve inspect.iscoroutinefunction with @deprecated decorator (#16295)

Adjusted `deprecate` decorator to make sure decorated async functions
are still recognized as "coroutinefunction" by `inspect`.

Before change, functions such as `LLMChain.acall` which are decorated as
deprecated are not recognized as coroutine functions. After the change,
they are recognized:

```python
import inspect
from langchain import LLMChain

# Is false before change but true after.
inspect.iscoroutinefunction(LLMChain.acall)
```
This commit is contained in:
Piotr Mardziel 2024-01-22 11:34:13 -08:00 committed by GitHub
parent 01c2f27ffa
commit 1b9001db47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 1 deletions

View File

@ -144,6 +144,15 @@ def deprecated(
emit_warning() emit_warning()
return wrapped(*args, **kwargs) return wrapped(*args, **kwargs)
async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Same as warning_emitting_wrapper, but for async functions."""
nonlocal warned
if not warned and not is_caller_internal():
warned = True
emit_warning()
return await wrapped(*args, **kwargs)
if isinstance(obj, type): if isinstance(obj, type):
if not _obj_type: if not _obj_type:
_obj_type = "class" _obj_type = "class"
@ -256,7 +265,10 @@ def deprecated(
f" {details}" f" {details}"
) )
return finalize(warning_emitting_wrapper, new_doc) if inspect.iscoroutinefunction(obj):
return finalize(awarning_emitting_wrapper, new_doc)
else:
return finalize(warning_emitting_wrapper, new_doc)
return deprecate return deprecate

View File

@ -1,3 +1,4 @@
import inspect
import warnings import warnings
from typing import Any, Dict from typing import Any, Dict
@ -74,6 +75,12 @@ def deprecated_function() -> str:
return "This is a deprecated function." return "This is a deprecated function."
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
async def deprecated_async_function() -> str:
"""original doc"""
return "This is a deprecated async function."
class ClassWithDeprecatedMethods: class ClassWithDeprecatedMethods:
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """original doc"""
@ -84,6 +91,11 @@ class ClassWithDeprecatedMethods:
"""original doc""" """original doc"""
return "This is a deprecated method." return "This is a deprecated method."
@deprecated(since="2.0.0", removal="3.0.0")
async def deprecated_async_method(self) -> str:
"""original doc"""
return "This is a deprecated async method."
@classmethod @classmethod
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_classmethod(cls) -> str: def deprecated_classmethod(cls) -> str:
@ -119,6 +131,30 @@ def test_deprecated_function() -> None:
assert isinstance(doc, str) assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc") assert doc.startswith("[*Deprecated*] original doc")
assert not inspect.iscoroutinefunction(deprecated_function)
@pytest.mark.asyncio
async def test_deprecated_async_function() -> None:
"""Test deprecated async function."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
assert (
await deprecated_async_function() == "This is a deprecated async function."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `deprecated_async_function` was deprecated "
"in LangChain 2.0.0 and will be removed in 3.0.0"
)
doc = deprecated_function.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert inspect.iscoroutinefunction(deprecated_async_function)
def test_deprecated_method() -> None: def test_deprecated_method() -> None:
"""Test deprecated method.""" """Test deprecated method."""
@ -137,6 +173,31 @@ def test_deprecated_method() -> None:
assert isinstance(doc, str) assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc") assert doc.startswith("[*Deprecated*] original doc")
assert not inspect.iscoroutinefunction(obj.deprecated_method)
@pytest.mark.asyncio
async def test_deprecated_async_method() -> None:
"""Test deprecated async method."""
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
obj = ClassWithDeprecatedMethods()
assert (
await obj.deprecated_async_method() == "This is a deprecated async method."
)
assert len(warning_list) == 1
warning = warning_list[0].message
assert str(warning) == (
"The function `deprecated_async_method` was deprecated in "
"LangChain 2.0.0 and will be removed in 3.0.0"
)
doc = obj.deprecated_method.__doc__
assert isinstance(doc, str)
assert doc.startswith("[*Deprecated*] original doc")
assert inspect.iscoroutinefunction(obj.deprecated_async_method)
def test_deprecated_classmethod() -> None: def test_deprecated_classmethod() -> None:
"""Test deprecated classmethod.""" """Test deprecated classmethod."""