Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[legacy] move engine to legacy (#4560)
* [legacy] move engine to legacy
* [example] fix seq parallel example
* [example] fix seq parallel example
* [test] test gemini plugin hang
* [test] test gemini plugin hang
* [test] test gemini plugin hang
* [test] test gemini plugin hang
* [test] test gemini plugin hang
* [example] update seq parallel requirements
11 colossalai/legacy/engine/gradient_handler/__init__.py (Normal file)
@@ -0,0 +1,11 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler

__all__ = [
    'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
    'MoeGradientHandler', 'SequenceParallelGradientHandler'
]
25 colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py (Normal file)
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod


class BaseGradientHandler(ABC):
    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
    before optimization.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer):
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """A method to accumulate gradients across different parallel groups. Users should
        write their own functions or just use the functions in pre-defined subclasses.
        """
        pass
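BaseGradientHandler fixes only the stored model/optimizer pair and the abstract handle_gradient hook; each subclass decides which process group to reduce over. As a minimal sketch of that contract (not part of this commit, and assuming torch.distributed is already initialized), a naive handler that averages every gradient over the default group could look like:

import torch.distributed as dist

from ._base_gradient_handler import BaseGradientHandler


class NaiveAllReduceGradientHandler(BaseGradientHandler):
    """Hypothetical subclass: average each gradient over the default group."""

    def handle_gradient(self):
        world_size = dist.get_world_size()
        for param in self._model.parameters():
            if param.requires_grad and param.grad is not None:
                # all_reduce defaults to SUM; divide to recover the mean.
                dist.all_reduce(param.grad.data)
                param.grad.data /= world_size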
27 colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py (Normal file)
@@ -0,0 +1,27 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce


@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.
    An all-reduce collective communication is performed in
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are
    of the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        """
        # TODO: add memory buffer
        if gpc.data_parallel_size > 1:
            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
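This handler runs between the backward pass and the optimizer step. A hedged sketch of that wiring (variable names assumed; in the legacy engine the Engine object normally invokes the handlers itself):

# after colossalai.launch(...) has initialized gpc
handler = DataParallelGradientHandler(model, optimizer)

output = model(batch)
loss = criterion(output, target)
loss.backward()            # each rank now holds local gradients
handler.handle_gradient()  # bucketed all-reduce averages them across ranks
optimizer.step()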
46 colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py (Normal file)
@@ -0,0 +1,46 @@
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict

from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce


@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group and
    in moe model parallelism. An all-reduce collective communication is performed in
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are
    of the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer=None):
        super().__init__(model, optimizer)

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group,
        then an all-reduce for all expert parameters across the
        moe model parallel group.
        """
        global_data = gpc.data_parallel_size

        if global_data > 1:
            epsize_param_dict = get_moe_epsize_param_dict(self._model)

            # epsize is 1, indicating the params are replicated among processes in data parallelism
            # use the ParallelMode.DATA to get data parallel group
            # reduce gradients for all parameters in data parallelism
            if 1 in epsize_param_dict:
                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))

            for ep_size in epsize_param_dict:
                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
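The branching above relies on get_moe_epsize_param_dict grouping parameters by expert parallel size: key 1 holds ordinary dense parameters, reduced over the full data parallel group, while expert parameters are reduced over the smaller data parallel group matching their ep_size; parameters whose ep_size equals MOE_CONTEXT.world_size have no replicas and are skipped. A toy illustration of that grouping (the moe_info attribute is an assumption, not taken from this diff):

from collections import defaultdict


def group_params_by_ep_size(params):
    # Toy stand-in for get_moe_epsize_param_dict: ep_size == 1 marks an
    # ordinary dense parameter replicated across all data parallel ranks.
    groups = defaultdict(list)
    for p in params:
        ep_size = p.moe_info.ep_size if hasattr(p, 'moe_info') else 1
        groups[ep_size].append(p)
    return dict(groups)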
53 colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py (Normal file)
@@ -0,0 +1,53 @@
#!/usr/bin/env python

from collections import defaultdict

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in sub parallel groups.
    An all-reduce collective communication is performed in
    :func:`handle_gradient` among all sub pipeline parallel groups.
    For better performance, it bucketizes the gradients of all parameters that are
    of the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in sub pipeline parallel groups.
        """
        if gpc.pipeline_parallel_size > 1:
            # bucketize and all-reduce
            buckets = defaultdict(lambda: defaultdict(list))
            # Pack the buckets.
            for param in self._model.parameters():
                group = getattr(param, 'pipeline_shared_module_pg', None)
                if param.requires_grad and group is not None and (
                        (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null())
                        or param.grad is not None):
                    tp = param.data.type()
                    buckets[group][tp].append(param)

            # For each bucket, all-reduce and copy all-reduced grads.
            for group, group_buckets in buckets.items():
                for tp, bucket in group_buckets.items():
                    grads = [
                        param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data
                        for param in bucket
                    ]
                    coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
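A parameter enters these buckets only if it carries a pipeline_shared_module_pg attribute pointing at the process group of the pipeline stages that share it (tied input/output embeddings are the classic case). A hedged sketch of attaching such a tag (helper name and wiring are assumptions; ColossalAI's actual shared-module wrapper may differ):

import torch.distributed as dist


def tag_pipeline_shared_params(module, ranks):
    # Hypothetical helper: mark every parameter of `module` as shared by
    # the pipeline stages in `ranks`, so the handler above all-reduces
    # its gradients over that group after each backward pass.
    group = dist.new_group(ranks)
    for param in module.parameters():
        param.pipeline_shared_module_pg = group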
26 colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py (Normal file)
@@ -0,0 +1,26 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce


@GRADIENT_HANDLER.register_module
class SequenceParallelGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in the data parallel group
    used by sequence parallelism (ParallelMode.SEQUENCE_DP).
    An all-reduce collective communication is performed in
    :func:`handle_gradient` among that group.
    For better performance, it bucketizes the gradients of all parameters that are
    of the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in the sequence-parallel data parallel group.
        """
        if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1:
            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP))
21 colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py (Normal file)
@@ -0,0 +1,21 @@
from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.
    An all-reduce collective communication is performed in
    :func:`handle_gradient` among a data parallel group.
    This class is specialized with ZeRO optimization.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        """
        self._optimizer.sync_grad()
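Unlike the other handlers, this one never touches param.grad itself: reduction is delegated to the optimizer, which knows how its gradient shards are laid out. A stub of the contract the handler assumes (only the sync_grad name comes from the code above; the body is illustrative):

class ShardedOptimizerStub:
    # Any optimizer handed to ZeROGradientHandler just needs sync_grad().
    def sync_grad(self):
        # A real ZeRO optimizer would reduce-scatter gradients so each
        # rank keeps only the shards for the parameters it owns.
        pass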
30 colossalai/legacy/engine/gradient_handler/utils.py (Normal file)
@@ -0,0 +1,30 @@
from typing import Iterable

import torch.distributed as dist
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
    # get communication world size
    comm_size = dist.get_world_size(group)
    # bucketize and all-reduce
    buckets = {}
    # Pack the buckets.
    for param in param_list:
        if param.requires_grad and param.grad is not None:
            tp = param.data.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(param)

    # For each bucket, all-reduce and copy all-reduced grads.
    for tp in buckets:
        bucket = buckets[tp]
        grads = [param.grad.data for param in bucket]
        coalesced = _flatten_dense_tensors(grads)
        coalesced /= comm_size

        dist.all_reduce(coalesced, group=group)
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)
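The flatten/unflatten round trip is what lets a single collective cover a whole bucket. A minimal single-process sketch of that round trip (no process group involved, so the division and all_reduce are omitted):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(2, 3), torch.arange(4.0)]
flat = _flatten_dense_tensors(grads)    # one contiguous 1-D buffer
assert flat.numel() == 2 * 3 + 4
for buf, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
    buf.copy_(synced)                   # write results back in place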