[zero] sharded model supports the reuse of fp16 shard (#495)
* sharded model supports reuse fp16 shard
* rename variable
* polish code
* polish code
* polish code
@@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
                          bias_correction2,
                          loss_scale,
                          use_adamw=False):
        # FIXME(ver217): remove the below line when replacing torch adam with fused adam
        grad = grad.float()
        if loss_scale is not None:
            grad.div_(loss_scale)
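As a reading aid, here is a minimal standalone sketch (not the library code) of what the added lines do to each gradient before the Adam update, assuming the gradient arrives as a loss-scaled fp16 tensor:

    import torch
    from typing import Optional

    def unscale_grad(grad: torch.Tensor, loss_scale: Optional[float]) -> torch.Tensor:
        # Work in fp32 for the optimizer math, then undo the loss scaling
        # that was applied during the fp16 backward pass.
        grad = grad.float()
        if loss_scale is not None:
            grad.div_(loss_scale)
        return grad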
@@ -29,24 +29,22 @@ class ShardedModelV2(nn.Module):
    compared to classic data parallelism while the computational granularity and communication efficiency are retained.
    Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.

    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
    :type module: nn.Module
    :param shard_strategy: A shard strategy to manage shard behavior.
    :type shard_strategy: BaseShardStrategy
    :param process_group: Data parallel process group, defaults to None
    :type process_group: Optional[ProcessGroup], optional
    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
    :type reduce_scatter_bucket_size_mb: int, optional
    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
    :type fp32_reduce_scatter: bool, optional
    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
    :type offload_config: Optional[dict], optional
    :param gradient_predivide_factor: Gradient is divided by this value before reduce-scatter, defaults to 1.0
    :type gradient_predivide_factor: Optional[float], optional
    :param use_memory_tracer: Whether to use memory tracer, defaults to False
    :type use_memory_tracer: bool, optional
    Args:
        module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
        shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
        process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
            Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
        reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
        fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
        offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
        gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
        use_memory_tracer (bool, optional): Whether to use memory tracer. Defaults to False.
        reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for both param and grad.
            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
            We find that PyTorch's optimizers don't support mixed precision,
            so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
    """

    def __init__(self,
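For readers skimming the new docstring, a rough usage sketch of the reuse_fp16_shard flag follows. It is illustrative only: the model is assumed to have been built inside ZeroInitContext (not shown, `build_model_in_zero_ctx` is a placeholder name), and the import paths reflect the ColossalAI release this commit targets and may differ later.

    import torch.nn as nn
    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2

    # Placeholder: builds the module under ZeroInitContext so its params
    # already carry sharded `col_attr` metadata.
    model: nn.Module = build_model_in_zero_ctx()

    zero_model = ShardedModelV2(model,
                                shard_strategy=TensorShardStrategy(),
                                offload_config={'device': 'cpu'},  # CPU offload, as in the docstring
                                reuse_fp16_shard=True)  # fp16 param shard doubles as the grad buffer

Per the docstring, this flag is recommended only together with CPUAdam and CPU offload, and must stay off when gradients are accumulated across steps.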
@@ -58,7 +56,8 @@ class ShardedModelV2(nn.Module):
                 fp32_reduce_scatter: bool = False,
                 offload_config: Optional[dict] = None,
                 gradient_predivide_factor: Optional[float] = 1.0,
                 use_memory_tracer: bool = False):
                 use_memory_tracer: bool = False,
                 reuse_fp16_shard: bool = False):
        super().__init__()
        self.logger = get_dist_logger()
@@ -97,8 +96,8 @@ class ShardedModelV2(nn.Module):
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
        for param in module.parameters():
            # Init `offload_fp32_grad`
            param.col_attr.offload_fp32_grad = self._cpu_offload
            # Init `offload_grad`
            param.col_attr.offload_grad = self._cpu_offload

        # We find that if gradient_predivide_factor != 1.0, there may be precision problems
        # So we use 1.0 as the default gradient_predivide_factor
@@ -114,6 +113,7 @@ class ShardedModelV2(nn.Module):
        self._require_backward_grad_sync: bool = True

        self._cuda_margin_space = 0
        self.reuse_fp16_shard = reuse_fp16_shard

    @property
    def cuda_margin_space(self):
@@ -143,11 +143,7 @@ class ShardedModelV2(nn.Module):
        for ophook in self._ophook_list:
            ophook.post_iter()

    @torch.no_grad()
    def _post_backward_operations(self) -> None:
        """
        The method includes operations required to be processed after backward
        """
    def _update_memstats(self):
        if self._iter_cnter == 0 and self._memstats_collector:
            self._memstats_collector.finish_collection()
        if self._memstats_collector:
@@ -160,6 +156,13 @@ class ShardedModelV2(nn.Module):

        self._iter_cnter += 1

    @torch.no_grad()
    def _post_backward_operations(self) -> None:
        """
        The method includes operations required to be processed after backward
        """
        self._update_memstats()

        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
@@ -171,9 +174,11 @@ class ShardedModelV2(nn.Module):
        self.reducer.free()
        # In case some post bwd hook is not fired
        if self.shard_param:
            tensor_list = []
            for p in self.module.parameters():
                if not p.col_attr.param_is_sharded:
                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
                    tensor_list.append(p.col_attr.sharded_data_tensor)
            self.shard_strategy.shard(tensor_list, self.process_group)
        for p in self.module.parameters():
            p.col_attr.bwd_count = 0
            if not p.requires_grad:
@@ -191,13 +196,17 @@ class ShardedModelV2(nn.Module):
            # If world size == 1 and sharded param,
            # the shape of `grad` is the same as the unsharded param
            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
            if self.reuse_fp16_shard:
                grad = p.col_attr.sharded_data_tensor.payload
            else:
                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
            if p.col_attr.offload_fp32_grad:
            if p.col_attr.offload_grad:
                col_move_to_cpu(grad)
            if p.col_attr.fp32_grad is not None:
                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
                grad = p.col_attr.fp32_grad
            p.grad.data = grad.view(-1)
            p.grad.data = grad
            p.col_attr.fp16_grad = None
            p.col_attr.fp32_grad = None
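The branch above is the heart of the feature. A toy illustration of the saving, with names that are mine rather than the library's: without reuse, each parameter keeps an fp16 shard plus a separate fp16 grad buffer; with reuse, the reduced gradient overwrites the storage that held the fp16 param shard, which is acceptable only because the fp16 shard is later rewritten from the fp32 master copy before the next forward pass.

    import torch

    shard_numel = 1024
    fp16_param_shard = torch.zeros(shard_numel, dtype=torch.float16)

    # Without reuse_fp16_shard: a second fp16 buffer per shard.
    fp16_grad = torch.zeros_like(fp16_param_shard)   # roughly 2x shard memory

    # With reuse_fp16_shard: the reduced grad lands in the param shard's storage.
    reduced_grad = torch.randn(shard_numel).half()
    fp16_param_shard.copy_(reduced_grad)             # no extra allocation
    grad_view = fp16_param_shard                     # the optimizer reads this as the grad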
@@ -250,10 +259,14 @@ class ShardedModelV2(nn.Module):
        return empty_grad

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        reduced_grad = reduced_grad.view(-1)
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)

        if self.reuse_fp16_shard:
            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
            param.col_attr.sharded_data_tensor.is_sharded = True
        else:
            param.col_attr.fp16_grad = reduced_grad.data

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
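The postdivide in this callback pairs with the predivide mentioned in __init__. A hedged sketch of the split, assuming the usual FairScale-style convention that predivide times postdivide equals the data-parallel world size (the function name and shape are mine, not the library's):

    import torch

    def ddp_style_average(grad: torch.Tensor, world_size: int, predivide: float = 1.0) -> torch.Tensor:
        # Splitting the world-size division into a pre- and a post-divide
        # reduces the risk of fp16 overflow/underflow around the reduce-scatter.
        postdivide = world_size / predivide
        if predivide > 1:
            grad.div_(predivide)
        # ... reduce-scatter across ranks would happen here ...
        if postdivide > 1:
            grad.div_(postdivide)
        return grad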
@@ -224,5 +224,5 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
                        p.col_attr.offload_fp32_grad = False
                        p.col_attr.offload_grad = False
                        fp32_shards_used_cuda_margin_mem += shard_mem
@@ -14,7 +14,7 @@ class ShardedParamV2(object):
        self.fp16_grad: Optional[torch.Tensor] = None
        self.fp32_grad: Optional[torch.Tensor] = None
        # This attribute must be initialized in ShardedModel
        self.offload_fp32_grad: bool = False
        self.offload_grad: bool = False

        # make sure the shared param is the only owner of payload
        # The param.data may be used to init the other part of the model.
@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                          offload_config=None,
                          gradient_predivide_factor=1.0,
                          use_memory_tracer=False,
                          shard_strategy=TensorShardStrategy())
                          shard_strategy=TensorShardStrategy(),
                          reuse_fp16_shard=False)

_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                              initial_scale=2**5,
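For orientation, a sketch of how these test defaults might look in a user config file. The zero = dict(model_config=..., optimizer_config=...) wrapper follows the ZeRO config layout of this release line; treat the wrapper key names as assumptions if your version differs.

    from colossalai.zero.shard_utils import TensorShardStrategy

    # Assumed config-file layout; the inner keys mirror the test dicts above.
    zero = dict(
        model_config=dict(
            reduce_scatter_bucket_size_mb=25,
            offload_config=None,
            gradient_predivide_factor=1.0,
            use_memory_tracer=False,
            shard_strategy=TensorShardStrategy(),
            reuse_fp16_shard=False,
        ),
        optimizer_config=dict(
            cpu_offload=False,
            initial_scale=2**5,
        ),
    )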
@@ -116,9 +117,12 @@ def check_params_padding(model, zero_model, loose=False):
        assert allclose(p, zero_p, loose=loose)


def check_sharded_params_padding(model, zero_model, loose=False):
def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        if reuse_fp16_shard:
            zero_p = zero_p.data.to(p.device).float()
        else:
            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
@@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_sharded_params_padding
from common import CONFIG, check_sharded_model_params


def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
        zero_model = ShardedModelV2(zero_model,
                                    shard_strategy,
                                    offload_config=dict(device='cpu') if cpu_offload else None,
                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)
                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0,
                                    reuse_fp16_shard=use_cpuadam)

        model = model_builder(checkpoint=True).half()
        col_model_deepcopy(zero_model, model)
@@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
            data, label = data.cuda(), label.cuda()
            _run_step(apex_model, apex_optimizer, data, label, criterion, False)
            _run_step(zero_model, sharded_optim, data, label, criterion, False)
            check_sharded_params_padding(model, zero_model, loose=True)
            check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
            for param in model.parameters():
                assert not has_inf_or_nan(param)
@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)


def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
        if parallel_config == MP_PARALLEL_CONFIG:
            check_params(torch_model, colo_model, loose=True)
        elif parallel_config == ZERO_PARALLEL_CONFIG:
            check_sharded_params_padding(torch_model, colo_model, loose=True)
            check_sharded_model_params(torch_model, colo_model, loose=True)


# FIXME: enable this test in next PR