[zero] support extra dp (#6123)

* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
Hongxin Liu
2024-11-12 11:20:46 +08:00
committed by GitHub
parent 30a9443132
commit a2596519fd
8 changed files with 238 additions and 57 deletions


@@ -1,6 +1,7 @@
import math
from typing import Optional
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -209,3 +210,42 @@ def sync_tensor(flat_tensor, tensor_list):
    # update the tensor data
    for p, q in zip(tensor_list, updated_params):
        p.data = q.data


def all_gather_into_flat_tensor_nd(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...]],
    async_op: bool = False,
):
    """All-gather a flat shard across an n-dimensional mesh of process groups.

    The gather runs hierarchically, innermost group first: each step gathers the
    current shard into the slice of ``output_tensor`` addressed by this rank's
    coordinates in the outer dimensions, and that slice becomes the input of the
    next (outer) gather.
    """
    if isinstance(group, dist.ProcessGroup):
        group = (group,)
    sizes = [dist.get_world_size(pg) for pg in group]
    ranks = [dist.get_rank(pg) for pg in group]
    # iterate from the innermost group to the outermost one
    for i, pg in list(enumerate(group))[::-1]:
        if i == 0:
            # the outermost gather fills the whole output tensor
            out = output_tensor
        else:
            # gather into the chunk addressed by this rank's outer coordinates
            prev_sizes = sizes[:i]
            prev_ranks = ranks[:i]
            chunks = output_tensor.chunk(np.prod(prev_sizes))
            out = chunks[np.ravel_multi_index(prev_ranks, prev_sizes)]
        handle = dist.all_gather_into_tensor(out, input_tensor, group=pg, async_op=async_op)
        input_tensor = out
    return handle
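A minimal usage sketch of the nd all-gather, assuming 4 ranks laid out row-major as a 2 x 2 mesh; the `outer_pg`/`inner_pg` construction, shard sizes, and the NCCL launch below are illustrative, not part of this commit. The nested gather is equivalent to one flat all-gather over all 4 ranks ordered by their flattened (outer_rank, inner_rank) index.

# Usage sketch (illustrative), launched e.g. with: torchrun --nproc_per_node=4 sketch.py
# all_gather_into_flat_tensor_nd is the helper defined above (import path omitted).
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
row, col = divmod(rank, 2)  # row-major 2 x 2 mesh: row = outer coord, col = inner coord
# every rank must create every subgroup, then keep the ones it belongs to
row_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]  # inner (contiguous) dim
col_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]  # outer dim
inner_pg = row_groups[row]
outer_pg = col_groups[col]

shard = torch.full((4,), float(rank), device="cuda")
output = torch.empty(4 * dist.get_world_size(), device="cuda")
all_gather_into_flat_tensor_nd(output, shard, group=(outer_pg, inner_pg))
# output == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]: shards ordered by the
# flattened (outer_rank, inner_rank) index, as a plain all-gather over all ranks would give.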
def get_nd_world_size(group) -> int:
    """Total number of ranks across a process group or a tuple of process groups."""
    if isinstance(group, tuple):
        return int(np.prod([dist.get_world_size(pg) for pg in group]))
    else:
        return dist.get_world_size(group)


def get_nd_rank(group) -> int:
    """Flattened (row-major) rank of this process in the n-dimensional group."""
    if isinstance(group, tuple):
        return int(
            np.ravel_multi_index(
                tuple(dist.get_rank(group=pg) for pg in group), [dist.get_world_size(pg) for pg in group]
            )
        )
    else:
        return dist.get_rank(group)
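The helpers use the same row-major flattening as the chunk indexing inside all_gather_into_flat_tensor_nd. A quick standalone check of that arithmetic (pure NumPy, no distributed setup; the (2, 2) sizes and (1, 0) coordinates are just an example, not from this commit):

import numpy as np

sizes = (2, 2)   # world sizes of (outer_pg, inner_pg)
coords = (1, 0)  # this rank's (outer_rank, inner_rank)
assert np.ravel_multi_index(coords, sizes) == 2  # row-major: 1 * 2 + 0, i.e. get_nd_rank
assert int(np.prod(sizes)) == 4                  # i.e. get_nd_world_size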