Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/utils/checkpoint/module_checkpoint.py (new file, 140 lines added)
@@ -0,0 +1,140 @@
from typing import Dict, Optional

import torch
import torch.distributed as dist

from colossalai.interface import OptimizerWrapper
from colossalai.tensor import ColoTensor

from .utils import gather_tensor, scatter_tensor


def save_checkpoint(path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: Optional[OptimizerWrapper] = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    *args,
                    **kwargs):
    """save_checkpoint

    Save a model whose parameters are `ColoTensor`s.

    Args:
        path (str): directory to save the checkpoint files.
        epoch (int): the epoch number.
        model (torch.nn.Module): a torch module initialized by ColoInitContext.
        optimizer (OptimizerWrapper, optional): optimizer. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): learning rate scheduler. Defaults to None.
    """
    rank = dist.get_rank()
    model_state = model.state_dict()
    # save the dist context of the tensors, while still maintaining the original dict
    for k, v in model_state.items():
        if isinstance(v, ColoTensor):
            gather_tensor(v)    # gather shared tensors to rank 0
            # don't recover tensors on rank 0, since the dict is only a copy of the model

    if rank == 0:
        # sanity check
        for k, v in model_state.items():
            if isinstance(v, ColoTensor):
                assert v.save_ready
                assert v.is_replicate()
                delattr(v, 'save_ready')
        # model saving
        save_state = {'epoch': epoch, 'model': model_state}
        torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)

    # delete the old dict
    del model_state
    # synchronize all the processes
    dist.barrier()

    if optimizer is not None:
        mapping = dict()
        optim_state = optimizer.state_dict()
        for k, v in optim_state['state'].items():
            for n, t in v.items():
                if isinstance(t, ColoTensor):
                    mapping[(k, n)] = t.dist_spec
                    gather_tensor(t)

        if rank == 0:
            save_state = {'epoch': epoch, 'optim': optim_state}
            torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
            # recover colo tensors on rank 0
            for k, v in optimizer.state_dict()['state'].items():
                for n, t in v.items():
                    if isinstance(t, ColoTensor):
                        assert hasattr(t, 'save_ready')
                        t.set_dist_spec(mapping[(k, n)])
                        delattr(t, 'save_ready')

        del optim_state
        del mapping
        dist.barrier()


def load_checkpoint(path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: Optional[OptimizerWrapper] = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    torch_load_kwargs: Optional[Dict] = None,
                    load_state_dict_kwargs: Optional[Dict] = None):
    """load_checkpoint

    Load a model whose parameters are `ColoTensor`s.

    Args:
        path (str): directory that holds the checkpoint files.
        epoch (int): the epoch number.
        model (torch.nn.Module): a torch module initialized by ColoInitContext.
        optimizer (OptimizerWrapper, optional): optimizer. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): learning rate scheduler. Defaults to None.
        torch_load_kwargs (dict, optional): keyword arguments passed to torch.load inside the function.
        load_state_dict_kwargs (dict, optional): keyword arguments passed to load_state_dict inside the function.
    """
    # initialize the default parameters
    if not torch_load_kwargs:
        torch_load_kwargs = dict()
    if not load_state_dict_kwargs:
        load_state_dict_kwargs = dict()

    rank = dist.get_rank()
    mapping = dict()
    for n, p in model.named_parameters():
        if isinstance(p, ColoTensor):
            mapping[n] = p.dist_spec
            gather_tensor(p)

    if rank == 0:
        load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
        model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
    dist.barrier()

    # scatter loaded parameters
    for n, p in model.named_parameters():
        if isinstance(p, ColoTensor):
            scatter_tensor(p, mapping[n])
            if rank == 0:
                assert hasattr(p, 'save_ready')
                delattr(p, 'save_ready')
    del mapping

    if optimizer is not None:
        mapping = dict()
        for k, v in optimizer.state_dict()['state'].items():
            for n, t in v.items():
                if isinstance(t, ColoTensor):
                    mapping[(k, n)] = t.dist_spec
                    gather_tensor(t)

        if rank == 0:
            colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
            optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
        dist.barrier()

        for k, v in optimizer.state_dict()['state'].items():
            for n, t in v.items():
                if isinstance(t, ColoTensor):
                    scatter_tensor(t, mapping[(k, n)])

        del mapping
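For anyone still relying on this legacy API, a minimal usage sketch follows. It assumes the distributed process group has already been initialized (for example via colossalai.launch) and that the model's parameters are ColoTensors created under ColoInitContext; the wrapper names checkpoint_every_epoch and resume_from_epoch and the ./checkpoints directory are illustrative placeholders, not part of this commit.

# Minimal sketch (see assumptions above): both helpers are collective calls,
# so every rank must execute them, not just rank 0.
import torch

from colossalai.interface import OptimizerWrapper
from colossalai.legacy.utils.checkpoint.module_checkpoint import load_checkpoint, save_checkpoint


def checkpoint_every_epoch(model: torch.nn.Module, optimizer: OptimizerWrapper, epoch: int,
                           ckpt_dir: str = './checkpoints') -> None:
    # Rank 0 writes epoch_{epoch}_model.pth and epoch_{epoch}_optim.pth once the
    # shards have been gathered; all other ranks only take part in the collectives.
    save_checkpoint(ckpt_dir, epoch, model, optimizer=optimizer)


def resume_from_epoch(model: torch.nn.Module, optimizer: OptimizerWrapper, epoch: int,
                      ckpt_dir: str = './checkpoints') -> None:
    # Rank 0 reads the two files and the loaded values are scattered back to every
    # rank, so the parallel layout must match the one used when saving.
    load_checkpoint(ckpt_dir, epoch, model, optimizer=optimizer)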
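Both helpers follow the gather-to-rank-0 / save / scatter-back pattern that gather_tensor and scatter_tensor implement for ColoTensors. As a rough analogue using only plain torch.distributed, assuming a tensor evenly sharded along its first dimension across ranks, the idea looks like this (save_sharded_tensor and load_sharded_tensor are hypothetical names, not ColossalAI APIs):

import torch
import torch.distributed as dist


def save_sharded_tensor(shard: torch.Tensor, path: str) -> None:
    # Gather every rank's shard onto rank 0, write the full tensor there, then resync.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    gathered = [torch.empty_like(shard) for _ in range(world_size)] if rank == 0 else None
    dist.gather(shard, gather_list=gathered, dst=0)
    if rank == 0:
        torch.save(torch.cat(gathered), path)    # the full tensor exists only on rank 0
    dist.barrier()                               # keep ranks in step, as the checkpoint code above does


def load_sharded_tensor(shard: torch.Tensor, path: str) -> None:
    # Rank 0 reads the full tensor and scatters one chunk back to each rank in place.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    chunks = None
    if rank == 0:
        chunks = [c.contiguous() for c in torch.load(path).chunk(world_size)]
    dist.scatter(shard, scatter_list=chunks, src=0)
    dist.barrier()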