[legacy] move engine to legacy (#4560)

* [legacy] move engine to legacy

* [example] fix seq parallel example

* [example] fix seq parallel example

* [test] test gemini pluging hang

* [test] test gemini pluging hang

* [test] test gemini pluging hang

* [test] test gemini pluging hang

* [test] test gemini pluging hang

* [example] update seq parallel requirements
This commit is contained in:
Hongxin Liu
2023-09-04 11:33:40 +08:00
parent 89fe027787
commit 8accecd55b
39 changed files with 93 additions and 105 deletions

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseGradientHandler(ABC):
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
before optimization.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer):
self._model = model
self._optimizer = optimizer
@abstractmethod
def handle_gradient(self):
"""A method to accumulate gradients across different parallel groups. Users should
write their own functions or just use the functions in pre-defined subclasses.
"""
pass