[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -1,19 +1,14 @@
import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench, data_gen_resnet
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port):
@@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
def auto_activation_checkpoint_batchsize_benchmark():
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, 1)
if __name__ == "__main__":

View File

@@ -4,14 +4,13 @@ from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port, args):
@@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
def auto_activation_checkpoint_benchmark(args):
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, world_size, args=args)
if __name__ == "__main__":