[zero] add zero context manager to change config during initialization (#546)
@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type


@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)

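The decorator applied above is defined at the bottom of this diff (in init_context.py). A rough sketch of the intended effect, assuming the decorator behaves as shown in this commit; `MyExpert` is a hypothetical expert class used only for illustration:

    import torch.nn as nn

    class MyExpert(nn.Module):
        # hypothetical expert; any nn.Module that owns parameters would do
        def __init__(self, d_model: int):
            super().__init__()
            self.fc = nn.Linear(d_model, d_model)

    # Inside a sharding ZeroInitContext, the decorated Experts.__init__ runs
    # with shard_param=False, so the expert weights created here are kept
    # whole instead of being sharded with the rest of the model:
    #
    #     with ZeroInitContext(target_device=..., shard_strategy=..., shard_param=True):
    #         experts = Experts(MyExpert, num_experts=4, d_model=1024)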
@@ -1,3 +1,4 @@
 import functools
+import math

 import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup

@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask


-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.

@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)


 class MoeLayer(nn.Module):
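For readers without colossalai installed, a self-contained sketch of the refactored gate above: a plain nn.Parameter stands in for ForceFP32Parameter (assumed to behave like a parameter pinned to fp32), and the explicit cast of the input is an assumption added so the standalone demo runs on half-precision activations:

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GateSketch(nn.Module):
        """Bias-free linear gate kept in fp32 (illustrative stand-in for FP32LinearGate)."""

        def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
            super().__init__()
            # explicit fp32 weight instead of inheriting the machinery of nn.Linear
            self.weight = nn.Parameter(torch.empty(num_experts, d_model, dtype=torch.float32))
            nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # cast activations up so the routing logits are computed in fp32
            return F.linear(x.float(), self.weight)

    gate = GateSketch(d_model=8, num_experts=4)
    logits = gate(torch.randn(2, 8, dtype=torch.float16))   # -> shape (2, 4), dtype fp32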
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
                                      min_capacity=min_capacity,
                                      noisy_func=noisy_func,
                                      drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:
@@ -371,7 +377,8 @@ class MoeModule(nn.Module):
                     "Expert class can't be None when residual instance is not given"
                 self.residual_module = expert_cls(**expert_args)

-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
             self.experts = expert_instance
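The with-block above is the second entry point this commit adds, exported from colossalai.zero.init_ctx in the next hunk. A hedged sketch of the pattern with an illustrative module (`Block` and its dimensions are not from the diff):

    # Keep a specific submodule's parameters unsharded while the surrounding
    # model is built under a sharding ZeroInitContext. no_shard_zero_context
    # temporarily swaps the active ZeroContextConfig to shard_param=False and
    # restores the previous config on exit; outside any ZeroInitContext it is
    # a no-op.
    import torch.nn as nn
    from colossalai.zero.init_ctx import no_shard_zero_context

    class Block(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)          # sharded as usual
            with no_shard_zero_context():
                self.combine = nn.Linear(dim, 2)     # stays unsharded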
@@ -1,3 +1,3 @@
-from .init_context import ZeroInitContext
+from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator

-__all__ = ['ZeroInitContext']
+__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']
@@ -1,9 +1,11 @@
+import contextlib
 import functools
 from typing import Optional

 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
         pass


+class ZeroContextConfig(object):
+    """The configuration used to control zero context initialization.
+
+    Args:
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing the model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on `param.data` after the context exits.
+            This is used when you add some logic to operate tensors in the `__init__` of the module.
+            See torchvision resnet18. Defaults to False.
+    """
+
+    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+        super().__init__()
+        self.shard_param: bool = shard_param
+        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
+
+
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.
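The config is a plain value object: the context (next hunk) reads its flags through properties, so swapping the whole config changes both flags atomically. A quick sanity sketch using only what the class above defines:

    cfg_default = ZeroContextConfig()                  # both flags default to False
    cfg_shard = ZeroContextConfig(shard_param=True)
    assert cfg_default.shard_param is False
    assert cfg_default.rm_torch_payload_on_the_fly is False
    assert cfg_shard.shard_param is True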
@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
-        convert_fp16 (bool): Whether to convert params to fp16.
         target_device (torch.device): The device where param data are placed after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing the model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         super().__init__()
         self.target_device = target_device
-        self.shard_param = shard_param
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

+        self.config = ZeroContextConfig(shard_param=shard_param,
+                                        rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
+        ZeroContextMgr().current_context = self
+
+    @property
+    def shard_param(self):
+        return self.config.shard_param
+
+    @property
+    def rm_torch_payload_on_the_fly(self):
+        return self.config.rm_torch_payload_on_the_fly
+
     def _pre_context_exec(self):
         """
         The Callback function when entering the context
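A hedged usage sketch of the constructor after this change; only the keyword arguments shown in the hunk above are taken from the diff, while TensorShardStrategy and the model are assumptions for illustration:

    # Building a model under ZeroInitContext after this PR. convert_fp16 is
    # gone from the docstring; sharding behaviour is carried by the internal
    # ZeroContextConfig and can be hijacked per-module.
    # Assumes colossalai.launch has already initialized the distributed
    # environment (the constructor queries gpc for the data-parallel group).
    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy   # assumed strategy class

    with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
    # parameters come out as fp16, on the target device, and sharded across
    # the data-parallel group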
@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+
+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t
+
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'col_attr'):
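half_fn only downcasts floating-point tensors, so integer parameters and buffers keep their dtype. A standalone check of that logic in plain torch:

    import torch

    def half_fn(t: torch.Tensor) -> torch.Tensor:
        # same logic as above: only floating-point tensors are downcast
        return t.half() if t.is_floating_point() else t

    assert half_fn(torch.zeros(2, dtype=torch.float32)).dtype == torch.float16
    assert half_fn(torch.zeros(2, dtype=torch.int64)).dtype == torch.int64   # untouched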
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

             self.model_numel_tensor += param.numel()

-            target_device = self.target_device
-
-            # convert to fp16
-            param.data = param.data.to(torch.half)
+            # convert parameters to half
+            param_half = half_fn(param)
+            param.data = param_half
             if param.grad is not None:
-                param.grad = param.grad.to(torch.half)
+                grad_half = half_fn(param.grad)
+                param.grad.data = grad_half

             # move torch parameters to the target device
+            target_device = self.target_device
             param.data = param.data.to(target_device)
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

-            self.initialized_param_list.append(param)
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+            self.initialized_param_list.append(param)

             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
             buffer.data = cast_tensor_to_fp16(buffer.data)
+
+
+class ZeroContextMgr(metaclass=SingletonMeta):
+    current_context: Optional[ZeroInitContext] = None
+
+    @contextlib.contextmanager
+    def hijack_context_config(self, **kwargs):
+        if self.current_context is None:
+            yield
+        else:
+            old_config = self.current_context.config
+            self.current_context.config = ZeroContextConfig(**kwargs)
+            yield
+            self.current_context.config = old_config
+
+
+def no_shard_zero_context():
+    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+
+
+def no_shard_zero_decrator(init_func):
+
+    def _no_shard(*args, **kwargs):
+        with no_shard_zero_context():
+            init_func(*args, **kwargs)
+
+    return _no_shard
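This manager is the "zero context manager" of the commit title: any keyword arguments passed to hijack_context_config temporarily replace the active ZeroInitContext's config, and outside any ZeroInitContext the hijack is a no-op (the `None` branch above). A hedged sketch of decorating a custom module; `MyHead` is illustrative and not part of the diff:

    import torch.nn as nn
    from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator

    class MyHead(nn.Module):
        @no_shard_zero_decrator
        def __init__(self, dim: int, num_classes: int):
            super().__init__()
            self.fc = nn.Linear(dim, num_classes)   # kept unsharded under a sharding context

    # Equivalent inline form for a single submodule:
    #     with no_shard_zero_context():
    #         self.fc = nn.Linear(dim, num_classes)
    #
    # Note that _no_shard discards init_func's return value, which is fine for
    # __init__ (it returns None), so the decorator is meant for constructors only.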