mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
Merge branch 'main' into dev/zero_bubble
This commit is contained in:
@@ -286,7 +286,6 @@ class InterleavedSchedule(PipelineSchedule):
|
||||
# for the first stage, input_obj is None
|
||||
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
|
||||
# Only attention_mask from micro_batch is used
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
|
||||
|
@@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||
output_obj = model_forward(model, micro_batch, input_obj)
|
||||
if self.stage_manager.is_last_stage():
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatches
|
||||
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
|
Reference in New Issue
Block a user