[fp8] add fp8 comm for low level zero

This commit is contained in:
ver217
2024-08-02 11:12:12 +08:00
parent 5fd0592767
commit ae486ce005
3 changed files with 32 additions and 9 deletions

View File

@@ -4,6 +4,8 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
class TensorBucket:
def __init__(self, size):
@@ -61,11 +63,14 @@ class TensorBucket:
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def all_gather(self, group=None):
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):