mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fix] fix zerobubble; support shardformer model type;
This commit is contained in:
@@ -431,7 +431,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# fwd calculate
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
@@ -562,7 +562,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
grad=output_obj_grad_,
|
||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
inputs=list(model_chunk.parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user