flybird11111 2025-04-10 12:55:02 +08:00
parent eaef783ec3
commit e8a3d52381
7 changed files with 15 additions and 8 deletions

View File

@@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+@clear_cache_before_run()
@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
@@ -24,6 +25,7 @@ def check_all2all(shape, dtype, async_op):
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
+@clear_cache_before_run()
@parameterize("shape", [(8, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])

View File

@@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import _all_to_all_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+@clear_cache_before_run()
@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])

View File

@@ -6,11 +6,12 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
dist.all_to_all_single
+@clear_cache_before_run()
@parameterize("shape", [(4), (8, 7), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])

View File

@@ -9,11 +9,11 @@ from colossalai.quantization.fp8 import _all_gather_fp8
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@clear_cache_before_run()
@parameterize(
"shape",
[(3, 7, 16)],
)
@clear_cache_before_run()
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])

View File

@@ -3,9 +3,11 @@ from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, clear_cache_before_run
+@clear_cache_before_run()
@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])

View File

@@ -8,7 +8,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close
from colossalai import launch
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
@@ -27,7 +27,7 @@ class ToyModel(nn.Module):
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
+@clear_cache_before_run()
@parameterize("mode", ["grad", "params"])
def run_model(mode):
rank = dist.get_rank()
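
These test files share the same spawn-based entry-point pattern; a rough sketch follows. The helper names (run_dist, test_fp8_fsdp), the process count, and the exact launch() keyword arguments are illustrative assumptions, not taken from this diff.

# Hedged sketch of the usual ColossalAI distributed-test scaffolding around run_model.
from colossalai import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # bring up a single-node process group so FSDP/collectives have peers to talk to
    launch(rank=rank, world_size=world_size, host="localhost", port=port)
    run_model()  # the @parameterize'd check defined above iterates over all "mode" values itself


@rerun_if_address_is_in_use()
def test_fp8_fsdp():
    spawn(run_dist, nprocs=2)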

View File

@@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import reduce_scatter_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+@clear_cache_before_run()
@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])