Merge branch 'main' into dev/zero_bubble

This commit is contained in:
duanjunwen
2024-08-26 04:03:20 +00:00
96 changed files with 2430 additions and 690 deletions

View File

@@ -286,7 +286,6 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)

View File

@@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None: