mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
import gc
|
||||
import random
|
||||
import re
|
||||
import torch
|
||||
from typing import Callable, List, Any
|
||||
import socket
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, List
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from packaging import version
|
||||
|
||||
|
||||
@@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable:
|
||||
# > davis: hello
|
||||
# > davis: bye
|
||||
# > davis: stop
|
||||
|
||||
|
||||
Args:
|
||||
argument (str): the name of the argument to parameterize
|
||||
values (List[Any]): a list of values to iterate for this argument
|
||||
@@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
|
||||
def test_method():
|
||||
print('hey')
|
||||
raise RuntimeError('Address already in use')
|
||||
|
||||
|
||||
# rerun for infinite times if Runtime error occurs
|
||||
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
|
||||
def test_method():
|
||||
print('hey')
|
||||
raise RuntimeError('Address already in use')
|
||||
|
||||
|
||||
# rerun only the exception message is matched with pattern
|
||||
# for infinite times if Runtime error occurs
|
||||
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
|
||||
@@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
|
||||
|
||||
Args:
|
||||
exception_type (Exception, Optional): The type of exception to detect for rerun
|
||||
pattern (str, Optional): The pattern to match the exception message.
|
||||
pattern (str, Optional): The pattern to match the exception message.
|
||||
If the pattern is not None and matches the exception message,
|
||||
the exception will be detected for rerun
|
||||
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
|
||||
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
|
||||
If max_try is None, it will rerun foreven if exception keeps occurings
|
||||
"""
|
||||
|
||||
@@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
|
||||
return _execute_by_gpu_num
|
||||
|
||||
return _wrap_func
|
||||
|
||||
|
||||
def free_port() -> int:
|
||||
"""Get a free port on localhost.
|
||||
|
||||
Returns:
|
||||
int: A free port on localhost.
|
||||
"""
|
||||
while True:
|
||||
port = random.randint(20000, 65000)
|
||||
try:
|
||||
with socket.socket() as sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("localhost", port))
|
||||
return port
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
|
||||
def spawn(func, nprocs=1, **kwargs):
|
||||
"""
|
||||
This function is used to spawn processes for testing.
|
||||
|
||||
Usage:
|
||||
# must contians arguments rank, world_size, port
|
||||
def do_something(rank, world_size, port):
|
||||
...
|
||||
|
||||
spawn(do_something, nprocs=8)
|
||||
|
||||
# can also pass other arguments
|
||||
def do_something(rank, world_size, port, arg1, arg2):
|
||||
...
|
||||
|
||||
spawn(do_something, nprocs=8, arg1=1, arg2=2)
|
||||
|
||||
Args:
|
||||
func (Callable): The function to be spawned.
|
||||
nprocs (int, optional): The number of processes to spawn. Defaults to 1.
|
||||
"""
|
||||
port = free_port()
|
||||
wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
|
||||
mp.spawn(wrapped_func, nprocs=nprocs)
|
||||
|
||||
|
||||
def clear_cache_before_run():
|
||||
"""
|
||||
This function is a wrapper to clear CUDA and python cache before executing the function.
|
||||
|
||||
Usage:
|
||||
@clear_cache_before_run()
|
||||
def test_something():
|
||||
...
|
||||
"""
|
||||
|
||||
def _wrap_func(f):
|
||||
|
||||
def _clear_cache(*args, **kwargs):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_max_memory_cached()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
f(*args, **kwargs)
|
||||
|
||||
return _clear_cache
|
||||
|
||||
return _wrap_func
|
||||
|
Reference in New Issue
Block a user