[zero] fix init bugs in zero context (#686)
* adapt model weight initialization for methods in PyTorch nn.init
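
For orientation, a minimal usage sketch of the context this commit fixes. The import paths and TensorShardStrategy are assumptions based on the ColossalAI layout of this era, not part of this diff; the point is that nn.init-driven weight initialization inside the context now sees the full parameter shape:

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
    from colossalai.zero.shard_utils import TensorShardStrategy   # assumed strategy class

    # Parameters created inside the context are sharded across the data-parallel
    # group; after this fix, the nn.init call in Linear.__init__ computes fan-in
    # and fan-out from the full (1024, 1024) shape rather than the local shard.
    with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = nn.Linear(1024, 1024)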
@@ -3,6 +3,8 @@ import functools
 from typing import Optional

 import torch
+import torch.nn as nn
+import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
@@ -10,7 +12,6 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
-from torch.distributed import ProcessGroup
 from contextlib import AbstractContextManager


@@ -93,24 +94,21 @@ class ZeroContextConfig(object):
         replicated (bool, optional): Whether the param is replicated across the data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
-        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init has finished.
-            This will reduce memory usage when initializing the model.
-            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
-            If set to `False`, remove tensor payload on `param.data` after the context exits.
-            This is used when you add some logic to operate on tensors in `__init__` of the module.
-            See torchvision resnet18. Defaults to False.
     """

-    def __init__(self,
-                 target_device: torch.device,
-                 replicated: bool = True,
-                 shard_param: bool = False,
-                 rm_torch_payload_on_the_fly: bool = False):
+    def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
         super().__init__()
+
+        if shard_param:
+            assert replicated, "Non-replicated parameters can't be sharded."
+
+        # replicated no-shard parameters should be located on CUDA, since we will broadcast them soon
+        if replicated and not shard_param:
+            assert target_device.type == 'cuda', "Replicated no-shard parameters should be located on CUDA."
+
         self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
-        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly


 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
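The two new assertions encode the legal combinations. A hedged sketch of what they accept and reject, assuming ZeroContextConfig is importable from this module and with illustrative devices:

    import torch

    # allowed: sharded (hence replicated), or replicated no-shard on CUDA
    ZeroContextConfig(torch.device('cuda', 0), replicated=True, shard_param=True)
    ZeroContextConfig(torch.device('cuda', 0), replicated=True, shard_param=False)

    # rejected by the new asserts (each would raise AssertionError):
    #   ZeroContextConfig(torch.device('cpu'), replicated=True, shard_param=False)
    #   ZeroContextConfig(torch.device('cuda', 0), replicated=False, shard_param=True)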
@@ -123,35 +121,27 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     Args:
         target_device (torch.device): The device where param data reside after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
+        seed (int, optional): Random seed for weight initialization.
         shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
-        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init has finished.
-            This will reduce memory usage when initializing the model.
-            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
-            If set to `False`, remove tensor payload on `param.data` after the context exits.
-            This is used when you add some logic to operate on tensors in `__init__` of the module.
-            See torchvision resnet18. Defaults to False.
         model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.long).
-        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
     """

     def __init__(self,
                  target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
+                 seed: int = 2**10 - 1,
                  shard_param: bool = False,
-                 rm_torch_payload_on_the_fly: bool = False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long),
-                 dp_process_group: Optional[ProcessGroup] = None):
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):

         super().__init__()
         self.shard_strategy = shard_strategy
-        self.initialized_param_list = []
+        self.sharded_param_list = []
+        self.unshard_param_list = []
         self.model_numel_tensor = model_numel_tensor
-        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
+        self.seed = seed
+        self.dp_process_group = gpc.get_group(ParallelMode.DATA)

-        self.config = ZeroContextConfig(target_device=target_device,
-                                        replicated=True,
-                                        shard_param=shard_param,
-                                        rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
+        self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)

         ZeroContextMgr().current_context = self

@@ -167,9 +157,35 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     def shard_param(self):
         return self.config.shard_param

-    @property
-    def rm_torch_payload_on_the_fly(self):
-        return self.config.rm_torch_payload_on_the_fly
+    @staticmethod
+    def calc_fanin_fanout(tensor: torch.Tensor):
+        """We use this function to substitute the fan-in and fan-out calculation in torch.nn.init.
+        This can help us get the correct fan-in and fan-out for sharded tensors.
+        """
+        assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"
+
+        # get the correct shape of the input tensor
+        if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded:
+            tensor_shape = tensor.shape
+        else:
+            tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape
+
+        dimensions = len(tensor_shape)
+        if dimensions < 2:
+            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
+
+        num_input_fmaps = tensor_shape[1]
+        num_output_fmaps = tensor_shape[0]
+        receptive_field_size = 1
+        if dimensions > 2:
+            # math.prod is not always available, accumulate the product manually
+            # we could use functools.reduce but that is not supported by TorchScript
+            for s in tensor_shape[2:]:
+                receptive_field_size *= s
+        fan_in = num_input_fmaps * receptive_field_size
+        fan_out = num_output_fmaps * receptive_field_size
+
+        return fan_in, fan_out

     def _pre_context_exec(self):
         """
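For an unsharded parameter, the substituted calculation agrees with PyTorch's own, which is easy to check by hand. A small sketch with an assumed conv-shaped weight (ZeroInitContext import path assumed as above):

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext

    w = nn.Parameter(torch.empty(64, 32, 3, 3))             # e.g. a Conv2d(32, 64, 3) weight
    fan_in, fan_out = ZeroInitContext.calc_fanin_fanout(w)  # no colo_attr, so w.shape is used
    assert (fan_in, fan_out) == (32 * 3 * 3, 64 * 3 * 3)    # (288, 576)
    # For a sharded parameter, the local payload shape would give wrong fans,
    # which is why the method reads colo_attr.sharded_data_tensor.origin_shape.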
@@ -177,15 +193,40 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         self.logger = get_dist_logger("ZeroInitContext")

+        # substitute the fan-in and fan-out calculation
+        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
+        nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout
+
+        # reserve rng states
+        self.cpu_rng_state = torch.get_rng_state()
+        self.cuda_rng_state = torch.cuda.get_rng_state()
+
+        # set a new seed for initialization, since we initialize sharded tensors separately
+        # we don't want all processes to have the same seed,
+        # otherwise all sharded tensors would be identical after init
+        offset = self.seed + 1    # we want more 1-bits in the binary form of the seed
+        torch.manual_seed(self.seed + offset * dist.get_rank())

     def _post_context_exec(self):
         """The callback function when exiting context.
         """
-        if not self.rm_torch_payload_on_the_fly:
-            for param in self.initialized_param_list:
-                assert hasattr(param, 'colo_attr')
-                param.colo_attr.remove_torch_payload()
+        for param in self.sharded_param_list:
+            assert hasattr(param, 'colo_attr')
+            param.colo_attr.remove_torch_payload()

-            del self.initialized_param_list
+        del self.sharded_param_list

+        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
+        for param in self.unshard_param_list:
+            assert hasattr(param, 'colo_attr')
+            if param.is_replicated:
+                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
+
+        del self.unshard_param_list
+
+        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
+        torch.set_rng_state(self.cpu_rng_state)
+        torch.cuda.set_rng_state(self.cuda_rng_state)

     def _post_init_method(self, module: torch.nn.Module):
         """
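The per-rank seed arithmetic above is easy to verify; with the default seed from the new signature:

    seed = 2 ** 10 - 1        # 1023 == 0b1111111111, i.e. a seed with many 1-bits
    offset = seed + 1         # 1024
    for rank in range(4):     # sketch for a 4-process data-parallel group
        print(seed + offset * rank)   # 1023, 2047, 3071, 4095: one distinct seed per rank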
def _post_init_method(self, module: torch.nn.Module):
|
||||
"""
|
||||
@@ -219,11 +260,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(target_device)
|
||||
|
||||
param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
|
||||
param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
self.initialized_param_list.append(param)
|
||||
param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
self.sharded_param_list.append(param)
|
||||
else:
|
||||
self.unshard_param_list.append(param)
|
||||
|
||||
# We must cast buffers
|
||||
# If we use BN, buffers may be on CPU and Float
|
||||
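The buffer comment refers to modules like BatchNorm, whose running statistics are created as fp32 CPU tensors. A short illustration of the starting state the cast has to handle (the cast itself happens in the code below this hunk, using the cast_tensor_to_fp16 imported at the top of the file):

    import torch.nn as nn

    bn = nn.BatchNorm2d(64)
    print(bn.running_mean.device, bn.running_mean.dtype)   # cpu torch.float32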
@@ -250,8 +294,7 @@ class ZeroContextMgr(metaclass=SingletonMeta):
 def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
     return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
                                                   replicated=is_replicated,
-                                                  shard_param=False,
-                                                  rm_torch_payload_on_the_fly=False)
+                                                  shard_param=False)


 def no_shard_zero_decrator(is_replicated: bool = True):
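
Since no_shard_zero_context returns an AbstractContextManager, it is entered as a with-block. A hedged usage sketch for the MOE-expert case mentioned in the docstring above:

    import torch.nn as nn

    # Parameters built here are never sharded, and with is_replicated=False they
    # are also skipped by the broadcast in _post_context_exec, as is appropriate
    # for per-rank MOE experts.
    with no_shard_zero_context(is_replicated=False):
        expert = nn.Linear(512, 512)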