Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[gemini] gemini supports lazy init (#3379)
* [gemini] fix nvme optimizer init
* [gemini] gemini supports lazy init
* [gemini] add init example
* [gemini] add fool model
* [zero] update gemini ddp
* [zero] update init example
* add chunk method
* add chunk method
* [lazyinit] fix lazy tensor tolist
* [gemini] fix buffer materialization
* [misc] remove useless file
* [booster] update gemini plugin
* [test] update gemini plugin test
* [test] fix gemini plugin test
* [gemini] fix import
* [gemini] fix import
* [lazyinit] use new metatensor
* [lazyinit] use new metatensor
* [lazyinit] fix __set__ method
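For orientation, here is a minimal sketch of how lazy initialization is meant to be combined with the Gemini plugin once this change lands. It follows the general ColossalAI booster workflow; the exact import path for LazyInitContext and the launcher signature are assumptions on my part and are not taken from this diff:

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext    # assumed import path; older releases exposed lazy init elsewhere

# Assumes the process group is launched via torchrun; older releases also required a config argument.
colossalai.launch_from_torch()

# Parameters are only recorded inside the context; no real storage is allocated yet.
with LazyInitContext():
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# boost() wraps the model with GeminiDDP, which materializes the lazy
# parameters chunk by chunk instead of requiring them to exist up front.
booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)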
@@ -16,10 +16,8 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
 from colossalai.zero.gemini.memory_tracer import MemStats
 
 from .plugin_base import Plugin
@@ -27,50 +25,6 @@ from .plugin_base import Plugin
 __all__ = ['GeminiPlugin']
 
 
-def convert_to_colo_param(module: nn.Module) -> None:
-    """Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
-
-    Args:
-        module (nn.Module): Module to be converted.
-    """
-    converted_modules = set()    # handle shared modules
-    converted_params = dict()    # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
-
-    def convert_recursively(m: nn.Module):
-        for child in m.children():
-            if child not in converted_modules:
-                converted_modules.add(child)
-                convert_recursively(child)
-
-        for name, p in m.named_parameters(recurse=False):
-            assert not isinstance(p, ColoParameter)
-            if p in converted_params:
-                target = converted_params[p]
-            else:
-                target = _convert_to_coloparam(p, p.device, p.dtype)
-                converted_params[p] = target
-            setattr(m, name, target)
-            target.shared_param_modules.append(m)
-
-    convert_recursively(module)
-
-    # optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
-    module._converted_params = converted_params
-
-
-def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
-    """Replace param in optimizer's group with converted ColoParameter.
-
-    Args:
-        optimizer (Optimizer): Optimizer to be replaced.
-        converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
-    """
-    for group in optimizer.param_groups:
-        for i, p in enumerate(group['params']):
-            if p in converted_params:
-                group['params'][i] = converted_params[p]
-
-
 class GeminiCheckpointIO(GeneralCheckpointIO):
 
     def __init__(self) -> None:
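The two helpers removed above existed because converting a module's parameters to ColoParameter after the optimizer had already been built left the optimizer's param groups pointing at the old tensor objects. A generic PyTorch illustration of that pitfall and of the fix replace_param_in_group applied (plain torch, not ColossalAI code; the module and variable names are made up for the example):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# Swap the weight for a new Parameter object, analogous to what
# convert_to_colo_param did when replacing torch.Tensor with ColoParameter.
old_weight = model.weight
new_weight = nn.Parameter(model.weight.detach().clone())
model.weight = new_weight

# The optimizer still holds the stale object, so step() would update a tensor
# the model no longer uses.
print(optim.param_groups[0]['params'][0] is model.weight)    # False

# The fix mirrors replace_param_in_group: rewrite the group entries in place
# using the old-tensor -> new-parameter mapping.
mapping = {old_weight: new_weight}
for group in optim.param_groups:
    group['params'] = [mapping.get(p, p) for p in group['params']]
print(optim.param_groups[0]['params'][0] is model.weight)    # True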
@@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper):
 
     def __init__(self, module: nn.Module, gemini_config: dict) -> None:
         super().__init__(module)
-        # TODO(ver217): only support Gemini now
-        convert_to_colo_param(module)
         self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
 
     def unwrap(self):
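With the workaround gone, GeminiModel hands the module straight to zero_model_wrapper and relies on lazy initialization to defer parameter allocation. The underlying idea can be sketched with plain PyTorch meta tensors (an illustration of the concept only, not ColossalAI's LazyTensor implementation; the device context manager requires PyTorch >= 2.0):

import torch
import torch.nn as nn

# Build the module on the meta device: shapes and dtypes are recorded, but no
# storage is allocated, so arbitrarily large models can be described cheaply.
with torch.device('meta'):
    layer = nn.Linear(1024, 1024)
print(layer.weight.device)    # meta

# Materialization happens later, once the wrapper knows where each parameter
# should live. to_empty() allocates uninitialized storage, which must then be
# filled by an initializer or a loaded checkpoint.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
layer = layer.to_empty(device=device)
nn.init.kaiming_uniform_(layer.weight)
nn.init.zeros_(layer.bias)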
@@ -125,8 +77,6 @@ class GeminiModel(ModelWrapper):
 class GeminiOptimizer(OptimizerWrapper):
 
     def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
-        replace_param_in_group(optimizer, module.module._converted_params)
-        del module.module._converted_params
         optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
         super().__init__(optimizer)
 