mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-19 00:16:51 +00:00
[util] fixed activation checkpointing on torch 1.9 (#719)
This commit is contained in:
@@ -68,7 +68,10 @@ class CheckpointFunction(torch.autograd.Function):
             else:
                 ctx.inputs.append(arg)
 
-        ctx.save_for_backward(*tensor_inputs)
+        if activation_offload:
+            ctx.tensor_inputs = tensor_inputs
+        else:
+            ctx.save_for_backward(*tensor_inputs)
         return outputs
 
     @staticmethod
@@ -79,7 +82,11 @@ class CheckpointFunction(torch.autograd.Function):
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
-        tensors = ctx.saved_tensors
+
+        if ctx.activation_offload:
+            tensors = ctx.tensor_inputs
+        else:
+            tensors = ctx.saved_tensors
 
         # store the current states
         bwd_cpu_rng_state = torch.get_rng_state()
Reference in New Issue
Block a user