[fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
Author: Hongxin Liu
Date: 2024-09-11 16:11:25 +08:00
Committed by: GitHub
Parent: c54c4fcd15
Commit: 13946c4448
6 changed files with 31 additions and 17 deletions


@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy
@@ -88,6 +88,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        backward_context=None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -130,6 +131,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
         self._fp8_communication = fp8_communication
+        self._backward_context = backward_context
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -429,7 +431,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
-        loss.backward(retain_graph=retain_graph)
+        ctx = nullcontext() if self._backward_context is None else self._backward_context()
+        with ctx:
+            loss.backward(retain_graph=retain_graph)
         if not self.require_grad_sync:
             return