Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
fixed bug in activation checkpointing test (#387)
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.random import add_seed, seed, set_mode
+from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
 from colossalai.utils import checkpoint
 
 
@@ -17,12 +17,12 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_
 
 
 @pytest.mark.gpu
+@pytest.mark.parametrize("cpu_offload", [True, False])
 def test_activation_checkpointing(cpu_offload):
-    if cpu_offload:
-        add_seed(ParallelMode.GLOBAL, 1024)
-        add_seed(ParallelMode.DATA, 1026)
+    add_seed(ParallelMode.GLOBAL, 1024)
+    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.DATA)
@@ -56,4 +56,8 @@ def test_activation_checkpointing(cpu_offload):
 
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
 
+    # as seed manager is singleton
+    # if we don't reset seeds here, other tests will fail if running together with this test
+    # as other tests can't overwrite the seed set by this test
+    reset_seeds()
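
The comment block in the final hunk carries the substance of the fix: the seed manager behind colossalai.context.random is a process-wide singleton, so seeds registered by one test survive into the next, and a later add_seed for the same ParallelMode cannot overwrite the stale entry. Because the test is parametrized over cpu_offload, pytest runs it twice in the same process, and the second run's add_seed calls would collide with the first run's registrations unless the manager is reset. A minimal sketch of that failure mode and the fix, using hypothetical test names (assumes a CUDA device and the colossalai.context.random API as imported above):

import pytest
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, set_mode, reset_seeds


@pytest.mark.gpu
def test_first():
    add_seed(ParallelMode.DATA, 1026)    # registers DATA in the singleton manager
    set_mode(ParallelMode.DATA)          # activates DATA's CUDA RNG state
    torch.rand(2, 2, device='cuda')      # draws from the DATA stream
    reset_seeds()                        # the fix: clear the singleton on the way out


@pytest.mark.gpu
def test_second():
    # Without reset_seeds() in test_first, this add_seed would hit the stale
    # DATA entry left behind, since registered seeds cannot be overwritten
    # by a plain add_seed call.
    add_seed(ParallelMode.DATA, 42)
    set_mode(ParallelMode.DATA)
    reset_seeds()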