diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 3f61aed2f..119d719b2 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -28,7 +28,8 @@ def save_checkpoint(dire: str,
         if isinstance(v, ColoTensor):
             mapping[k] = (v.dist_spec, v.compute_spec)
             new_dict[k] = v.to_replicate().detach()
-
+        else:
+            new_dict[k] = v
     if dist.get_rank() == 0:
         for k, v in new_dict.items():
             if isinstance(v, ColoTensor):
@@ -60,7 +61,7 @@ def load_checkpoint(dire,
     """

     mapping = dict()
-    for k, v in model.named_parameters():
+    for k, v in model.state_dict().items():
         if isinstance(v, ColoTensor):
             mapping[k] = (v.dist_spec, v.compute_spec)
             v.to_replicate_()
@@ -70,6 +71,6 @@ def load_checkpoint(dire,

     # reset tensors to original dist spec.
     with DistSpecManager.no_grad():
-        for k, v in model.named_parameters():
+        for k, v in model.state_dict().items():
             if isinstance(v, ColoTensor):
                 v.set_tensor_spec(*mapping[k])
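Note on the hunks above: iterating `model.state_dict().items()` instead of `model.named_parameters()` also visits buffers (and any other plain tensors), and the new `else` branch keeps those non-ColoTensor entries in the saved dict instead of silently dropping them. A plain-PyTorch sketch of the difference, independent of the patched code:

    import torch.nn as nn

    bn = nn.BatchNorm1d(4)
    print([k for k, _ in bn.named_parameters()])
    # ['weight', 'bias']
    print(list(bn.state_dict().keys()))
    # ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']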
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index 0581d7bf0..4557cfa28 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -1,91 +1,65 @@
-from abc import ABC, abstractmethod
 import os, shutil
 import torch
-import torch.nn as nn
 import pytest
 from functools import partial

 import torch.multiprocessing as mp
 import torch.distributed as dist
+
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.optim.lr_scheduler import MultiplicativeLR
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

 import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import ColoOptimizer
+
+from tests.components_to_test.registry import non_distributed_component_funcs


-class DummyDataGenerator(ABC):
-
-    def __init__(self, length=10):
-        self.length = length
-
-    @abstractmethod
-    def generate(self):
-        pass
-
-    def __iter__(self):
-        self.step = 0
-        return self
-
-    def __next__(self):
-        if self.step < self.length:
-            self.step += 1
-            return self.generate()
-        else:
-            raise StopIteration
-
-    def __len__(self):
-        return self.length
+def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)


-class DummyDataLoader(DummyDataGenerator):
-
-    def __init__(self, batch_size, category, feature_size, length=10):
-        super().__init__(length)
-        self.batch_size = batch_size
-        self.category = category
-        self.feature_size = feature_size
-
-    def generate(self):
-        image_dict = {}
-        image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
-                                            dtype=torch.int64,
-                                            device=get_current_device())
-        return image_dict
+def init_1d_col_linear(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)


-class MLP(nn.Module):
+def init_1d_row_embedding(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)

-    def __init__(self, in_features, out_features, hidden_features=None):
-        super().__init__()
-        if hidden_features is None:
-            hidden_features = out_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.activation = nn.ReLU()

-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.activation(x)
-        x = self.fc2(x)
-        return x
+def init_1d_col_embedding(weight, pg):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)


 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n:
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
+    for name, p in model.named_parameters():
+        if not isinstance(p, ColoTensor):
+            continue
+        if 'embed' in name and 'weight' in name:
+            init_1d_col_embedding(p, pg)
+        if 'proj1' in name and ('weight' in name or 'bias' in name):
+            init_1d_col_linear(p, pg)
+        if 'proj2' in name and 'weight' in name:
+            init_1d_row_linear(p, pg)
+        if 'classifier' in name and ('weight' in name or 'bias' in name):
+            init_1d_col_linear(p, pg)


 def check_param_equal(model, torch_model):
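Note on the four `init_1d_*` helpers above: `ShardSpec([-1], [N])` splits a parameter along its last dimension across the N tensor-parallel ranks, while `ShardSpec([0], [N])` splits along the first. A plain-torch visualization of the two split axes, assuming a hypothetical world size of 2 (ColoTensor performs the actual distribution internally):

    import torch

    w = torch.randn(8, 4)                         # e.g. a Linear weight (out=8, in=4)
    last_dim_shards = torch.chunk(w, 2, dim=-1)   # ShardSpec([-1], [2]) -> two (8, 2) shards
    first_dim_shards = torch.chunk(w, 2, dim=0)   # ShardSpec([0], [2])  -> two (4, 4) shards
    print(last_dim_shards[0].shape, first_dim_shards[0].shape)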
@@ -103,56 +77,75 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))


-def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
-    num_epoch = 5
-    warmup_epoch = 2
+def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

-    batch = 3
-    feature = 32
-    category = 16
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()

+    # set_seed(1)
     with ColoInitContext(device=get_current_device()):
-        model = MLP(feature, category)
+        model = model_builder(checkpoint=True)
+        model_reload = model_builder(checkpoint=True)

-    with ColoInitContext(device=get_current_device()):
-        model_reload = MLP(feature, category)
+    if use_mp_reload:
+        if 'bert' == model_name:
+            for name, p in model.named_parameters():
+                if not isinstance(p, ColoTensor):
+                    continue
+                # num_class = type_vocab_size = 2 | (8, 2)
+                if 'classifier' in name and 'weight' in name:
+                    init_1d_row_linear(p, pg)
+                # num_class = vocab_size = 30524 | (30524, 8)
+                elif 'word_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = seq_len = 512 | (512, 8)
+                elif 'position_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = type_vocab_size = 2 | (2, 8)
+                elif 'token_type_embeddings' in name and 'weight' in name:
+                    init_1d_col_embedding(p, pg)
+                elif p.process_group.tp_world_size() == 1:
+                    p.redistribute(ReplicaSpec(), pg)
+        elif "simple_net" == model_name:
+            init_spec_func(model, pg)

     model = model.cuda()
+    model.train()
+
     model_reload = model_reload.cuda()
-    if use_ddp:
-        model = ColoDDP(model, pg)
-        model_reload = ColoDDP(model_reload, pg)
+    model_reload.train()

-    init_spec_func(model, pg)
-    if use_mp_reload:
-        init_spec_func(model_reload, pg)
+    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)

-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-    optimizer_reload = torch.optim.Adam(model_reload.parameters(),
-                                        lr=0.001,
-                                        betas=(0.9, 0.999),
-                                        eps=1e-08,
-                                        weight_decay=0)
+    for i, (data, label) in enumerate(train_dataloader):

-    lr_scheduler = None
-    if test_scheduler == 'colossalai_cosine_warmup':
-        lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
-        lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
-                                                      total_steps=num_epoch,
-                                                      warmup_steps=warmup_epoch)
-    elif test_scheduler == 'torch_cosine':
-        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
-        lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
-    elif test_scheduler == 'torch_lambda':
-        lr_lambda = lambda epoch: 0.95
-        lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
-        lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
-    else:
-        raise TypeError(f"{test_scheduler} is invalid")
+        # Zero grad
+        colo_optimizer.zero_grad()

-    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
+        data = data.to(get_current_device())
+        label = label.to(get_current_device())
+
+        # Forward and backward pass on this rank's batch
+        if criterion:
+            output = model(data)
+            loss = criterion(output, label)
+        else:
+            output = model(data, label)
+            loss = output
+
+        loss.backward()
+        colo_optimizer.step()
+
+        if i > 2:
+            break
+
+    if not os.path.isdir('./checkpoint') and rank == 0:
+        os.mkdir('./checkpoint')
+    save_checkpoint('./checkpoint', 0, model, None, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
+    load_checkpoint('./checkpoint', 0, model_reload, None, None)

     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():
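Note on the save/load sequence above: `save_checkpoint` first converts every ColoTensor to a replicated layout (see module_checkpoint.py), so only rank 0 needs to write, and the `dist.barrier()` keeps the other ranks from loading a half-written file. Conceptually, replicating a 1-D sharded tensor is an all-gather along the shard dimension; a sketch under that assumption (the real conversion lives in ColoTensor/DistSpecManager, and `naive_replicate` is purely illustrative):

    import torch
    import torch.distributed as dist

    def naive_replicate(shard: torch.Tensor, world_size: int, dim: int) -> torch.Tensor:
        # Requires an initialized process group; gathers every rank's shard
        # and concatenates them back along the sharded dimension.
        parts = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(parts, shard.contiguous())
        return torch.cat(parts, dim=dim)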
@@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):

     check_param_equal(model, model_reload)

+    if rank == 0:
+        remove('./checkpoint')
+

 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
-    if use_ddp and world_size == 1:
-        return
-    tp_world_size = world_size // 2 if use_ddp else world_size
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
+    for model_name in ['bert', 'simple_net']:
+        _run_checkpoint(model_name,
+                        init_1d_row_for_linear_weight_spec,
+                        use_ddp,
+                        use_mp_reload,
+                        test_scheduler=test_scheduler,
+                        pg=pg)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('use_ddp', [True, False])
+@pytest.mark.parametrize('use_ddp', [False])
 @pytest.mark.parametrize('use_mp_reload', [True, False])
-@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
+# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
-def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
-    if not os.path.isdir('./checkpoint'):
-        os.mkdir('./checkpoint')
+def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
@@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
                        use_mp_reload=use_mp_reload,
                        test_scheduler=test_scheduler)
     mp.spawn(run_func, nprocs=world_size)
-    remove('./checkpoint')


 if __name__ == '__main__':
-    test_checkpoint(2, True, False, "torch_cosine")
+    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
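Note on the test changes above: with the `test_scheduler` parametrize mark commented out, the argument needs its new `=None` default; pytest would otherwise treat `test_scheduler` as an unknown fixture and fail collection (arguments with defaults are not requested as fixtures). Checkpoint-directory creation and cleanup also move from the parent pytest process into the spawned workers, where only rank 0 touches the filesystem.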