mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[gemini] zero supports gemini (#1093)
* add placement policy * add gemini mgr * update mem stats collector * update zero * update zero optim * fix bugs * zero optim monitor os * polish unit test * polish unit test * add assert
This commit is contained in:
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
@@ -14,9 +15,10 @@ class TrainingPhase(Enum):
|
||||
|
||||
class ZeROHookV2(ParamOpHook):
|
||||
|
||||
def __init__(self, chunk_manager: ChunkManager) -> None:
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
||||
self._chunk_manager = chunk_manager
|
||||
self._gemini_manager = gemini_manager
|
||||
self._chunk_manager = gemini_manager.chunk_manager
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
|
||||
def pre_op(self, params):
|
||||
@@ -24,9 +26,11 @@ class ZeROHookV2(ParamOpHook):
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._chunk_manager.exec_lazy_release()
|
||||
# TODO: evict chunks
|
||||
self._gemini_manager.sample_overall_data()
|
||||
self._gemini_manager.adjust_layout(chunks, 'fp16_param')
|
||||
for chunk in chunks:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
self._gemini_manager.sample_model_data()
|
||||
|
||||
def post_op(self, params):
|
||||
for p in params:
|
||||
|
@@ -7,6 +7,7 @@ from typing import Dict
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
@@ -19,6 +20,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
module: ColoDDPV2,
|
||||
gpu_margin_mem_ratio: float = 0.0,
|
||||
initial_scale: float = 2**32,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
@@ -29,6 +31,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
super().__init__(optim)
|
||||
assert isinstance(module, ColoDDPV2)
|
||||
self.module = module
|
||||
self.gemini_manager = module.gemini_manager
|
||||
self.chunk_manager = self.gemini_manager.chunk_manager
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
|
||||
for p, fp32_p in zip(module.parameters(), module.fp32_params):
|
||||
@@ -45,6 +49,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
||||
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
|
||||
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
|
||||
# and it must set `num_fp32_shards_per_param` correctly
|
||||
self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
|
||||
optim, 'num_fp32_shards_per_param', 0) >= 2
|
||||
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
|
||||
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
|
||||
|
||||
self._register_states = disposable(self._register_states_)
|
||||
|
||||
def _update_params_ptr(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
@@ -82,6 +98,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self._maybe_move_fp32_params()
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
@@ -94,6 +111,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
return
|
||||
self._update_params_ptr()
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
self._register_states()
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
@@ -109,3 +127,29 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
|
||||
self.module.backward_by_grad(tensor, grad)
|
||||
|
||||
def _maybe_move_fp32_params(self):
|
||||
if self._should_move_fp32_params_h2d:
|
||||
self._should_move_fp32_params_h2d = False
|
||||
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
|
||||
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
||||
fp32_params_used_cuda_margin_mem = 0
|
||||
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
|
||||
self.chunk_manager.chunk_groups['fp32_param']):
|
||||
if fp32_param_chunk.is_free:
|
||||
continue
|
||||
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
|
||||
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
|
||||
# stores grad now
|
||||
self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device())
|
||||
self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device())
|
||||
fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem
|
||||
self.module._setup_grads_ptr()
|
||||
|
||||
def _register_states_(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.optim.state[p]
|
||||
for val in state.values():
|
||||
if isinstance(val, torch.Tensor):
|
||||
self.chunk_manager.add_extern_static_tensor(val)
|
||||
|
Reference in New Issue
Block a user