[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitor os

* polish unit test

* polish unit test

* add assert
This commit is contained in:
ver217
2022-06-10 14:48:28 +08:00
committed by GitHub
parent 2b2dc1c86b
commit 1f894e033f
9 changed files with 366 additions and 12 deletions

View File

@@ -4,8 +4,11 @@ from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
from colossalai.tensor.param_op_hook import use_param_op_hooks
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict
from colossalai.logging import get_dist_logger
def free_storage(data: torch.Tensor) -> None:
@@ -89,12 +92,14 @@ class ColoDDP(torch.nn.Module):
class ColoDDPV2(ColoDDP):
def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
super().__init__(module)
self.chunk_manager = chunk_manager
self.param_op_hook = ZeROHookV2(chunk_manager)
self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager
self.param_op_hook = ZeROHookV2(gemini_manager)
self.fp32_params = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
# TODO: get param order and filter unused params
for p in module.parameters():
assert p.dtype == torch.half
@@ -102,22 +107,32 @@ class ColoDDPV2(ColoDDP):
self.chunk_manager.append_tensor(p, 'fp16_param')
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter()
with use_param_op_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
self.chunk_manager.exec_lazy_release()
return outputs
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
def _setup_grads_ptr(self):
for p in self.module.parameters():
if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
p.grad = None
else:
p.grad = p.data
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
self._setup_grads_ptr()
self._logger.info(
f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
)
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
loss.backward()
@@ -141,7 +156,12 @@ class ColoDDPV2(ColoDDP):
self.chunk_manager.release_chunk(chunk)
if reduced and not chunk.is_free:
self.overflow_counter += chunk.has_inf_or_nan
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device