diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 722cbce9a..0de5e836a 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
+@clear_cache_before_run()
 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("async_op", [True, False])
@@ -24,6 +25,7 @@ def check_all2all(shape, dtype, async_op):
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
 
+@clear_cache_before_run()
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("async_op", [True, False])
diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py
index 98bbbad85..236ac2af8 100644
--- a/tests/test_fp8/test_fp8_all_to_all.py
+++ b/tests/test_fp8/test_fp8_all_to_all.py
@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import _all_to_all_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
+@clear_cache_before_run()
 @parameterize("shape", [(16, 8, 4)])
 @parameterize("scatter_dim", [0, 1, 2])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
diff --git a/tests/test_fp8/test_fp8_all_to_all_single.py b/tests/test_fp8/test_fp8_all_to_all_single.py
index 70765f2d4..b5229d097 100644
--- a/tests/test_fp8/test_fp8_all_to_all_single.py
+++ b/tests/test_fp8/test_fp8_all_to_all_single.py
@@ -6,11 +6,12 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 dist.all_to_all_single
 
 
+@clear_cache_before_run()
 @parameterize("shape", [(4), (8, 7), (4, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
diff --git a/tests/test_fp8/test_fp8_allgather.py b/tests/test_fp8/test_fp8_allgather.py
index ebbe2476a..79b55395d 100644
--- a/tests/test_fp8/test_fp8_allgather.py
+++ b/tests/test_fp8/test_fp8_allgather.py
@@ -9,11 +9,11 @@ from colossalai.quantization.fp8 import _all_gather_fp8
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 
+@clear_cache_before_run()
 @parameterize(
     "shape",
     [(3, 7, 16)],
 )
-@clear_cache_before_run()
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
 @parameterize("async_op", [True, False])
diff --git a/tests/test_fp8/test_fp8_cast.py b/tests/test_fp8/test_fp8_cast.py
index db9a909e6..88bdc094f 100644
--- a/tests/test_fp8/test_fp8_cast.py
+++ b/tests/test_fp8/test_fp8_cast.py
@@ -3,9 +3,11 @@ from torch.testing import assert_close
 
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, clear_cache_before_run
 
 
+
+@clear_cache_before_run()
 @parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
 @parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
diff --git a/tests/test_fp8/test_fp8_fsdp_comm_hook.py b/tests/test_fp8/test_fp8_fsdp_comm_hook.py
index 3d0660961..97ba0ff36 100644
--- a/tests/test_fp8/test_fp8_fsdp_comm_hook.py
+++ b/tests/test_fp8/test_fp8_fsdp_comm_hook.py
@@ -8,7 +8,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing import assert_close
 
 from colossalai import launch
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 # example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
 
@@ -27,7 +27,7 @@ class ToyModel(nn.Module):
     def forward(self, x):
         return self.net2(self.relu(self.net1(x)))
 
-
+@clear_cache_before_run()
 @parameterize("mode", ["grad", "params"])
 def run_model(mode):
     rank = dist.get_rank()
diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py
index e0b558a25..7a2dc3188 100644
--- a/tests/test_fp8/test_fp8_reduce_scatter.py
+++ b/tests/test_fp8/test_fp8_reduce_scatter.py
@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import reduce_scatter_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
+@clear_cache_before_run()
 @parameterize("shape", [(16, 8, 4)])
 @parameterize("scatter_dim", [0, 1, 2])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
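
Note (outside the patch): a minimal sketch of the decorator pattern these hunks apply, assuming colossalai.testing.parameterize runs the wrapped function once per listed value (taking the cross product when stacked) and clear_cache_before_run() frees cached accelerator memory before the run, as in the tests above. The function check_dummy and its parameter values are hypothetical, for illustration only.

# Hypothetical example: stacking clear_cache_before_run() on top of parameterize(),
# mirroring the pattern this diff introduces in the fp8 tests.
import torch

from colossalai.testing import clear_cache_before_run, parameterize


@clear_cache_before_run()  # assumed: empties the accelerator cache before the test body runs
@parameterize("shape", [(4,), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_dummy(shape, dtype):
    # each shape/dtype combination is passed in as keyword arguments
    x = torch.rand(shape).to(dtype)
    assert x.shape == torch.Size(shape) and x.dtype == dtype


if __name__ == "__main__":
    check_dummy()  # runs every parameter combination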