[fp8] Disable intranode all_gather; disable redundant fp8 all_gather (#6059)

* use fp8 all_gather internode only; fix pytest

* fix pytest compile error on CUDA arch < 89

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* resolve review conversations

* convert chunk tuple to list
Author: Guangyao Zhang
Date: 2024-09-14 10:40:01 +08:00 (committed via GitHub)
Parent: 696fced0d7
Commit: f20b066c59
8 changed files with 43 additions and 147 deletions
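Why internode only: fp8-compressed collectives trade cast overhead for bandwidth, which likely only pays off across the slower internode links; over intranode NVLink the quantize/dequantize cost dominates. Below is a minimal sketch of such a dispatch. The wrapper name gather_with_optional_fp8 and the GPU-count heuristic for detecting node spanning are hypothetical illustrations, not part of this commit.

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_fp8


def gather_with_optional_fp8(tensor_list, tensor, group=None, fp8_communication=False, fp8_format="e4m3"):
    # Hypothetical heuristic: assume the group spans nodes when it holds more
    # ranks than this host has GPUs.
    spans_nodes = dist.get_world_size(group) > torch.cuda.device_count()
    if fp8_communication and spans_nodes:
        # Compressed path: quantize, gather, dequantize (saves internode traffic).
        all_gather_fp8(tensor_list, tensor, fp8_format=fp8_format, group=group)
    else:
        # Plain full-precision gather for intranode groups or when fp8 is off.
        dist.all_gather(tensor_list, tensor, group=group)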


@@ -17,10 +17,10 @@ except ImportError:
     _grad_accum_fusion_available = False
 from colossalai.quantization.fp8 import (
+    all_gather_fp8,
     all_reduce_fp8,
     all_to_all_fp8,
     all_to_all_single_fp8,
-    gather_fp8,
     reduce_scatter_fp8,
 )
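(The import change reflects the rename of gather_fp8 to all_gather_fp8 in colossalai/quantization/fp8.py; the call site is updated to match in the next hunk.)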
@@ -961,7 +961,7 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
     input_ = input_.contiguous()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     if fp8_communication:
-        gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
+        all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
     else:
         dist.all_gather(tensor_list, input_, group=process_group)
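For reference, here is a minimal sketch of the general fp8 all_gather technique: quantize with a per-tensor scale, gather the one-byte payload and the scales, then dequantize each shard. This is an illustrative reimplementation under stated assumptions, not ColossalAI's actual all_gather_fp8; the helper name sketch_all_gather_fp8 and the uint8-transport detail are assumptions.

import torch
import torch.distributed as dist


def sketch_all_gather_fp8(tensor_list, input_, fp8_format="e4m3", group=None):
    """Gather input_ from all ranks into tensor_list, sending fp8 payloads."""
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    world_size = dist.get_world_size(group)

    # Per-tensor scale so values fit the narrow fp8 dynamic range.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = (input_.abs().max().float() / fp8_max).clamp(min=1e-12).reshape(1)
    payload = (input_ / scale).to(fp8_dtype)

    # Gather fp8 payloads and fp32 scales separately. The fp8 tensors are
    # viewed as uint8 for transport, since not every communication backend
    # supports fp8 dtypes directly.
    gathered = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather([g.view(torch.uint8) for g in gathered], payload.view(torch.uint8), group=group)
    dist.all_gather(scales, scale, group=group)

    # Dequantize each shard into the caller-provided output buffers.
    for out, g, s in zip(tensor_list, gathered, scales):
        out.copy_(g.to(out.dtype) * s)

Each rank sends one byte per element plus a single scale instead of two or four bytes per element, which is where the internode bandwidth saving comes from.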