[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitor os

* polish unit test

* polish unit test

* add assert
This commit is contained in:
ver217
2022-06-10 14:48:28 +08:00
committed by GitHub
parent 2b2dc1c86b
commit 1f894e033f
9 changed files with 366 additions and 12 deletions

View File

@@ -5,6 +5,7 @@ from enum import Enum
from typing import List
from contextlib import contextmanager
from functools import partial
from colossalai.gemini.gemini_mgr import GeminiManager
class TrainingPhase(Enum):
@@ -14,9 +15,10 @@ class TrainingPhase(Enum):
class ZeROHookV2(ParamOpHook):
def __init__(self, chunk_manager: ChunkManager) -> None:
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()
self._chunk_manager = chunk_manager
self._gemini_manager = gemini_manager
self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
@@ -24,9 +26,11 @@ class ZeROHookV2(ParamOpHook):
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks, 'fp16_param')
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
self._gemini_manager.sample_model_data()
def post_op(self, params):
for p in params: