[fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
This commit is contained in:
Hongxin Liu
2024-09-11 16:11:25 +08:00
committed by GitHub
parent c54c4fcd15
commit 13946c4448
6 changed files with 31 additions and 17 deletions

View File

@@ -318,7 +318,7 @@ class InterleavedSchedule(PipelineSchedule):
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
accum_loss.add_(loss.data)
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss