Compare commits

...

1 Commits

Author SHA1 Message Date
Eugene Yurtsev
10282f34cc x 2024-07-23 09:13:36 -04:00
2 changed files with 140 additions and 0 deletions

View File

@@ -0,0 +1,83 @@
import inspect
import time
from typing import Any, Callable, TypeVar
import pytest
F = TypeVar("F", bound=Callable[..., Any])
def _timeout(*, seconds: float) -> Callable[[F], F]:
"""Decorator to measure the execution time of a test function and fail the test
if it exceeds a specified maximum time.
This function does **not** terminate the test function if it exceeds the maximum
allowed time.
Args:
seconds: Maximum allowed time for the test function to execute, in seconds.
Returns:
Callable[[F], F]: A decorated function that measures execution time and
enforces the maximum allowed time by failing the test if it is exceeded
the allowed time.
"""
def decorator(func: F) -> F:
"""Decorator function to wrap the test function.
Args:
func (F): The test function to be decorated.
Returns:
F: The wrapped test function.
"""
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for asynchronous test functions to measure execution time.
Args:
*args (Any): Positional arguments for the test function.
**kwargs (Any): Keyword arguments for the test function.
Returns:
Any: The result of the test function.
"""
start_time = time.time()
result = await func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
if duration > seconds:
pytest.fail(
f"{func.__name__} exceeded the maximum allowed time of {seconds} "
f"seconds."
)
return result
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for synchronous test functions to measure execution time.
Args:
*args (Any): Positional arguments for the test function.
**kwargs (Any): Keyword arguments for the test function.
Returns:
Any: The result of the test function.
"""
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
duration = end_time - start_time
if duration > seconds:
pytest.fail(
f"{func.__name__} exceeded the maximum allowed time of {seconds} "
f"seconds."
)
return result
if inspect.iscoroutinefunction(func):
return async_wrapper # type: ignore
else:
return sync_wrapper # type: ignore
return decorator

View File

@@ -0,0 +1,57 @@
"""Check that the max time decorator works."""
import asyncio
import time
import pytest
from langchain_standard_tests.utils.timeout import _timeout
@_timeout(seconds=0.5)
def test_sync_fast() -> None:
"""Test function that completes within the allowed time."""
time.sleep(0.01)
@_timeout(seconds=0.5)
async def test_async_fast() -> None:
"""Test async function that completes within the allowed time."""
await asyncio.sleep(0.01)
@pytest.mark.xfail(strict=True)
@_timeout(seconds=0)
def test_sync_slow() -> None:
"""Test async function that exceeds the allowed time."""
time.sleep(0.01)
@pytest.mark.xfail(strict=True)
@_timeout(seconds=0)
async def test_async_slow() -> None:
"""Test async function that exceeds the allowed time."""
await asyncio.sleep(0.01)
class TestMethodDecoration:
@_timeout(seconds=0.5)
def test_sync_fast_method(self) -> None:
"""Test function that completes within the allowed time."""
time.sleep(0.01)
@_timeout(seconds=0.5)
async def test_async_fast_method(self) -> None:
"""Test async function that completes within the allowed time."""
await asyncio.sleep(0.01)
@pytest.mark.xfail(strict=True)
@_timeout(seconds=0)
def test_sync_slow_method(self) -> None:
"""Test async function that exceeds the allowed time."""
time.sleep(0.01)
@pytest.mark.xfail(strict=True)
@_timeout(seconds=0)
async def test_async_slow_method(self) -> None:
"""Test async function that exceeds the allowed time."""
await asyncio.sleep(0.01)