mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] zero supports gemini (#1093)
* add placement policy
* add gemini mgr
* update mem stats collector
* update zero
* update zero optim
* fix bugs
* zero optim monitor os
* polish unit test
* polish unit test
* add assert
@@ -4,8 +4,11 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.chunk import ChunkManager, TensorState
+from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
 from colossalai.tensor.param_op_hook import use_param_op_hooks
+from colossalai.gemini.gemini_mgr import GeminiManager
+from typing import Dict
+from colossalai.logging import get_dist_logger


 def free_storage(data: torch.Tensor) -> None:
@@ -89,12 +92,14 @@ class ColoDDP(torch.nn.Module):

 class ColoDDPV2(ColoDDP):

-    def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
+    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
         super().__init__(module)
-        self.chunk_manager = chunk_manager
-        self.param_op_hook = ZeROHookV2(chunk_manager)
+        self.gemini_manager = gemini_manager
+        self.chunk_manager = gemini_manager.chunk_manager
+        self.param_op_hook = ZeROHookV2(gemini_manager)
         self.fp32_params = []
         self.overflow_counter = 0
+        self.grads_device: Dict[torch.Tensor, torch.device] = {}
         # TODO: get param order and filter unused params
         for p in module.parameters():
             assert p.dtype == torch.half
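
With this change, ColoDDPV2 is no longer handed a ChunkManager directly; it receives a GeminiManager and pulls the chunk manager out of it, so chunk bookkeeping and tensor placement are driven by one object. A minimal construction sketch follows. The ChunkManager/GeminiManager constructor arguments and the import path of ColoDDPV2 are assumptions for illustration and are not shown in this diff:

    # Hedged wiring sketch: constructor arguments below are assumptions;
    # only the ColoDDPV2(module, gemini_manager) signature comes from this diff.
    import torch
    from colossalai.tensor.chunk import ChunkManager
    from colossalai.gemini.gemini_mgr import GeminiManager
    from colossalai.nn.parallel import ColoDDPV2            # assumed module path

    model = torch.nn.Linear(8, 8).half().cuda()             # params must be fp16 (see the assert above)
    chunk_manager = ChunkManager(1024 * 1024)                # chunk size: an assumed positional argument
    gemini_manager = GeminiManager('cuda', chunk_manager)    # placement policy + chunk manager: assumed signature
    model = ColoDDPV2(model, gemini_manager)                 # new: pass the GeminiManager, not the ChunkManager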
@@ -102,22 +107,32 @@ class ColoDDPV2(ColoDDP):
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
+            self.grads_device[p] = self.gemini_manager.default_device
+        self._logger = get_dist_logger()

     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
+        self.gemini_manager.pre_iter()
         with use_param_op_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs

-    def _post_backward(self):
-        self.chunk_manager.exec_lazy_release()
+    def _setup_grads_ptr(self):
+        for p in self.module.parameters():
+            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
+                p.grad = None
+            else:
+                p.grad = p.data
+
+    def _post_backward(self):
+        self.chunk_manager.exec_lazy_release()
+        self._setup_grads_ptr()
+        self._logger.info(
+            f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
+        )
+        self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
         with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
             loss.backward()
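
The per-iteration flow is now bracketed by the Gemini manager: forward() zeroes the grads and calls gemini_manager.pre_iter() before running the module under the ZeRO param-op hook, and _post_backward() releases chunks, repoints p.grad at the chunk storage via _setup_grads_ptr(), logs layout/eviction statistics, and closes the iteration with post_iter(). A hedged sketch of one training step, continuing the construction sketch above (the dummy input and loss are placeholders; the optimizer step is omitted because the matching ZeRO optimizer changes live in other files of this commit):

    # 'model' is the ColoDDPV2-wrapped module from the sketch above.
    data = torch.randn(4, 8, dtype=torch.half, device='cuda')
    for _ in range(2):
        out = model(data)          # forward(): zero_grad(set_to_none=True), gemini_manager.pre_iter(),
                                   # then the module forward under use_param_op_hooks(...)
        loss = out.float().sum()   # placeholder loss
        model.backward(loss)       # loss.backward() under the param-op hook; _post_backward() then
                                   # releases chunks, sets grad pointers, logs layout/evict time,
                                   # and calls gemini_manager.post_iter()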
@@ -141,7 +156,12 @@ class ColoDDPV2(ColoDDP):
             self.chunk_manager.release_chunk(chunk)
             if reduced and not chunk.is_free:
                 self.overflow_counter += chunk.has_inf_or_nan
+                self.chunk_manager.move_chunk(chunk, self.grads_device[p])
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
+
+    def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
+        for tensor in chunk.get_tensors():
+            self.grads_device[tensor] = device
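
The new grads_device map decides where a parameter's reduced gradient chunk is moved in the reduction path above, defaulting to gemini_manager.default_device, and _set_chunk_grad_device retargets every tensor of a chunk at once. A hedged sketch of how a caller (for example an offloading optimizer) might redirect gradients to CPU; iterating parameters this way is an assumption, only get_chunk, move_chunk and _set_chunk_grad_device appear in this diff:

    # 'model' is the ColoDDPV2 instance from the sketches above.
    cpu = torch.device('cpu')
    for p in model.module.parameters():
        chunk = model.chunk_manager.get_chunk(p)   # get_chunk(p) also appears in _setup_grads_ptr
        model._set_chunk_grad_device(chunk, cpu)   # every tensor in the chunk now maps its grads to CPU,
                                                   # so the reduction path will move_chunk(chunk, cpu)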