[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitors optimizer states (OS)

* polish unit test

* polish unit test

* add assert
This commit is contained in:
ver217
2022-06-10 14:48:28 +08:00
committed by GitHub
parent 2b2dc1c86b
commit 1f894e033f
9 changed files with 366 additions and 12 deletions

View File

@@ -5,6 +5,7 @@ from enum import Enum
from typing import List
from contextlib import contextmanager
from functools import partial
from colossalai.gemini.gemini_mgr import GeminiManager
class TrainingPhase(Enum):
@@ -14,9 +15,10 @@ class TrainingPhase(Enum):
class ZeROHookV2(ParamOpHook):
def __init__(self, chunk_manager: ChunkManager) -> None:
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()
self._chunk_manager = chunk_manager
self._gemini_manager = gemini_manager
self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
@@ -24,9 +26,11 @@ class ZeROHookV2(ParamOpHook):
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks, 'fp16_param')
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
self._gemini_manager.sample_model_data()
def post_op(self, params):
for p in params:

View File

@@ -7,6 +7,7 @@ from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
class OptimState(Enum):
@@ -19,6 +20,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def __init__(self,
optim: Optimizer,
module: ColoDDPV2,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
@@ -29,6 +31,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
super().__init__(optim)
assert isinstance(module, ColoDDPV2)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager = self.gemini_manager.chunk_manager
self.optim_state = OptimState.UNSCALED
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
for p, fp32_p in zip(module.parameters(), module.fp32_params):
@@ -45,6 +49,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
# and it must set `num_fp32_shards_per_param` correctly
self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
optim, 'num_fp32_shards_per_param', 0) >= 2
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
self._register_states = disposable(self._register_states_)
def _update_params_ptr(self):
for group in self.optim.param_groups:
for p in group['params']:
@@ -82,6 +98,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
@@ -94,6 +111,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
return
self._update_params_ptr()
ret = self.optim.step(*args, **kwargs)
self._register_states()
self._update_fp16_params()
return ret
@@ -109,3 +127,29 @@ class ZeroOptimizer(ColossalaiOptimizer):
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
    """Forward a backward-by-gradient request to the wrapped module.

    Args:
        tensor: Passed through unchanged as the first positional argument.
        grad: Passed through unchanged as the second positional argument.

    Returns:
        None. Whatever ``self.module.backward_by_grad`` returns is discarded.
    """
    wrapped_module = self.module
    wrapped_module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self):
    """Move fp32 parameter chunks to the current device while a memory budget allows.

    Does nothing unless ``self._should_move_fp32_params_h2d`` is truthy. The
    flag is cleared before any chunk is touched, so a later call is a no-op
    unless the flag is set again elsewhere.

    NOTE(review): the nesting below is reconstructed from the statement order
    of a flattened listing -- confirm against the real file.
    """
    if not self._should_move_fp32_params_h2d:
        return
    self._should_move_fp32_params_h2d = False

    # Budget = (CUDA margin memory * user ratio), divided by the inner
    # optimizer's `num_fp32_shards_per_param`.
    margin_share = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
    fp32_budget = margin_share / self.optim.num_fp32_shards_per_param
    fp32_spent = 0

    groups = self.chunk_manager.chunk_groups
    # The two chunk groups are walked in lockstep; presumably the i-th fp16
    # chunk corresponds to the i-th fp32 chunk -- verify against ChunkManager.
    for half_chunk, full_chunk in zip(groups['fp16_param'], groups['fp32_param']):
        if full_chunk.is_free:
            continue
        fits_in_budget = fp32_spent + full_chunk.mem < fp32_budget
        if not fits_in_budget:
            # Keep scanning: a later chunk may still fit.
            continue
        self.chunk_manager.move_chunk(full_chunk, get_current_device())
        # The paired fp16 chunk "stores grad now" (original author's note), so
        # it and its grad device follow the fp32 chunk.
        self.chunk_manager.move_chunk(half_chunk, get_current_device())
        self.module._set_chunk_grad_device(half_chunk, get_current_device())
        fp32_spent += full_chunk.mem

    # Runs once per enabled call, whether or not any chunk was moved.
    self.module._setup_grads_ptr()
def _register_states_(self):
    """Register every tensor in the inner optimizer's state with the chunk manager.

    Walks ``self.optim.param_groups``, looks up ``self.optim.state[p]`` for
    each parameter, and passes each ``torch.Tensor`` value found there to
    ``self.chunk_manager.add_extern_static_tensor``. Non-tensor state values
    are skipped.
    """
    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            per_param_state = self.optim.state[param]
            # Lazy filter, so registration order matches the state dict's order.
            tensor_values = (v for v in per_param_state.values() if isinstance(v, torch.Tensor))
            for state_tensor in tensor_values:
                self.chunk_manager.add_extern_static_tensor(state_tensor)