[checkpoint] add test for bert and hotfix save bugs (#1297)

Authored by Jiarui Fang on 2022-07-14 15:38:18 +08:00, committed by GitHub
parent bd71e2a88b
commit 3ef3791a3b
2 changed files with 109 additions and 113 deletions

Changed file 1 of 2: checkpoint utilities (save_checkpoint / load_checkpoint)

@@ -28,7 +28,8 @@ def save_checkpoint(dire: str,
         if isinstance(v, ColoTensor):
             mapping[k] = (v.dist_spec, v.compute_spec)
             new_dict[k] = v.to_replicate().detach()
+        else:
+            new_dict[k] = v
     if dist.get_rank() == 0:
         for k, v in new_dict.items():
             if isinstance(v, ColoTensor):
@@ -60,7 +61,7 @@ def load_checkpoint(dire,
     """
     mapping = dict()
-    for k, v in model.named_parameters():
+    for k, v in model.state_dict().items():
         if isinstance(v, ColoTensor):
             mapping[k] = (v.dist_spec, v.compute_spec)
             v.to_replicate_()
@@ -70,6 +71,6 @@ def load_checkpoint(dire,
     # reset tensors to original dist spec.
     with DistSpecManager.no_grad():
-        for k, v in model.named_parameters():
+        for k, v in model.state_dict().items():
             if isinstance(v, ColoTensor):
                 v.set_tensor_spec(*mapping[k])
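Note on the hotfix above: load_checkpoint now walks model.state_dict() instead of model.named_parameters(), and the save path gains an else branch so entries that are not ColoTensors (plain tensors and buffers) are no longer dropped from the saved state dict. A minimal sketch of that save-side pattern, with is_colo_tensor and to_replicate as hypothetical stand-ins for the ColoTensor API rather than the real calls:

import torch

def build_replicated_state_dict(state_dict, is_colo_tensor, to_replicate):
    # ColoTensors are gathered into replicated tensors and their
    # (dist_spec, compute_spec) remembered so they can be re-sharded later;
    # everything else is passed through unchanged (the fixed `else` branch).
    mapping, new_dict = {}, {}
    for k, v in state_dict.items():
        if is_colo_tensor(v):
            mapping[k] = (v.dist_spec, v.compute_spec)
            new_dict[k] = to_replicate(v).detach()
        else:
            new_dict[k] = v    # e.g. buffers / plain torch.Tensor entries
    return new_dict, mapping

# Toy usage with plain tensors only, so the else branch is the one exercised.
sd = {'weight': torch.randn(4, 4), 'running_mean': torch.zeros(4)}
new_sd, mapping = build_replicated_state_dict(sd, lambda v: False, lambda v: v)
assert set(new_sd) == set(sd) and mapping == {}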

Changed file 2 of 2: the distributed checkpoint test (test_checkpoint)

@@ -1,91 +1,65 @@
-from abc import ABC, abstractmethod
 import os, shutil
 import torch
-import torch.nn as nn
 import pytest
 from functools import partial

 import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.optim.lr_scheduler import MultiplicativeLR
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

 import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import ColoOptimizer
+
+from tests.components_to_test.registry import non_distributed_component_funcs


-class DummyDataGenerator(ABC):
-
-    def __init__(self, length=10):
-        self.length = length
-
-    @abstractmethod
-    def generate(self):
-        pass
-
-    def __iter__(self):
-        self.step = 0
-        return self
-
-    def __next__(self):
-        if self.step < self.length:
-            self.step += 1
-            return self.generate()
-        else:
-            raise StopIteration
-
-    def __len__(self):
-        return self.length
-
-
-class DummyDataLoader(DummyDataGenerator):
-
-    def __init__(self, batch_size, category, feature_size, length=10):
-        super().__init__(length)
-        self.batch_size = batch_size
-        self.category = category
-        self.feature_size = feature_size
-
-    def generate(self):
-        image_dict = {}
-        image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
-                                            dtype=torch.int64,
-                                            device=get_current_device())
-        return image_dict
-
-
-class MLP(nn.Module):
-
-    def __init__(self, in_features, out_features, hidden_features=None):
-        super().__init__()
-        if hidden_features is None:
-            hidden_features = out_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.activation = nn.ReLU()
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.activation(x)
-        x = self.fc2(x)
-        return x
+def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
+
+
+def init_1d_col_linear(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
+
+
+def init_1d_row_embedding(weight, pg):
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)
+
+
+def init_1d_col_embedding(weight, pg):
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    weight.set_process_group(pg)
+    weight.set_tensor_spec(*spec)


 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        for n, p in model.named_parameters():
-            if 'weight' in n:
-                p.set_process_group(pg)
-                p.set_tensor_spec(*spec)
+    for name, p in model.named_parameters():
+        if not isinstance(p, ColoTensor):
+            continue
+        if 'embed' in name and 'weight' in name:
+            init_1d_col_embedding(p, pg)
+        if 'proj1' in name and ('weight' in name or 'bias' in name):
+            init_1d_col_linear(p, pg)
+        if 'proj2' in name and 'weight' in name:
+            init_1d_row_linear(p, pg)
+        if 'classifier' in name and ('weight' in name or 'bias' in name):
+            init_1d_col_linear(p, pg)


 def check_param_equal(model, torch_model):
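All four init_1d_* helpers added above follow one pattern: a ShardSpec picks which dimension of the parameter to split across pg.tp_world_size() ranks, and ComputeSpec(ComputePattern.TP1D) marks it for 1D tensor parallelism. A small framework-free illustration of what the two ShardSpecs mean for a (out_features, in_features) linear weight; the shapes and torch.chunk calls are illustrative only, not the ColossalAI sharding implementation:

import torch

weight = torch.randn(16, 32)    # (out_features, in_features)
tp_world_size = 2

# ShardSpec([-1], [tp_world_size]) -> "row" linear: split the last dim (in_features).
row_shards = torch.chunk(weight, tp_world_size, dim=-1)
assert row_shards[0].shape == (16, 16)

# ShardSpec([0], [tp_world_size]) -> "col" linear: split dim 0 (out_features).
col_shards = torch.chunk(weight, tp_world_size, dim=0)
assert col_shards[0].shape == (8, 32)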
@@ -103,56 +77,75 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))


-def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
-    num_epoch = 5
-    warmup_epoch = 2
-
-    batch = 3
-    feature = 32
-    category = 16
+def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()

+    # set_seed(1)
     with ColoInitContext(device=get_current_device()):
-        model = MLP(feature, category)
-
-    with ColoInitContext(device=get_current_device()):
-        model_reload = MLP(feature, category)
+        model = model_builder(checkpoint=True)
+        model_reload = model_builder(checkpoint=True)
+
+    if use_mp_reload:
+        if 'bert' == model_name:
+            for name, p in model.named_parameters():
+                if not isinstance(p, ColoTensor):
+                    continue
+                # num_class = type_vocab_size = 2 | (8, 2)
+                if 'classifier' in name and 'weight' in name:
+                    init_1d_row_linear(p, pg)
+                # num_class = vocab_size = 30524 | (30524, 8)
+                elif 'word_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = seq_len = 512 | (512, 8)
+                elif 'position_embeddings' in name and 'weight' in name:
+                    init_1d_row_embedding(p, pg)
+                # num_class = type_vocab_size = 2 | (2, 8)
+                elif 'token_type_embeddings' in name and 'weight' in name:
+                    init_1d_col_embedding(p, pg)
+                elif p.process_group.tp_world_size() == 1:
+                    p.redistribute(ReplicaSpec(), pg)
+        elif "simple_net" == model_name:
+            init_spec_func(model, pg)

     model = model.cuda()
+    model.train()
+
     model_reload = model_reload.cuda()
-    if use_ddp:
-        model = ColoDDP(model, pg)
-        model_reload = ColoDDP(model_reload, pg)
+    model_reload.train()

-    init_spec_func(model, pg)
-    if use_mp_reload:
-        init_spec_func(model_reload, pg)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-    optimizer_reload = torch.optim.Adam(model_reload.parameters(),
-                                        lr=0.001,
-                                        betas=(0.9, 0.999),
-                                        eps=1e-08,
-                                        weight_decay=0)
-
-    lr_scheduler = None
-    if test_scheduler == 'colossalai_cosine_warmup':
-        lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
-        lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
-                                                      total_steps=num_epoch,
-                                                      warmup_steps=warmup_epoch)
-    elif test_scheduler == 'torch_cosine':
-        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
-        lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
-    elif test_scheduler == 'torch_lambda':
-        lr_lambda = lambda epoch: 0.95
-        lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
-        lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
-    else:
-        raise TypeError(f"{test_scheduler} is invalid")
-
-    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
+    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
+
+    for i, (data, label) in enumerate(train_dataloader):
+
+        # Zero grad
+        colo_optimizer.zero_grad()
+
+        data = data.to(get_current_device())
+        label = label.to(get_current_device())
+
+        # Bcast rank0 data to all processes
+        if criterion:
+            output = model(data)
+            loss = criterion(output, label)
+        else:
+            output = model(data, label)
+            loss = output
+
+        loss.backward()
+        colo_optimizer.step()
+
+        if i > 2:
+            break
+
+    if not os.path.isdir('./checkpoint') and rank == 0:
+        os.mkdir('./checkpoint')
+
+    save_checkpoint('./checkpoint', 0, model, None, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
+    load_checkpoint('./checkpoint', 0, model_reload, None, None)

     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():
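The rewritten helper reduces to a save → barrier → load → compare round trip after a few training steps. A condensed sketch of that round trip, assuming the same save_checkpoint/load_checkpoint imports this test uses; check_param_equal stands for the comparison helper defined earlier in the file, and the optimizer/lr_scheduler slots are left as None exactly as above:

import os
import torch.distributed as dist
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

def checkpoint_round_trip(model, model_reload, check_param_equal, ckpt_dir='./checkpoint'):
    # Rank 0 creates the directory; the barrier keeps other ranks from
    # loading before the checkpoint files exist.
    if dist.get_rank() == 0 and not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    save_checkpoint(ckpt_dir, 0, model, None, None)    # epoch 0, no optimizer/lr_scheduler state
    dist.barrier()
    load_checkpoint(ckpt_dir, 0, model_reload, None, None)
    check_param_equal(model, model_reload)             # reloaded params must match the originals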
@@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):

     check_param_equal(model, model_reload)

+    if rank == 0:
+        remove('./checkpoint')
+

 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
-    if use_ddp and world_size == 1:
-        return
-    tp_world_size = world_size // 2 if use_ddp else world_size
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
+
+    for model_name in ['bert', 'simple_net']:
+        _run_checkpoint(model_name,
+                        init_1d_row_for_linear_weight_spec,
+                        use_ddp,
+                        use_mp_reload,
+                        test_scheduler=test_scheduler,
+                        pg=pg)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.parametrize('use_ddp', [True, False])
+@pytest.mark.parametrize('use_ddp', [False])
 @pytest.mark.parametrize('use_mp_reload', [True, False])
-@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
+# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
-def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
-    if not os.path.isdir('./checkpoint'):
-        os.mkdir('./checkpoint')
+def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
@@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
                        use_mp_reload=use_mp_reload,
                        test_scheduler=test_scheduler)
     mp.spawn(run_func, nprocs=world_size)
-    remove('./checkpoint')


 if __name__ == '__main__':
-    test_checkpoint(2, True, False, "torch_cosine")
+    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
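For reference, the test drives one worker process per rank with the functools.partial + mp.spawn pattern shown above. A minimal self-contained sketch of that pattern; the worker is a trivial stand-in for run_dist, and a fixed port stands in for free_port():

from functools import partial
import torch.multiprocessing as mp

def _worker(rank, world_size, port):
    # Stand-in for run_dist: a real worker would call colossalai.launch(...)
    # with this rank, world_size and port, then run the checkpoint test body.
    print(f"rank {rank}/{world_size} would launch on port {port}")

if __name__ == '__main__':
    world_size = 2
    run_func = partial(_worker, world_size=world_size, port=29500)
    # mp.spawn passes the process index as the first positional argument.
    mp.spawn(run_func, nprocs=world_size)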