Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 13:59:08 +00:00)
[legacy] move engine to legacy (#4560)
* [legacy] move engine to legacy
* [example] fix seq parallel example
* [example] fix seq parallel example
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [test] test gemini pluging hang
* [example] update seq parallel requirements
@@ -0,0 +1,21 @@
from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class ZeROGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.
    An all-reduce collective communication will be operated in
    :func:`handle_gradient` among a data parallel group.
    This class is specialized with ZeRO optimization.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group."""
        self._optimizer.sync_grad()
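Below is a minimal, hypothetical sketch of a companion handler built on the same registration pattern shown in the diff: subclass BaseGradientHandler, register it with @GRADIENT_HANDLER.register_module, and implement handle_gradient(). It assumes BaseGradientHandler stores the constructor's model as self._model (mirroring the self._optimizer attribute used above); the class name NaiveAllReduceGradientHandler and the torch.distributed-based reduction are illustrative additions, not part of this commit.

import torch.distributed as dist

from colossalai.registry import GRADIENT_HANDLER

from ._base_gradient_handler import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class NaiveAllReduceGradientHandler(BaseGradientHandler):
    """Illustrative handler: all-reduce every gradient across the default
    process group after the backward pass, instead of delegating the sync
    to the optimizer as ZeROGradientHandler does."""

    def handle_gradient(self):
        # self._model is assumed to be populated by BaseGradientHandler.
        for param in self._model.parameters():
            if param.grad is not None:
                # Sum gradients from all ranks, then average them.
                dist.all_reduce(param.grad)
                param.grad.div_(dist.get_world_size())

The handler would then be selected the same way as the one in the diff, by referencing its registered type name in the gradient_handler section of the training configuration.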