[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

This commit is contained in:
duanjunwen
2024-09-20 07:34:43 +00:00
parent b6616f544e
commit 1739df423c
2 changed files with 6 additions and 15 deletions

View File

@@ -429,18 +429,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
if isinstance(model_chunk, ModuleList):
# fwd for ModuleList model
if input_obj is None:
output_obj = model_chunk[model_chunk_id](**micro_batch)
else:
output_obj = model_chunk[model_chunk_id](**input_obj)
else:
# fwd for shardformer
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):