Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 12:43:02 +00:00)
[zero] fix gradient clipping in hybrid parallelism (#2521)
* [zero] fix gradient clipping in hybrid parallelism
* [testing] change model name to avoid pytest warning
* [hotfix] fix unit testing
@@ -58,10 +58,12 @@ class DynamicGradScaler(BaseGradScaler):
         if self._min_scale:
             assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
+            assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
         if self._max_scale:
-            assert self._min_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+            assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+            assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
         assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
-        assert self._backoff_factor < 1 and self._backoff_factor > 0, 'The backoff factor must be between 0 and 1'
+        assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
         assert self._hysteresis >= 0, 'The hysteresis cannot be negative'

     def update(self, overflow: bool) -> None:
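The new assertions pin down a consistent ordering, min_scale <= scale <= max_scale, for the dynamic loss scaler. Below is a minimal, self-contained sketch of how such a scaler behaves at runtime; it is not ColossalAI's actual DynamicGradScaler (the real class also tracks a growth interval and optional logging), only an illustration of why both bounds must be positive and must bracket the current scale.

import torch


class ToyDynamicScaler:
    """Illustrative dynamic loss scaler with growth/backoff/hysteresis and bounds."""

    def __init__(self,
                 initial_scale: float = 2**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 hysteresis: int = 2,
                 min_scale: float = 1.0,
                 max_scale: float = 2**24):
        # The same invariants the commit asserts: both bounds are positive
        # and bracket the current scale.
        assert min_scale > 0 and min_scale <= initial_scale
        assert max_scale > 0 and max_scale >= initial_scale
        assert growth_factor > 1
        assert 0 < backoff_factor < 1
        assert hysteresis >= 0
        self.scale = torch.tensor(float(initial_scale))
        self.min_scale = torch.tensor(float(min_scale))
        self.max_scale = torch.tensor(float(max_scale))
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.hysteresis = hysteresis
        self._overflow_steps = 0

    def update(self, overflow: bool) -> None:
        if overflow:
            self._overflow_steps += 1
            # Back off only after `hysteresis` consecutive overflowing steps,
            # and never drop below the configured minimum.
            if self._overflow_steps >= self.hysteresis:
                self.scale = torch.max(self.scale * self.backoff_factor, self.min_scale)
                self._overflow_steps = 0
        else:
            self._overflow_steps = 0
            # Grow the scale again, but clamp it at the configured maximum.
            self.scale = torch.min(self.scale * self.growth_factor, self.max_scale)


scaler = ToyDynamicScaler()
for overflowed in [False, True, True, False, False]:
    scaler.update(overflowed)
print(scaler.scale.item())  # always stays inside [min_scale, max_scale]

With the bounds enforced at construction time, update() can clamp with torch.min/torch.max without ever producing a scale that contradicts the configuration.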
@@ -103,3 +105,17 @@ class DynamicGradScaler(BaseGradScaler):
         self._scale = self._scale * self._growth_factor
         if self._max_scale:
             self._scale = torch.min(self._scale, self._max_scale)
+
+    def state_dict(self):
+        state_dict = dict()
+        state_dict['scale'] = self._scale
+        state_dict['growth_factor'] = self._growth_factor
+        state_dict['backoff_factor'] = self._backoff_factor
+        state_dict['hysteresis'] = self._hysteresis
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+        self._growth_factor = state_dict['growth_factor']
+        self._backoff_factor = state_dict['backoff_factor']
+        self._hysteresis = state_dict['hysteresis']
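The appended state_dict/load_state_dict pair makes the scaler checkpointable alongside the model and optimizer state; note that load_state_dict moves the saved scale tensor back onto the current CUDA device. The sketch below shows the intended usage pattern; the import path and constructor arguments are assumptions that may differ across ColossalAI versions.

import torch

# Assumed import path; check your ColossalAI version for the exact module.
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

# Illustrative constructor arguments; min_scale/max_scale match the bounds
# validated in the diff above. A CUDA device is required, since
# load_state_dict() calls .cuda() on the restored scale tensor.
scaler = DynamicGradScaler(initial_scale=2**16, min_scale=1, max_scale=2**24)

# ... run training steps that call scaler.update(found_overflow) ...

# Save the scaler state next to the model/optimizer checkpoint.
torch.save({'grad_scaler': scaler.state_dict()}, 'checkpoint.pt')

# On resume, restore the exact loss scale the previous run had reached,
# so training does not restart from the (possibly too large) initial scale.
ckpt = torch.load('checkpoint.pt')
scaler.load_state_dict(ckpt['grad_scaler'])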