[zero] add zero context manager to change config during initialization (#546)

Author: HELSON
Date: 2022-03-29 17:57:59 +08:00 (committed by GitHub)
Parent: ec5086c49c
Commit: 8c90d4df54
5 changed files with 185 additions and 18 deletions
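The change routes MoE parameter construction through two new helpers from colossalai.zero.init_ctx: no_shard_zero_context, a context manager, and no_shard_zero_decrator, a decorator for __init__ methods. Both tell an active ZeRO initialization context not to shard the parameters created inside them. The real implementation is not part of this diff; the following is only a minimal sketch of the idea, assuming a module-level switch that a ZeRO init context would consult (the names _ShardConfig and _SHARD_CONFIG are invented for the sketch).

    import contextlib
    import functools

    class _ShardConfig:
        # Hypothetical switch; the real helpers coordinate with the ZeRO init context instead.
        shard_param: bool = True

    _SHARD_CONFIG = _ShardConfig()

    @contextlib.contextmanager
    def no_shard_zero_context():
        # Temporarily disable parameter sharding for anything built inside the block.
        previous = _SHARD_CONFIG.shard_param
        _SHARD_CONFIG.shard_param = False
        try:
            yield
        finally:
            _SHARD_CONFIG.shard_param = previous

    def no_shard_zero_decrator(init_func):
        # Wrap a module's __init__ so its parameters are created with sharding disabled.
        @functools.wraps(init_func)
        def wrapper(*args, **kwargs):
            with no_shard_zero_context():
                init_func(*args, **kwargs)
        return wrapper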


@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type
@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
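Expert parameters already have their own placement across the expert-parallel group, so the decorator keeps their construction out of ZeRO's parameter sharding when Experts is built inside a ZeRO init context. Reusing the hypothetical helpers sketched above (not the real colossalai decorator), the same pattern on a toy expert container looks like this:

    import torch.nn as nn

    class ToyExperts(nn.Module):
        # Mirrors the decorator placement on Experts.__init__ shown in the hunk above.
        @no_shard_zero_decrator
        def __init__(self, num_experts: int, d_model: int):
            super().__init__()
            self.experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))

    experts = ToyExperts(num_experts=4, d_model=16)
    print(_SHARD_CONFIG.shard_param)   # True again: sharding was only disabled inside __init__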


@@ -1,3 +1,4 @@
+import functools
 import math
 import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask


-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.

@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)


 class MoeLayer(nn.Module):
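FP32LinearGate no longer inherits from nn.Linear; it now owns a single ForceFP32Parameter of shape (num_experts, d_model), initialized from a truncated normal with std = sqrt(scale / d_model), and applies a bias-free F.linear in forward. A small shape check of the rewritten gate, independent of any router or parallel grouping (the tensor sizes are arbitrary):

    import torch

    gate = FP32LinearGate(d_model=16, num_experts=4)
    tokens = torch.randn(8, 16, device=gate.weight.device)   # (num_tokens, d_model)
    logits = gate(tokens)                                     # (num_tokens, num_experts) routing scores
    assert logits.shape == (8, 4)
    # ForceFP32Parameter is meant to keep the gate weight in fp32 (see the docstring above).
    assert gate.weight.dtype == torch.float32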
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
             min_capacity=min_capacity,
             noisy_func=noisy_func,
             drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:
@@ -371,7 +377,8 @@
                     "Expert class can't be None when residual instance is not given"
                 self.residual_module = expert_cls(**expert_args)

-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
             self.experts = expert_instance
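Building the residual combine gate under no_shard_zero_context() means that when the whole MoeModule is constructed inside a ZeRO init context, this small linear layer stays unsharded alongside the expert weights. The sketch below shows the same pattern in user code; it assumes a distributed environment already launched via colossalai.launch, and the ZeroInitContext and TensorShardStrategy argument names are taken from nearby versions of the API and may not match this exact release.

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext, no_shard_zero_context
    from colossalai.zero.shard_utils import TensorShardStrategy

    # Argument names are illustrative; check the ZeroInitContext signature of your release.
    with ZeroInitContext(target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        backbone = nn.Linear(512, 512)      # parameters sharded by ZeRO at init
        with no_shard_zero_context():
            combine = nn.Linear(512, 2)     # left whole, like MoeModule.residual_combine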