mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
import torch
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.gemini import TensorState
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
@@ -13,7 +15,7 @@ class TrainingPhase(Enum):
|
||||
BACKWARD = 1
|
||||
|
||||
|
||||
class ZeROHookV2(ParamOpHook):
|
||||
class GeminiZeROHook(ParamOpHook):
|
||||
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
Reference in New Issue
Block a user