mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[unittest] supported condititonal testing based on env var (#1701)
polish code
This commit is contained in:
17
colossalai/testing/pytest_wrapper.py
Normal file
17
colossalai/testing/pytest_wrapper.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
def run_on_environment_flag(name: str):
|
||||
"""
|
||||
Conditionally run a test based on the environment variable. If this environment variable is set
|
||||
to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0.
|
||||
"""
|
||||
assert isinstance(name, str)
|
||||
flag = os.environ.get(name.upper(), '0')
|
||||
|
||||
reason = f'Environment varialbe {name} is {flag}'
|
||||
if flag == '1':
|
||||
return pytest.mark.skipif(False, reason=reason)
|
||||
else:
|
||||
return pytest.mark.skipif(True, reason=reason)
|
@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
|
||||
"""
|
||||
|
||||
def _wrap_func(f):
|
||||
|
||||
def _execute_by_gpu_num(*args, **kwargs):
|
||||
num_avail_gpu = torch.cuda.device_count()
|
||||
if num_avail_gpu >= min_gpus:
|
||||
f(*args, **kwargs)
|
||||
|
||||
return _execute_by_gpu_num
|
||||
|
||||
return _wrap_func
|
||||
|
||||
|
Reference in New Issue
Block a user