[test] make zero engine test really work (#447)

Jiarui Fang 2022-03-17 17:24:25 +08:00 committed by GitHub
parent bb2790cf0b
commit 0fcfb1e00d
7 changed files with 39 additions and 28 deletions

View File

@@ -20,6 +20,7 @@ class CPUAdam(torch.optim.Optimizer):
         The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
         The sharded param of model_params can resident on both CPU and CUDA.
         """
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
         self.opt_id = CPUAdam.optimizer_id
@@ -34,7 +35,8 @@ class CPUAdam(torch.optim.Optimizer):
         self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)

     def __del__(self):
-        self.cpu_adam_op.destroy_adam(self.opt_id)
+        if self.cpu_adam_op:
+            self.cpu_adam_op.destroy_adam(self.opt_id)

     def torch_adam_update(self,
                           data,
@@ -72,7 +74,6 @@ class CPUAdam(torch.optim.Optimizer):
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             with torch.enable_grad():
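
The new guard in `__del__` avoids calling `destroy_adam` on an op handle that was never created, for example when the C++ extension fails to build or `__init__` exits early. A minimal standalone sketch of the same teardown pattern (the class and names below are illustrative stand-ins, not the real CPUAdam extension):

class GuardedOpOwner:
    """Illustrative only: mirrors the guarded teardown in CPUAdam.__del__."""

    def __init__(self, build_ok: bool = True):
        # In CPUAdam this handle comes from a JIT-built C++ op; a failed build
        # can leave it unset, so __del__ must not assume it exists.
        self.op_handle = object() if build_ok else None

    def release(self):
        print("native optimizer state released")

    def __del__(self):
        # Without this check, garbage collection of a half-initialized object
        # would raise during interpreter shutdown.
        if self.op_handle is not None:
            self.release()


GuardedOpOwner(build_ok=True)   # releases its native state on teardown
GuardedOpOwner(build_ok=False)  # tears down silently, no error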

View File

@@ -2,9 +2,10 @@ from typing import List
 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors as flatten
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from torch._utils import _flatten_dense_tensors as flatten
 from .tensor_shard_strategy import TensorShardStrategy

View File

@@ -2,6 +2,7 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

View File

@@ -1,9 +1,14 @@
 from enum import Enum
-from typing import Callable, Dict, Optional, Union
+from typing import Dict, Optional, Type, Any
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
@@ -11,11 +16,8 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
-from typing import Type, Any
+from colossalai.logging import get_dist_logger
 from ._utils import has_inf_or_nan
@@ -82,7 +84,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         :type defaults: dict()
         """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
+        self._logger = get_dist_logger('ShardedOptimV2 logger')
         self._optim_defaults = defaults
         # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
@@ -136,23 +138,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.grad_scaler.update(found_inf)
         if found_inf:
+            self._logger.info('found inf during ShardedOptimV2 step')
             self.zero_grad()
             return
         # assign master param pointers to p.data.
         # We will not trigger data copy here.
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
                 # Now p.data is sharded
                 # So optimizer states are sharded naturally
-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optimizer.step(*args, **kwargs)
         # Copy master param data (fp32) to payload of col_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 is_param_sharded = p.col_attr.data.is_sharded
                 if not is_param_sharded:
@@ -196,7 +199,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow.fill_(0.0)
         # check for overflow
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 if has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
@@ -212,7 +215,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 if p.grad is not None:
                     p.grad.data.div_(self.loss_scale)
@@ -222,7 +225,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
-        self.optim.zero_grad(set_to_none=True)
+        self.optimizer.zero_grad(set_to_none=True)

     def sync_grad(self):
         pass
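
The `step()` logic shown above relies on a pointer swap: each fp16 parameter's `.data` is pointed at its fp32 master shard before the wrapped optimizer runs, so optimizer states are allocated against the shard, and the updated fp32 values are then copied back into the fp16 payload. A minimal standalone sketch of that pattern using a plain `torch.optim.SGD` (the variable names are illustrative, not the ShardedOptimizerV2 API):

import torch

# Toy "model" parameter kept in fp16, as the sharded model would hold it.
p_fp16 = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
p_fp16.grad = torch.randn(4, dtype=torch.float16)

# fp32 master copy, analogous to self.master_params[p].
master_fp32 = p_fp16.data.float()

optim = torch.optim.SGD([p_fp16], lr=0.1)

# 1) Point p.data at the fp32 master so the inner optimizer updates fp32 state.
saved_fp16 = p_fp16.data
p_fp16.data = master_fp32
p_fp16.grad = p_fp16.grad.float()   # grads must match the working dtype

optim.step()

# 2) Copy the updated fp32 values back into the fp16 payload and restore the pointer.
saved_fp16.copy_(p_fp16.data)
p_fp16.data = saved_fp16

print(p_fp16.dtype, p_fp16.data)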

View File

@@ -6,6 +6,7 @@ import torch.distributed as dist
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.nn.optimizer import CPUAdam
 LOGGER = get_dist_logger('zero_test')
@@ -19,16 +20,16 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           use_memory_tracer=False)
 _ZERO_OPTIMIZER_CONFIG = dict(
-    optimizer_class=torch.optim.Adam,
+    optimizer_class=torch.optim.Adam,    #CPUAdam
     cpu_offload=False,
-    initial_scale=2**32,
+    initial_scale=2**5,
     min_scale=1,
     growth_factor=2,
     backoff_factor=0.5,
     growth_interval=1000,
     hysteresis=2,
     max_scale=2**32,
-)
+    lr=1e-3)
 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
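
Lowering `initial_scale` from `2**32` to `2**5` (and passing an explicit `lr`) matters because dynamic loss scaling multiplies the loss by the current scale: an enormous starting scale overflows fp16 gradients, and the scaler then skips updates while it backs off. A toy sketch of that back-off/growth behaviour (a generic illustration under assumed hyperparameters, not the `DynamicGradScaler` source):

import torch

def toy_dynamic_scaling(initial_scale: float, steps: int = 5) -> None:
    """Show how an oversized initial scale forces skipped steps at the start."""
    scale = initial_scale
    backoff, growth, growth_interval = 0.5, 2.0, 1000
    good_steps = 0
    grad_fp16 = torch.full((4,), 1e-3, dtype=torch.float16)  # a typical small grad

    for step in range(steps):
        scaled_grad = grad_fp16 * scale          # what a scaled backward would produce
        overflow = torch.isinf(scaled_grad).any() or torch.isnan(scaled_grad).any()
        if overflow:
            scale *= backoff                      # back off and skip the update
            good_steps = 0
            print(f"step {step}: overflow, new scale {scale:g}")
        else:
            good_steps += 1
            if good_steps % growth_interval == 0:
                scale *= growth                   # grow after a streak of clean steps
            print(f"step {step}: ok, scale {scale:g}")

# With 2**32 the first steps all overflow while the scale backs off;
# with 2**5 updates apply immediately.
toy_dynamic_scaling(2**32)
toy_dynamic_scaling(2**5)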

View File

@@ -13,6 +13,7 @@ from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.optimizer import CPUAdam
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from common import CONFIG, check_sharded_params_padding
@@ -71,6 +72,8 @@ def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
         _run_step(model, optim, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
         check_sharded_params_padding(model, zero_model, loose=True)
+        for param in model.parameters():
+            assert not has_inf_or_nan(param)

 # use_cpuadam = True can be used with cpu_offload = False
@@ -105,7 +108,4 @@ def test_sharded_optim_v2_cpu_adam(world_size, cpu_offload, shard_strategy, use_cpuadam):
 if __name__ == '__main__':
-    test_sharded_optim_v2_cpu_adam(world_size=2,
-                                   cpu_offload=False,
-                                   shard_strategy=TensorShardStrategy,
-                                   use_cpuadam=True)
+    test_sharded_optim_v2_cpu_adam(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy, use_cpuadam=True)
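
Both test files now import `has_inf_or_nan` from `colossalai.zero.sharded_optim._utils` and assert that no parameter or gradient has blown up after a step. The helper itself is not part of this diff; the following is a plausible stand-in for illustration only:

import torch

def has_inf_or_nan_sketch(t: torch.Tensor) -> bool:
    """Hypothetical stand-in for colossalai.zero.sharded_optim._utils.has_inf_or_nan."""
    # Summing first keeps the check cheap on large tensors; any inf/nan in the
    # input propagates into the sum and is caught by the final test.
    s = t.float().sum()
    return bool(torch.isinf(s) or torch.isnan(s))

# Usage in the spirit of the new test assertions:
ok = torch.randn(8)
bad = torch.tensor([1.0, float('nan'), 3.0])
assert not has_inf_or_nan_sketch(ok)
assert has_inf_or_nan_sketch(bad)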

View File

@@ -8,6 +8,7 @@ import pytest
 import colossalai
 from colossalai.utils import free_port
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 import torch.multiprocessing as mp
 import torch.distributed as dist
@@ -32,12 +33,13 @@ def run_dist(rank, world_size, port, parallel_config):
         colo_model = model_builder(checkpoint=True)
         torch_model = copy.deepcopy(colo_model).cuda()
+        torch_model.train()
         engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                                optimizer=optimizer_class,
                                                                criterion=criterion,
                                                                train_dataloader=train_dataloader)
         engine.train()
-        torch_optimizer = optimizer_class(torch_model.parameters())
+        torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
         if dist.get_world_size() > 1:
             torch_model = DDP(torch_model)
@@ -66,15 +68,17 @@ def run_dist(rank, world_size, port, parallel_config):
             engine.step()
             torch_loss.backward()
+            for param in torch_model.parameters():
+                if param.grad is not None:
+                    assert not has_inf_or_nan(param.grad)
             torch_optimizer.step()
             i += 1
-        # for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
-        #     assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"
         if parallel_config == MP_PARALLEL_CONFIG:
             check_params(torch_model, colo_model, loose=True)
-        elif isinstance(colo_model, ShardedModelV2):
+        elif parallel_config == ZERO_PARALLEL_CONFIG:
             check_sharded_params_padding(torch_model, colo_model, loose=True)
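
The engine test now dispatches its final check on `parallel_config` and compares the reference `torch_model` against the ZeRO model with a relaxed tolerance (`loose=True`), since fp16 arithmetic and shard padding rule out bit-exact equality. A sketch of what such a loose parameter comparison can look like (a hypothetical helper, not the `check_params` / `check_sharded_params_padding` implementation):

import torch

def check_params_loose_sketch(model_a: torch.nn.Module,
                              model_b: torch.nn.Module,
                              rtol: float = 1e-3,
                              atol: float = 1e-3) -> None:
    """Hypothetical loose parameter comparison for mixed-precision tests."""
    for (name, pa), (_, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
        # Compare in fp32 so fp16 parameters don't lose precision in the check itself.
        pa32, pb32 = pa.detach().float().cpu(), pb.detach().float().cpu()
        assert torch.allclose(pa32, pb32, rtol=rtol, atol=atol), \
            f"mismatch in {name}: max diff {(pa32 - pb32).abs().max().item():.3e}"

# Usage sketch: two identical linear layers pass the loose check.
a = torch.nn.Linear(4, 4)
b = torch.nn.Linear(4, 4)
b.load_state_dict(a.state_dict())
check_params_loose_sketch(a, b)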