Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
@@ -1,19 +1,14 @@
-import time
-from argparse import ArgumentParser
-from copy import deepcopy
-from functools import partial
 
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import torch.multiprocessing as mp
 import torchvision.models as tm
 from bench_utils import bench, data_gen_resnet
 
 import colossalai
 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace, symbolic_trace
-from colossalai.utils import free_port
+from colossalai.testing import spawn
 
 
 def _benchmark(rank, world_size, port):
@@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
 
 
 def auto_activation_checkpoint_batchsize_benchmark():
-    world_size = 1
-    run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_benchmark, 1)
 
 
 if __name__ == "__main__":
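Both hunks above apply the same refactoring to the first benchmark script: the launch boilerplate that bound world_size and a freshly allocated port with functools.partial and then called torch.multiprocessing.spawn is replaced by a single call to the new colossalai.testing.spawn helper, and the now-unused imports are dropped. The helper's real implementation is not part of this diff; the following is only a minimal sketch of what such a launcher could look like, assuming the call shape spawn(worker_fn, nprocs, **kwargs) seen in the diff and an OS-assigned free port standing in for colossalai.utils.free_port.

import socket
from functools import partial

import torch.multiprocessing as mp


def _free_port() -> int:
    # Ask the OS for an unused TCP port (stand-in for colossalai.utils.free_port).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]


def spawn(worker_fn, nprocs=1, **kwargs):
    # Bind world_size, a fresh port, and any extra keyword arguments, then launch
    # one process per rank; torch.multiprocessing.spawn passes the rank as the
    # first positional argument, matching the _benchmark(rank, ...) signature.
    wrapped = partial(worker_fn, world_size=nprocs, port=_free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)

With a helper of this shape, mp.spawn(partial(_benchmark, world_size=world_size, port=free_port()), nprocs=world_size) collapses to spawn(_benchmark, 1), which is exactly the substitution made above. The second benchmark script below receives the same treatment.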
@@ -4,14 +4,13 @@ from functools import partial
 
 import matplotlib.pyplot as plt
 import torch
-import torch.multiprocessing as mp
 import torchvision.models as tm
 from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
 
 import colossalai
 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace, symbolic_trace
-from colossalai.utils import free_port
+from colossalai.testing import spawn
 
 
 def _benchmark(rank, world_size, port, args):
@@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
 
 def auto_activation_checkpoint_benchmark(args):
     world_size = 1
-    run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_benchmark, world_size, args=args)
 
 
 if __name__ == "__main__":
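In this second script the worker also takes the parsed command-line arguments, and the refactored call spawn(_benchmark, world_size, args=args) forwards them as a keyword argument. A short usage sketch of that pattern, assuming the keyword arguments reach every rank as the call suggests, with a hypothetical _echo_worker standing in for _benchmark:

from colossalai.testing import spawn  # or the stand-in sketched above


def _echo_worker(rank, world_size, port, args=None):
    # Each rank receives its rank first, then the injected world_size and port,
    # plus any extra keyword arguments forwarded by spawn.
    print(f"rank={rank} world_size={world_size} port={port} args={args}")


if __name__ == "__main__":
    # Mirrors auto_activation_checkpoint_benchmark(args): one process, with the
    # parsed arguments handed through to the worker.
    spawn(_echo_worker, 1, args={"batch_size": 128})

Either way, each rank still gets its rank, world_size, and port in the same positions as before, so the worker bodies in both scripts are untouched by the refactor.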