Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-13 13:11:05 +00:00
[zero] add zero context manager to change config during initialization (#546)
@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type


@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)

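The new import and the @no_shard_zero_decrator on Experts.__init__ follow a single pattern: parameters created while the decorator (or the matching context manager) is active are meant to be skipped by ZeRO sharding during initialization. A minimal, hypothetical sketch of that pattern follows; it is not the actual colossalai.zero.init_ctx implementation, and the config object and its shard_param flag are assumptions made here for illustration.

import functools
from contextlib import contextmanager

class _ZeroInitConfig:
    # Stand-in for the ZeRO init configuration; `shard_param` is an assumed flag.
    shard_param = True

@contextmanager
def no_shard_zero_context(config=_ZeroInitConfig):
    # Temporarily turn off parameter sharding while modules are constructed.
    previous = config.shard_param
    config.shard_param = False
    try:
        yield
    finally:
        config.shard_param = previous

def no_shard_zero_decrator(init_fn):
    # Wrap a module's __init__ so every parameter it creates stays unsharded.
    @functools.wraps(init_fn)
    def wrapper(*args, **kwargs):
        with no_shard_zero_context():
            return init_fn(*args, **kwargs)
    return wrapper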
@@ -1,3 +1,4 @@
 import functools
+import math

 import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup

@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask


-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
    """Gate module used in MOE layer. Just a linear function without bias.
    But it should be kept as fp32 forever.

@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)


 class MoeLayer(nn.Module):
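The gate rewrite above replaces the nn.Linear subclass with a plain nn.Module that owns a single bias-free weight, initialized from a truncated normal with std sqrt(scale / d_model). A self-contained sketch of the same shape of class, using a plain nn.Parameter where the diff uses ForceFP32Parameter (whose definition is not part of this diff), plus a small usage check:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class FP32Gate(nn.Module):
    # Bias-free routing gate; the weight is created directly in fp32.
    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, d_model, dtype=torch.float32))
        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)

gate = FP32Gate(d_model=512, num_experts=4)
logits = gate(torch.randn(2, 512))   # fp32 input -> (2, num_experts) routing logits
assert logits.shape == (2, 4) and gate.weight.dtype == torch.float32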
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
                                      min_capacity=min_capacity,
                                      noisy_func=noisy_func,
                                      drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:
@@ -371,7 +377,8 @@ class MoeModule(nn.Module):
                     "Expert class can't be None when residual instance is not given"
                 self.residual_module = expert_cls(**expert_args)

-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
             self.experts = expert_instance
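Taken together, the diff applies the two entry points in two ways: the decorator guards an entire __init__ (Experts, MoeLayer), while the context manager guards the construction of a single layer (residual_combine). A sketch of that usage, assuming both behave as no-ops when no ZeRO initialization context is active (that behavior is not shown in this diff); the class and function names below are hypothetical:

import torch.nn as nn
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator

class TinyMoeBlock(nn.Module):
    # Mirrors how MoeLayer applies the decorator: everything built in
    # __init__ is meant to stay unsharded.
    @no_shard_zero_decrator
    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts, bias=False)

def build_residual_combiner(d_model: int) -> nn.Linear:
    # Mirrors how MoeModule wraps residual_combine in the context manager.
    with no_shard_zero_context():
        return nn.Linear(d_model, 2)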