Mirror of https://github.com/hpcaitech/ColossalAI.git
[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)
* all_gather only internode, fix pytest
* fix cuda arch <89 compile pytest error
* fix pytest failure
* disable all_gather_into_tensor_flat_fp8
* fix fp8 format
* fix pytest
* fix conversations
* fix chunk tuple to list
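The idea behind "all_gather only internode" is that fp8 compression only pays for itself on comparatively slow internode links; inside a node, NVLink/PCIe bandwidth makes the quantize/dequantize overhead a net loss. A minimal sketch of such gating, using LOCAL_WORLD_SIZE as a crude single-node test and taking the fp8 collective as an injected callable; the helper name and heuristic are illustrative assumptions, not ColossalAI's actual code. The commit's test change follows the sketch.

import os
import torch.distributed as dist

def all_gather_maybe_fp8(output_list, tensor, group=None, fp8_gather=None):
    # Heuristic: if the group is no larger than one node's rank count, treat it
    # as intranode and skip fp8. A real check would inspect the group topology.
    ranks_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    if fp8_gather is None or dist.get_world_size(group) <= ranks_per_node:
        dist.all_gather(output_list, tensor, group=group)  # full-precision path
    else:
        fp8_gather(output_list, tensor, group=group)  # fp8 path for internode traffic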
@@ -5,7 +5,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
-from colossalai.quantization.fp8 import all_to_all_fp8
+from colossalai.quantization.fp8 import _all_to_all_fp8
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
@@ -20,7 +20,7 @@ def check_4gpu(shape, scatter_dim, dtype, fp8_format):
     input_tensor_list = [x.contiguous() for x in input_tensor_list]
     output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
     output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
-    all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
+    _all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
     dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
     assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)
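For context, a rough sketch of what an fp8 all_to_all such as _all_to_all_fp8 may do under the hood: quantize each outgoing chunk with a per-tensor scale, move the fp8 bytes as uint8 (NCCL has no fp8 wire type), then dequantize on the receiving side. Everything here, including the scale handling, is an illustrative assumption rather than ColossalAI's actual implementation.

import torch
import torch.distributed as dist

def all_to_all_fp8_sketch(output_list, input_list, group=None, fp8_format="e4m3"):
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per outgoing chunk so its values fit the fp8 dynamic range.
    scale_send = [(x.abs().max().clamp(min=1e-12) / fp8_max).float().reshape(1) for x in input_list]
    payload_send = [(x / s).to(fp8_dtype).view(torch.uint8) for x, s in zip(input_list, scale_send)]
    scale_recv = [torch.empty_like(s) for s in scale_send]
    payload_recv = [torch.empty_like(p) for p in payload_send]
    # Ship scales alongside the payload; uint8 views keep the backend happy.
    dist.all_to_all(scale_recv, scale_send, group=group)
    dist.all_to_all(payload_recv, payload_send, group=group)
    # Reinterpret the received bytes as fp8 and dequantize into the outputs.
    for out, p, s in zip(output_list, payload_recv, scale_recv):
        out.copy_(p.view(fp8_dtype).to(out.dtype) * s)

The rtol=0.1/atol=0.1 tolerances in the test's assertion reflect this lossy round trip: fp8 keeps only 2-3 mantissa bits, so the test checks coarse agreement with the full-precision dist.all_to_all rather than exact equality.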