Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 17:46:42 +00:00
[fp8] hotfix backward hook (#6053)
* [fp8] hotfix backward hook
* [fp8] hotfix pipeline loss accumulation
@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy
@@ -88,6 +88,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        backward_context=None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
 
@@ -130,6 +131,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
         self._fp8_communication = fp8_communication
+        self._backward_context = backward_context
 
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
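Taken together, the two constructor hunks above thread an optional backward_context factory from the caller into the instance attribute self._backward_context, to be consumed later in backward(). A hedged caller-side sketch of what such a factory could look like (the context-manager function is an illustrative assumption, not part of this commit, and the other LowLevelZeroOptimizer arguments are omitted):

from contextlib import contextmanager

@contextmanager
def enable_fp8_backward_hooks():
    # Hypothetical setup/teardown around autograd; the real fp8 hook
    # logic lives elsewhere in ColossalAI.
    try:
        yield
    finally:
        pass

# Illustrative construction only; required arguments omitted for brevity.
# zero_optim = LowLevelZeroOptimizer(optim, backward_context=enable_fp8_backward_hooks)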
@@ -429,7 +431,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
-        loss.backward(retain_graph=retain_graph)
+        ctx = nullcontext() if self._backward_context is None else self._backward_context()
+        with ctx:
+            loss.backward(retain_graph=retain_graph)
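The backward hunk replaces the bare loss.backward(...) call with one wrapped in either the caller-supplied context or a nullcontext() fallback, so callers that pass nothing see unchanged behavior. A minimal, self-contained sketch of the same pattern (function and variable names here are illustrative, not ColossalAI API):

from contextlib import contextmanager, nullcontext

import torch

@contextmanager
def fp8_backward_hooks():
    # Hypothetical stand-in for a context that enables fp8-specific
    # backward hooks while autograd runs.
    print("fp8 hooks enabled")
    try:
        yield
    finally:
        print("fp8 hooks disabled")

def run_backward(loss: torch.Tensor, backward_context=None, retain_graph: bool = False):
    # Same guard as the patch: fall back to a no-op context when the
    # caller did not supply a factory.
    ctx = nullcontext() if backward_context is None else backward_context()
    with ctx:
        loss.backward(retain_graph=retain_graph)

loss = (torch.randn(4, requires_grad=True) ** 2).sum()
run_backward(loss, backward_context=fp8_backward_hooks)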