Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)
* implement sharded optimizer saving
* add more param info
* finish implementation of sharded optimizer saving
* fix bugs in optimizer sharded saving
* add pp+zero test
* param group loading
* greedy loading of optimizer
* fix bug when loading
* implement optimizer sharded saving
* add optimizer test & arrange checkpointIO utils
* fix gemini sharding state_dict
* add verbose option
* add loading of master params
* fix typehint
* fix master/working mapping in fp16 amp
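For orientation, the user-facing path this commit enables is sharded saving and loading of optimizer states through the Booster API. A rough usage sketch, not code from this PR: the toy model, checkpoint path, and plugin arguments are placeholders, and it assumes a torchrun launch with at least two ranks.

    import torch
    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})    # distributed init is assumed to succeed

    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision='fp16')
    booster = Booster(plugin=plugin)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    # ... a few training steps so the optimizer has states worth checkpointing ...

    # Each rank writes only its own shard of the optimizer states, plus an index file.
    booster.save_optimizer(optimizer, "optim_ckpt", shard=True, size_per_shard=1024)
    # Shards are read back and mapped onto the local (sharded) parameters.
    booster.load_optimizer(optimizer, "optim_ckpt")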
@@ -1,7 +1,7 @@
 import random
 from contextlib import nullcontext
 from functools import partial
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
 import torch
@@ -110,6 +110,36 @@ class HybridParallelModule(ModelWrapper):
         return module
 
 
+def get_param_info(optim: Optimizer):
+    # Get a backup of necessary information of parameters for future use, which includes:
+    # 1. A complete param_group, with params in the form of param_id
+    # 2. A mapping from param address (obtained using id(param)) to integer param_id
+    # 3. A mapping from integer param_id to param address.
+    # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
+    # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
+
+    if optim is None:
+        return {}
+    param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+    start_index = 0
+    for group in optim.param_groups:
+
+        packed_group = {k: v for k, v in group.items() if k != 'params'}
+        packed_group['params'] = []
+
+        for param_id, param in enumerate(group['params'], start_index):
+            original_shape = param.shape if isinstance(param, torch.Tensor) else None
+            packed_group['params'].append(param_id)
+            param_info['param2id'][id(param)] = param_id
+            param_info['id2param'][param_id] = id(param)
+            param_info['param2shape'][id(param)] = original_shape
+
+        param_info['param_groups'].append(packed_group)
+        start_index += len(group['params'])
+
+    return param_info
+
+
 def init_pipeline_optimizer(optim: Optimizer, model: Module):
     params = set(model.parameters())
     new_param_groups = []
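To make the shape of the returned mapping concrete, here is a hedged illustration with a hypothetical toy model; the dict contents follow directly from the function added in this hunk.

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)                        # weight: [2, 4], bias: [2]
    optim = torch.optim.SGD(net.parameters(), lr=0.1)

    info = get_param_info(optim)
    # info['param_groups'][0] keeps 'lr', 'momentum', ... but 'params' becomes [0, 1]
    # info['param2id']    : {id(net.weight): 0, id(net.bias): 1}
    # info['id2param']    : {0: id(net.weight), 1: id(net.bias)}
    # info['param2shape'] : {id(net.weight): torch.Size([2, 4]), id(net.bias): torch.Size([2])}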
@@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
 
 class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(optim)
@@ -133,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                  optim: Optimizer,
                  model: Module,
                  use_pipeline: bool,
+                 param_info: OrderedDict,
                  precision: str = 'fp16',
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
@@ -142,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                  hysteresis: int = 2,
                  max_scale: float = 2**32,
                  max_norm: float = 0):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@@ -155,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                  optimizer: Optimizer,
                  model: Module,
                  use_pipeline: bool,
+                 param_info: OrderedDict,
                  initial_scale: int = 2**16,    # grad scaler config
                  min_scale: int = 1,
                  growth_factor: float = 2.,
@@ -172,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                  dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
                  tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
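Each optimizer wrapper now carries self.param_info, so the sharded checkpoint logic (implemented in HypridParallelCheckpointIO, not shown in this excerpt) can key optimizer states by the stable integer ids rather than by live parameter objects. A hedged sketch of that idea, with a hypothetical helper, not the actual checkpoint code:

    # Hypothetical helper: pack optimizer states under the integer ids from param_info.
    # With AMP or ZeRO, the optimizer's own params are fp32 masters, so a
    # master-to-working map (linked to the checkpoint IO below) is needed first to
    # recover the working param whose id() was recorded in param_info.
    def pack_optimizer_states(optim, param_info, master_to_working=None):
        packed = {}
        for group in optim.param_groups:
            for param in group['params']:
                working = master_to_working[id(param)] if master_to_working else param
                param_id = param_info['param2id'][id(working)]
                packed[param_id] = {
                    'shape': param_info['param2shape'][id(working)],
                    'state': optim.state.get(param, {}),    # e.g. exp_avg / exp_avg_sq for Adam
                }
        return packed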
@@ -356,6 +391,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
@@ -366,25 +402,33 @@ class HybridParallelPlugin(PipelinePluginBase):
                 optimizer = HybridParallelAMPOptimizer(optimizer,
                                                        model,
                                                        use_pipeline=self.enable_pipeline_parallelism,
+                                                       param_info=param_info,
                                                        precision=self.precision,
                                                        max_norm=self.max_norm,
                                                        **self.amp_config)
+                self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
+                                                                 optimizer.master_to_working_map)
             else:
                 optimizer = HybridParallelNaiveOptimizer(optimizer,
                                                          model,
-                                                         use_pipeline=self.enable_pipeline_parallelism)
+                                                         use_pipeline=self.enable_pipeline_parallelism,
+                                                         param_info=param_info)
         else:
             assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
             assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
             optimizer = HybridParallelZeroOptimizer(optimizer,
                                                     model,
                                                     use_pipeline=self.enable_pipeline_parallelism,
+                                                    param_info=param_info,
                                                     dp_process_group=self.dp_group,
                                                     tp_process_group=self.tp_group,
                                                     verbose=True,
                                                     clip_grad_norm=self.max_norm,
                                                     **self.zero_config,
                                                     **self.amp_config)
+            self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
+                                                             optimizer._param_store.master_to_working_param)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(self,
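The two link_master_and_working_param calls hand the checkpoint IO the pairing between fp16/bf16 working parameters and their fp32 master copies, taken from the AMP wrapper in one branch and from the ZeRO parameter store (optimizer._param_store) in the other. Schematically the maps look roughly like this; illustration only, since the exact key type (parameter object versus id(param)) depends on the wrapper:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2).half()                 # fp16 "working" parameters
    working_to_master, master_to_working = {}, {}
    for p in model.parameters():
        master = p.detach().clone().float()        # fp32 "master" copy kept by the optimizer
        working_to_master[id(p)] = master
        master_to_working[id(master)] = p

    # The plugin passes exactly this kind of pairing on:
    # self.checkpoint_io.link_master_and_working_param(working_to_master, master_to_working)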
@@ -461,7 +505,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)
+        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return self.checkpoint_io
 
     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
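Note the design change in get_checkpoint_io: it is no longer a pure factory. The HypridParallelCheckpointIO (which now also receives self.zero_stage) is cached as self.checkpoint_io, so that configure() above can attach the master/working maps to the very instance the Booster later uses for saving and loading. A minimal sketch of the resulting contract; plugin arguments are placeholders and a distributed launch is assumed to have happened:

    plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision='fp16')
    io = plugin.get_checkpoint_io()        # creates and caches plugin.checkpoint_io
    assert io is plugin.checkpoint_io      # configure() links the master/working maps into this object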