Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 09:59:38 +00:00)

add moe context, moe utilities and refactor gradient handler (#455)

@@ -1,12 +1,8 @@
 #!/usr/bin/env python
 
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
+from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module

@@ -23,26 +19,4 @@ class DataParallelGradientHandler(BaseGradientHandler):
         """
         # TODO: add memory buffer
         if gpc.data_parallel_size > 1:
-            # bucketize and all-reduce
-            buckets = {}
-            # Pack the buckets.
-            for param in self._model.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = param.data.type()
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-                    # param.main_grad = param.grad
-
-            # For each bucket, all-reduce and copy all-reduced grads.
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                coalesced /= gpc.get_world_size(ParallelMode.DATA)
-
-                dist.all_reduce(
-                    coalesced, group=gpc.get_group(ParallelMode.DATA))
-                for buf, synced in zip(grads, _unflatten_dense_tensors(
-                        coalesced, grads)):
-                    buf.copy_(synced)
+            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
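
For orientation, a gradient handler is constructed from the model and the optimizer and is invoked once per iteration, between the backward pass and the optimizer step. The sketch below illustrates that driving pattern; it is not part of this commit and assumes a GPU plus a torchrun launch so that colossalai.launch_from_torch can set up the ParallelMode.DATA group.

    # Illustrative driver for the refactored handler (assumes a launched
    # distributed environment with a DATA parallel group).
    import torch
    import colossalai
    from colossalai.engine.gradient_handler import DataParallelGradientHandler

    colossalai.launch_from_torch(config={})   # torchrun supplies rank / world size
    model = torch.nn.Linear(32, 32).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    handler = DataParallelGradientHandler(model, optimizer)

    loss = model(torch.randn(8, 32, device='cuda')).sum()
    loss.backward()
    handler.handle_gradient()   # now a single call into bucket_allreduce over ParallelMode.DATA
    optimizer.step()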

@@ -1,10 +1,9 @@
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.core import global_context as gpc
+from colossalai.core import global_context as gpc, moe_context as moe_env
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.global_variables import moe_env
+from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
+from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module

@@ -21,41 +20,15 @@ class MoeGradientHandler(BaseGradientHandler):
             Then running an all-reduce operation for all parameters in experts
             across moe model parallel group
         """
         moe_data = moe_env.data_parallel_size
         global_data = gpc.data_parallel_size
 
         if global_data > 1:
-            # bucketize and all-reduce
-            buckets = {}
-            # Pack the buckets.
-            for param in self._model.parameters():
-                if param.requires_grad and \
-                        param.grad is not None and \
-                        not hasattr(param, 'moe_param'):
-                    tp = param.data.type()
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-                    # param.main_grad = param.grad
+            param_dict = get_moe_epsize_param_dict(self._model)
 
-            # For each bucket, all-reduce and copy all-reduced grads.
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                coalesced /= gpc.get_world_size(ParallelMode.DATA)
+            # reduce gradients for all parameters in data parallelism
+            if 1 in param_dict:
+                bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
 
-                dist.all_reduce(
-                    coalesced, group=gpc.get_group(ParallelMode.DATA))
-                for buf, synced in zip(grads, _unflatten_dense_tensors(
-                        coalesced, grads)):
-                    buf.copy_(synced)
-
-        if global_data > 1:
-            for param in self._model.parameters():
-                if not param.requires_grad or param.grad is None:
-                    continue
-                if moe_data > 1 and hasattr(param, 'moe_param'):
-                    param.grad.data /= moe_data
-                    dist.all_reduce(param.grad.data,
-                                    group=gpc.get_group(ParallelMode.MOE_DATA))
+            for ep_size in param_dict:
+                if ep_size != 1 and ep_size != moe_env.world_size:
+                    bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group)
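
The refactored handler leans on get_moe_epsize_param_dict from colossalai.utils.moe, whose body is not part of this excerpt. Read from the call sites above, it appears to return a dict mapping an expert-parallel size to the parameters reduced at that size, with ep_size == 1 denoting ordinary non-expert parameters that go through the regular ParallelMode.DATA group. A minimal sketch under that assumption follows; the moe_info attribute name is hypothetical and not taken from this diff.

    # Hypothetical sketch of the grouping helper used above -- NOT the actual
    # colossalai.utils.moe implementation, only the shape of the return value
    # that the new MoeGradientHandler relies on.
    from collections import defaultdict
    from typing import Dict, List

    import torch.nn as nn


    def get_moe_epsize_param_dict_sketch(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
        epsize_param_dict = defaultdict(list)
        for param in model.parameters():
            # Assumption: expert parameters carry MoE metadata (a hypothetical
            # `moe_info.ep_size`); plain parameters default to ep_size 1.
            info = getattr(param, 'moe_info', None)
            ep_size = info.ep_size if info is not None else 1
            epsize_param_dict[ep_size].append(param)
        return dict(epsize_param_dict)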

@@ -1,14 +1,8 @@
 #!/usr/bin/env python
-from functools import total_ordering
-import torch
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
-import colossalai
+from .utils import bucket_allreduce
 
 
 @GRADIENT_HANDLER.register_module

@@ -23,29 +17,5 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
     def handle_gradient(self):
         """A method running a all-reduce operation in a data parallel group.
         """
-
-        # bucketize and all-reduce
-        buckets = {}
-
-        # Pack the buckets.
-        for param in self._model.parameters():
-            if param.requires_grad and param.grad is not None:
-                tp = param.data.type()
-                if tp not in buckets:
-                    buckets[tp] = []
-                buckets[tp].append(param)
-
-        # For each bucket, all-reduce and copy all-reduced grads.
-        for tp in buckets:
-            bucket = buckets[tp]
-            grads = [param.grad.data for param in bucket]
-            coalesced = _flatten_dense_tensors(grads)
-
-            coalesced /= gpc.get_world_size(ParallelMode.SEQUENCE_DP)
-
-            dist.all_reduce(
-                coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))
-
-            for buf, synced in zip(grads, _unflatten_dense_tensors(
-                    coalesced, grads)):
-                buf.copy_(synced)
+        if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1:
+            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP))

colossalai/engine/gradient_handler/utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from typing import Iterable
+
+
+def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
+    # get communication world size
+    comm_size = dist.get_world_size(group)
+    # bucketize and all-reduce
+    buckets = {}
+    # Pack the buckets.
+    for param in param_list:
+        if param.requires_grad and param.grad is not None:
+            tp = param.data.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(param)
+
+    # For each bucket, all-reduce and copy all-reduced grads.
+    for tp in buckets:
+        bucket = buckets[tp]
+        grads = [param.grad.data for param in bucket]
+        coalesced = _flatten_dense_tensors(grads)
+        coalesced /= comm_size
+
+        dist.all_reduce(coalesced, group=group)
+        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
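
Since bucket_allreduce groups gradients by tensor type and coalesces each bucket into a single flattened all-reduce, a quick sanity check is a single-process gloo group, where averaging over a world size of 1 must leave every gradient unchanged. A minimal smoke-test sketch, assuming no prior process-group setup (address, port, and the toy model are placeholders):

    # Single-process smoke test for bucket_allreduce (illustrative only).
    import os
    import torch
    import torch.distributed as dist
    from colossalai.engine.gradient_handler.utils import bucket_allreduce

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)

    model = torch.nn.Linear(8, 8)
    model(torch.randn(2, 8)).sum().backward()
    before = [p.grad.clone() for p in model.parameters()]

    # With world_size == 1, dividing by comm_size and all-reducing is a no-op.
    bucket_allreduce(param_list=model.parameters(), group=None)

    for old, p in zip(before, model.parameters()):
        assert torch.allclose(old, p.grad)
    dist.destroy_process_group()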