mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs
This commit is contained in:
@@ -429,18 +429,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
# Only attention_mask from micro_batch is used
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
# fwd calculate
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
# fwd for ModuleList model
|
||||
if input_obj is None:
|
||||
output_obj = model_chunk[model_chunk_id](**micro_batch)
|
||||
else:
|
||||
output_obj = model_chunk[model_chunk_id](**input_obj)
|
||||
else:
|
||||
# fwd for shardformer
|
||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
|
Reference in New Issue
Block a user