[utils] Add use_reetrant=False in utils.activation_checkpoint (#1460)

* [utils] Add use_reetrant=False into colossalai checkpoint

* [utils] add some annotation in utils.activaion_checkpoint

* [test] add reset_seed at the beginning of tests in test_actiavion_checkpointing.py

* [test] modify test_activation_checkpoint.py

* [test] modify test for reentrant=False
This commit is contained in:
Boyuan Yao
2022-08-16 15:39:20 +08:00
committed by GitHub
parent 36824a304c
commit 47fd8e4a02
2 changed files with 174 additions and 6 deletions

View File

@@ -7,6 +7,8 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
from .cuda import get_current_device
import weakref
def copy_to_device(obj, device):
if torch.is_tensor(obj):
@@ -136,14 +138,122 @@ class CheckpointFunction(torch.autograd.Function):
return (None, None) + grads
def checkpoint(function, activation_offload, *args):
def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
Args:
function: Describe the forward pass function. It should know how to handle the input tuples.
activation_offload: The variable to check whether we should offload activation to cpu
args (list): Tuple containing the parameters of the function
use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there
might be more flexibility for user to define there checkpoint function
Returns:
Output of running function with provided args.
"""
return CheckpointFunction.apply(function, activation_offload, *args)
if use_reentrant:
return CheckpointFunction.apply(function, activation_offload, *args)
else:
return _checkpoint_without_reentrant(
function,
activation_offload,
*args,
)
def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# store rng_state
fwd_cpu_state = torch.get_rng_state()
sync_states()
fwd_seed_states = get_states(copy=True)
fwd_current_mode = get_current_mode()
# check if use autocast
if hasattr(torch, 'is_autocast_enabled'):
has_autocast_in_fwd = torch.is_autocast_enabled()
else:
has_autocast_in_fwd = False
# using WeakKeyDictionary to store all the activation the first time we call unpack
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
weak_holder_list = []
# class for weakref.ref
class Holder():
pass
# return a Holder object for later unpack process
def pack(x):
res = Holder()
weak_holder_list.append(weakref.ref(res))
return res
# unpack hook
def unpack(x):
unpack_counter = 0
# re-compute all the activation inside the function when we first call unpack
if len(storage) == 0:
def inner_pack(inner):
nonlocal unpack_counter
unpack_counter += 1
# If the holder went out of scope, the SavedVariable is dead and so
# the value will never be read from the storage. Skip filling it.
if weak_holder_list[unpack_counter - 1]() is None:
return
# Use detach here to ensure we don't keep the temporary autograd
# graph created during the second forward
storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
return
def inner_unpack(packed):
raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
# restore rng state
torch.set_rng_state(fwd_cpu_state)
for parallel_mode, state in fwd_seed_states.items():
set_seed_states(parallel_mode, state)
set_mode(fwd_current_mode)
# reload arg into device if needed
if activation_offload:
for arg in args:
if torch.is_tensor(arg):
arg = arg.to(device=device)
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), \
torch.cuda.amp.autocast(), \
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
else:
with torch.enable_grad(), \
torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
_unused = function(*args)
if x not in storage:
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
" recomputation being triggered in between, this is not currently supported. Please"
" open an issue with details on your use case so that we can prioritize adding this.")
return storage[x]
# get device if we need to offload the activation
if activation_offload:
device = get_current_device()
# run function with pack and unpack as saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
output = function(*args)
# offload activation if needed
if activation_offload:
for arg in args:
if torch.is_tensor(arg):
arg = arg.to(device="cpu")
return output