[gemini] gemini supports lazy init (#3379)

* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
Author: Hongxin Liu
Date: 2023-04-12 16:03:25 +08:00
Committed by: GitHub
Parent: 366a035552
Commit: 152239bbfa
7 changed files with 80 additions and 72 deletions
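For context, the end state this patch enables looks roughly like the sketch below: build the model under the experimental lazy-init context so parameters and buffers stay as LazyTensor placeholders, then let GeminiDDP materialize and shard them when the model is wrapped. Only LazyTensor and its materialize() method appear in this diff; the LazyInitContext name, the GeminiDDP import path, and the keyword arguments are assumptions, not confirmed here.

    import torch.nn as nn

    from colossalai.utils import get_current_device
    from colossalai.utils.model.experimental import LazyInitContext  # assumed companion of LazyTensor
    from colossalai.nn.parallel import GeminiDDP  # assumed import path

    # Parameters and buffers created inside the context stay as LazyTensor
    # placeholders; no real storage is allocated yet.
    with LazyInitContext():
        model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

    # Wrapping triggers _preprocess_param (materialize + in-place conversion to
    # ColoParameter) and _cast_buffers (materialize, move to CUDA, cast floats to half).
    model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu')  # kwarg names assumed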


@@ -1,7 +1,7 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
+from colossalai.utils.model.experimental import LazyTensor

 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -55,7 +56,6 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False) -> None:
-        super().__init__(module, process_group=ColoProcessGroup())
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -67,7 +67,6 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()

-        self._cast_buffers()

         self._logger = get_dist_logger()
         if self.gemini_manager._premade_memstats_:
@@ -91,6 +90,8 @@ class ZeroDDP(ColoDDP):
             for p_name, p_var in m_var.named_parameters(recurse=False):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
+        super().__init__(module, process_group=ColoProcessGroup())
+        self._cast_buffers()

     def _post_forward(self):
         """This function is only triggered for inference.
@@ -478,7 +479,8 @@ class ZeroDDP(ColoDDP):
     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
         ddp_pg = ColoProcessGroup()
         for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
+            self._preprocess_param(p)
+            assert type(p) is ColoParameter

             # gather sharded parameters in the strict ddp mode
             if strict_ddp_mode:
@@ -531,10 +533,27 @@ class ZeroDDP(ColoDDP):
     def _cast_buffers(self):
         for buffer in self.module.buffers():
+            if isinstance(buffer, LazyTensor):
+                buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.half()

+    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
+        """Convert parameter to ColoParameter in-place.
+
+        Args:
+            p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
+        """
+        if type(p) is ColoParameter:
+            # model is initialized with ColoInitContext
+            return
+        requires_grad = p.requires_grad
+        if isinstance(p, LazyTensor):
+            # model is initialized with LazyInitContext
+            p.materialize()
+        p.__class__ = ColoParameter
+        p.__init__(p, requires_grad=requires_grad)

 class GeminiDDP(ZeroDDP):
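The interesting part of _preprocess_param is that the conversion is in place: reassigning __class__ and re-running __init__ upgrades the existing tensor object, so every reference already held by the module (and later by the chunk manager) keeps pointing at a valid ColoParameter. A self-contained sketch of the same pattern, using a hypothetical TaggedParameter subclass instead of ColoParameter:

    import torch.nn as nn


    class TaggedParameter(nn.Parameter):
        """Hypothetical stand-in for ColoParameter: an nn.Parameter subclass with extra state."""

        def __init__(self, data, requires_grad: bool = True):
            self.tag = 'converted'


    linear = nn.Linear(4, 4)
    p = linear.weight                           # a plain nn.Parameter held by the module

    requires_grad = p.requires_grad
    p.__class__ = TaggedParameter               # swap the class of the existing object...
    p.__init__(p, requires_grad=requires_grad)  # ...then re-run the subclass initializer

    assert linear.weight is p                   # the module still holds the very same tensor
    assert isinstance(linear.weight, TaggedParameter) and linear.weight.tag == 'converted'

The new code applies the same swap to ColoParameter after materializing any LazyTensor, which is why parameters created lazily need no re-registration on the module before chunk initialization.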