Added MoE parallel (#127)
colossalai/engine/gradient_handler/__init__.py
@@ -2,6 +2,8 @@ from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
 from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
+from ._moe_gradient_handler import MoeGradientHandler
 
 __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
-           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']
+           'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
+           'MoeGradientHandler']
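For context: handlers registered with `@GRADIENT_HANDLER.register_module` are looked up by class name through the registry, so exporting `MoeGradientHandler` makes it selectable from a training config. A minimal sketch of that selection, assuming the config-driven pattern ColossalAI used at the time (the `gradient_handler` key and `dict(type=...)` form are assumptions, not confirmed by this diff):

# Hypothetical ColossalAI config snippet (e.g. config.py); the key name and
# the dict(type=...) structure are assumptions based on the registry pattern.
gradient_handler = [dict(type='MoeGradientHandler')]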
colossalai/engine/gradient_handler/_moe_gradient_handler.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.global_variables import moe_env
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode


@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in the data parallel group
    and the MoE parallel group. An all-reduce collective communication is
    performed in :func:`handle_gradient` across the data parallel group.
    For better performance, it bucketizes the gradients of all parameters of the
    same type to improve the efficiency of communication.
    """

    def handle_gradient(self):
        """Runs an all-reduce operation on non-expert gradients in the data
        parallel group, then an all-reduce on all expert parameters across the
        MoE data parallel group.
        """
        moe_data = moe_env.data_parallel_size
        global_data = gpc.data_parallel_size

        if global_data > 1:
            # Bucketize and all-reduce gradients of non-expert parameters
            # (those without a 'moe_param' attribute) across the data parallel group.
            buckets = {}
            # Pack the buckets, one per parameter dtype.
            for param in self._model.parameters():
                if param.requires_grad and \
                        param.grad is not None and \
                        not hasattr(param, 'moe_param'):
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    # param.main_grad = param.grad

            # For each bucket, all-reduce and copy the all-reduced grads back.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= gpc.get_world_size(ParallelMode.DATA)

                dist.all_reduce(
                    coalesced, group=gpc.get_group(ParallelMode.DATA))
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)

        if global_data > 1:
            # Expert parameters are replicated across the MoE data parallel
            # group instead; average their gradients there.
            for param in self._model.parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                if moe_data > 1 and hasattr(param, 'moe_param'):
                    param.grad.data /= moe_data
                    dist.all_reduce(param.grad.data,
                                    group=gpc.get_group(ParallelMode.MOE_DATA))
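The bucketize -> flatten -> all-reduce -> unflatten pattern used above is plain PyTorch and can be exercised outside ColossalAI. A minimal, runnable sketch, assuming a single-process "gloo" group purely for demonstration (in real training the group would span the data-parallel ranks); `allreduce_grads_bucketed` is a hypothetical helper name, not part of this commit:

# Standalone sketch of bucketized gradient all-reduce (assumptions noted above).
import os
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_grads_bucketed(params, group=None):
    """Average .grad of params across a process group, one flattened bucket per dtype."""
    buckets = {}
    for p in params:
        if p.requires_grad and p.grad is not None:
            buckets.setdefault(p.grad.data.type(), []).append(p.grad.data)
    for grads in buckets.values():
        coalesced = _flatten_dense_tensors(grads)      # one contiguous buffer per dtype
        coalesced /= dist.get_world_size(group=group)  # pre-divide so the sum is a mean
        dist.all_reduce(coalesced, group=group)
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)                          # write the averaged grads back


if __name__ == "__main__":
    # Single-process group so the sketch runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    model = torch.nn.Linear(4, 2)
    model(torch.randn(3, 4)).sum().backward()
    allreduce_grads_bucketed(model.parameters())
    dist.destroy_process_group()

Flattening each bucket into one tensor replaces many small collectives with a single large one per dtype, which is the efficiency gain the handler's docstring refers to.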
colossalai/engine/schedule/_base_schedule.py
@@ -38,8 +38,9 @@ class BaseSchedule(ABC):
         return data
 
     @staticmethod
-    def _check_sanity(data, tag):
-        assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'
+    def _check_sanity(data, tag: str):
+        assert isinstance(data, (torch.Tensor, dict)), \
+            f'{tag} must be torch.Tensor or dict'
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
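The tightened `_check_sanity` is easy to exercise on its own. A standalone copy of the static method (same assert semantics as the diff) with a quick pass/fail check:

import torch


def _check_sanity(data, tag: str):
    assert isinstance(data, (torch.Tensor, dict)), \
        f'{tag} must be torch.Tensor or dict'


_check_sanity(torch.zeros(2), 'data')           # a tensor passes
_check_sanity({'x': torch.zeros(2)}, 'data')    # a dict passes
try:
    _check_sanity([1, 2, 3], 'data')            # a list is rejected
except AssertionError as e:
    print(e)                                    # "data must be torch.Tensor or dict"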