[fix] fix zerobubble; support shardformer model type;

This commit is contained in:
duanjunwen
2024-09-26 06:11:56 +00:00
parent 83163fa70c
commit a92e16719b
3 changed files with 109 additions and 129 deletions

View File

@@ -431,7 +431,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# fwd calculate
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
- output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+ output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -562,7 +562,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
- inputs=list(model_chunk[model_chunk_id].parameters()),
+ inputs=list(model_chunk.parameters()),
retain_graph=False,
)