[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)

* all_gather only internode, fix pytest

* fix cuda arch <89 compile pytest error

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* fix conversations

* fix chunk tuple to list
This commit is contained in:
Guangyao Zhang
2024-09-14 10:40:01 +08:00
committed by GitHub
parent 696fced0d7
commit f20b066c59
8 changed files with 43 additions and 147 deletions

View File

@@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
from colossalai.quantization.fp8 import all_gather_fp8
class TensorBucket:
@@ -67,7 +67,7 @@ class TensorBucket:
flat = self.flatten()
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]