mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[legacy] move engine to legacy (#4560)
* [legacy] move engine to legacy * [example] fix seq parallel example * [example] fix seq parallel example * [test] test gemini pluging hang * [test] test gemini pluging hang * [test] test gemini pluging hang * [test] test gemini pluging hang * [test] test gemini pluging hang * [example] update seq parallel requirements
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseGradientHandler(ABC):
|
||||
"""A basic helper class to handle all-reduce operations of gradients across different parallel groups
|
||||
before optimization.
|
||||
|
||||
Args:
|
||||
model (Module): Model where the gradients accumulate.
|
||||
optimizer (Optimizer): Optimizer for updating the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, model, optimizer):
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
|
||||
@abstractmethod
|
||||
def handle_gradient(self):
|
||||
"""A method to accumulate gradients across different parallel groups. Users should
|
||||
write their own functions or just use the functions in pre-defined subclasses.
|
||||
"""
|
||||
pass
|
Reference in New Issue
Block a user