[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -1,11 +1,9 @@
import os
import random
from functools import partial
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from vit import get_training_components
@@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
@@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, use_ddp=use_ddp)
if __name__ == '__main__':

View File

@@ -1,20 +1,20 @@
import time
import pytest
import argparse
from functools import partial
import time
import pytest
import torch
from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from model_zoo import get_gpt2_components, GPTLMLoss
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn
from colossalai.utils import get_current_device
def parse_args():
parser = argparse.ArgumentParser()
@@ -24,6 +24,7 @@ def parse_args():
parser.add_argument('--memory_budget', type=float, default=16)
return parser.parse_args()
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def train_gpt(args):
memory_budget = args.memory_budget * 1024 * 1024 * 1024
@@ -33,13 +34,16 @@ def train_gpt(args):
# build model
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
label = torch.randint(low=0, high=128, size=(
64,
8,
), device=get_current_device())
criterion = GPTLMLoss()
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
param_size = parameter_size(model) / 1024**2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
@@ -74,21 +78,20 @@ def train_gpt(args):
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
print(f'solver_type: {solver_type} | model_type: {model_type}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
print(time_list)
def run(rank, world_size, port, args):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
train_gpt(args)
if __name__ == '__main__':
args = parse_args()
run_func = partial(run, world_size=1, port=free_port(), args=args)
mp.spawn(run_func, nprocs=1)
spawn(run, 1, args=args)

View File

@@ -1,18 +1,13 @@
from functools import partial
from time import time
from typing import Dict, Optional, Tuple, Union
import psutil
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger

View File

@@ -1,19 +1,14 @@
import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench, data_gen_resnet
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port):
@@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
def auto_activation_checkpoint_batchsize_benchmark():
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, 1)
if __name__ == "__main__":

View File

@@ -4,14 +4,13 @@ from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port, args):
@@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
def auto_activation_checkpoint_benchmark(args):
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, world_size, args=args)
if __name__ == "__main__":