mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpoint] add ColoOptimizer checkpointing (#1316)
parent: 7c2634f4b3
commit: 9e4c6449b0
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import torch
 import torch.nn as nn
 from torch import Tensor

@@ -1,12 +1,15 @@
 import torch
 import torch.distributed as dist
 from colossalai.tensor import ColoTensor, DistSpecManager
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from copy import copy
+from typing import Optional


 def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):

@@ -16,7 +19,7 @@ def save_checkpoint(dire: str,
         dire (str): directory to save the checkpoint files.
         epoch (int): the number of epoch
         model (torch.nn.Module): a torch module initialized by ColoInitContext
-        optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
 
@@ -41,11 +44,21 @@ def save_checkpoint(dire: str,
     # delete the new dict
     del new_dict
 
+    optim_state_copy = copy(optimizer.state_dict())
+    for k, v in optim_state_copy['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                t.to_replicate_()
+    if dist.get_rank() == 0:
+        model_state = {'epoch': epoch, 'optim': optim_state_copy}
+        torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
+    del optim_state_copy
+
 
 def load_checkpoint(dire,
                     epoch: int,
                     model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer = None,
+                    optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     *args,
                     **kwargs):

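For orientation, a minimal usage sketch of the round trip these hunks enable (not part of the commit): the import location of save_checkpoint/load_checkpoint is not shown in this diff, build_colo_model() is a hypothetical stand-in for whatever builds a module whose parameters are ColoTensors (for example one constructed under ColoInitContext), and torch.distributed is assumed to be initialized on every rank.

# Usage sketch only -- not part of the diff. Assumes an initialized distributed
# environment and a model whose parameters are ColoTensors; build_colo_model()
# is hypothetical, and the import of save_checkpoint / load_checkpoint depends
# on where the patched module lives.
import torch
from colossalai.nn.optimizer import ColossalaiOptimizer

model = build_colo_model().cuda()
optimizer = ColossalaiOptimizer(torch.optim.Adam(model.parameters(), lr=1e-3))

# ... a few forward/backward/step() iterations so Adam materializes its state ...

# Every rank replicates the ColoTensors in a copy of the optimizer state;
# rank 0 then writes './checkpoint/epoch_0_optim.pth'.
save_checkpoint('./checkpoint', 0, model, optimizer)

# Later (possibly into a freshly built model/optimizer of the same layout):
# load_checkpoint records each ColoTensor's (dist_spec, compute_spec),
# replicates it, loads the saved state, then re-applies the recorded spec.
load_checkpoint('./checkpoint', 0, model, optimizer)
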
@@ -56,7 +69,7 @@ def load_checkpoint(dire,
         epoch (int): _description_
         rank (int): _description_
         model (torch.nn.Module): _description_
-        optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
+        optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
     """
 
@@ -74,3 +87,24 @@ def load_checkpoint(dire,
     for k, v in model.state_dict().items():
         if isinstance(v, ColoTensor):
             v.set_tensor_spec(*mapping[k])
+
+    del mapping
+    mapping = dict()
+
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                mapping[(k, n)] = (t.dist_spec, t.compute_spec)
+                t.to_replicate_()
+
+    colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
+    optimizer.load_state_dict(colo_checkpoint['optim'])
+
+    for k, v in optimizer.state_dict()['state'].items():
+        for n, t in v.items():
+            if isinstance(t, ColoTensor):
+                # skip keys not in mapping.
+                # For Adam, if it does not execute step() once, there will be no exp_avg and exp_avg_sq in the optimizer state.
+                if (k, n) not in mapping:
+                    continue
+                t.set_tensor_spec(*mapping[(k, n)])

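The "skip keys not in mapping" branch above exists because Adam creates its per-parameter state lazily. A small self-contained illustration in plain PyTorch (no ColoTensor involved), just to show the behaviour that comment refers to:

import torch

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p], lr=0.1)

print(opt.state_dict()['state'])   # {} -- no per-parameter state before the first step()

p.grad = torch.ones_like(p)
opt.step()

print(sorted(opt.state_dict()['state'][0].keys()))
# ['exp_avg', 'exp_avg_sq', 'step'] -- the state appears only after step() has run once
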
@@ -77,6 +77,18 @@ def remove(path):
         raise ValueError("file {} is not a file or dir.".format(path))
 
 
+def compare_optims(optim1, optim2):
+    state1 = optim1.state_dict()['state']
+    state2 = optim2.state_dict()['state']
+    for k, p1 in state1.items():
+        if k not in state2:
+            continue
+        p2 = state2[k]
+        if isinstance(p1, ColoTensor):
+            assert isinstance(p2, ColoTensor)
+            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
+
+
 def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -117,7 +129,10 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     model_reload = model_reload.cuda()
     model_reload.train()
 
-    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.named_parameters(), r=0.1))
+    opt_class = torch.optim.Adam
+    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
+    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
+    run_reload = False
 
     for i, (data, label) in enumerate(train_dataloader):
 
@@ -130,22 +145,35 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
+            output_reload = model_reload(data)
             loss = criterion(output, label)
+            loss_reload = criterion(output_reload, label)
         else:
-            output = model(data, label)
-            loss = output
+            loss = model(data, label)
+            loss_reload = model_reload(data, label)
 
         loss.backward()
-        colo_optimizer.step()
+        loss_reload.backward()
 
+        if run_reload:
+            colo_optimizer_reload.zero_grad()
+            if criterion:
+                output_reload = model_reload(data)
+                loss_reload = criterion(output_reload, label)
+            else:
+                loss_reload = model_reload(data, label)
+            loss_reload.backward()
+            colo_optimizer_reload.step()
+
         if i > 2:
             break
 
     if not os.path.isdir('./checkpoint') and rank == 0:
         os.mkdir('./checkpoint')
-    save_checkpoint('./checkpoint', 0, model, None, None)
+    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
+    dist.barrier()
+    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
     dist.barrier()
-    load_checkpoint('./checkpoint', 0, model_reload, None, None)
 
     # Since model is sharded, we merge them before param checking.
     for p in model.parameters():

@@ -155,7 +183,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         p.to_replicate_()
 
     check_param_equal(model, model_reload)
+    compare_optims(colo_optimizer, colo_optimizer_reload)
     if rank == 0:
         remove('./checkpoint')
 
@@ -163,7 +191,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    for model_name in ['bert', 'simple_net']:
+    for model_name in ['simple_net', 'bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,