[zero] adapt zero for unsharded parameters (#561)

* support existing sharded and unsharded parameters in zero

* add unitest for moe-zero model init

* polish moe gradient handler
This commit is contained in:
HELSON
2022-03-31 18:34:11 +08:00
committed by GitHub
parent 13ed4b6441
commit e6d50ec107
11 changed files with 211 additions and 70 deletions

View File

@@ -16,6 +16,9 @@ class MoeGradientHandler(BaseGradientHandler):
the same type to improve the efficiency of communication.
"""
def __init__(self, model, optimizer=None):
super().__init__(model, optimizer)
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
@@ -24,13 +27,15 @@ class MoeGradientHandler(BaseGradientHandler):
global_data = gpc.data_parallel_size
if global_data > 1:
param_dict = get_moe_epsize_param_dict(self._model)
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism
if 1 in param_dict:
bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
if 1 in epsize_param_dict:
bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
for ep_size in param_dict:
for ep_size in epsize_param_dict:
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
bucket_allreduce(param_list=param_dict[ep_size],
bucket_allreduce(param_list=epsize_param_dict[ep_size],
group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Callable, Optional
import torch
@@ -61,18 +61,25 @@ class PostBackwardFunction(torch.autograd.Function):
return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook] = None,
name: str = "",
filter_fn: Optional[Callable] = None):
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
# Add hooks for submodules
for child_name, child in module.named_children():
register_ophooks_recursively(child, ophook_list, name + child_name)
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
# return from flitered module
if filter_fn is not None and filter_fn(module):
return
if ophook_list is not None:
for hook in ophook_list:
assert (isinstance(hook, BaseOpHook))