Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)
* all_gather only internode, fix pytest
* fix cuda arch <89 compile pytest error
* fix pytest failure
* disable all_gather_into_tensor_flat_fp8
* fix fp8 format
* fix pytest
* fix conversations
* fix chunk tuple to list
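The first bullet says fp8 all_gather is now used only across nodes: inside a node, NVLink bandwidth makes the quantize/dequantize round-trip a net loss. A hedged sketch of such a gate is below; the function name and the use of torchrun's LOCAL_WORLD_SIZE variable are assumptions for illustration, not ColossalAI's actual logic.

# Hedged sketch: one way to keep fp8 all-gather for internode traffic only.
# use_fp8_all_gather and the LOCAL_WORLD_SIZE-based node detection are
# assumptions made for this illustration, not the library's real gate.
import os

import torch.distributed as dist

def use_fp8_all_gather(group=None) -> bool:
    # torchrun exports LOCAL_WORLD_SIZE = number of ranks on this node.
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    # If the group is no larger than one node, traffic stays intranode
    # (NVLink), so fp8 packing is redundant overhead: use plain all_gather.
    return dist.get_world_size(group) > local_world_size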
@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
-from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+from colossalai.quantization.fp8 import all_gather_fp8
 
 
 class TensorBucket:
@@ -67,7 +67,7 @@ class TensorBucket:
         flat = self.flatten()
         buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
         if fp8_communication:
-            all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
+            all_gather_fp8(list(buffer.chunk(dist.get_world_size(group))), flat, group=group, fp8_format="e4m3")
         else:
             dist.all_gather_into_tensor(buffer, flat, group=group)
         unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
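The new call site passes list(buffer.chunk(world_size)) as the output. torch.Tensor.chunk returns a tuple of views into the parent tensor, so gathering into those views fills the flat buffer in place, and converting the tuple to a list is what the "fix chunk tuple to list" bullet refers to. A minimal, single-process sketch of that aliasing behavior:

# Single-process demonstration (no distributed setup needed): chunk()
# returns a tuple of views, so writing into the chunks fills the parent
# buffer in place -- which is why list(buffer.chunk(world_size)) can be
# handed to the collective instead of a separate copy.
import torch

world_size = 4
flat = torch.zeros(8)                         # stand-in for the local shard
buffer = torch.empty(flat.numel() * world_size)

chunks = buffer.chunk(world_size)             # tuple of views, not copies
assert isinstance(chunks, tuple)              # hence the "chunk tuple to list" fix

for rank, chunk in enumerate(list(chunks)):
    chunk.copy_(torch.full_like(flat, rank))  # stand-in for the collective write

assert buffer[8:16].eq(1).all()               # writes landed in the parent buffer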
@@ -20,7 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
+from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
@@ -580,8 +580,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 else:
                     if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
                         if self._fp8_communication:
-                            all_gather_into_tensor_flat_fp8(
-                                padded_working_param, param_to_gather, pg, fp8_format="e4m3"
+                            all_gather_fp8(
+                                list(padded_working_param.chunk(dist.get_world_size(pg))),
+                                param_to_gather,
+                                pg,
+                                fp8_format="e4m3",
                             )
                         else:
                             dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
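Both hunks swap the flat-tensor helper for all_gather_fp8, which takes a list of per-rank output chunks. A rough sketch of what a per-tensor-scaled fp8 all-gather of this shape can look like follows; the function name suffix, the scaling scheme, and the uint8 byte transport are assumptions made for illustration, not the library's actual implementation.

# Illustrative sketch only: a per-tensor-scaled fp8 all-gather with the same
# calling shape as all_gather_fp8(output_list, input_, group=..., fp8_format=...).
import torch
import torch.distributed as dist

def all_gather_fp8_sketch(output_list, input_, group=None, fp8_format="e4m3"):
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    world_size = dist.get_world_size(group)

    # Quantize: scale the local shard into the fp8 representable range.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = (input_.abs().max() / fp8_max).clamp(min=1e-12).reshape(1)
    fp8_local = (input_ / scale).to(fp8_dtype)

    # NCCL has no fp8 dtype, so communicate the raw bytes as uint8.
    byte_buffers = [torch.empty_like(fp8_local, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(byte_buffers, fp8_local.view(torch.uint8), group=group)

    # Every rank also needs every shard's scale in order to dequantize.
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)

    # Dequantize into the caller-provided chunks (views of the flat buffer).
    for out, raw, s in zip(output_list, byte_buffers, scales):
        out.copy_(raw.view(fp8_dtype).to(out.dtype) * s)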