Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 17:46:42 +00:00)
[utils] Add use_reentrant=False in utils.activation_checkpoint (#1460)
* [utils] Add use_reentrant=False into colossalai checkpoint
* [utils] add some annotations in utils.activation_checkpoint
* [test] add reset_seeds at the beginning of tests in test_activation_checkpointing.py
* [test] modify test_activation_checkpointing.py
* [test] modify test for reentrant=False
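Background: PyTorch ships two implementations of activation checkpointing. The reentrant one (the historical default) replays the checkpointed function inside a custom autograd.Function, which forbids inplace ops in the checkpointed region and requires at least one input with requires_grad=True. Passing use_reentrant=False selects the non-reentrant implementation built on saved-tensor hooks, which lifts both restrictions. A minimal sketch against plain torch.utils.checkpoint (ColossalAI's checkpoint wraps this and additionally manages per-ParallelMode CUDA rng states and optional CPU offload):

    import torch
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    def fwd(x, w):
        # an activation-heavy block whose intermediates we choose not to keep
        return torch.relu(x @ w)

    x = torch.randn(2, 2, requires_grad=True)
    w = torch.randn(2, 2, requires_grad=True)

    # reentrant (historical default): recomputes fwd in a nested backward pass
    out = torch_checkpoint(fwd, x, w, use_reentrant=True)

    # non-reentrant: recomputation is driven by saved-tensor hooks; it tolerates
    # inplace ops inside fwd and inputs that do not require grad
    out = torch_checkpoint(fwd, x, w, use_reentrant=False)
    out.sum().backward()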
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
-from colossalai.utils import checkpoint
+from colossalai.utils.activation_checkpoint import checkpoint


 def forward(x, weight):
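Note: the import now points at the defining module rather than the package re-export; the function is the same. Its call convention, as exercised throughout this test (inferred from the diff, not quoted from the library), puts the offload flag before the checkpointed function's own arguments:

    # checkpoint(function, activation_offload, *args, use_reentrant=True)  -- inferred
    out = checkpoint(forward, False, inputs_, weight_)                       # keep saved tensors on GPU
    out = checkpoint(forward, True, inputs_, weight_)                        # offload saved inputs to CPU
    out = checkpoint(forward, False, inputs_, weight_, use_reentrant=False)  # non-reentrant mode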
@@ -16,10 +16,37 @@ def forward(x, weight):
     return out_


+def forward_inplace_ckpt(x, weight, cpu_offload=False):
+    out = torch.matmul(x, weight)
+    bn = torch.nn.BatchNorm1d(4, affine=False)
+    bn = bn.to(device="cuda")
+    out = bn(out)
+
+    def ckpt0(x):
+        return F.relu(x, inplace=True)
+
+    out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
+    return out
+
+
+def forward_inplace(x, weight):
+    out = torch.matmul(x, weight)
+    bn = torch.nn.BatchNorm1d(4, affine=False)
+    bn = bn.to(device="cuda")
+    out = bn(out)
+    out = F.relu(out, inplace=True)
+    return out
+
+
 @pytest.mark.gpu
 @pytest.mark.skip("set seed error")
+@pytest.mark.parametrize("use_reentrant", [True, False])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_activation_checkpointing(cpu_offload):
+def test_activation_checkpointing(cpu_offload, use_reentrant):
+
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests might affect this test
+    reset_seeds()

     # We put initialization here to avoid changing the cuda rng state below
     inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
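The two new helpers isolate the case the reentrant implementation cannot handle: an inplace F.relu inside the checkpointed region. With use_reentrant=True, the region's input is stashed via save_for_backward and then mutated by the inplace op, so backward is expected to trip the version-counter check ("one of the variables needed for gradient computation has been modified by an inplace operation"); with use_reentrant=False the saved-tensor-hook machinery replays the region correctly. A standalone sketch of this using plain torch.utils.checkpoint rather than the ColossalAI wrapper:

    import torch
    import torch.nn.functional as F
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    x = torch.randn(4, 4, requires_grad=True)
    y = 2.0 * x                          # non-leaf input to the checkpointed region

    def ckpt0(t):
        return F.relu(t, inplace=True)   # mutates its input in place

    out = torch_checkpoint(ckpt0, y, use_reentrant=False)   # works
    out.sum().backward()
    # the same call with use_reentrant=True is expected to raise a
    # RuntimeError during backward, since the saved input was modified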
@@ -50,15 +77,46 @@ def test_activation_checkpointing(cpu_offload):
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)

-    out = checkpoint(forward, cpu_offload, inputs_, weight_)
+    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
     loss = out.sum()
     loss.backward()

     assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()

+    # Extra test for use_reentrant=False
+    if use_reentrant == False:
+        # Recover cuda rng states
+        set_mode(ParallelMode.GLOBAL)
+        torch.cuda.set_rng_state(global_cuda_rng_state)
+        set_mode(ParallelMode.DATA)
+        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+        set_mode(ParallelMode.GLOBAL)
+
+        out = forward_inplace(inputs, weight)
+        loss = out.sum()
+        loss.backward()
+
+        # Recover cuda rng states
+        set_mode(ParallelMode.GLOBAL)
+        torch.cuda.set_rng_state(global_cuda_rng_state)
+        set_mode(ParallelMode.DATA)
+        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+        set_mode(ParallelMode.GLOBAL)
+
+        out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
+        loss = out.sum()
+        loss.backward()
+
+        assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+        torch.cuda.empty_cache()
+
     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
     # as other tests can't overwrite the seed set by this test
     reset_seeds()


 if __name__ == "__main__":
-    test_activation_checkpointing(False)
+    test_activation_checkpointing(False, False)
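The repeated set_mode / set_rng_state pairs exist because ColossalAI keeps one CUDA rng state per ParallelMode in a singleton seed manager, and each forward run in the test must start from identical rng states for the gradient comparison to be meaningful. A rough sketch of the lifecycle the test relies on (seed values are made up; the API is the one imported at the top of the file, with signatures assumed):

    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.context.random import add_seed, set_mode, reset_seeds

    add_seed(ParallelMode.GLOBAL, 1024)   # register one CUDA rng state per mode
    add_seed(ParallelMode.DATA, 1026)     # (hypothetical seed values)
    set_mode(ParallelMode.GLOBAL)         # later CUDA rng draws use GLOBAL's state

    # ... forward/backward; checkpoint() snapshots and restores these states
    # around recomputation so random ops replay identically ...

    reset_seeds()   # singleton: clear all registered states for the next test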