add fp32 master params in sharded adam

ver217 2022-03-03 15:42:53 +08:00
parent 6185b9772d
commit 6c290dbb08

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -11,6 +11,7 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from torch import Tensor
 from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
 from torch.optim import Optimizer

 from ._utils import has_inf_or_nan
@@ -39,7 +40,7 @@ class ShardedAdam(ColossalaiOptimizer):
         super().__init__(adam_optim)
         self.model: Union[nn.Module, ShardedModelV2] = sharded_model
         self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
-        self.state_device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
@@ -51,35 +52,18 @@ class ShardedAdam(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.state_device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        # Store fp32 params
+        self.master_params: Dict[Parameter, Tensor] = {}
+        # Early state initialization
         for group in adam_optim.param_groups:
             for p in group['params']:
-                state_shape = p.shape
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    # TODO: use payload shape
-                    state_shape = p.ca_attr.payload(self.state_device)
-                state = adam_optim.state[p]
-                assert len(state) == 0, 'adam optimizer initialized'
-                state['step'] = 0
-                # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros(state_shape,
-                                               memory_format=torch.preserve_format,
-                                               dtype=torch.float,
-                                               device=self.state_device)
-                # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.zeros(state_shape,
-                                                  memory_format=torch.preserve_format,
-                                                  dtype=torch.float,
-                                                  device=self.state_device)
-                if group['amsgrad']:
-                    # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_sq'] = torch.zeros(state_shape,
-                                                          memory_format=torch.preserve_format,
-                                                          dtype=torch.float,
-                                                          device=self.state_device)
+                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                else:
+                    self.master_params[p] = p.data.to(torch.float)

     def step(self, *args, **kwargs):
         # unscale grads if scaled
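
Setting the sharded payload case aside, the bookkeeping introduced in this hunk is the usual fp16/fp32 master-weight pattern. A minimal standalone sketch, assuming a plain fp16 nn.Linear instead of a ShardedModelV2 (so no ca_attr or payload is involved):

from typing import Dict

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

# Hypothetical stand-in for the sharded model: a plain fp16 module.
model = nn.Linear(4, 4).half()

# One fp32 copy per parameter; the optimizer math runs on these copies
# while the model itself keeps its fp16 (or sharded) storage.
master_params: Dict[Parameter, torch.Tensor] = {
    p: p.data.to(torch.float) for p in model.parameters()
}

Keying the dictionary by Parameter lets the optimizer look up its fp32 copy later without touching the model's own storage.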
@@ -93,19 +77,15 @@ class ShardedAdam(ColossalaiOptimizer):
             self.zero_grad()
             return

-        # Write payload back to p.data
+        # Write master param to p.data
         for group in self.optim.param_groups:
             for p in group['params']:
-                data = p.data
-                if hasattr(p, 'ca_attr'):
-                    data = p.ca_attr.payload(self.state_device)
-                if torch.is_floating_point(data) and data.dtype != torch.float:
-                    data = data.to(torch.float)
-                p.data = data
+                p.data = self.master_params[p]

         ret = self.optim.step(*args, **kwargs)
-        # Set p.data to None
+        # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
+                # TODO: update payload
                 p.data = None
         return ret
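
The step() changes above amount to pointing p.data at the fp32 master copy before calling the wrapped Adam, then detaching it again (the write-back into the sharded payload is still left as a TODO). A rough self-contained sketch of that flow, again assuming a plain fp16 module and adding an explicit fp16 write-back in place of the payload handling:

from typing import Dict

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

model = nn.Linear(4, 4).half()
master_params: Dict[Parameter, torch.Tensor] = {
    p: p.data.to(torch.float) for p in model.parameters()
}
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

def master_step() -> None:
    # Swap p.data to the fp32 master copy so Adam updates full-precision weights.
    for group in optim.param_groups:
        for p in group['params']:
            p.data = master_params[p]
    optim.step()
    # Copy the updated fp32 weights back into fp16 storage for the model;
    # the commit above instead sets p.data to None and defers the payload write-back.
    for group in optim.param_groups:
        for p in group['params']:
            p.data = master_params[p].to(torch.half)

A hypothetical training loop would run the fp16 forward/backward (populating p.grad) and then call master_step() in place of a bare optim.step().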