[gemini] gemini supports lazy init (#3379)

* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
Hongxin Liu authored 2023-04-12 16:03:25 +08:00, committed by GitHub
parent 366a035552
commit 152239bbfa
7 changed files with 80 additions and 72 deletions

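What the change enables, as a rough usage sketch (not code from this commit): build the model under the lazy-init context and hand it straight to the Gemini plugin, with no ColoParameter conversion step in between. The LazyInitContext import path (colossalai.lazy in newer releases; it lived elsewhere around this commit) and the default GeminiPlugin arguments are assumptions, and the script assumes a distributed launch, e.g. via torchrun.

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.lazy import LazyInitContext  # assumed import path; older releases expose it elsewhere

    colossalai.launch_from_torch(config={})

    # Parameters are created lazily (meta/lazy tensors); Gemini materializes and
    # shards them when the model is boosted, so no eager ColoParameter conversion
    # is needed anymore.
    with LazyInitContext():
        model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))

    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    booster = Booster(plugin=GeminiPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)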

@@ -16,10 +16,8 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
from colossalai.zero.gemini.memory_tracer import MemStats
from .plugin_base import Plugin
@@ -27,50 +25,6 @@ from .plugin_base import Plugin
__all__ = ['GeminiPlugin']


def convert_to_colo_param(module: nn.Module) -> None:
    """Convert module's parameters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.

    Args:
        module (nn.Module): Module to be converted.
    """
    converted_modules = set()  # handle shared modules
    converted_params = dict()  # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference

    def convert_recursively(m: nn.Module):
        for child in m.children():
            if child not in converted_modules:
                converted_modules.add(child)
                convert_recursively(child)

        for name, p in m.named_parameters(recurse=False):
            assert not isinstance(p, ColoParameter)
            if p in converted_params:
                target = converted_params[p]
            else:
                target = _convert_to_coloparam(p, p.device, p.dtype)
                converted_params[p] = target
            setattr(m, name, target)
            target.shared_param_modules.append(m)

    convert_recursively(module)

    # optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
    module._converted_params = converted_params


def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
    """Replace param in optimizer's group with converted ColoParameter.

    Args:
        optimizer (Optimizer): Optimizer to be replaced.
        converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
    """
    for group in optimizer.param_groups:
        for i, p in enumerate(group['params']):
            if p in converted_params:
                group['params'][i] = converted_params[p]


class GeminiCheckpointIO(GeneralCheckpointIO):

    def __init__(self) -> None:
@@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper):
    def __init__(self, module: nn.Module, gemini_config: dict) -> None:
        super().__init__(module)
        # TODO(ver217): only support Gemini now
        convert_to_colo_param(module)
        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)

    def unwrap(self):
@@ -125,8 +77,6 @@ class GeminiModel(ModelWrapper):
class GeminiOptimizer(OptimizerWrapper):

    def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
        replace_param_in_group(optimizer, module.module._converted_params)
        del module.module._converted_params
        optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
        super().__init__(optimizer)
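Why the removed helpers existed at all: convert_to_colo_param swapped module parameters after the user's optimizer had already been built, so optimizer.param_groups still referenced the stale tensors and replace_param_in_group had to patch them in place; with lazy init handled by Gemini itself, parameters are in their final form before the ZeRO wrappers run and the patching step drops out. A small self-contained PyTorch illustration of that stale-reference hazard (plain torch, not ColossalAI code):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    old_weight = model.weight
    # Swap the parameter out after the optimizer was constructed, which is what
    # the removed convert_to_colo_param workaround effectively did.
    model.weight = nn.Parameter(old_weight.detach().clone())

    # The optimizer still holds the stale tensor, so its updates would never
    # reach the parameter the module actually uses.
    assert any(p is old_weight for group in optimizer.param_groups for p in group['params'])
    assert all(p is not model.weight for group in optimizer.param_groups for p in group['params'])

    # The removed replace_param_in_group helper fixed this by patching the groups in place:
    for group in optimizer.param_groups:
        group['params'] = [model.weight if p is old_weight else p for p in group['params']]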