mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -6,13 +6,14 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.rpc as rpc
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from torch import nn
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
from torch.optim import SGD, Adam, Optimizer, RMSprop
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
|
||||
rpc_is_initialized = _is_current_rpc_agent_set
|
||||
|
||||
|
||||
@@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'):
|
||||
color = color.upper()
|
||||
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, layers: int):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList()
|
||||
@@ -32,8 +35,10 @@ class MLP(nn.Module):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x.sum()
|
||||
|
||||
|
||||
|
||||
class DAG_MLP(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, layers: int):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList()
|
||||
@@ -48,6 +53,7 @@ class DAG_MLP(nn.Module):
|
||||
y = self.dag_layer(y)
|
||||
return x.sum(), y.sum()
|
||||
|
||||
|
||||
class RpcTestModel(nn.Module):
|
||||
|
||||
def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
|
||||
|
@@ -1,27 +1,27 @@
|
||||
import torch
|
||||
import pytest
|
||||
import os
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed.rpc as rpc
|
||||
from functools import partial
|
||||
|
||||
from torch import nn
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
from rpc_test_utils import DAG_MLP, MLP
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
from rpc_test_utils import MLP, DAG_MLP
|
||||
from functools import partial
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
# global variable for model created
|
||||
batch_size = 16
|
||||
dim = 10
|
||||
rpc_is_initialized = _is_current_rpc_agent_set
|
||||
|
||||
|
||||
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
||||
model.eval()
|
||||
tracer = ColoTracer()
|
||||
@@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
||||
for submodule in split_submodules:
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
setattr(submodule, '_topo', topo)
|
||||
return split_submodules[pp_rank+1]
|
||||
return split_submodules[pp_rank + 1]
|
||||
|
||||
|
||||
def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
|
||||
torch.manual_seed(1024)
|
||||
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||
return partition
|
||||
|
||||
|
||||
def run_master(model_cls, world_size, forward_only):
|
||||
torch.manual_seed(100)
|
||||
|
||||
@@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only):
|
||||
chunk = 1
|
||||
num_microbatches = 8
|
||||
use_checkpoint = 'store_true'
|
||||
|
||||
|
||||
if model_cls == MLP:
|
||||
|
||||
def data_gen():
|
||||
x = torch.zeros((batch_size, dim))
|
||||
kwargs = dict(x=x)
|
||||
return kwargs
|
||||
|
||||
model = model_cls(dim, stage_num * 3)
|
||||
if forward_only:
|
||||
labels = None
|
||||
else:
|
||||
labels = 1
|
||||
elif model_cls == DAG_MLP:
|
||||
|
||||
def data_gen():
|
||||
x = torch.zeros((batch_size, dim))
|
||||
y = torch.zeros((batch_size, dim))
|
||||
kwargs = dict(x=x, y=y)
|
||||
return kwargs
|
||||
|
||||
model = model_cls(dim, stage_num * 3)
|
||||
if forward_only:
|
||||
labels = None
|
||||
@@ -74,15 +80,17 @@ def run_master(model_cls, world_size, forward_only):
|
||||
labels = 1
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
data_kwargs = data_gen()
|
||||
|
||||
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
checkpoint=use_checkpoint,)
|
||||
|
||||
engine = OneFOneBPipelineEngine(
|
||||
partition_fn=partial(partition, model, data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
checkpoint=use_checkpoint,
|
||||
)
|
||||
if not forward_only:
|
||||
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
|
||||
|
||||
@@ -90,13 +98,14 @@ def run_master(model_cls, world_size, forward_only):
|
||||
input_x = torch.randn((batch_size, dim), device=device)
|
||||
input_y = torch.randn((batch_size, dim), device=device)
|
||||
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
|
||||
|
||||
def run_worker(rank, model_cls, world_size, forward_only, master_func):
|
||||
|
||||
|
||||
def run_worker(rank, world_size, port, model_cls, forward_only, master_func):
|
||||
master_addr = 'localhost'
|
||||
master_port = 29020
|
||||
os.environ['MASTER_ADDR'] = master_addr
|
||||
os.environ['MASTER_PORT'] = str(master_port)
|
||||
|
||||
|
||||
disable_existing_loggers()
|
||||
|
||||
launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False)
|
||||
@@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
|
||||
# barrier here
|
||||
if rpc_is_initialized():
|
||||
rpc.shutdown()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to CI torch version 1.11")
|
||||
@parameterize('model_cls', [MLP, DAG_MLP])
|
||||
@parameterize('forward_only', [True, False])
|
||||
@@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
|
||||
def test_pp_middleware_fwd(model_cls, forward_only):
|
||||
world_size = 4
|
||||
master_func = run_master
|
||||
mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
|
||||
spawn(
|
||||
run_worker,
|
||||
world_size,
|
||||
model_cls=model_cls,
|
||||
forward_only=forward_only,
|
||||
master_func=master_func,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pp_middleware_fwd()
|
||||
test_pp_middleware_fwd()
|
||||
|
@@ -1,9 +1,7 @@
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn
|
||||
|
||||
NUM_CHUNKS = 1
|
||||
PIPELINE_SIZE = 2
|
||||
@@ -27,7 +25,7 @@ class MLP(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def run_pipelinable(rank):
|
||||
def run_pipelinable(rank, world_size, port):
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = MLP()
|
||||
@@ -50,9 +48,9 @@ def run_pipelinable(rank):
|
||||
assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
|
||||
|
||||
|
||||
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pipelinable():
|
||||
mp.spawn(run_pipelinable, nprocs=1)
|
||||
spawn(run_pipelinable, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -1,13 +1,12 @@
|
||||
import os
|
||||
|
||||
import torch.distributed.rpc as rpc
|
||||
import torch.multiprocessing as mp
|
||||
import pytest
|
||||
from rpc_test_utils import pg_parse_args, rpc_is_initialized
|
||||
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from rpc_test_utils import pg_parse_args, rpc_is_initialized
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def run_worker(rank, args):
|
||||
@@ -40,4 +39,4 @@ def run_worker(rank, args):
|
||||
if __name__ == "__main__":
|
||||
args = pg_parse_args()
|
||||
world_size = args.world_size
|
||||
mp.spawn(run_worker, args=(args,), nprocs=world_size)
|
||||
spawn(run_worker, world_size, args=args)
|
||||
|
Reference in New Issue
Block a user