[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -1,8 +1,13 @@
import gc
import random
import re
import torch
from typing import Callable, List, Any
import socket
from functools import partial
from inspect import signature
from typing import Any, Callable, List
import torch
import torch.multiprocessing as mp
from packaging import version
@@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable:
# > davis: hello
# > davis: bye
# > davis: stop
Args:
argument (str): the name of the argument to parameterize
values (List[Any]): a list of values to iterate for this argument
@@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun only the exception message is matched with pattern
# for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
@@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
Args:
exception_type (Exception, Optional): The type of exception to detect for rerun
pattern (str, Optional): The pattern to match the exception message.
pattern (str, Optional): The pattern to match the exception message.
If the pattern is not None and matches the exception message,
the exception will be detected for rerun
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
If max_try is None, it will rerun forever if the exception keeps occurring
"""
@@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
return _execute_by_gpu_num
return _wrap_func
def free_port() -> int:
    """Find an unused TCP port on localhost.

    Repeatedly draws a random port in the range [20000, 65000] and
    attempts to bind it until a bind succeeds.

    Returns:
        int: A port number that was bindable at the time of the check.
    """
    while True:
        candidate = random.randint(20000, 65000)
        sock = socket.socket()
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(("localhost", candidate))
        except OSError:
            # port already taken (or otherwise unbindable) -- try another one
            continue
        else:
            return candidate
        finally:
            sock.close()
def spawn(func, nprocs=1, **kwargs):
    """Spawn ``nprocs`` worker processes that each run ``func`` for testing.

    The target function receives its rank as the first positional argument
    (supplied by ``torch.multiprocessing.spawn``), plus ``world_size`` and
    ``port`` keyword arguments; any extra keyword arguments passed to this
    function are forwarded as well.

    Usage::

        # must contain arguments rank, world_size, port
        def do_something(rank, world_size, port):
            ...

        spawn(do_something, nprocs=8)

        # can also pass other arguments
        def do_something(rank, world_size, port, arg1, arg2):
            ...

        spawn(do_something, nprocs=8, arg1=1, arg2=2)

    Args:
        func (Callable): The function to be spawned.
        nprocs (int, optional): The number of processes to spawn. Defaults to 1.
    """
    # pick one free port up front so every worker agrees on the rendezvous port
    bound = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(bound, nprocs=nprocs)
def clear_cache_before_run():
    """Decorator factory that clears CUDA and Python caches before each call.

    Usage::

        @clear_cache_before_run()
        def test_something():
            ...

    Returns:
        Callable: a decorator that wraps the target function.
    """

    def _wrap_func(f):
        from functools import wraps

        # wraps preserves f's name/docstring so pytest collection and
        # test reporting still see the original function identity
        @wraps(f)
        def _clear_cache(*args, **kwargs):
            # Guard: the reset/synchronize calls initialize the CUDA context
            # and raise on CPU-only machines, so only touch CUDA when usable.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                # deprecated aliases of reset_peak_memory_stats; kept for
                # parity with older torch versions that track them separately
                torch.cuda.reset_max_memory_allocated()
                torch.cuda.reset_max_memory_cached()
                torch.cuda.synchronize()
            gc.collect()
            # propagate the wrapped function's result (the original dropped it)
            return f(*args, **kwargs)

        return _clear_cache

    return _wrap_func