mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00
x
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
import inspect
|
||||
import time
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _timeout(*, seconds: float) -> Callable[[F], F]:
|
||||
"""Decorator to measure the execution time of a test function and fail the test
|
||||
if it exceeds a specified maximum time.
|
||||
|
||||
This function does **not** terminate the test function if it exceeds the maximum
|
||||
allowed time.
|
||||
|
||||
Args:
|
||||
seconds: Maximum allowed time for the test function to execute, in seconds.
|
||||
|
||||
Returns:
|
||||
Callable[[F], F]: A decorated function that measures execution time and
|
||||
enforces the maximum allowed time by failing the test if it is exceeded
|
||||
the allowed time.
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
"""Decorator function to wrap the test function.
|
||||
|
||||
Args:
|
||||
func (F): The test function to be decorated.
|
||||
|
||||
Returns:
|
||||
F: The wrapped test function.
|
||||
"""
|
||||
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for asynchronous test functions to measure execution time.
|
||||
|
||||
Args:
|
||||
*args (Any): Positional arguments for the test function.
|
||||
**kwargs (Any): Keyword arguments for the test function.
|
||||
|
||||
Returns:
|
||||
Any: The result of the test function.
|
||||
"""
|
||||
start_time = time.time()
|
||||
result = await func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
if duration > seconds:
|
||||
pytest.fail(
|
||||
f"{func.__name__} exceeded the maximum allowed time of {seconds} "
|
||||
f"seconds."
|
||||
)
|
||||
return result
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Wrapper for synchronous test functions to measure execution time.
|
||||
|
||||
Args:
|
||||
*args (Any): Positional arguments for the test function.
|
||||
**kwargs (Any): Keyword arguments for the test function.
|
||||
|
||||
Returns:
|
||||
Any: The result of the test function.
|
||||
"""
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
if duration > seconds:
|
||||
pytest.fail(
|
||||
f"{func.__name__} exceeded the maximum allowed time of {seconds} "
|
||||
f"seconds."
|
||||
)
|
||||
return result
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return async_wrapper # type: ignore
|
||||
else:
|
||||
return sync_wrapper # type: ignore
|
||||
|
||||
return decorator
|
||||
57
libs/standard-tests/tests/unit_tests/test_max_time.py
Normal file
57
libs/standard-tests/tests/unit_tests/test_max_time.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Check that the max time decorator works."""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_standard_tests.utils.timeout import _timeout
|
||||
|
||||
|
||||
@_timeout(seconds=0.5)
|
||||
def test_sync_fast() -> None:
|
||||
"""Test function that completes within the allowed time."""
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
@_timeout(seconds=0.5)
|
||||
async def test_async_fast() -> None:
|
||||
"""Test async function that completes within the allowed time."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=True)
|
||||
@_timeout(seconds=0)
|
||||
def test_sync_slow() -> None:
|
||||
"""Test async function that exceeds the allowed time."""
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=True)
|
||||
@_timeout(seconds=0)
|
||||
async def test_async_slow() -> None:
|
||||
"""Test async function that exceeds the allowed time."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
class TestMethodDecoration:
|
||||
@_timeout(seconds=0.5)
|
||||
def test_sync_fast_method(self) -> None:
|
||||
"""Test function that completes within the allowed time."""
|
||||
time.sleep(0.01)
|
||||
|
||||
@_timeout(seconds=0.5)
|
||||
async def test_async_fast_method(self) -> None:
|
||||
"""Test async function that completes within the allowed time."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
@pytest.mark.xfail(strict=True)
|
||||
@_timeout(seconds=0)
|
||||
def test_sync_slow_method(self) -> None:
|
||||
"""Test async function that exceeds the allowed time."""
|
||||
time.sleep(0.01)
|
||||
|
||||
@pytest.mark.xfail(strict=True)
|
||||
@_timeout(seconds=0)
|
||||
async def test_async_slow_method(self) -> None:
|
||||
"""Test async function that exceeds the allowed time."""
|
||||
await asyncio.sleep(0.01)
|
||||
Reference in New Issue
Block a user