Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/utils/checkpoint/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .module_checkpoint import load_checkpoint, save_checkpoint

__all__ = ['save_checkpoint', 'load_checkpoint']
colossalai/legacy/utils/checkpoint/module_checkpoint.py (new file, 140 lines)
@@ -0,0 +1,140 @@
from typing import Dict, Optional

import torch
import torch.distributed as dist

from colossalai.interface import OptimizerWrapper
from colossalai.tensor import ColoTensor

from .utils import gather_tensor, scatter_tensor


def save_checkpoint(path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: Optional[OptimizerWrapper] = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    *args,
                    **kwargs):
    """save_checkpoint

    Save a model whose parameters are `ColoTensor`s.

    Args:
        path (str): directory to save the checkpoint files.
        epoch (int): the epoch number.
        model (torch.nn.Module): a torch module initialized by ColoInitContext.
        optimizer (OptimizerWrapper, optional): optimizer. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler. Defaults to None.
    """
    rank = dist.get_rank()
    model_state = model.state_dict()
    # save the dist context of the tensors in a new dict, while still maintaining the original dict
    for k, v in model_state.items():
        if isinstance(v, ColoTensor):
            gather_tensor(v)    # gather sharded tensors to rank 0
            # don't recover tensors on rank 0, since the dict is only a copy of the model

    if rank == 0:
        # sanity check
        for k, v in model_state.items():
            if isinstance(v, ColoTensor):
                assert v.save_ready
                assert v.is_replicate()
                delattr(v, 'save_ready')
        # model saving
        save_state = {'epoch': epoch, 'model': model_state}
        torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)

    # delete old dicts
    del model_state
    # synchronize all the processes
    dist.barrier()

    if optimizer is not None:
        mapping = dict()
        optim_state = optimizer.state_dict()
        for k, v in optim_state['state'].items():
            for n, t in v.items():
                if isinstance(t, ColoTensor):
                    mapping[(k, n)] = t.dist_spec
                    gather_tensor(t)

        if rank == 0:
            save_state = {'epoch': epoch, 'optim': optim_state}
            torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
            # recover colo tensors on rank 0
            for k, v in optimizer.state_dict()['state'].items():
                for n, t in v.items():
                    if isinstance(t, ColoTensor):
                        assert hasattr(t, 'save_ready')
                        t.set_dist_spec(mapping[(k, n)])
                        delattr(t, 'save_ready')

        del optim_state
        del mapping
        dist.barrier()


def load_checkpoint(path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: Optional[OptimizerWrapper] = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    torch_load_kwargs: Optional[Dict] = None,
                    load_state_dict_kwargs: Optional[Dict] = None):
    """load_checkpoint

    Load a model whose parameters are `ColoTensor`s.

    Args:
        path (str): directory containing the checkpoint files.
        epoch (int): the epoch number.
        model (torch.nn.Module): a torch module initialized by ColoInitContext.
        optimizer (OptimizerWrapper, optional): optimizer. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler. Defaults to None.
        torch_load_kwargs (dict, optional): kwargs passed to torch.load inside the function.
        load_state_dict_kwargs (dict, optional): kwargs passed to load_state_dict inside the function.
    """
    # initialize the default parameters
    if not torch_load_kwargs:
        torch_load_kwargs = dict()
    if not load_state_dict_kwargs:
        load_state_dict_kwargs = dict()

    rank = dist.get_rank()
    mapping = dict()
    for n, p in model.named_parameters():
        if isinstance(p, ColoTensor):
            mapping[n] = p.dist_spec
            gather_tensor(p)

    if rank == 0:
        load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
        model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
    dist.barrier()

    # scatter loaded parameters
    for n, p in model.named_parameters():
        if isinstance(p, ColoTensor):
            scatter_tensor(p, mapping[n])
            if rank == 0:
                assert hasattr(p, 'save_ready')
                delattr(p, 'save_ready')
    del mapping

    if optimizer is not None:
        mapping = dict()
        for k, v in optimizer.state_dict()['state'].items():
            for n, t in v.items():
                if isinstance(t, ColoTensor):
                    mapping[(k, n)] = t.dist_spec
                    gather_tensor(t)

        if rank == 0:
            colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
            optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
        dist.barrier()

        for k, v in optimizer.state_dict()['state'].items():
            for n, t in v.items():
                if isinstance(t, ColoTensor):
                    scatter_tensor(t, mapping[(k, n)])

        del mapping
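For orientation, here is a minimal, hypothetical usage sketch of the two entry points above. It runs as a single process with a gloo group standing in for a real multi-rank launch (e.g. via colossalai.launch), and plain torch objects stand in for ColoInitContext-built models and OptimizerWrapper instances; with plain tensors the ColoTensor branches are simply skipped.

import os

import torch
import torch.distributed as dist

from colossalai.legacy.utils.checkpoint import load_checkpoint, save_checkpoint

# hypothetical single-process setup; a real run would use colossalai.launch
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

model = torch.nn.Linear(4, 2)    # stands in for a ColoInitContext-built model
optimizer = torch.optim.Adam(model.parameters())    # stands in for an OptimizerWrapper

ckpt_dir = './checkpoints'
os.makedirs(ckpt_dir, exist_ok=True)

# every rank must enter these calls: non-zero ranks take part in the
# gather/scatter collectives even though only rank 0 writes the .pth files
save_checkpoint(ckpt_dir, epoch=0, model=model, optimizer=optimizer)
load_checkpoint(ckpt_dir, epoch=0, model=model, optimizer=optimizer)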
colossalai/legacy/utils/checkpoint/utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import torch
import torch.distributed as dist

from colossalai.legacy.tensor import ColoTensorSpec
from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor import ColoTensor


def robust_broadcast(tensor):
    with torch.no_grad():
        is_cpu_ten = tensor.device.type == 'cpu'
        if is_cpu_ten:
            # NCCL only operates on CUDA tensors, so stage CPU tensors through the GPU
            b_data = tensor.cuda()
        else:
            b_data = tensor

        dist.broadcast(b_data, 0)

        if is_cpu_ten:
            tensor.copy_(b_data)


def gather_tensor(colo_tensor: ColoTensor) -> None:
    """Make colo_tensor replicated when the rank is 0.
    """
    if not colo_tensor.is_replicate():
        pg = colo_tensor.get_process_group()
        # for the group which contains rank 0
        if pg.dp_local_rank() == 0:
            old_dist_spec = colo_tensor.dist_spec
            colo_tensor.to_replicate_()
            if dist.get_rank() != 0:
                colo_tensor.set_dist_spec(old_dist_spec)

        # synchronize all processes to guard against unexpected desynchronization
        dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)    # set saving signature


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    """Reversal operation of `gather_tensor`.
    """
    if dist_spec.placement == DistPlacementPattern.REPLICATE:
        robust_broadcast(colo_tensor.data)
    else:
        global_size = colo_tensor.size_global()

        if dist.get_rank() == 0:
            entire_data = colo_tensor.data
        else:
            entire_data = torch.empty(global_size, device=colo_tensor.device)
        robust_broadcast(entire_data)

        if dist.get_rank() == 0:
            colo_tensor.set_dist_spec(dist_spec)
        else:
            rep_tensor = ColoTensor(
                entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
            rep_tensor.set_dist_spec(dist_spec)
            with torch.no_grad():
                colo_tensor.data.copy_(rep_tensor.data)
        # synchronize all processes to guard against unexpected desynchronization
        dist.barrier()
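The subtle point in gather_tensor is that the full tensor is materialized only on global rank 0 (every other rank immediately restores its old shard spec), so torch.save sees complete data while peak memory elsewhere stays at one shard. A hypothetical standalone sketch of that root-only assembly in plain torch.distributed terms; gather_to_root is an illustrative name, not part of this module:

import os

import torch
import torch.distributed as dist


def gather_to_root(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    """Assemble the full tensor on rank 0 only; other ranks keep their shard."""
    if dist.get_rank() == 0:
        pieces = [torch.empty_like(shard) for _ in range(world_size)]
        dist.gather(shard, gather_list=pieces, dst=0)
        return torch.cat(pieces)    # rank 0 now holds the complete data
    dist.gather(shard, dst=0)       # other ranks only contribute their shard
    return shard


if __name__ == '__main__':
    # single-process demo; under torchrun with world_size > 1 each rank
    # would contribute a different shard
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29501')
    dist.init_process_group('gloo', rank=0, world_size=1)
    full = gather_to_root(torch.arange(4.0), dist.get_world_size())
    print(full)    # rank 0 sees the assembled tensor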