[fix] fix bwd step if condition; remove useless comments and formatted debug info;

duanjunwen
2024-09-03 08:56:08 +00:00
parent ab643c9af7
commit 4c1f81c683
4 changed files with 54 additions and 1188 deletions


@@ -33,14 +33,11 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     only useful for its '.grad_fn' field, and not its '.data'.
     """
     if (out is None) or (not deallocate_pipeline_outputs):
-        print(
-            f"(out is None) or (not deallocate_pipeline_outputs): {(out is None) or (not deallocate_pipeline_outputs)}"
-        )
         return
     assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
     assert out._base is None, "counter-productive to free a view of another tensor."
     # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
-    out.data.storage().resize_(0)
+    out.data.untyped_storage().resize_(0)
class ZeroBubbleVPipeScheduler(PipelineSchedule):
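
For context on the last change above: Tensor.storage() returns a TypedStorage wrapper that recent PyTorch releases deprecate, while untyped_storage() exposes the same underlying bytes, so resize_(0) still releases the activation memory of a pipeline output whose autograd graph has to stay alive. A minimal sketch of the pattern in plain PyTorch (not this repo's scheduler code):

import torch

x = torch.randn(4, 4, requires_grad=True)
out = x * 2                                # stage output, carries the autograd graph
keep = out.clone()                         # handle kept around for the later backward pass
keep.data.untyped_storage().resize_(0)     # free the payload, keep .grad_fn and shape metadata

# Backward through the kept handle still works: autograd only needs the graph,
# because the incoming gradient is supplied explicitly.
torch.autograd.backward(keep, grad_tensors=torch.ones(4, 4))
print(keep.untyped_storage().size())       # 0 bytes
print(x.grad)                              # all 2s, i.e. d(2x)/dx times the incoming ones
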
@@ -457,33 +454,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Retain the grad on the input_obj.
         tree_map(retain_grad, input_obj)
-        if model_chunk_id == 0:
-            # bwd step
-            optimizer.backward_by_grad(
-                tensor=output_obj,
-                grad=output_obj_grad,
-                inputs=input_obj,
-                retain_graph=True,
-            )
-        else:
-            if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # loss backward; output_obj is loss
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=None,
-                    inputs=input_obj,
-                    retain_graph=True,
-                )
-            else:
-                # commom bwd step
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=output_obj_grad,
-                    inputs=input_obj,
-                    retain_graph=True,
-                )
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=input_obj,
+            retain_graph=True,
+        )
         return input_obj.grad

     def backward_w_step(
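
The rewritten branch above collapses three nearly identical calls into one: output_obj is the loss only when model_chunk_id == 1 on the first pipeline stage (where the last layers of the V-shaped schedule sit), and in that case autograd must start from the implicit gradient of a scalar loss (grad=None); in every other case the gradient received from the downstream stage is propagated. backward_by_grad is ColossalAI's optimizer-wrapper API; the sketch below shows the equivalent dx-only step with plain torch.autograd, assuming a single tensor input:

import torch

def backward_b_sketch(output_obj, output_obj_grad, input_obj, output_is_loss):
    # dx-only backward: gradients are accumulated into the stage input only,
    # and the graph is retained so the later dw pass can reuse it.
    grad = None if output_is_loss else output_obj_grad
    torch.autograd.backward(
        tensors=output_obj,
        grad_tensors=grad,
        inputs=[input_obj],
        retain_graph=True,
    )
    return input_obj.grad

# toy single-stage example: y = W @ x, loss = y.sum()
W = torch.nn.Parameter(torch.randn(3, 3))
x = torch.randn(3, requires_grad=True)
loss = (W @ x).sum()
dx = backward_b_sketch(loss, None, x, output_is_loss=True)
print(dx)        # equals W.sum(dim=0); W.grad is still None (dw is deferred)

Because inputs restricts where gradients are accumulated, the parameter gradients are deliberately left for the separate W step.
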
@@ -507,29 +486,39 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Nothing need to return; we only calculate dw then update w;
         """
         # calculate bwd w step ; only dw = x*dy;
-        if model_chunk_id == 0:
-            optimizer.backward_by_grad(
-                tensor=output_obj,
-                grad=output_obj_grad,
-                inputs=list(model_chunk[model_chunk_id].parameters()),
-                retain_graph=False,
-            )
-        else:
-            if self.stage_manager.is_first_stage(ignore_chunk=True):
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=None,
-                    inputs=list(model_chunk[model_chunk_id].parameters()),
-                    retain_graph=False,
-                )
-            else:
-                optimizer.backward_by_grad(
-                    tensor=output_obj,
-                    grad=output_obj_grad,
-                    inputs=list(model_chunk[model_chunk_id].parameters()),
-                    retain_graph=False,
-                )
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # loss backward; output_obj is loss
+            output_obj_grad = None
+        optimizer.backward_by_grad(
+            tensor=output_obj,
+            grad=output_obj_grad,
+            inputs=list(model_chunk[model_chunk_id].parameters()),
+            retain_graph=False,
+        )
+        # if model_chunk_id == 0:
+        #     optimizer.backward_by_grad(
+        #         tensor=output_obj,
+        #         grad=output_obj_grad,
+        #         inputs=list(model_chunk[model_chunk_id].parameters()),
+        #         retain_graph=False,
+        #     )
+        # else:
+        #     if self.stage_manager.is_first_stage(ignore_chunk=True):
+        #         optimizer.backward_by_grad(
+        #             tensor=output_obj,
+        #             grad=None,
+        #             inputs=list(model_chunk[model_chunk_id].parameters()),
+        #             retain_graph=False,
+        #         )
+        #     else:
+        #         optimizer.backward_by_grad(
+        #             tensor=output_obj,
+        #             grad=output_obj_grad,
+        #             inputs=list(model_chunk[model_chunk_id].parameters()),
+        #             retain_graph=False,
+        #         )

     def schedule_f(
         self,
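
The W step gets the same condition fix. It reuses the graph that the B step kept alive, accumulates gradients only into the current chunk's parameters (the dw = x*dy part of the comment above), and then releases the graph with retain_graph=False. Again a sketch with plain torch.autograd rather than ColossalAI's backward_by_grad:

import torch

def backward_w_sketch(output_obj, output_obj_grad, params, output_is_loss):
    # dw-only backward: accumulate gradients into the parameters only and let
    # autograd free the graph afterwards (retain_graph=False).
    grad = None if output_is_loss else output_obj_grad
    torch.autograd.backward(
        tensors=output_obj,
        grad_tensors=grad,
        inputs=list(params),
        retain_graph=False,
    )

W = torch.nn.Parameter(torch.randn(3, 3))
x = torch.randn(3, requires_grad=True)
loss = (W @ x).sum()
torch.autograd.backward(loss, inputs=[x], retain_graph=True)   # B step: dx only
backward_w_sketch(loss, None, [W], output_is_loss=True)        # W step: dw only
print(W.grad)    # every row equals x, i.e. dw = x * dy with dy = 1
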
@@ -578,15 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append(input_obj)
-        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
-        detached_output_obj = output_obj.clone()
-        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
-        self.output_tensors[model_chunk_id].append(detached_output_obj)
-        # add output object for backward w
-        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)

         # Step3: send fwd
         # add output to send_fwd_buffer
@@ -603,6 +583,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             else:
                 self.send_forward_buffer[model_chunk_id].append(output_obj)

+        # add input and output object for backward b
+        self.input_tensors[model_chunk_id].append(input_obj)
+        # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
+        detached_output_obj = output_obj.clone()
+        deallocate_output_tensor(detached_output_obj, deallocate_pipeline_outputs=True)
+        self.output_tensors[model_chunk_id].append(detached_output_obj)
+        # add output object for backward w
+        self.output_tensors_dw[model_chunk_id].append(detached_output_obj)

     def schedule_b(
         self,
         scheduled_node,
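
The last two hunks move the backward bookkeeping in schedule_f so that it runs after the forward output has been handed to the send buffer (Step 3) instead of before it. What gets stored per chunk is the stage input plus a clone of the output whose storage has been freed by deallocate_output_tensor, since the later B and W passes only need the clone's grad_fn. A condensed sketch of that ordering, with simplified stand-ins for the scheduler's buffers:

import torch
from collections import defaultdict

input_tensors = defaultdict(list)        # consumed by backward_b_step
output_tensors = defaultdict(list)       # consumed by backward_b_step
output_tensors_dw = defaultdict(list)    # consumed by backward_w_step
send_forward_buffer = defaultdict(list)  # consumed by the send step

def after_forward(model_chunk_id, input_obj, output_obj):
    # Step 3 first: queue the intact output for the next stage.
    send_forward_buffer[model_chunk_id].append(output_obj)
    # Then stash what the B and W passes will need: the stage input and a
    # storage-freed clone that still carries the autograd graph.
    input_tensors[model_chunk_id].append(input_obj)
    detached_output_obj = output_obj.clone()
    detached_output_obj.data.untyped_storage().resize_(0)
    output_tensors[model_chunk_id].append(detached_output_obj)
    output_tensors_dw[model_chunk_id].append(detached_output_obj)

x = torch.randn(2, 2, requires_grad=True)
after_forward(0, x, x * 2)
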