diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 0de5e836a..448a3f031 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -6,7 +6,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 
 @clear_cache_before_run()
diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py
index 236ac2af8..a86741b4c 100644
--- a/tests/test_fp8/test_fp8_all_to_all.py
+++ b/tests/test_fp8/test_fp8_all_to_all.py
@@ -6,7 +6,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import _all_to_all_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 
 @clear_cache_before_run()
diff --git a/tests/test_fp8/test_fp8_all_to_all_single.py b/tests/test_fp8/test_fp8_all_to_all_single.py
index b5229d097..a301301b3 100644
--- a/tests/test_fp8/test_fp8_all_to_all_single.py
+++ b/tests/test_fp8/test_fp8_all_to_all_single.py
@@ -6,7 +6,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 
 dist.all_to_all_single
diff --git a/tests/test_fp8/test_fp8_cast.py b/tests/test_fp8/test_fp8_cast.py
index 88bdc094f..479cb3770 100644
--- a/tests/test_fp8/test_fp8_cast.py
+++ b/tests/test_fp8/test_fp8_cast.py
@@ -3,8 +3,7 @@ from torch.testing import assert_close
 
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
-from colossalai.testing import parameterize, clear_cache_before_run
-
+from colossalai.testing import clear_cache_before_run, parameterize
 
 
 @clear_cache_before_run()
diff --git a/tests/test_fp8/test_fp8_fsdp_comm_hook.py b/tests/test_fp8/test_fp8_fsdp_comm_hook.py
index 97ba0ff36..a95fbdf01 100644
--- a/tests/test_fp8/test_fp8_fsdp_comm_hook.py
+++ b/tests/test_fp8/test_fp8_fsdp_comm_hook.py
@@ -8,7 +8,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing import assert_close
 
 from colossalai import launch
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 
 # example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
@@ -27,6 +27,7 @@ class ToyModel(nn.Module):
     def forward(self, x):
         return self.net2(self.relu(self.net1(x)))
 
+
 @clear_cache_before_run()
 @parameterize("mode", ["grad", "params"])
 def run_model(mode):
diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py
index 7a2dc3188..a2eac1c7e 100644
--- a/tests/test_fp8/test_fp8_reduce_scatter.py
+++ b/tests/test_fp8/test_fp8_reduce_scatter.py
@@ -6,7 +6,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import reduce_scatter_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 
 @clear_cache_before_run()