[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)

This commit is contained in:
Jiarui Fang
2022-11-17 14:43:49 +08:00
committed by GitHub
parent f8a7148dec
commit cc0ed7cf33
4 changed files with 12 additions and 10 deletions

View File

@@ -1,11 +1,13 @@
import torch
from colossalai.tensor.param_op_hook import ParamOpHook
from colossalai.gemini import TensorState
from enum import Enum
from typing import List
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
import torch
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor.param_op_hook import ParamOpHook
class TrainingPhase(Enum):
@@ -13,7 +15,7 @@ class TrainingPhase(Enum):
BACKWARD = 1
class ZeROHookV2(ParamOpHook):
class GeminiZeROHook(ParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()