[bug] fixed grad scaler compatibility with torch 1.8 (#735)
@@ -12,6 +12,7 @@ from colossalai.context import ParallelMode
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from packaging import version
 
 
 class _MultiDeviceReplicator(object):
@@ -122,6 +123,14 @@ class GradScaler(object):
         else:
             self._enabled = enabled
 
+        # check version
+        torch_version = version.parse(torch.__version__)
+        assert torch_version.major == 1
+        if torch_version.minor > 8:
+            self._higher_than_torch18 = True
+        else:
+            self._higher_than_torch18 = False
+
         if self._enabled:
             assert growth_factor > 1.0, "The growth factor must be > 1.0."
             assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
@@ -404,8 +413,12 @@ class GradScaler(object):
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]
 
-            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
-                                     self._backoff_factor, self._growth_interval)
+            if self._higher_than_torch18:
+                torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                         self._backoff_factor, self._growth_interval)
+            else:
+                self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor,
+                                                      self._backoff_factor, self._growth_interval)
 
         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
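
Note on the branch added above: torch 1.8 and earlier only ship the out-of-place private op torch._amp_update_scale (growth tracker first, new scale tensor returned), while torch 1.9 and later replaced it with the in-place torch._amp_update_scale_ (scale first, mutated in place). Below is a minimal standalone sketch of the same version-gated dispatch; the helper name _update_scale is illustrative only and not part of this patch.

    # Sketch of the version-gated scale update; _update_scale is a hypothetical
    # helper, not a function added by this commit.
    import torch
    from packaging import version


    def _update_scale(scale, growth_tracker, found_inf,
                      growth_factor, backoff_factor, growth_interval):
        torch_version = version.parse(torch.__version__)
        if torch_version.major == 1 and torch_version.minor <= 8:
            # torch <= 1.8: out-of-place private op; growth tracker comes first
            # and the updated scale tensor is returned.
            return torch._amp_update_scale(growth_tracker, scale, found_inf,
                                           growth_factor, backoff_factor, growth_interval)
        # torch >= 1.9: in-place private op; scale comes first and is mutated.
        torch._amp_update_scale_(scale, growth_tracker, found_inf,
                                 growth_factor, backoff_factor, growth_interval)
        return scale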