diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py
index 2edd6b1a5..fa9ed827a 100644
--- a/colossalai/utils/activation_checkpoint.py
+++ b/colossalai/utils/activation_checkpoint.py
@@ -7,6 +7,8 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
 from .cuda import get_current_device
 
+import weakref
+
 
 def copy_to_device(obj, device):
     if torch.is_tensor(obj):
@@ -136,14 +138,122 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads
 
 
-def checkpoint(function, activation_offload, *args):
+def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
     """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
 
     Args:
         function: Describe the forward pass function. It should know how to handle the input tuples.
+        activation_offload: Whether to offload activations to CPU
         args (list): Tuple containing the parameters of the function
+        use_reentrant: Whether to use the reentrant implementation. With use_reentrant=False, the user
+            may have more flexibility in defining their checkpointed function
 
     Returns:
         Output of running function with provided args.
     """
-    return CheckpointFunction.apply(function, activation_offload, *args)
+    if use_reentrant:
+        return CheckpointFunction.apply(function, activation_offload, *args)
+    else:
+        return _checkpoint_without_reentrant(
+            function,
+            activation_offload,
+            *args,
+        )
+
+
+def _checkpoint_without_reentrant(function, activation_offload=False, *args):
+    # store rng_state
+    fwd_cpu_state = torch.get_rng_state()
+    sync_states()
+    fwd_seed_states = get_states(copy=True)
+    fwd_current_mode = get_current_mode()
+
+    # check if use autocast
+    if hasattr(torch, 'is_autocast_enabled'):
+        has_autocast_in_fwd = torch.is_autocast_enabled()
+    else:
+        has_autocast_in_fwd = False
+
+    # using WeakKeyDictionary to store all the activation the first time we call unpack
+    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+    weak_holder_list = []
+
+    # class for weakref.ref
+    class Holder():
+        pass
+
+    # return a Holder object for later unpack process
+    def pack(x):
+        res = Holder()
+        weak_holder_list.append(weakref.ref(res))
+        return res
+
+    # unpack hook
+    def unpack(x):
+        unpack_counter = 0
+
+        # re-compute all the activation inside the function when we first call unpack
+        if len(storage) == 0:
+
+            def inner_pack(inner):
+                nonlocal unpack_counter
+                unpack_counter += 1
+
+                # If the holder went out of scope, the SavedVariable is dead and so
+                # the value will never be read from the storage. Skip filling it.
+                if weak_holder_list[unpack_counter - 1]() is None:
+                    return
+
+                # Use detach here to ensure we don't keep the temporary autograd
+                # graph created during the second forward
+                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
+                return
+
+            def inner_unpack(packed):
+                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
+
+            # restore rng state
+            torch.set_rng_state(fwd_cpu_state)
+            for parallel_mode, state in fwd_seed_states.items():
+                set_seed_states(parallel_mode, state)
+            set_mode(fwd_current_mode)
+
+            # reload arg into device if needed
+            if activation_offload:
+                for arg in args:
+                    if torch.is_tensor(arg):
+                        arg = arg.to(device=device)
+
+            # rerun forward, the inner_pack will store all the activations in storage
+            if has_autocast_in_fwd:
+                with torch.enable_grad(), \
+                     torch.cuda.amp.autocast(), \
+                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                    _unused = function(*args)
+            else:
+                with torch.enable_grad(), \
+                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                    _unused = function(*args)
+
+        if x not in storage:
+            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
+                               " recomputation being triggered in between, this is not currently supported. Please"
+                               " open an issue with details on your use case so that we can prioritize adding this.")
+
+        return storage[x]
+
+    # get device if we need to offload the activation
+    if activation_offload:
+        device = get_current_device()
+
+    # run function with pack and unpack as saved_tensors_hooks
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+        output = function(*args)
+
+    # offload activation if needed
+    if activation_offload:
+        for arg in args:
+            if torch.is_tensor(arg):
+                arg = arg.to(device="cpu")
+
+    return output
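
Note on the mechanism (not part of the patch): _checkpoint_without_reentrant is built on torch.autograd.graph.saved_tensors_hooks (available since PyTorch 1.10), where the pack hook intercepts every tensor autograd would keep for backward and the unpack hook hands it back during the backward pass. A minimal, self-contained sketch of such a hook pair, here doing plain CPU offload instead of recomputation; the function names are illustrative:

import torch

def pack_to_cpu(tensor):
    # remember the original device so unpack can restore it
    return tensor.device, tensor.cpu()

def unpack_from_cpu(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x * x).sum()
y.backward()    # unpack_from_cpu runs here, when autograd needs the saved tensor
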
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index a68644254..3ac75fb00 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
-from colossalai.utils import checkpoint
+from colossalai.utils.activation_checkpoint import checkpoint
 
 
 def forward(x, weight):
@@ -16,10 +16,37 @@ def forward(x, weight):
     return out_
 
 
+def forward_inplace_ckpt(x, weight, cpu_offload=False):
+    out = torch.matmul(x, weight)
+    bn = torch.nn.BatchNorm1d(4, affine=False)
+    bn = bn.to(device="cuda")
+    out = bn(out)
+
+    def ckpt0(x):
+        return F.relu(x, inplace=True)
+
+    out = checkpoint(ckpt0, cpu_offload, out, use_reentrant=False)
+    return out
+
+
+def forward_inplace(x, weight):
+    out = torch.matmul(x, weight)
+    bn = torch.nn.BatchNorm1d(4, affine=False)
+    bn = bn.to(device="cuda")
+    out = bn(out)
+    out = F.relu(out, inplace=True)
+    return out
+
+
 @pytest.mark.gpu
-@pytest.mark.skip("set seed error")
+@pytest.mark.parametrize("use_reentrant", [True, False])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_activation_checkpointing(cpu_offload):
+def test_activation_checkpointing(cpu_offload, use_reentrant):
+
+    # as seed manager is singleton
+    # if we don't reset seeds here,
+    # other tests might affect this test
+    reset_seeds()
 
     # We put initilization here to avoid change cuda rng state below
     inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
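
For orientation, a hedged usage sketch of the extended checkpoint signature exercised by forward_inplace_ckpt above. It assumes a CUDA device and a seed manager initialized the way this test does it; the toy block and tensor shapes are illustrative, not part of the patch:

import torch
import torch.nn.functional as F
from colossalai.utils.activation_checkpoint import checkpoint

x = torch.randn(2, 4, device="cuda", requires_grad=True)

def block(inp):
    # in-place op inside the checkpointed region, the case the new test targets
    return F.relu(inp * 2, inplace=True)

# use_reentrant=True keeps the original CheckpointFunction.apply path;
# use_reentrant=False dispatches to _checkpoint_without_reentrant and the
# saved_tensors_hooks machinery added in this patch
out = checkpoint(block, False, x, use_reentrant=False)
out.sum().backward()
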
@@ -50,15 +77,46 @@ def test_activation_checkpointing(cpu_offload):
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)
 
-    out = checkpoint(forward, cpu_offload, inputs_, weight_)
+    out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
     loss = out.sum()
     loss.backward()
 
     assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
 
+    # Extra test for use_reentrant=False
+    if not use_reentrant:
+        # Recover cuda rng states
+        set_mode(ParallelMode.GLOBAL)
+        torch.cuda.set_rng_state(global_cuda_rng_state)
+        set_mode(ParallelMode.DATA)
+        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+        set_mode(ParallelMode.GLOBAL)
+
+        out = forward_inplace(inputs, weight)
+        loss = out.sum()
+        loss.backward()
+
+        # Recover cuda rng states
+        set_mode(ParallelMode.GLOBAL)
+        torch.cuda.set_rng_state(global_cuda_rng_state)
+        set_mode(ParallelMode.DATA)
+        torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
+        set_mode(ParallelMode.GLOBAL)
+
+        out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
+        loss = out.sum()
+        loss.backward()
+
+        assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
+        torch.cuda.empty_cache()
+
     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
     # as other tests can't overwrite the seed set by this test
     reset_seeds()
+
+
+if __name__ == "__main__":
+    test_activation_checkpointing(False, False)
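
Side note on the storage design in _checkpoint_without_reentrant (a purely standard-library illustration, not part of the patch): keying the recomputed activations by weakly referenced Holder objects means an entry is discarded as soon as autograd drops the corresponding SavedVariable, so a recomputed activation does not need to outlive the backward pass:

import weakref

class Holder:
    pass

storage = weakref.WeakKeyDictionary()
key = Holder()
storage[key] = "recomputed activation"
print(len(storage))    # 1 -- alive while a strong reference to the key exists
del key                # autograd releasing the SavedVariable plays this role
print(len(storage))    # 0 -- the entry is discarded together with its key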