mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
[MOE] add unitest for MOE experts layout, gradient handler and kernel (#469)
This commit is contained in:
@@ -2,7 +2,6 @@ from typing import Optional
|
||||
|
||||
|
||||
class TensorParallelEnv(object):
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@@ -33,7 +32,7 @@ class TensorParallelEnv(object):
|
||||
self.depth_3d = depth_3d
|
||||
self.input_group_3d = input_group_3d
|
||||
self.weight_group_3d = weight_group_3d
|
||||
self.output_group_3d = output_group_3d
|
||||
self.output_group_3d = output_group_3d
|
||||
|
||||
def save(self):
|
||||
return dict(mode=self.mode,
|
||||
@@ -48,43 +47,4 @@ class TensorParallelEnv(object):
|
||||
output_group_3d=self.output_group_3d)
|
||||
|
||||
|
||||
class MoeEnv:
|
||||
"""Moe enviroment variables.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data_parallel_size = None
|
||||
self.model_parallel_size = None
|
||||
self.aux_loss = None
|
||||
self.enable_cuda = True
|
||||
|
||||
def setup(self, moe_model_size):
|
||||
from .core import global_context as gpc
|
||||
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")
|
||||
|
||||
assert gpc.data_parallel_size % moe_model_size == 0, \
|
||||
"The size of data parallel needs to be divided by moe model parallel size"
|
||||
|
||||
self.data_parallel_size = gpc.data_parallel_size // moe_model_size
|
||||
self.model_parallel_size = moe_model_size
|
||||
|
||||
def is_initialized(self):
|
||||
return self.model_parallel_size is not None
|
||||
|
||||
def set_cuda_false(self):
|
||||
self.enable_cuda = False
|
||||
|
||||
def reset_loss(self):
|
||||
self.aux_loss = 0
|
||||
|
||||
def add_loss(self, loss):
|
||||
self.aux_loss += loss
|
||||
|
||||
def get_loss(self):
|
||||
return self.aux_loss
|
||||
|
||||
|
||||
tensor_parallel_env = TensorParallelEnv()
|
||||
|
||||
moe_env = MoeEnv()
|
||||
|
Reference in New Issue
Block a user