Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[zero] support extra dp (#6123)
* [zero] support extra dp
* [zero] update checkpoint
* fix bugs
* fix bugs
@@ -1,10 +1,12 @@
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from colossalai.quantization.fp8 import all_gather_fp8
+from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
 
 
 class TensorBucket:
@@ -65,12 +67,18 @@ class TensorBucket:
 
     def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
-        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+        buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
         if fp8_communication:
             # TODO: fit fp8
             all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
         else:
-            dist.all_gather_into_tensor(buffer, flat, group=group)
-        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
+            # dist.all_gather_into_tensor(buffer, flat, group=group)
+            all_gather_into_flat_tensor_nd(buffer, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(world_size)]
         # transpose the list of list
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
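The core of the change: when `group` is a tuple of process groups (the "extra" data-parallel dimension), the gather buffer is sized by the product of the per-group world sizes and filled by `all_gather_into_flat_tensor_nd` instead of a single `dist.all_gather_into_tensor`. The helper's implementation is not part of this diff; the sketch below is only an assumption of how such an n-dimensional flat all-gather could be composed from successive per-group gathers. The function name `all_gather_flat_nd_sketch` and the resulting chunk layout are illustrative, not the library's actual behavior.

from typing import Tuple, Union

import torch
import torch.distributed as dist


def all_gather_flat_nd_sketch(
    output: torch.Tensor,
    local: torch.Tensor,
    group: Union[dist.ProcessGroup, Tuple[dist.ProcessGroup, ...], None] = None,
) -> None:
    # Gather `local` from every rank of every group into the flat `output`.
    # `output.numel()` must equal `local.numel()` times the product of the
    # group world sizes, which is exactly how the patched all_gather sizes it.
    if not isinstance(group, tuple):
        # Single process group: the ordinary flat all-gather is enough.
        dist.all_gather_into_tensor(output, local, group=group)
        return
    # Tuple of groups: gather along one dimension at a time. After gathering
    # over group `pg`, every rank in `pg` holds the same enlarged shard, which
    # is then gathered across the next group in the tuple.
    current = local
    for pg in group:
        ws = dist.get_world_size(pg)
        gathered = torch.empty(current.numel() * ws, device=current.device, dtype=current.dtype)
        dist.all_gather_into_tensor(gathered, current, group=pg)
        current = gathered
    # Note: the chunk ordering produced here depends on the order of the groups
    # in the tuple and may differ from the real helper's layout.
    output.copy_(current)

With this in place, the caller's side of the diff stays simple: `buffer.chunk(world_size)` still yields one shard per participating rank, and `self.unflatten` restores each shard's original tensor shapes exactly as before.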