add moe context, moe utilities and refactor gradient handler (#455)

Author: HELSON
Date: 2022-03-18 16:38:32 +08:00
Committed by: GitHub
Parent: af185b5519
Commit: 84fd7c1d4d
11 changed files with 255 additions and 125 deletions


@@ -1,12 +1,8 @@
 #!/usr/bin/env python
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
+from .utils import bucket_allreduce
 @GRADIENT_HANDLER.register_module
@@ -23,26 +19,4 @@ class DataParallelGradientHandler(BaseGradientHandler):
         """
         # TODO: add memory buffer
         if gpc.data_parallel_size > 1:
-            # bucketize and all-reduce
-            buckets = {}
-            # Pack the buckets.
-            for param in self._model.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = param.data.type()
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-                    # param.main_grad = param.grad
-
-            # For each bucket, all-reduce and copy all-reduced grads.
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                coalesced /= gpc.get_world_size(ParallelMode.DATA)
-                dist.all_reduce(
-                    coalesced, group=gpc.get_group(ParallelMode.DATA))
-                for buf, synced in zip(grads, _unflatten_dense_tensors(
-                        coalesced, grads)):
-                    buf.copy_(synced)
+            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
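Pieced together from the context and added lines above, the refactored handler body reduces to a single call into the shared utility. A minimal sketch of the result (docstring and class scaffolding abridged; it assumes the imports retained in the hunk above):

# Sketch only -- assembled from the hunks above, not the verbatim file.
@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):

    def handle_gradient(self):
        # TODO: add memory buffer
        if gpc.data_parallel_size > 1:
            # flatten / all-reduce / unflatten now happen inside bucket_allreduce
            bucket_allreduce(param_list=self._model.parameters(),
                             group=gpc.get_group(ParallelMode.DATA))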


@@ -1,10 +1,9 @@
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.core import global_context as gpc
+from colossalai.core import global_context as gpc, moe_context as moe_env
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.global_variables import moe_env
+from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
+from .utils import bucket_allreduce
 @GRADIENT_HANDLER.register_module
@@ -21,41 +20,15 @@ class MoeGradientHandler(BaseGradientHandler):
         Then running an all-reduce operation for all parameters in experts
         across moe model parallel group
         """
-        moe_data = moe_env.data_parallel_size
         global_data = gpc.data_parallel_size
         if global_data > 1:
-            # bucketize and all-reduce
-            buckets = {}
-            # Pack the buckets.
-            for param in self._model.parameters():
-                if param.requires_grad and \
-                        param.grad is not None and \
-                        not hasattr(param, 'moe_param'):
-                    tp = param.data.type()
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-                    # param.main_grad = param.grad
+            param_dict = get_moe_epsize_param_dict(self._model)
-            # For each bucket, all-reduce and copy all-reduced grads.
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                coalesced /= gpc.get_world_size(ParallelMode.DATA)
+            # reduce gradients for all parameters in data parallelism
+            if 1 in param_dict:
+                bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
-                dist.all_reduce(
-                    coalesced, group=gpc.get_group(ParallelMode.DATA))
-                for buf, synced in zip(grads, _unflatten_dense_tensors(
-                        coalesced, grads)):
-                    buf.copy_(synced)
-
-        if global_data > 1:
-            for param in self._model.parameters():
-                if not param.requires_grad or param.grad is None:
-                    continue
-                if moe_data > 1 and hasattr(param, 'moe_param'):
-                    param.grad.data /= moe_data
-                    dist.all_reduce(param.grad.data,
-                                    group=gpc.get_group(ParallelMode.MOE_DATA))
+            for ep_size in param_dict:
+                if ep_size != 1 and ep_size != moe_env.world_size:
+                    bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group)
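get_moe_epsize_param_dict itself is not part of this diff. The sketch below is a hypothetical illustration of the grouping the new handler relies on: parameters keyed by the size of the expert-parallel group they belong to, with plain data-parallel parameters filed under 1. The attribute names (moe_info, ep_size) are placeholders, not ColossalAI's API.

# Hypothetical illustration only; the attribute names are placeholders.
from collections import defaultdict
from typing import Dict, List

import torch.nn as nn


def group_params_by_epsize(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    param_dict = defaultdict(list)
    for param in model.parameters():
        # treat non-expert parameters as ep_size == 1, i.e. replicated
        # across the whole data parallel group
        info = getattr(param, 'moe_info', None)
        ep_size = info.ep_size if info is not None else 1
        param_dict[ep_size].append(param)
    return dict(param_dict)

Given such a dict, the new body reduces the ep_size == 1 bucket over ParallelMode.DATA, skips buckets whose ep_size equals moe_env.world_size (those experts have no data-parallel replicas to synchronize with), and reduces every other bucket over the dp_group recorded for its ep_size.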


@@ -1,14 +1,8 @@
 #!/usr/bin/env python
-from functools import total_ordering
-import torch
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
-import colossalai
+from .utils import bucket_allreduce
 @GRADIENT_HANDLER.register_module
@@ -23,29 +17,5 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
     def handle_gradient(self):
         """A method running a all-reduce operation in a data parallel group.
         """
-        # bucketize and all-reduce
-        buckets = {}
-        # Pack the buckets.
-        for param in self._model.parameters():
-            if param.requires_grad and param.grad is not None:
-                tp = param.data.type()
-                if tp not in buckets:
-                    buckets[tp] = []
-                buckets[tp].append(param)
-
-        # For each bucket, all-reduce and copy all-reduced grads.
-        for tp in buckets:
-            bucket = buckets[tp]
-            grads = [param.grad.data for param in bucket]
-            coalesced = _flatten_dense_tensors(grads)
-            coalesced /= gpc.get_world_size(ParallelMode.SEQUENCE_DP)
-            dist.all_reduce(
-                coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))
-            for buf, synced in zip(grads, _unflatten_dense_tensors(
-                    coalesced, grads)):
-                buf.copy_(synced)
+        if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1:
+            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP))
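All three handlers are registered through GRADIENT_HANDLER, so in training scripts they are usually selected by name from the config rather than constructed by hand. A sketch of that usage, assuming the registry-driven gradient_handler config convention read by colossalai.initialize around the time of this commit (the key name is an assumption, not shown in this diff):

# config.py -- sketch only; assumes colossalai.initialize instantiates the
# handlers listed under `gradient_handler` via the GRADIENT_HANDLER registry.
gradient_handler = [dict(type='MoeGradientHandler')]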


@@ -0,0 +1,29 @@
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from typing import Iterable
+
+
+def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
+    # get communication world size
+    comm_size = dist.get_world_size(group)
+
+    # bucketize and all-reduce
+    buckets = {}
+    # Pack the buckets.
+    for param in param_list:
+        if param.requires_grad and param.grad is not None:
+            tp = param.data.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(param)
+
+    # For each bucket, all-reduce and copy all-reduced grads.
+    for tp in buckets:
+        bucket = buckets[tp]
+        grads = [param.grad.data for param in bucket]
+        coalesced = _flatten_dense_tensors(grads)
+        coalesced /= comm_size
+        dist.all_reduce(coalesced, group=group)
+        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
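For orientation, a minimal driver for the new helper (hypothetical code; it assumes the default process group has already been initialized on every rank, e.g. via dist.init_process_group or colossalai.launch, and that bucket_allreduce is imported from the module above):

# Hypothetical usage -- average gradients over the default process group
# after a backward pass.
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
loss = model(torch.randn(2, 8)).sum()
loss.backward()
bucket_allreduce(param_list=model.parameters(), group=None)  # None -> default group

Because dist.all_reduce sums, dividing the coalesced buffer by comm_size beforehand leaves every rank holding the mean gradient across the group, and bucketing by param.data.type() keeps each flattened buffer to a single dtype.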