mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 21:22:04 +00:00)
fix
This commit is contained in:
parent c0811d7342
commit a4e5ed9990
@@ -6,13 +6,14 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import _all_gather_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
 @parameterize(
     "shape",
     [(3, 7, 16)],
 )
+@clear_cache_before_run()
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
 @parameterize("async_op", [True, False])
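The hunk above extends the all-gather test (identified by its _all_gather_fp8 import) to import and apply clear_cache_before_run, so that cached accelerator memory is released before each parameterized check runs. For orientation only, the following is a minimal sketch of what a cache-clearing decorator of this kind typically does; it is a hypothetical stand-in, not the actual implementation in colossalai.testing:

    import functools
    import gc

    import torch


    def clear_cache_before_run_sketch():
        # Hypothetical stand-in for colossalai.testing.clear_cache_before_run;
        # the real decorator may differ in signature and behaviour.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                gc.collect()  # drop unreachable Python objects first
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()  # return cached CUDA blocks to the driver
                return fn(*args, **kwargs)

            return wrapper

        return decorator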
@@ -5,7 +5,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_reduce_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn, clear_cache_before_run
 
 
 @parameterize(
@@ -20,6 +20,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
         (8,),
     ],
 )
+@clear_cache_before_run()
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
 @parameterize("async_op", [True, False])
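For context, this is roughly how such a parameterized check is wired into a distributed test with the helpers imported above: spawn starts one worker process per rank, each worker calls launch to join the process group, and rerun_if_address_is_in_use retries the whole test if the rendezvous port is already taken. The names check_all_reduce, run_dist and test_all_reduce_fp8, the single (8,) shape, the test body and the world size of 4 are illustrative placeholders rather than lines from this commit, and the exact launch and spawn signatures should be verified against the colossalai and colossalai.testing APIs:

    import torch

    from colossalai import launch
    from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


    @parameterize("shape", [(8,)])  # (8,) taken from the hunk above; the real list has more shapes
    @clear_cache_before_run()       # the decorator this commit adds: start each run with a clean cache
    @parameterize("dtype", [torch.float16, torch.bfloat16])
    def check_all_reduce(shape, dtype):
        # Placeholder body: the real check quantizes a tensor to fp8, all-reduces it,
        # and compares the result against a regular all-reduce of the same tensor.
        ...


    def run_dist(rank, world_size, port):
        # Each spawned worker joins the default process group, then runs the checks.
        launch(rank=rank, world_size=world_size, port=port, host="localhost")
        check_all_reduce()


    @rerun_if_address_is_in_use()
    def test_all_reduce_fp8():
        spawn(run_dist, 4)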