mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[zero] support extra dp (#6123)
* [zero] support extra dp * [zero] update checkpoint * fix bugs * fix bugs
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
|
||||
# update the tensor data
|
||||
for p, q in zip(tensor_list, updated_params):
|
||||
p.data = q.data
|
||||
|
||||
|
||||
def all_gather_into_flat_tensor_nd(
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
|
||||
async_op: bool = False,
|
||||
):
|
||||
if isinstance(group, dist.ProcessGroup):
|
||||
group = (group,)
|
||||
sizes = [dist.get_world_size(pg) for pg in group]
|
||||
ranks = [dist.get_rank(pg) for pg in group]
|
||||
for i, pg in list(enumerate(group))[::-1]:
|
||||
if i == 0:
|
||||
out = output_tensor
|
||||
else:
|
||||
prev_sizes = sizes[:i]
|
||||
prev_ranks = ranks[:i]
|
||||
chunks = output_tensor.chunk(np.prod(prev_sizes))
|
||||
out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
|
||||
handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
|
||||
input_tensor = out
|
||||
return handle
|
||||
|
||||
|
||||
def get_nd_world_size(group) -> int:
|
||||
if isinstance(group, tuple):
|
||||
return int(np.prod([dist.get_world_size(pg) for pg in group]))
|
||||
else:
|
||||
return dist.get_world_size(group)
|
||||
|
||||
|
||||
def get_nd_rank(group) -> int:
|
||||
if isinstance(group, tuple):
|
||||
return np.ravel_multi_index(
|
||||
tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
|
||||
)
|
||||
else:
|
||||
return dist.get_rank(group)
|
||||
|
Reference in New Issue
Block a user