From a4e91bc87f3f6a9c2e9e0c2cef9062d624572220 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 12 Apr 2022 16:04:21 +0800
Subject: [PATCH] [bug] fixed grad scaler compatibility with torch 1.8 (#735)

---
 colossalai/amp/torch_amp/_grad_scaler.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py
index 48c7eb949..d5e526d0c 100644
--- a/colossalai/amp/torch_amp/_grad_scaler.py
+++ b/colossalai/amp/torch_amp/_grad_scaler.py
@@ -12,6 +12,7 @@ from colossalai.context import ParallelMode
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from packaging import version
 
 
 class _MultiDeviceReplicator(object):
@@ -122,6 +123,14 @@ class GradScaler(object):
         else:
             self._enabled = enabled
 
+        # check version
+        torch_version = version.parse(torch.__version__)
+        assert torch_version.major == 1
+        if torch_version.minor > 8:
+            self._higher_than_torch18 = True
+        else:
+            self._higher_than_torch18 = False
+
         if self._enabled:
             assert growth_factor > 1.0, "The growth factor must be > 1.0."
             assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
@@ -404,8 +413,12 @@ class GradScaler(object):
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]
 
-            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
-                                     self._backoff_factor, self._growth_interval)
+            if self._higher_than_torch18:
+                torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                         self._backoff_factor, self._growth_interval)
+            else:
+                self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor,
+                                                      self._backoff_factor, self._growth_interval)
 
         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
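
Note on the change: the in-place helper torch._amp_update_scale_ (scale tensor first) is only available on torch 1.9 and later, while torch 1.8 exposes the functional torch._amp_update_scale (growth tracker first, returning the new scale), so the patch records the running torch version once in __init__ and branches in update(). The snippet below is a minimal standalone sketch of that version gate, not the library's code; it assumes torch 1.x and the packaging library are installed, and the name higher_than_torch18 simply mirrors the attribute added by the patch.

    # Sketch of the patch's version check (illustrative, not ColossalAI source).
    import torch
    from packaging import version

    torch_version = version.parse(torch.__version__)
    assert torch_version.major == 1, "the patch only targets torch 1.x"
    higher_than_torch18 = torch_version.minor > 8

    if higher_than_torch18:
        # torch >= 1.9: torch._amp_update_scale_(scale, growth_tracker, ...) updates the scale in place
        print(f"torch {torch.__version__}: use the in-place _amp_update_scale_")
    else:
        # torch 1.8: torch._amp_update_scale(growth_tracker, scale, ...) returns the new scale tensor
        print(f"torch {torch.__version__}: use the functional _amp_update_scale and keep its return value")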