Mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] gemini supports lazy init (#3379)
* [gemini] fix nvme optimizer init
* [gemini] gemini supports lazy init
* [gemini] add init example
* [gemini] add fool model
* [zero] update gemini ddp
* [zero] update init example
* add chunk method
* add chunk method
* [lazyinit] fix lazy tensor tolist
* [gemini] fix buffer materialization
* [misc] remove useless file
* [booster] update gemini plugin
* [test] update gemini plugin test
* [test] fix gemini plugin test
* [gemini] fix import
* [gemini] fix import
* [lazyinit] use new metatensor
* [lazyinit] use new metatensor
* [lazyinit] fix __set__ method
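The diff below wires lazy initialization into ZeroDDP/GeminiDDP: parameters and buffers created as LazyTensors are materialized and converted inside the wrapper itself. The following is a minimal usage sketch of what that enables; the LazyInitContext import path, the launch call, and the GeminiDDP arguments are assumptions inferred from the imports this diff touches, not the init example that this commit adds.

# Hedged sketch: the entry points below (LazyInitContext location, GeminiDDP signature) are assumptions.
import colossalai
import torch.nn as nn
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.experimental import LazyInitContext  # assumed: same module as LazyTensor in the diff

colossalai.launch_from_torch(config={})  # Gemini needs an initialized distributed environment (e.g. via torchrun)

with LazyInitContext():
    # parameters and buffers come out as LazyTensors: no real storage is allocated yet
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10))

# ZeroDDP now materializes LazyTensors itself: parameters in _preprocess_param,
# buffers in _cast_buffers (see the diff below)
model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu', pin_memory=True)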
@@ -1,7 +1,7 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
+from colossalai.utils.model.experimental import LazyTensor
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -55,7 +56,6 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False) -> None:
-        super().__init__(module, process_group=ColoProcessGroup())
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -67,7 +67,6 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
 
-        self._cast_buffers()
         self._logger = get_dist_logger()
 
         if self.gemini_manager._premade_memstats_:
@@ -91,6 +90,8 @@ class ZeroDDP(ColoDDP):
             for p_name, p_var in m_var.named_parameters(recurse=False):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
+        super().__init__(module, process_group=ColoProcessGroup())
+        self._cast_buffers()
 
     def _post_forward(self):
         """This function is only triggered for inference.
@@ -478,7 +479,8 @@ class ZeroDDP(ColoDDP):
     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
         ddp_pg = ColoProcessGroup()
         for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
+            self._preprocess_param(p)
+            assert type(p) is ColoParameter
 
             # gather sharded parameters in the strict ddp mode
             if strict_ddp_mode:
@@ -531,10 +533,27 @@ class ZeroDDP(ColoDDP):
 
     def _cast_buffers(self):
         for buffer in self.module.buffers():
+            if isinstance(buffer, LazyTensor):
+                buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.half()
 
+    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
+        """Convert parameter to ColoParameter in-place.
+        Args:
+            p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
+        """
+        if type(p) is ColoParameter:
+            # model is initialized with ColoInitContext
+            return
+        requires_grad = p.requires_grad
+        if isinstance(p, LazyTensor):
+            # model is initialized with LazyInitContext
+            p.materialize()
+        p.__class__ = ColoParameter
+        p.__init__(p, requires_grad=requires_grad)
+
 
 class GeminiDDP(ZeroDDP):
 
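The core of the change is _preprocess_param, which converts a freshly materialized (or eagerly created) parameter to a ColoParameter in place by reassigning __class__ and re-running __init__, so every existing reference to the object sees the converted parameter. Below is a standalone sketch of that pattern; MyParameter is a hypothetical stand-in for ColoParameter, not ColossalAI code.

import torch
import torch.nn as nn


class MyParameter(nn.Parameter):
    """Hypothetical stand-in for ColoParameter; defines __init__ so the
    in-place conversion below has something to (re)initialize."""

    def __init__(self, data=None, requires_grad=True):
        # ColoParameter would set up its tensor spec / process group here;
        # this flag is just a visible marker for the sketch.
        self.converted = True


p = nn.Parameter(torch.empty(4, 4))   # stands in for a plain or just-materialized parameter
requires_grad = p.requires_grad

p.__class__ = MyParameter             # same object, new class: callers holding `p` are unaffected
p.__init__(p, requires_grad=requires_grad)

assert type(p) is MyParameter and p.converted and p.requires_grad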