Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 13:59:08 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object):
     """

     def __init__(self, master_tensor: torch.Tensor) -> None:
-        assert master_tensor.is_cuda or master_tensor.device.type == 'xla'
+        assert master_tensor.is_cuda or master_tensor.device.type == "xla"
         self.master = master_tensor
         self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

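The hunk above only touches `_MultiDeviceReplicator.__init__`; the rest of the helper lazily hands out one copy of the master scale tensor per device. A self-contained sketch of that behaviour, modelled on the upstream PyTorch helper this file adapts (the class name here is illustrative, not the repository's code):

import torch
from typing import Dict

class _MultiDeviceReplicatorSketch:
    """Lazily serves per-device copies of a master tensor (illustrative only)."""

    def __init__(self, master_tensor: torch.Tensor) -> None:
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device: torch.device) -> torch.Tensor:
        # Create the copy for this device on first request, then reuse it.
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval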
@@ -118,7 +118,7 @@ class GradScaler(object):
     invokes the underlying ``optimizer.step()``, and other methods become no-ops.
     """

-    def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
+    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
         if enabled and not torch.cuda.is_available():
             warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
             self._enabled = False
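The constructor above keeps the standard dynamic-loss-scaling knobs. For reference, a minimal sketch of how such a scaler is typically driven in a training loop; `model`, `optimizer`, and `loader` are assumed to exist and are not part of this diff:

import torch

scaler = torch.cuda.amp.GradScaler(init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000)

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step if inf/NaN was found
    scaler.update()                # grows or backs off the scale for the next iteration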
@@ -174,7 +174,7 @@ class GradScaler(object):

         # Short-circuit for the common case.
         if isinstance(outputs, torch.Tensor):
-            assert outputs.is_cuda or outputs.device.type == 'xla'
+            assert outputs.is_cuda or outputs.device.type == "xla"
             if self._scale is None:
                 self._lazy_init_scale_growth_tracker(outputs.device)
             assert self._scale is not None
@@ -186,7 +186,7 @@ class GradScaler(object):

         def apply_scale(val):
             if isinstance(val, torch.Tensor):
-                assert val.is_cuda or val.device.type == 'xla'
+                assert val.is_cuda or val.device.type == "xla"
                 if len(stash) == 0:
                     if self._scale is None:
                         self._lazy_init_scale_growth_tracker(val.device)
@@ -214,7 +214,7 @@ class GradScaler(object):

         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
-        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))    # type: ignore[var-annotated]
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
         with torch.no_grad():
             for group in optimizer.param_groups:
                 for param in group["params"]:
@@ -238,8 +238,9 @@ class GradScaler(object):

         for device, per_dtype_grads in per_device_and_dtype_grads.items():
             for grads in per_dtype_grads.values():
-                torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
-                                                                 per_device_inv_scale.get(device))
+                torch._amp_foreach_non_finite_check_and_unscale_(
+                    grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
+                )
         # For tensor parallel parameters it should be all-reduced over tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
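The comment in the hunk above notes that the inf/NaN flags must be reconciled across the tensor-parallel group, since each rank only sees its own shard of the gradients. A minimal sketch of that reconciliation using plain torch.distributed (the repository itself goes through gpc and ParallelMode.MODEL; `group` here is an assumed process-group handle, not the project's API):

import torch
import torch.distributed as dist

def sync_found_inf(found_inf_tensors, group=None):
    # Each flag is a 1-element float tensor: 0.0 = clean, 1.0 = overflow seen.
    # MAX keeps the flag raised if any rank in the group observed inf/NaN.
    for flag in found_inf_tensors:
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return found_inf_tensors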
@@ -328,7 +329,7 @@ class GradScaler(object):
         .. warning::
             Closure use is not currently supported.
         """
-        if (not self._enabled):
+        if not self._enabled:
             return optimizer.step(*args, **kwargs)

         if "closure" in kwargs:
@@ -343,7 +344,7 @@ class GradScaler(object):

         retval = None

-        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
+        if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling:
             # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
             # The contract with custom optimizers is that their step() should accept an additional,
             # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
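The comments above describe the opt-in contract for optimizers that handle scaling themselves. A hypothetical example of what such an optimizer could look like (the class and its behaviour are illustrative assumptions, not part of this commit):

import torch

class ScaleAwareSGD(torch.optim.SGD):
    # Tells GradScaler.step() to call step() directly and pass itself along.
    _step_supports_amp_scaling = True

    @torch.no_grad()
    def step(self, closure=None, *, grad_scaler=None):
        # A real implementation would consult grad_scaler here, e.g. to unscale
        # gradients and skip the update when an overflow was recorded.
        return super().step(closure)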
@@ -391,14 +392,14 @@ class GradScaler(object):
         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale.fill_(new_scale)    # type: ignore[union-attr]
+                self._scale.fill_(new_scale)  # type: ignore[union-attr]
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                 # type: ignore[attr-defined]
                 assert isinstance(new_scale, torch.cuda.FloatTensor), reason
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
-                self._scale.copy_(new_scale)    # type: ignore[union-attr]
+                self._scale.copy_(new_scale)  # type: ignore[union-attr]
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
@@ -416,11 +417,23 @@ class GradScaler(object):
                 found_inf_combined += found_infs[i]

             if self._higher_than_torch18:
-                torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
-                                         self._backoff_factor, self._growth_interval)
+                torch._amp_update_scale_(
+                    _scale,
+                    _growth_tracker,
+                    found_inf_combined,
+                    self._growth_factor,
+                    self._backoff_factor,
+                    self._growth_interval,
+                )
             else:
-                self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor,
-                                                      self._backoff_factor, self._growth_interval)
+                self._scale = torch._amp_update_scale(
+                    _growth_tracker,
+                    _scale,
+                    found_inf_combined,
+                    self._growth_factor,
+                    self._backoff_factor,
+                    self._growth_interval,
+                )

         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
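For readers unfamiliar with the private call being reformatted above, here is a pure-Python sketch of the dynamic-loss-scaling rule that torch._amp_update_scale_ implements; the real kernel updates _scale and _growth_tracker in place on the device, this function only mirrors the logic:

def update_scale(scale, growth_tracker, found_inf, growth_factor, backoff_factor, growth_interval):
    if found_inf:
        # Overflow this iteration: shrink the scale and reset the streak.
        return scale * backoff_factor, 0
    growth_tracker += 1  # one more consecutive overflow-free iteration
    if growth_tracker == growth_interval:
        # Long enough without overflow: grow the scale and restart the streak.
        return scale * growth_factor, 0
    return scale, growth_tracker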
@@ -507,13 +520,17 @@ class GradScaler(object):
         If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
         should be called after :meth:`update`.
         """
-        return {
-            "scale": self.get_scale(),
-            "growth_factor": self._growth_factor,
-            "backoff_factor": self._backoff_factor,
-            "growth_interval": self._growth_interval,
-            "_growth_tracker": self._get_growth_tracker()
-        } if self._enabled else {}
+        return (
+            {
+                "scale": self.get_scale(),
+                "growth_factor": self._growth_factor,
+                "backoff_factor": self._backoff_factor,
+                "growth_interval": self._growth_interval,
+                "_growth_tracker": self._get_growth_tracker(),
+            }
+            if self._enabled
+            else {}
+        )

     def load_state_dict(self, state_dict):
         r"""
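A short usage sketch of the serialization path reformatted above, following the standard torch.cuda.amp checkpointing recipe; `model`, `optimizer`, `scaler`, and the file name are assumptions for illustration:

import torch

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),  # per the docstring, call this after scaler.update()
}
torch.save(checkpoint, "checkpoint.pt")

# On resume, restore the scale and growth tracker before training continues.
checkpoint = torch.load("checkpoint.pt")
scaler.load_state_dict(checkpoint["scaler"])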
@@ -526,8 +543,10 @@ class GradScaler(object):
             return

         if len(state_dict) == 0:
-            raise RuntimeError("The source state dict is empty, possibly because it was saved "
-                               "from a disabled instance of GradScaler.")
+            raise RuntimeError(
+                "The source state dict is empty, possibly because it was saved "
+                "from a disabled instance of GradScaler."
+            )

         self._init_scale = state_dict["scale"]
         if self._scale is not None:
@@ -542,15 +561,17 @@ class GradScaler(object):
     def __getstate__(self):
         state = self.__dict__.copy()
         if self._enabled:
-            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
-                                                         "of an iteration, or at the end after scaler.update()."
+            assert len(self._per_optimizer_states) == 0, (
+                "A GradScaler instance may only be pickled at the beginning "
+                "of an iteration, or at the end after scaler.update()."
+            )
             # Pickling _scale and _growth_tracker Tensors directly triggers
             # "warnings.warn("pickle support for Storage will be removed in 1.5..."
             # so instead, we set the unpickled instance up to reinitialize them lazily.
-            state['_init_scale'] = self.get_scale()
-            state['_init_growth_tracker'] = self._get_growth_tracker()
-            state['_scale'] = None
-            state['_growth_tracker'] = None
+            state["_init_scale"] = self.get_scale()
+            state["_init_growth_tracker"] = self._get_growth_tracker()
+            state["_scale"] = None
+            state["_growth_tracker"] = None
         return state

     def __setstate__(self, state):
@@ -562,8 +583,9 @@ class GradScaler(object):
         dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
         found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

-        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
-            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
+        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_(
+            optimizer, dummy_inv_scale, found_inf, True
+        )

         return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

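The final hunk belongs to the inf/NaN check that inspects gradients with a dummy inverse scale of 1.0, i.e. without modifying them. For reference, the related public pattern is explicit unscaling before gradient clipping; a sketch assuming the `scaler`, `optimizer`, `model`, and `loss` names from the training-loop example earlier:

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients are now in true (unscaled) units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)      # detects the earlier unscale_ and does not repeat it
scaler.update()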