[fp8] Disable all_gather intranode. Disable redundant all_gather fp8 (#6059)

* apply fp8 all_gather only internode; fix pytest

* fix pytest compile error on CUDA arch < 89

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* resolve review conversations

* convert chunk tuple to list
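
The last bullet refers to torch.Tensor.chunk returning a tuple of views, while torch.distributed collectives take a list of output tensors, hence the list(...) wrapper visible in the diff below. A minimal illustration (the tensors here are hypothetical):

    import torch

    x = torch.arange(8.0)
    parts = x.chunk(4)     # Tensor.chunk returns a tuple of views
    type(parts)            # <class 'tuple'>
    outputs = list(parts)  # collectives expect List[Tensor], so convert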
Author:    Guangyao Zhang
Date:      2024-09-14 10:40:01 +08:00
Committer: GitHub
Parent:    696fced0d7
Commit:    f20b066c59

8 changed files with 43 additions and 147 deletions


@@ -7,6 +7,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from colossalai.accelerator import get_accelerator
+from colossalai.quantization.fp8 import all_gather_fp8
 
 
 class TensorState(Enum):
@@ -523,11 +524,12 @@ class Chunk:
         alloc_storage(self.cuda_global_chunk)
         assert self.cuda_global_chunk.is_contiguous()
         if self.fp8_communication:
-            assert async_op == False, "fp8 all-gather does not support async_op!"
-            from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
-
-            work = all_gather_into_tensor_flat_fp8(
-                self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
+            work = all_gather_fp8(
+                list(self.cuda_global_chunk.chunk(self.pg_size)),
+                self.cuda_shard,
+                self.torch_pg,
+                fp8_format="e4m3",
+                async_op=async_op,
             )
         else:
             work = dist.all_gather_into_tensor(
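
For context, here is a minimal sketch of what an fp8 all-gather along these lines could do. This is an assumption for illustration (the helper name all_gather_fp8_sketch is hypothetical), not ColossalAI's actual all_gather_fp8 implementation: each rank quantizes its shard to the requested fp8 format, gathers the raw bytes plus a per-rank scale, then dequantizes into the caller's output chunks.

    import torch
    import torch.distributed as dist

    # Hypothetical sketch, not ColossalAI's implementation. Argument order
    # mirrors the call site in the diff above.
    def all_gather_fp8_sketch(output_list, input_, group=None, fp8_format="e4m3", async_op=False):
        fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
        fp8_max = torch.finfo(fp8_dtype).max
        # Per-tensor scale so the shard fits into the fp8 dynamic range.
        scale = (input_.abs().max() / fp8_max).clamp(min=1e-12).reshape(1)
        payload = (input_ / scale).to(fp8_dtype).view(torch.uint8)  # ship raw bytes

        world_size = dist.get_world_size(group)
        gathered = [torch.empty_like(payload) for _ in range(world_size)]
        scales = [torch.empty_like(scale) for _ in range(world_size)]
        dist.all_gather(gathered, payload, group=group)
        dist.all_gather(scales, scale, group=group)  # scales travel with the data

        # Dequantize each rank's shard into the caller-provided output views.
        # A real implementation would defer this behind a Work handle when
        # async_op=True; this sketch is synchronous and returns None.
        for out, buf, s in zip(output_list, gathered, scales):
            out.copy_(buf.view(fp8_dtype).to(out.dtype) * s)
        return None

Shipping a scale per rank is what makes e4m3 workable here: it trades exponent range for mantissa precision (max value roughly 448), so an unscaled cast would overflow shards with large magnitudes.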