[zero] adapt zero for unsharded parameters (#561)
* support existing sharded and unsharded parameters in zero
* add a unit test for moe-zero model init
* polish the moe gradient handler
@@ -16,6 +16,9 @@ class MoeGradientHandler(BaseGradientHandler):
     the same type to improve the efficiency of communication.
     """
 
+    def __init__(self, model, optimizer=None):
+        super().__init__(model, optimizer)
+
     def handle_gradient(self):
         """A method running an all-reduce operation in a data parallel group.
         Then running an all-reduce operation for all parameters in experts
@@ -24,13 +27,15 @@ class MoeGradientHandler(BaseGradientHandler):
         global_data = gpc.data_parallel_size
 
         if global_data > 1:
-            param_dict = get_moe_epsize_param_dict(self._model)
+            epsize_param_dict = get_moe_epsize_param_dict(self._model)
 
             # epsize is 1, indicating the params are replicated among processes in data parallelism
             # use the ParallelMode.DATA to get data parallel group
            # reduce gradients for all parameters in data parallelism
-            if 1 in param_dict:
-                bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
+            if 1 in epsize_param_dict:
+                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
 
-            for ep_size in param_dict:
+            for ep_size in epsize_param_dict:
                 if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                    bucket_allreduce(param_list=param_dict[ep_size],
+                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                      group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
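
The rename from param_dict to epsize_param_dict spells out what the dictionary is keyed on: the expert-parallel size of each parameter group. Parameters with ep_size == 1 are reduced over the regular data-parallel group, while every other group is reduced over the data-parallel group of its own expert-parallel setting. Below is a minimal sketch of that group-then-bucketed-all-reduce idea written against plain torch.distributed; the moe_info.ep_size attribute and both helper names are assumptions for illustration, not the ColossalAI implementation.

from collections import defaultdict
from typing import Dict, List

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def group_params_by_ep_size(model: torch.nn.Module) -> Dict[int, List[torch.nn.Parameter]]:
    """Bucket parameters by an assumed per-parameter `moe_info.ep_size` attribute;
    parameters without MoE info fall into the ep_size == 1 (pure data-parallel) bucket."""
    param_dict: Dict[int, List[torch.nn.Parameter]] = defaultdict(list)
    for param in model.parameters():
        moe_info = getattr(param, "moe_info", None)
        param_dict[moe_info.ep_size if moe_info is not None else 1].append(param)
    return param_dict


def bucket_allreduce_grads(params: List[torch.nn.Parameter], group=None) -> None:
    """Flatten all gradients into one buffer and all-reduce it once (the idea behind
    bucket_allreduce). Assumes every grad in the bucket shares one dtype; a production
    version would also bucket per dtype and handle gradient averaging."""
    grads = [p.grad.data for p in params if p.grad is not None]
    if not grads:
        return
    flat = _flatten_dense_tensors(grads)
    dist.all_reduce(flat, group=group)
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)

Flattening each bucket first means one collective per expert-parallel group instead of one per tensor, which is where the communication saving comes from.
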
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Callable, Optional
 
 import torch
 
@@ -61,18 +61,25 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args
 
 
-def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
+def register_ophooks_recursively(module: torch.nn.Module,
+                                 ophook_list: List[BaseOpHook] = None,
+                                 name: str = "",
+                                 filter_fn: Optional[Callable] = None):
     r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
 
     # Add hooks for submodules
     for child_name, child in module.named_children():
-        register_ophooks_recursively(child, ophook_list, name + child_name)
+        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
 
     # Early return on modules with no parameters.
     if len(list(module.parameters(recurse=False))) == 0:
         return
 
+    # Return early from filtered modules.
+    if filter_fn is not None and filter_fn(module):
+        return
+
     if ophook_list is not None:
         for hook in ophook_list:
             assert (isinstance(hook, BaseOpHook))
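
The new filter_fn parameter lets the caller veto hook registration per module. Because the recursion into children runs before the filter check, filtering a parent skips only the parent's own hooks, not its children's. A hedged usage sketch follows; the keep_unsharded tag, the predicate, and the import path are illustrative assumptions, not part of this PR.

import torch.nn as nn

# Import path assumed from the ColossalAI layout of this era.
from colossalai.engine.ophooks import register_ophooks_recursively


def skip_unsharded(module: nn.Module) -> bool:
    # Hypothetical predicate: skip modules whose own (non-recursive) parameters
    # carry an assumed `keep_unsharded` tag.
    return any(getattr(p, "keep_unsharded", False) for p in module.parameters(recurse=False))


model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
for p in model[2].parameters():
    p.keep_unsharded = True    # pretend the last Linear stays unsharded

# model[2] receives no ophooks; the other parameterized submodules do.
# In real use ophook_list would hold BaseOpHook instances (e.g. a ZeRO hook).
register_ophooks_recursively(model, ophook_list=[], filter_fn=skip_unsharded)
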