[zero] support extra dp (#6123)

* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
This commit is contained in:
Hongxin Liu
2024-11-12 11:20:46 +08:00
committed by GitHub
parent 30a9443132
commit a2596519fd
8 changed files with 238 additions and 57 deletions

View File

@@ -1,10 +1,12 @@
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_fp8
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
class TensorBucket:
@@ -65,12 +67,18 @@ class TensorBucket:
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
if fp8_communication:
# TODO: fit fp8
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# dist.all_gather_into_tensor(buffer, flat, group=group)
all_gather_into_flat_tensor_nd(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):