[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -2,20 +2,18 @@
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import datasets, transforms
import colossalai
from torchvision import transforms, datasets
from colossalai.context import ParallelMode, Config
from colossalai.context import Config, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader
CONFIG = Config(dict(
parallel=dict(
@@ -58,9 +56,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
mp.spawn(test_func, nprocs=world_size)
spawn(run_data_sampler, 4)
if __name__ == '__main__':

View File

@@ -2,21 +2,18 @@
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import transforms, datasets
from torchvision import datasets, transforms
import colossalai
from colossalai.context import ParallelMode, Config
from colossalai.context import Config, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_if_address_is_in_use
from torchvision import transforms
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader
CONFIG = Config(
dict(
@@ -70,9 +67,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
mp.spawn(test_func, nprocs=world_size)
spawn(run_data_sampler, 4)
if __name__ == '__main__':