mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-01 09:42:35 +00:00
[checkpoint] checkpoint for ColoTensor Model (#1196)
This commit is contained in:
parent
291e22aac6
commit
f38006ea83
3
colossalai/utils/checkpoint/__init__.py
Normal file
3
colossalai/utils/checkpoint/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .module_checkpoint import save_checkpoint, load_checkpoint
|
||||||
|
|
||||||
|
__all__ = ['save_checkpoint', 'load_checkpoint']
|
73
colossalai/utils/checkpoint/module_checkpoint.py
Normal file
73
colossalai/utils/checkpoint/module_checkpoint.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
import collections
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
|
||||||
|
from colossalai.utils.model.colo_init_context import colo_state_dict
|
||||||
|
|
||||||
|
def save_checkpoint(dire,
|
||||||
|
epoch: int,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer = None,
|
||||||
|
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
"""save_checkpoint
|
||||||
|
save a model, whose parameters are `ColoTensor`s.
|
||||||
|
Args:
|
||||||
|
dire (_type_): _description_
|
||||||
|
epoch (int): _description_
|
||||||
|
model (torch.nn.Module): _description_
|
||||||
|
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
|
||||||
|
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
|
||||||
|
"""
|
||||||
|
model_state = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)
|
||||||
|
}
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
||||||
|
lr_scheduler_dict = lr_scheduler.state_dict()
|
||||||
|
lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
|
||||||
|
optim_state = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'lr_scheduler': lr_scheduler_dict
|
||||||
|
}
|
||||||
|
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(dire,
|
||||||
|
epoch: int,
|
||||||
|
rank: int,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer = None,
|
||||||
|
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
"""load_checkpoint
|
||||||
|
load a model, whose parameters are `ColoTensor`s.
|
||||||
|
Args:
|
||||||
|
dire (_type_): _description_
|
||||||
|
epoch (int): _description_
|
||||||
|
rank (int): _description_
|
||||||
|
model (torch.nn.Module): _description_
|
||||||
|
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
|
||||||
|
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
|
||||||
|
"""
|
||||||
|
model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
|
||||||
|
model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
|
||||||
|
model.load_state_dict(model_state['model'])
|
||||||
|
optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
|
||||||
|
optimizer.load_state_dict(optim_state['optimizer'])
|
||||||
|
lr_scheduler_dict = optim_state['lr_scheduler']
|
||||||
|
after_scheduler_dict = lr_scheduler_dict['after_scheduler']
|
||||||
|
lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(
|
||||||
|
optimizer,
|
||||||
|
after_scheduler_dict['T_max'],
|
||||||
|
after_scheduler_dict['eta_min'],
|
||||||
|
after_scheduler_dict['last_epoch']
|
||||||
|
)
|
||||||
|
lr_scheduler.load_state_dict(lr_scheduler_dict)
|
@ -38,15 +38,18 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
|
|||||||
# build param to spec mapping
|
# build param to spec mapping
|
||||||
mapping1 = dict()
|
mapping1 = dict()
|
||||||
mapping2 = dict()
|
mapping2 = dict()
|
||||||
|
mapping3 = dict()
|
||||||
# gather all params
|
# gather all params
|
||||||
has_dist_parameter = False
|
has_dist_parameter = False
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
if isinstance(param, ColoParameter) and param.has_compute_spec():
|
if isinstance(param, ColoParameter):
|
||||||
has_dist_parameter = True
|
has_dist_parameter = True
|
||||||
mapping1[id(param)] = copy(param.dist_spec)
|
mapping1[id(param)] = copy(param.dist_spec)
|
||||||
mapping2[id(param)] = copy(param.compute_spec)
|
mapping2[id(param)] = copy(param.compute_spec)
|
||||||
|
mapping3[id(param)] = param.get_process_group()
|
||||||
param.set_dist_spec(distspec.replicate())
|
param.set_dist_spec(distspec.replicate())
|
||||||
|
param.process_group = None
|
||||||
|
|
||||||
# TODO: fix when keep_vars = True
|
# TODO: fix when keep_vars = True
|
||||||
# when keep_vars = False, the state_dict_func will call detach to create
|
# when keep_vars = False, the state_dict_func will call detach to create
|
||||||
@ -64,6 +67,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
|
|||||||
if param_id in mapping1:
|
if param_id in mapping1:
|
||||||
dist_spec = mapping1[id(param)]
|
dist_spec = mapping1[id(param)]
|
||||||
compute_spec = mapping2[id(param)]
|
compute_spec = mapping2[id(param)]
|
||||||
|
param.process_group = mapping3[id(param)]
|
||||||
param.set_tensor_spec(dist_spec, compute_spec)
|
param.set_tensor_spec(dist_spec, compute_spec)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
211
tests/test_utils/test_colo_checkpoint.py
Normal file
211
tests/test_utils/test_colo_checkpoint.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import os, sys, shutil
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
import copy
|
||||||
|
import operator
|
||||||
|
import colossalai
|
||||||
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils.cuda import get_current_device
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from functools import partial
|
||||||
|
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||||
|
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
|
||||||
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataGenerator(ABC):
|
||||||
|
|
||||||
|
def __init__(self, length=10):
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.step = 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.step < self.length:
|
||||||
|
self.step += 1
|
||||||
|
return self.generate()
|
||||||
|
else:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataLoader(DummyDataGenerator):
|
||||||
|
batch_size = 128
|
||||||
|
category = 16
|
||||||
|
feature_size = 256
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
image_dict = {}
|
||||||
|
image_dict['pixel_values'] = torch.rand(
|
||||||
|
DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1
|
||||||
|
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=get_current_device())
|
||||||
|
return image_dict
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, hidden_features=None):
|
||||||
|
super().__init__()
|
||||||
|
if hidden_features is None:
|
||||||
|
hidden_features = out_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.activation = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
|
||||||
|
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||||
|
with DistSpecManager.no_grad():
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if 'weight' in n:
|
||||||
|
p.set_process_group(pg)
|
||||||
|
p.set_tensor_spec(*spec)
|
||||||
|
|
||||||
|
|
||||||
|
def check_param_equal(model, torch_model):
|
||||||
|
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||||
|
assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
def remove(path):
|
||||||
|
""" param <path> could either be relative or absolute. """
|
||||||
|
if os.path.isfile(path) or os.path.islink(path):
|
||||||
|
os.remove(path)
|
||||||
|
elif os.path.isdir(path):
|
||||||
|
shutil.rmtree(path)
|
||||||
|
else:
|
||||||
|
raise ValueError("file {} is not a file or dir.".format(path))
|
||||||
|
|
||||||
|
|
||||||
|
def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
|
||||||
|
train_dataloader = DummyDataLoader(length=16)
|
||||||
|
with ColoInitContext(device=get_current_device()):
|
||||||
|
model = MLP(256, 16, 64)
|
||||||
|
model_reload = MLP(256, 16, 64)
|
||||||
|
model_ref = MLP(256, 16, 64)
|
||||||
|
model = model.cuda()
|
||||||
|
model_reload = model_reload.cuda()
|
||||||
|
model_ref = model_ref.cuda()
|
||||||
|
if use_ddp:
|
||||||
|
model = ColoDDP(model, pg)
|
||||||
|
model_reload = ColoDDP(model_reload, pg)
|
||||||
|
model_ref = ColoDDP(model_ref, pg)
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||||
|
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
|
||||||
|
lr=0.001,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
eps=1e-08,
|
||||||
|
weight_decay=0)
|
||||||
|
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||||
|
|
||||||
|
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)
|
||||||
|
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5)
|
||||||
|
lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5)
|
||||||
|
|
||||||
|
init_spec_func(model, pg)
|
||||||
|
init_spec_func(model_ref, pg)
|
||||||
|
|
||||||
|
for epoch in range(0, 20):
|
||||||
|
if epoch <= test_epoch:
|
||||||
|
for i, image_dict in enumerate(train_dataloader):
|
||||||
|
if use_ddp:
|
||||||
|
model.zero_grad()
|
||||||
|
else:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
logits = model(image_dict['pixel_values'])
|
||||||
|
loss = criterion(logits, image_dict['label'])
|
||||||
|
if use_ddp:
|
||||||
|
model.backward(loss)
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if epoch == test_epoch:
|
||||||
|
for ref_p, p in zip(model_ref.parameters(), model.parameters()):
|
||||||
|
ref_p.data.copy_(p)
|
||||||
|
optimizer_ref = copy.deepcopy(optimizer)
|
||||||
|
lr_scheduler_ref = copy.deepcopy(lr_scheduler)
|
||||||
|
|
||||||
|
check_param_equal(model, model_ref)
|
||||||
|
save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
|
||||||
|
dist.barrier()
|
||||||
|
else:
|
||||||
|
if epoch == test_epoch + 1:
|
||||||
|
load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
|
||||||
|
lr_scheduler_reload)
|
||||||
|
init_spec_func(model_reload, pg)
|
||||||
|
for i, image_dict in enumerate(train_dataloader):
|
||||||
|
if use_ddp:
|
||||||
|
model_ref.zero_grad()
|
||||||
|
model_reload.zero_grad()
|
||||||
|
else:
|
||||||
|
optimizer_ref.zero_grad()
|
||||||
|
optimizer_reload.zero_grad()
|
||||||
|
logits_ref = model_ref(image_dict['pixel_values'])
|
||||||
|
logits_reload = model_reload(image_dict['pixel_values'])
|
||||||
|
loss_ref = criterion(logits_ref, image_dict['label'])
|
||||||
|
loss_reload = criterion(logits_reload, image_dict['label'])
|
||||||
|
if use_ddp:
|
||||||
|
model_ref.backward(loss_ref)
|
||||||
|
model_reload.backward(loss_reload)
|
||||||
|
else:
|
||||||
|
loss_ref.backward()
|
||||||
|
loss_reload.backward()
|
||||||
|
optimizer_ref.step()
|
||||||
|
optimizer_reload.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
check_param_equal(model_ref, model_reload)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port, use_ddp, test_epoch):
|
||||||
|
if use_ddp and world_size == 1:
|
||||||
|
return
|
||||||
|
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||||
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||||
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
pg = ProcessGroup(tp_degree=world_size)
|
||||||
|
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [4])
|
||||||
|
@pytest.mark.parametrize('use_ddp', [True])
|
||||||
|
@pytest.mark.parametrize('test_epoch', [1, 2, 3])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_checkpoint(world_size, use_ddp, test_epoch):
|
||||||
|
if not os.path.isdir('./checkpoint'):
|
||||||
|
os.mkdir('./checkpoint')
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch)
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
remove('./checkpoint')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_checkpoint(4, True, 1)
|
Loading…
Reference in New Issue
Block a user