mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[pipeline/chimera] test chimera | fix bug of initializing (#1615)
* [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera * [pipeline/chimera] test chimera | fix bug of initializing
This commit is contained in:
@@ -8,8 +8,13 @@ import torch.multiprocessing as mp
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch.optim import SGD, Adam, RMSprop, Optimizer
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
import torch.distributed as dist
|
||||
from colorama import Back, Style
|
||||
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai import launch
|
||||
|
||||
rpc_is_initialized = _is_current_rpc_agent_set
|
||||
|
||||
|
||||
@@ -25,12 +30,15 @@ class RpcTestModel(nn.Module):
|
||||
self.rank = stage_id
|
||||
self.is_last_rank = stage_id == actual_stage_num - 1
|
||||
self.linear_name = f'linear_{stage_id}'
|
||||
|
||||
if stage_id == 0:
|
||||
setattr(self, self.linear_name, nn.Linear(feat_num, h))
|
||||
linear = nn.Linear(feat_num, h)
|
||||
elif stage_id == actual_stage_num - 1:
|
||||
setattr(self, self.linear_name, nn.Linear(h, 1))
|
||||
linear = nn.Linear(h, 1)
|
||||
else:
|
||||
setattr(self, self.linear_name, nn.Linear(h, h))
|
||||
linear = nn.Linear(h, h)
|
||||
|
||||
setattr(self, self.linear_name, linear)
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
linear: nn.Module = getattr(self, self.linear_name)
|
||||
@@ -46,6 +54,8 @@ def parse_args():
|
||||
parser.add_argument('--epoch', type=int, default=1)
|
||||
parser.add_argument('--world_size', type=int, default=2)
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--dp_degree', type=int, default=1)
|
||||
parser.add_argument('--tp_degree', type=int, default=1)
|
||||
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||
parser.add_argument('--chunk', type=int, default=1)
|
||||
parser.add_argument('--use_checkpoint', action='store_true')
|
||||
@@ -74,16 +84,24 @@ def run_worker(rank, args, master_func):
|
||||
os.environ['MASTER_ADDR'] = args.master_addr
|
||||
os.environ['MASTER_PORT'] = args.master_port
|
||||
|
||||
# config rpc
|
||||
# if cuda is used, set_device_map is a must is configured
|
||||
# for cuda is not supported in torch rpc by default
|
||||
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
|
||||
|
||||
device = args.device
|
||||
world_size = args.world_size
|
||||
for rank_idx in range(world_size):
|
||||
options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
|
||||
dp_degree = args.dp_degree
|
||||
tp_degree = args.tp_degree
|
||||
num_worker_threads = args.num_worker_threads
|
||||
host = args.master_addr
|
||||
port = args.master_port
|
||||
backend = 'nccl' if device == 'cuda' else 'gloo'
|
||||
|
||||
rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
|
||||
disable_existing_loggers()
|
||||
|
||||
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||
ppg.set_global_info(rank=rank,
|
||||
world_size=world_size,
|
||||
dp_degree=dp_degree,
|
||||
tp_degree=tp_degree,
|
||||
num_worker_threads=num_worker_threads,
|
||||
device=device)
|
||||
|
||||
# in rpc mode, only rank 0 is needed to be coded
|
||||
if rank == 0:
|
||||
|
@@ -1,9 +1,21 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.autograd as autograd
|
||||
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
|
||||
from colossalai.pipeline.rpc import ChimeraPipelineEngine
|
||||
from colossalai.testing import assert_close
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
# global variable for model created
|
||||
feat_num = 100
|
||||
h = 100
|
||||
|
||||
|
||||
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||
torch.manual_seed(1024)
|
||||
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||
return partition
|
||||
|
||||
|
||||
def run_master(args):
|
||||
torch.manual_seed(100)
|
||||
@@ -17,23 +29,51 @@ def run_master(args):
|
||||
use_checkpoint = False
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 10
|
||||
h = 10
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
|
||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
engine = ChimeraPipelineEngine(module_partitions=module_partitions,
|
||||
engine = ChimeraPipelineEngine(partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
checkpoint=use_checkpoint)
|
||||
engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
for _ in range(epoch):
|
||||
_ = engine.forward_backward(input_sample, forward_only=False)
|
||||
forward_result = engine.forward_backward(input_sample)
|
||||
|
||||
cuda_rpc_result = []
|
||||
single_result = []
|
||||
actual_stage_num = engine._get_actual_stage_num()
|
||||
|
||||
# compute forward result and backward grad of parameters in cuda rpc
|
||||
cuda_rpc_result.append(sum(forward_result[0]))
|
||||
grad = engine.remote_grad()
|
||||
for stage_id in range(actual_stage_num):
|
||||
for p in grad[stage_id]:
|
||||
cuda_rpc_result.append(p)
|
||||
|
||||
# compute forward result and backward grad of parameters just in rank_0
|
||||
test_model = nn.Sequential(
|
||||
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||
# input_sample = input_sample[len(input_sample) // 2:]
|
||||
input_sample = input_sample.requires_grad_()
|
||||
out_val = test_model(input_sample).sum()
|
||||
autograd.backward(out_val)
|
||||
single_result.append(out_val)
|
||||
for p in test_model.parameters():
|
||||
single_result.append(p.grad)
|
||||
|
||||
# print("my")
|
||||
# print(cuda_rpc_result[1])
|
||||
# print("answer:")
|
||||
# print(single_result[1])
|
||||
|
||||
# assert len(cuda_rpc_result) == len(single_result)
|
||||
# for r_c, r_s in zip(cuda_rpc_result, single_result):
|
||||
# assert_close(r_c, r_s, 0.001, 0.001)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -7,6 +7,16 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine,
|
||||
from colossalai.testing import assert_close
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
# global variable for model created
|
||||
feat_num = 100
|
||||
h = 100
|
||||
|
||||
|
||||
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||
torch.manual_seed(1024)
|
||||
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||
return partition
|
||||
|
||||
|
||||
def run_master(args):
|
||||
torch.manual_seed(100)
|
||||
@@ -20,20 +30,14 @@ def run_master(args):
|
||||
optimizer_class = globals()[args.optimizer]
|
||||
|
||||
lr = 1e-3
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 100
|
||||
h = 100
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
batch_num = sample_num // batch_size
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
|
||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
@@ -55,7 +59,8 @@ def run_master(args):
|
||||
cuda_rpc_result.append(p)
|
||||
|
||||
# compute forward result and backward grad of parameters just in rank_0
|
||||
test_model = nn.Sequential(*module_partitions).to(device)
|
||||
test_model = nn.Sequential(
|
||||
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||
optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
|
||||
input_sample = input_sample.requires_grad_()
|
||||
out_val = test_model(input_sample).sum()
|
||||
|
@@ -18,17 +18,30 @@ from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
|
||||
|
||||
def flatten(x):
|
||||
return torch.flatten(x, 1)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||
pipelinable = PipelinableContext()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.flatten(x, start_dim=1)
|
||||
# build model partitions
|
||||
with pipelinable:
|
||||
# input : [B, 3, 32, 32]
|
||||
_ = resnet50()
|
||||
|
||||
pipelinable.policy = "customized"
|
||||
|
||||
exec_seq = [
|
||||
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
|
||||
]
|
||||
pipelinable.to_layer_list(exec_seq)
|
||||
partition = pipelinable.partition(chunk, stage_num, pp_rank)
|
||||
return partition
|
||||
|
||||
|
||||
def run_master(args):
|
||||
@@ -39,37 +52,12 @@ def run_master(args):
|
||||
stage_num = world_size
|
||||
num_microbatches = args.num_microbatches
|
||||
|
||||
assert chunk == 1
|
||||
|
||||
pipelinable = PipelinableContext()
|
||||
|
||||
# build model partitions
|
||||
with pipelinable:
|
||||
# input : [B, 3, 32, 32]
|
||||
model = resnet50()
|
||||
|
||||
exec_seq = [
|
||||
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
|
||||
]
|
||||
pipelinable.to_layer_list(exec_seq)
|
||||
module_partitions: List[PipelinableModel] = [
|
||||
pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size)
|
||||
]
|
||||
|
||||
# build dataloader
|
||||
root = os.environ.get('DATA', './data')
|
||||
train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
partition_1 = module_partitions[0]
|
||||
partition_2 = []
|
||||
for model in module_partitions[1]._module_list:
|
||||
partition_2.append(model)
|
||||
partition_2.insert(len(partition_2) - 1, Flatten())
|
||||
partition_2 = nn.Sequential(*partition_2)
|
||||
module_partitions = [partition_1, partition_2]
|
||||
|
||||
pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
|
@@ -4,6 +4,16 @@ from torch import nn
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
# global variable for model created
|
||||
feat_num = 100
|
||||
h = 100
|
||||
|
||||
|
||||
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||
torch.manual_seed(1024)
|
||||
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||
return partition
|
||||
|
||||
|
||||
def run_master(args):
|
||||
torch.manual_seed(100)
|
||||
@@ -13,22 +23,16 @@ def run_master(args):
|
||||
stage_num = args.world_size
|
||||
chunk = args.chunk
|
||||
num_microbatches = args.num_microbatches
|
||||
actual_stage_num = stage_num * chunk
|
||||
use_checkpoint = args.use_checkpoint
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 10
|
||||
h = 10
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
batch_num = sample_num // batch_size
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
|
||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
|
@@ -6,6 +6,15 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine,
|
||||
from colossalai.testing import assert_close
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
feat_num = 100
|
||||
h = 100
|
||||
|
||||
|
||||
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||
torch.manual_seed(1024)
|
||||
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||
return partition
|
||||
|
||||
|
||||
def run_master(args):
|
||||
torch.manual_seed(100)
|
||||
@@ -18,25 +27,20 @@ def run_master(args):
|
||||
num_microbatches = args.num_microbatches
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 100
|
||||
h = 100
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
batch_num = sample_num // batch_size
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
|
||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
checkpoint=use_checkpoint)
|
||||
|
||||
forward_result = engine.forward_backward(input_sample)[0]
|
||||
forward_result = engine.forward_backward(input_sample)
|
||||
|
||||
cuda_rpc_result = []
|
||||
single_result = []
|
||||
@@ -50,7 +54,8 @@ def run_master(args):
|
||||
cuda_rpc_result.append(p)
|
||||
|
||||
# compute forward result and backward grad of parameters just in rank_0
|
||||
test_model = nn.Sequential(*module_partitions).to(device)
|
||||
test_model = nn.Sequential(
|
||||
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||
input_sample = input_sample.requires_grad_()
|
||||
out_val = test_model(input_sample).sum()
|
||||
autograd.backward(out_val)
|
||||
|
@@ -4,7 +4,7 @@ import torch.distributed.rpc as rpc
|
||||
import torch.multiprocessing as mp
|
||||
import pytest
|
||||
|
||||
from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from rpc_test_utils import pg_parse_args, rpc_is_initialized
|
||||
@@ -26,12 +26,12 @@ def run_worker(rank, args):
|
||||
disable_existing_loggers()
|
||||
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||
|
||||
pg = PipelineProcessGroup(rank=rank,
|
||||
world_size=world_size,
|
||||
dp_degree=dp_degree,
|
||||
tp_degree=tp_degree,
|
||||
num_worker_threads=num_worker_threads,
|
||||
device=device)
|
||||
ppg.set_global_info(rank=rank,
|
||||
world_size=world_size,
|
||||
dp_degree=dp_degree,
|
||||
tp_degree=tp_degree,
|
||||
num_worker_threads=num_worker_threads,
|
||||
device=device)
|
||||
|
||||
if rpc_is_initialized():
|
||||
rpc.shutdown()
|
||||
|
Reference in New Issue
Block a user