Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -28,7 +28,6 @@ def copy_to_device(obj, device):
 class CheckpointFunction(torch.autograd.Function):

     @staticmethod
     def forward(ctx, run_function, activation_offload=False, *args):
         check_backward_validity(args)
@@ -42,7 +41,7 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.fwd_seed_states = get_states(copy=True)
         ctx.fwd_current_mode = get_current_mode()

-        if hasattr(torch, 'is_autocast_enabled'):
+        if hasattr(torch, "is_autocast_enabled"):
             ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
         else:
             ctx.had_autocast_in_fwd = False
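The hunk above only swaps quote styles, but the guarded call it touches matters for compatibility: torch.is_autocast_enabled() does not exist in very old PyTorch releases, so the flag silently falls back to False there, and the saved flag lets backward replay the forward under the same autocast setting. A minimal self-contained sketch of that pattern (not this file's actual code):

import torch

def capture_autocast_state():
    # Older PyTorch builds lack torch.is_autocast_enabled(); treat that as "autocast off".
    if hasattr(torch, "is_autocast_enabled"):
        return torch.is_autocast_enabled()
    return False

print(capture_autocast_state())  # False outside any autocast context

with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
    print(capture_autocast_state())  # True on a CUDA machine, False otherwise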
@@ -62,7 +61,7 @@ class CheckpointFunction(torch.autograd.Function):
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
                 if activation_offload:
-                    tensor_inputs.append(copy_to_device(arg, 'cpu'))
+                    tensor_inputs.append(copy_to_device(arg, "cpu"))
                 else:
                     tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
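For context, copy_to_device is the helper defined at the top of this file; its body is not part of this diff. A hedged sketch of what such a helper plausibly does, assuming it simply moves tensors and nested containers to the target device (an illustration, not the actual ColossalAI implementation):

import torch

def copy_to_device_sketch(obj, device):
    # Move a tensor, or any list/tuple/dict of tensors, to `device`,
    # preserving requires_grad on the copies (assumed behaviour).
    if torch.is_tensor(obj):
        ret = obj.detach().to(device)
        ret.requires_grad_(obj.requires_grad)
        return ret
    if isinstance(obj, (list, tuple)):
        return type(obj)(copy_to_device_sketch(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: copy_to_device_sketch(v, device) for k, v in obj.items()}
    return obj

x = torch.randn(4, 4, requires_grad=True)
cpu_copy = copy_to_device_sketch(x, "cpu")
print(cpu_copy.device, cpu_copy.requires_grad)  # cpu True

With activation_offload=True, forward stashes CPU copies of the tensor inputs, so GPU memory does not hold them between forward and backward; backward moves them back before recomputation.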
@@ -79,8 +78,10 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
-                               "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
+            raise RuntimeError(
+                "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
+                "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
+            )
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
@@ -131,8 +132,7 @@ class CheckpointFunction(torch.autograd.Function):
                 outputs_with_grad.append(outputs[i])
                 args_with_grad.append(args[i])
         if len(outputs_with_grad) == 0:
-            raise RuntimeError("none of output has requires_grad=True,"
-                               " this checkpoint() is not necessary")
+            raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary")
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
         return (None, None) + grads
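Given the forward signature shown earlier, forward(ctx, run_function, activation_offload=False, *args), the class is invoked in the usual torch.autograd.Function way, i.e. CheckpointFunction.apply(run_function, activation_offload, *inputs); the return (None, None) + grads line mirrors this, since the first two apply() arguments get no gradient. A short sketch of the calling convention, using PyTorch's built-in checkpoint as a stand-in so the snippet runs on its own:

import torch
from torch.utils.checkpoint import checkpoint

def run_block(x, w):
    # The function whose activations are recomputed in backward instead of stored.
    return torch.relu(x @ w)

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)

# With the class from this diff the call would be (hypothetical usage):
#   out = CheckpointFunction.apply(run_block, False, x, w)
out = checkpoint(run_block, x, w)  # stand-in with the same recompute-in-backward behaviour
out.sum().backward()
print(x.grad.shape, w.grad.shape)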
@@ -169,7 +169,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
     fwd_current_mode = get_current_mode()

     # check if use autocast
-    if hasattr(torch, 'is_autocast_enabled'):
+    if hasattr(torch, "is_autocast_enabled"):
         has_autocast_in_fwd = torch.is_autocast_enabled()
     else:
         has_autocast_in_fwd = False
@@ -179,7 +179,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
     weak_holder_list = []

     # class for weakref.ref
-    class Holder():
+    class Holder:
         pass

     # return a Holder object for later unpack process
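The trivial Holder class is not just a style artifact: weakref.ref() cannot target a bare object() instance, but it can target instances of an ordinary user-defined class, which is what the pack hook hands back to autograd. A small self-contained illustration of that property (the surrounding bookkeeping is omitted):

import weakref

class Holder:
    pass

h = Holder()
r = weakref.ref(h)
print(r() is h)  # True while a strong reference to the holder exists

try:
    weakref.ref(object())
except TypeError as err:
    print("bare object() cannot be weak-referenced:", err)

del h
print(r())  # None: once autograd drops its holder, the corresponding storage entry can be cleared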
@@ -226,19 +226,20 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):

             # rerun forward, the inner_pack will store all the activations in storage
             if has_autocast_in_fwd:
-                with torch.enable_grad(), \
-                     torch.cuda.amp.autocast(), \
-                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
+                    inner_pack, inner_unpack
+                ):
                     _unused = function(*args)
             else:
-                with torch.enable_grad(), \
-                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
+                with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                     _unused = function(*args)

         if x not in storage:
-            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
-                               " recomputation being triggered in between, this is not currently supported. Please"
-                               " open an issue with details on your use case so that we can prioritize adding this.")
+            raise RuntimeError(
+                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
+                " recomputation being triggered in between, this is not currently supported. Please"
+                " open an issue with details on your use case so that we can prioritize adding this."
+            )

         return storage[x]
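The hunk above only reflows the with statements, but it covers the core of the non-reentrant path: torch.autograd.graph.saved_tensors_hooks (available since PyTorch 1.10) makes pack() run whenever autograd saves a tensor and unpack() run when backward needs it again. A minimal runnable sketch of that mechanism, with a plain dict standing in for the checkpoint's weakref-keyed storage and no recomputation:

import torch

storage = {}

def pack(tensor):
    # Called when autograd would normally keep `tensor` for backward.
    key = len(storage)
    storage[key] = tensor      # the checkpoint code stores a Holder placeholder instead
    return key                 # whatever pack returns is handed back to unpack later

def unpack(key):
    # Called during backward; the checkpoint code triggers forward recomputation here.
    return storage[key]

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()          # x is routed through pack() instead of being saved directly

y.backward()                   # unpack() supplies the saved tensor to the backward pass
print(x.grad is not None, len(storage) > 0)  # True True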