Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[hotfix] fix bugs in testing (#659)
* remove hybrid adam in test_moe_zero_optim
* fix activation checkpointing and its unit test
@@ -10,7 +10,11 @@ from .cuda import get_current_device
 
 def copy_to_device(obj, device):
     if torch.is_tensor(obj):
-        return obj.to(device)
+        # Notice:
+        #   When in no_grad context, requires_grad is False after movement
+        ret = obj.to(device)
+        ret.requires_grad = obj.requires_grad
+        return ret
     elif isinstance(obj, list):
         return [copy_to_device(i, device) for i in obj]
     elif isinstance(obj, tuple):
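Why the added lines matter: inside a torch.no_grad() block, Tensor.to() returns a tensor with requires_grad=False, so the flag must be copied back by hand, which is exactly what the patched copy_to_device does. A minimal standalone sketch of the behavior (illustrative, not part of the diff; a dtype move is used so it runs on CPU-only machines):

```python
import torch

x = torch.randn(4, requires_grad=True)

with torch.no_grad():
    moved = x.to(torch.float64)    # .to() under no_grad drops the flag
    print(moved.requires_grad)     # False

    moved.requires_grad = x.requires_grad   # restore it, as copy_to_device now does
    print(moved.requires_grad)     # True
```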
@@ -20,6 +24,7 @@ def copy_to_device(obj, device):
     else:
         return obj
 
 
 class CheckpointFunction(torch.autograd.Function):
+
     @staticmethod
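CheckpointFunction extends torch.autograd.Function. For orientation, a toy custom Function (a sketch, not the library's code) showing the pattern: backward returns one gradient per forward input, with None for non-tensor arguments, which is why the real backward below ends with `return (None, None) + grads`:

```python
import torch

class Scale(torch.autograd.Function):
    """Toy example of the torch.autograd.Function pattern used here."""

    @staticmethod
    def forward(ctx, factor, x):
        # stash non-tensor state on ctx, like ctx.had_autocast_in_fwd below
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward input: None for `factor`, grad for `x`
        return None, grad_output * ctx.factor

x = torch.randn(3, requires_grad=True)
Scale.apply(2.0, x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```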
@@ -39,7 +44,7 @@ class CheckpointFunction(torch.autograd.Function):
             ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
         else:
             ctx.had_autocast_in_fwd = False
-
+
         if activation_offload:
             inputs_cuda = copy_to_device(args, ctx.device)
         else:
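ctx.had_autocast_in_fwd records whether autocast was active during the forward pass so that recomputation in backward can replay the same mode. A minimal sketch of that record-and-replay idea (illustrative, using the older-style torch.cuda.amp.autocast API and guarded so it runs without a GPU):

```python
import torch

use_cuda = torch.cuda.is_available()

# "forward": record the autocast state
with torch.cuda.amp.autocast(enabled=use_cuda):
    had_autocast_in_fwd = torch.is_autocast_enabled()

# "backward": recompute under the recorded mode
with torch.cuda.amp.autocast(enabled=had_autocast_in_fwd):
    out = torch.randn(8, 8) @ torch.randn(8, 8)
```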
@@ -69,10 +74,8 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError(
-                "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
-                "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
-            )
+            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
+                               "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
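The RuntimeError above is the standard restriction of reentrant checkpointing: gradients must flow through a plain .backward() call. A small sketch against torch.utils.checkpoint, which enforces the same restriction (assumes a PyTorch version that accepts the use_reentrant flag):

```python
import torch
from torch.utils.checkpoint import checkpoint

x = torch.randn(3, requires_grad=True)
y = checkpoint(lambda t: (t * 2).sum(), x, use_reentrant=True)

# torch.autograd.grad(y, x) would raise the RuntimeError quoted above;
# plain .backward() with no `inputs` argument is the supported path.
y.backward()
print(x.grad)  # tensor([2., 2., 2.])
```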
@@ -119,16 +122,14 @@ class CheckpointFunction(torch.autograd.Function):
                 outputs_with_grad.append(outputs[i])
                 args_with_grad.append(args[i])
         if len(outputs_with_grad) == 0:
-            raise RuntimeError(
-                "none of output has requires_grad=True,"
-                " this checkpoint() is not necessary")
+            raise RuntimeError("none of output has requires_grad=True,"
+                               " this checkpoint() is not necessary")
         torch.autograd.backward(outputs_with_grad, args_with_grad)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
-                      for inp in detached_inputs)
+        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
         return (None, None) + grads
 
 
-def checkpoint(function, activation_offload ,*args):
+def checkpoint(function, activation_offload, *args):
     """Checkpoint the computation while preserving the RNG states, modified from PyTorch torch.utils.checkpoint.
 
     Args:
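After the fix, the offload flag sits between the function and its arguments: checkpoint(function, activation_offload, *args). A hypothetical call (module path assumed from the `from .cuda import get_current_device` header, i.e. colossalai/utils/activation_checkpoint.py):

```python
import torch
from colossalai.utils.activation_checkpoint import checkpoint  # path assumed

def run_fn(x):
    return torch.nn.functional.relu(x) * 2

x = torch.randn(4, requires_grad=True)
out = checkpoint(run_fn, False, x)  # activation_offload=False
out.sum().backward()
```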