From a4e5ed9990dbeaaeca4ea3355a9129fbb40e3d37 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 9 Apr 2025 16:32:10 +0800
Subject: [PATCH] fix

---
 tests/test_fp8/test_fp8_allgather.py | 3 ++-
 tests/test_fp8/test_fp8_allreduce.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/test_fp8/test_fp8_allgather.py b/tests/test_fp8/test_fp8_allgather.py
index 91e66e83c..e6b618560 100644
--- a/tests/test_fp8/test_fp8_allgather.py
+++ b/tests/test_fp8/test_fp8_allgather.py
@@ -6,13 +6,14 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import _all_gather_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
 @parameterize(
     "shape",
     [(3, 7, 16)],
 )
+@clear_cache_before_run()
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
 @parameterize("async_op", [True, False])
diff --git a/tests/test_fp8/test_fp8_allreduce.py b/tests/test_fp8/test_fp8_allreduce.py
index ccc43ed29..d7e706ffd 100644
--- a/tests/test_fp8/test_fp8_allreduce.py
+++ b/tests/test_fp8/test_fp8_allreduce.py
@@ -5,7 +5,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_reduce_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
 @parameterize(
@@ -20,6 +20,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
         (8,),
     ],
 )
+@clear_cache_before_run()
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
 @parameterize("async_op", [True, False])