[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions
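
Note: the commit message refers to a new `spawn` utility in `colossalai.testing`, but its implementation is not part of the hunks shown below. Judging only from how it is called in these tests (`spawn(run_worker, world_size, model_cls=..., forward_only=..., master_func=...)`, `spawn(run_pipelinable, 1)`) and from the new worker signatures (`rank, world_size, port, ...`), a minimal sketch of such a helper could look like the following. The names (`_get_free_port`) and the port-selection logic are assumptions for illustration, not the actual ColossalAI code:

import socket
from functools import partial

import torch.multiprocessing as mp


def _get_free_port() -> int:
    # Ask the OS for an unused TCP port so parallel test runs do not collide.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('localhost', 0))
        return sock.getsockname()[1]


def spawn(func, nprocs=1, **kwargs):
    # Sketch only: run func(rank, world_size, port, **kwargs) in nprocs processes,
    # mirroring how run_worker/run_pipelinable are invoked in the tests below.
    wrapped = partial(func, world_size=nprocs, port=_get_free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)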

View File

@@ -6,13 +6,14 @@ import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
rpc_is_initialized = _is_current_rpc_agent_set
@@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'):
    color = color.upper()
    print(getattr(Back, color), prefix, Style.RESET_ALL, text)

class MLP(nn.Module):

    def __init__(self, dim: int, layers: int):
        super().__init__()
        self.layers = torch.nn.ModuleList()
@@ -32,8 +35,10 @@ class MLP(nn.Module):
        for layer in self.layers:
            x = layer(x)
        return x.sum()

class DAG_MLP(nn.Module):

    def __init__(self, dim: int, layers: int):
        super().__init__()
        self.layers = torch.nn.ModuleList()
@@ -48,6 +53,7 @@ class DAG_MLP(nn.Module):
        y = self.dag_layer(y)
        return x.sum(), y.sum()

class RpcTestModel(nn.Module):

    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:

View File

@@ -1,27 +1,27 @@
import torch
import pytest
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from functools import partial
from torch import nn
import pytest
import torch
import torch.distributed.rpc as rpc
from rpc_test_utils import DAG_MLP, MLP
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colossalai import launch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from rpc_test_utils import MLP, DAG_MLP
from functools import partial
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# global variable for model created
batch_size = 16
dim = 10
rpc_is_initialized = _is_current_rpc_agent_set
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
    model.eval()
    tracer = ColoTracer()
@@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            setattr(submodule, '_topo', topo)
    return split_submodules[pp_rank+1]
    return split_submodules[pp_rank + 1]

def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
    torch.manual_seed(1024)
    partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
    return partition

def run_master(model_cls, world_size, forward_only):
    torch.manual_seed(100)
@@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only):
    chunk = 1
    num_microbatches = 8
    use_checkpoint = 'store_true'

    if model_cls == MLP:

        def data_gen():
            x = torch.zeros((batch_size, dim))
            kwargs = dict(x=x)
            return kwargs

        model = model_cls(dim, stage_num * 3)
        if forward_only:
            labels = None
        else:
            labels = 1
    elif model_cls == DAG_MLP:

        def data_gen():
            x = torch.zeros((batch_size, dim))
            y = torch.zeros((batch_size, dim))
            kwargs = dict(x=x, y=y)
            return kwargs

        model = model_cls(dim, stage_num * 3)
        if forward_only:
            labels = None
@@ -74,15 +80,17 @@ def run_master(model_cls, world_size, forward_only):
            labels = 1
    else:
        pass
    data_kwargs = data_gen()
    engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    checkpoint=use_checkpoint,)
    engine = OneFOneBPipelineEngine(
        partition_fn=partial(partition, model, data_kwargs),
        stage_num=stage_num,
        num_microbatches=num_microbatches,
        device=device,
        chunk=chunk,
        checkpoint=use_checkpoint,
    )
    if not forward_only:
        engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
@@ -90,13 +98,14 @@ def run_master(model_cls, world_size, forward_only):
    input_x = torch.randn((batch_size, dim), device=device)
    input_y = torch.randn((batch_size, dim), device=device)
    logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)

def run_worker(rank, model_cls, world_size, forward_only, master_func):
def run_worker(rank, world_size, port, model_cls, forward_only, master_func):
    master_addr = 'localhost'
    master_port = 29020
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    disable_existing_loggers()
    launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False)
@@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
    # barrier here
    if rpc_is_initialized():
        rpc.shutdown()
@pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP])
@parameterize('forward_only', [True, False])
@@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
def test_pp_middleware_fwd(model_cls, forward_only):
    world_size = 4
    master_func = run_master
    mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
    spawn(
        run_worker,
        world_size,
        model_cls=model_cls,
        forward_only=forward_only,
        master_func=master_func,
    )

if __name__ == "__main__":
    test_pp_middleware_fwd()
    test_pp_middleware_fwd()

View File

@@ -1,9 +1,7 @@
import torch
import torch.multiprocessing as mp
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn
NUM_CHUNKS = 1
PIPELINE_SIZE = 2
@@ -27,7 +25,7 @@ class MLP(torch.nn.Module):
        return x

def run_pipelinable(rank):
def run_pipelinable(rank, world_size, port):
    pipelinable = PipelinableContext()
    with pipelinable:
        model = MLP()
@@ -50,9 +48,9 @@ def run_pipelinable(rank):
    assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count

@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_pipelinable():
    mp.spawn(run_pipelinable, nprocs=1)
    spawn(run_pipelinable, 1)

if __name__ == '__main__':

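In the file above, `@rerun_if_address_is_in_use()` replaces the more verbose `@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")`. A rough sketch of what such a shorthand decorator might do, assuming it simply retries the test on a port collision (the real helper in `colossalai.testing` may be implemented differently):

from functools import wraps

import torch.multiprocessing as mp


def rerun_if_address_is_in_use_sketch(max_try: int = 5):
    # Hypothetical stand-in for colossalai.testing.rerun_if_address_is_in_use.
    def decorator(test_fn):

        @wraps(test_fn)
        def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(max_try):
                try:
                    return test_fn(*args, **kwargs)
                except mp.ProcessRaisedException as exc:
                    if 'Address already in use' not in str(exc):
                        raise
                    last_exc = exc    # port collision: retry with a fresh spawn
            raise last_exc

        return wrapper

    return decorator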
View File

@@ -1,13 +1,12 @@
import os
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import pytest
from rpc_test_utils import pg_parse_args, rpc_is_initialized
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from rpc_test_utils import pg_parse_args, rpc_is_initialized
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.testing import spawn
def run_worker(rank, args):
@@ -40,4 +39,4 @@ def run_worker(rank, args):
if __name__ == "__main__":
    args = pg_parse_args()
    world_size = args.world_size
    mp.spawn(run_worker, args=(args,), nprocs=world_size)
    spawn(run_worker, world_size, args=args)