Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/utils/checkpoint/utils.py (new file, 65 additions)
@@ -0,0 +1,65 @@
import torch
import torch.distributed as dist

from colossalai.legacy.tensor import ColoTensorSpec
from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor import ColoTensor


def robust_broadcast(tensor):
    with torch.no_grad():
        is_cpu_ten = tensor.device.type == 'cpu'
        if is_cpu_ten:
            b_data = tensor.cuda()
        else:
            b_data = tensor

        dist.broadcast(b_data, 0)

        if is_cpu_ten:
            tensor.copy_(b_data)


def gather_tensor(colo_tensor: ColoTensor) -> None:
    """Make colo_tensor replicated when the rank is 0
    """
    if not colo_tensor.is_replicate():
        pg = colo_tensor.get_process_group()
        # for the group which contains rank 0
        if pg.dp_local_rank() == 0:
            old_dist_spec = colo_tensor.dist_spec
            colo_tensor.to_replicate_()
            if dist.get_rank() != 0:
                colo_tensor.set_dist_spec(old_dist_spec)

        # synchronize all processes for unexpected problems
        dist.barrier()

    if dist.get_rank() == 0:
        setattr(colo_tensor, 'save_ready', True)    # set saving signature


def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
    """Reversal operation of `gather_tensor`.
    """
    if dist_spec.placement == DistPlacementPattern.REPLICATE:
        robust_broadcast(colo_tensor.data)
    else:
        global_size = colo_tensor.size_global()

        if dist.get_rank() == 0:
            entire_data = colo_tensor.data
        else:
            entire_data = torch.empty(global_size, device=colo_tensor.device)
        robust_broadcast(entire_data)

        if dist.get_rank() == 0:
            colo_tensor.set_dist_spec(dist_spec)
        else:
            rep_tensor = ColoTensor(
                entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
            rep_tensor.set_dist_spec(dist_spec)
            with torch.no_grad():
                colo_tensor.data.copy_(rep_tensor.data)
        # synchronize all processes for unexpected problems
        dist.barrier()
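For context, the sketch below shows one plausible way a caller could drive `gather_tensor` and `scatter_tensor` around a checkpoint save. It is not part of this commit: the function name `save_colo_checkpoint`, the `model`/`path` arguments, the import path, and the assumption that every entry in the state dict is a ColoTensor are illustrative only.

import torch
import torch.distributed as dist

from colossalai.legacy.utils.checkpoint.utils import gather_tensor, scatter_tensor


def save_colo_checkpoint(model, path):
    # Hypothetical helper; assumes torch.distributed is initialized and
    # every value in the state dict is a ColoTensor.
    state = model.state_dict()
    # Record each tensor's current layout so it can be restored after saving.
    old_specs = {name: t.dist_spec for name, t in state.items()}

    # Replicate every tensor onto rank 0; other ranks keep their local shards.
    for t in state.values():
        gather_tensor(t)

    if dist.get_rank() == 0:
        torch.save(state, path)

    # Undo the gather: push each tensor back to its recorded dist spec.
    for name, t in state.items():
        scatter_tensor(t, old_specs[name])

Because `gather_tensor` only materializes the full data on rank 0 and `scatter_tensor` re-applies the recorded spec everywhere, the round trip leaves the distributed layout unchanged after the save.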