[hotfix] fix bugs in testing (#659)

* remove hybrid adam in test_moe_zero_optim

* fix activation checkpointing and its unit test
Author: HELSON
Date: 2022-04-02 21:58:47 +08:00
Committed by: GitHub
Parent: 036404ca8a
Commit: e5d615aeee
3 changed files with 31 additions and 31 deletions


@@ -10,7 +10,11 @@ from .cuda import get_current_device
def copy_to_device(obj, device):
    if torch.is_tensor(obj):
-        return obj.to(device)
+        # Notice:
+        # When in no_grad context, requires_grad is False after movement
+        ret = obj.to(device)
+        ret.requires_grad = obj.requires_grad
+        return ret
    elif isinstance(obj, list):
        return [copy_to_device(i, device) for i in obj]
    elif isinstance(obj, tuple):
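The comment added above is the heart of the fix: any tensor produced by an operation inside a `torch.no_grad()` block, including `.to()`, comes back with `requires_grad=False`, so the moved copy must have the flag restored by hand. A minimal standalone sketch of that behaviour (not part of this diff; a dtype change stands in for the device move so it runs without a GPU):

import torch

x = torch.randn(2, requires_grad=True)
with torch.no_grad():
    y = x.to(torch.float64)               # movement inside no_grad drops the grad flag
print(x.requires_grad, y.requires_grad)   # True False
y.requires_grad = x.requires_grad         # what the patched copy_to_device does for tensors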
@@ -20,6 +24,7 @@ def copy_to_device(obj, device):
    else:
        return obj

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
@@ -39,7 +44,7 @@ class CheckpointFunction(torch.autograd.Function):
            ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        else:
            ctx.had_autocast_in_fwd = False

        if activation_offload:
            inputs_cuda = copy_to_device(args, ctx.device)
        else:
@@ -69,10 +74,8 @@ class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError(
-                "Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
-                "passed to .backward(). Please use .backward() and do not pass its `inputs` argument."
-            )
+            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
+                               "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
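The error message above is inherited from PyTorch's reentrant `torch.utils.checkpoint`, from which this file is adapted: a checkpointed segment can only be recomputed from inside `.backward()`. A standalone sketch of the restriction, using stock PyTorch rather than the ColossalAI class in this diff:

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

x = torch.randn(4, requires_grad=True)
out = torch_checkpoint(lambda t: t.sin().sum(), x)   # reentrant-style checkpointing

out.backward()                 # supported: recomputation happens inside .backward()
# torch.autograd.grad(out, x)  # with the reentrant implementation this raises the
#                              # "Checkpointing is not compatible with .grad()" RuntimeError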
@@ -119,16 +122,14 @@ class CheckpointFunction(torch.autograd.Function):
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
-            raise RuntimeError(
-                "none of output has requires_grad=True,"
-                " this checkpoint() is not necessary")
+            raise RuntimeError("none of output has requires_grad=True,"
+                               " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
-                      for inp in detached_inputs)
+        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
        return (None, None) + grads
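The `(None, None) + grads` return follows the `torch.autograd.Function` contract: `backward` must return one value per argument of `forward`, and the first two forward arguments here (`function` and `activation_offload`) are not tensors, so they get `None`. A toy sketch of the same pattern, with an illustrative `Scale` function that is not part of this repo:

import torch

class Scale(torch.autograd.Function):

    @staticmethod
    def forward(ctx, factor, x):      # `factor` is a plain Python float, not a tensor
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward argument: None for `factor`, a tensor for `x`.
        return None, grad_out * ctx.factor

x = torch.randn(3, requires_grad=True)
Scale.apply(2.0, x).sum().backward()
print(x.grad)                         # tensor([2., 2., 2.])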
-def checkpoint(function, activation_offload ,*args):
+def checkpoint(function, activation_offload, *args):
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint.
Args:
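A hedged usage sketch of the `checkpoint(function, activation_offload, *args)` signature fixed above. The import path and the CUDA device are assumptions, not shown in this excerpt:

import torch
import torch.nn as nn
from colossalai.utils.activation_checkpoint import checkpoint   # assumed module path

layer = nn.Linear(128, 128).cuda()
x = torch.randn(4, 128, device='cuda', requires_grad=True)

# activation_offload=True asks CheckpointFunction to keep the tensors needed for
# recomputation on the CPU between forward and backward; the copy_to_device fix
# above preserves their requires_grad when they are moved back.
out = checkpoint(layer, True, x)
out.sum().backward()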