[shardformer] support sharded optimizer checkpointIO of HybridParallelPlugin (#4540)

* implement sharded optimizer saving

* add more param info

* finish implementation of sharded optimizer saving

* fix bugs in optimizer sharded saving

* add pp+zero test

* param group loading

* greedy loading of optimizer

* fix bug when loading

* implement optimizer sharded saving

* add optimizer test & arrange checkpointIO utils

* fix gemini sharding state_dict

* add verbose option

* add loading of master params

* fix typehint

* fix master/working mapping in fp16 amp
Baizhou Zhang authored on 2023-08-31 14:50:47 +08:00, committed by GitHub
parent 2c787d7f47
commit c9625dbb63
6 changed files with 812 additions and 369 deletions

colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -1,7 +1,7 @@
 import random
 from contextlib import nullcontext
 from functools import partial
-from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
 import torch
@@ -110,6 +110,36 @@ class HybridParallelModule(ModelWrapper):
         return module
 
 
+def get_param_info(optim: Optimizer):
+    # Get a backup of necessary information of parameters for future use, which includes:
+    # 1. A complete param_group, with params in the form of param_id
+    # 2. A mapping from param address (obtained using id(param)) to integer param_id
+    # 3. A mapping from integer param_id to param address.
+    # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
+    # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
+    if optim is None:
+        return {}
+    param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+    start_index = 0
+    for group in optim.param_groups:
+        packed_group = {k: v for k, v in group.items() if k != 'params'}
+        packed_group['params'] = []
+        for param_id, param in enumerate(group['params'], start_index):
+            original_shape = param.shape if isinstance(param, torch.Tensor) else None
+            packed_group['params'].append(param_id)
+            param_info['param2id'][id(param)] = param_id
+            param_info['id2param'][param_id] = id(param)
+            param_info['param2shape'][id(param)] = original_shape
+        param_info['param_groups'].append(packed_group)
+        start_index += len(group['params'])
+    return param_info
+
+
 def init_pipeline_optimizer(optim: Optimizer, model: Module):
     params = set(model.parameters())
     new_param_groups = []
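For reference, the mapping that get_param_info records can be reproduced on a toy optimizer. The snippet below is a standalone sketch (the model and hyperparameters are illustrative only, not part of the patch):

    import torch
    import torch.nn as nn

    # Toy setup -- illustrative only.
    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Reproduce (slightly simplified) the bookkeeping done by get_param_info above.
    param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
    start_index = 0
    for group in optim.param_groups:
        packed_group = {k: v for k, v in group.items() if k != 'params'}
        packed_group['params'] = []
        for param_id, param in enumerate(group['params'], start_index):
            packed_group['params'].append(param_id)             # params stored as integer ids
            param_info['param2id'][id(param)] = param_id        # address -> id
            param_info['id2param'][param_id] = id(param)        # id -> address
            param_info['param2shape'][id(param)] = param.shape  # address -> unsharded shape
        param_info['param_groups'].append(packed_group)
        start_index += len(group['params'])

    print(param_info['param_groups'][0]['params'])          # [0, 1, 2, 3]
    print(param_info['param2shape'][id(model[0].weight)])   # torch.Size([8, 4])

Storing integer ids instead of tensor objects keeps the packed param groups serializable and lets every rank refer to the same parameter even after sharding changes its local shape.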
@@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
 
 class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
+    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(optim)
@@ -133,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                  optim: Optimizer,
                  model: Module,
                  use_pipeline: bool,
+                 param_info: OrderedDict,
                  precision: str = 'fp16',
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
@@ -142,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                  hysteresis: int = 2,
                  max_scale: float = 2**32,
                  max_norm: float = 0):
+        self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@@ -155,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                  optimizer: Optimizer,
                  model: Module,
                  use_pipeline: bool,
+                 param_info: OrderedDict,
                  initial_scale: int = 2**16,    # grad scaler config
                  min_scale: int = 1,
                  growth_factor: float = 2.,
@@ -172,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                  dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
                  tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):
+        self.param_info = param_info
        if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
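All three wrappers keep param_info as an attribute, so the checkpoint IO can translate between parameter objects and the stable integer ids and unsharded shapes recorded before boosting. The snippet below is not the code added in the checkpoint IO; it is only a sketch of how such a translation could be used to re-key optimizer states by param id:

    from typing import Dict

    import torch

    def pack_states_by_param_id(optim: torch.optim.Optimizer, param_info: Dict) -> Dict[int, dict]:
        # Hypothetical helper: re-key optimizer states from parameter objects to the
        # integer ids stored in param_info['param2id'], so that shards written by
        # different ranks agree on which state belongs to which parameter.
        packed = {}
        for param, state in optim.state.items():
            if id(param) in param_info['param2id']:
                packed[param_info['param2id'][id(param)]] = state
        return packed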
@@ -356,6 +391,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
@@ -366,25 +402,33 @@ class HybridParallelPlugin(PipelinePluginBase):
                     optimizer = HybridParallelAMPOptimizer(optimizer,
                                                            model,
                                                            use_pipeline=self.enable_pipeline_parallelism,
+                                                           param_info=param_info,
                                                            precision=self.precision,
                                                            max_norm=self.max_norm,
                                                            **self.amp_config)
+                    self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
+                                                                     optimizer.master_to_working_map)
                 else:
                     optimizer = HybridParallelNaiveOptimizer(optimizer,
                                                              model,
-                                                             use_pipeline=self.enable_pipeline_parallelism)
+                                                             use_pipeline=self.enable_pipeline_parallelism,
+                                                             param_info=param_info)
             else:
                 assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
                 assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                 optimizer = HybridParallelZeroOptimizer(optimizer,
                                                         model,
                                                         use_pipeline=self.enable_pipeline_parallelism,
+                                                        param_info=param_info,
                                                         dp_process_group=self.dp_group,
                                                         tp_process_group=self.tp_group,
                                                         verbose=True,
                                                         clip_grad_norm=self.max_norm,
                                                         **self.zero_config,
                                                         **self.amp_config)
+                self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
+                                                                 optimizer._param_store.master_to_working_param)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(self,
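The two link_master_and_working_param calls above hand the checkpoint IO a bridge between the fp32 master copies held by the AMP/ZeRO optimizer and the fp16/bf16 working parameters that param_info was recorded against. A rough sketch of what such a mapping contains (keys shown as id() for illustration; the real wrappers may key by the tensor object itself):

    import torch

    # One working (low-precision) parameter and its fp32 master copy.
    working = torch.nn.Parameter(torch.randn(8, 4).half())
    master = working.detach().clone().float()

    working_to_master = {id(working): master}   # used when saving: find the fp32 state for a working param
    master_to_working = {id(master): working}   # used when loading: find the working param a state belongs to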
@@ -461,7 +505,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                                       **_kwargs)
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group)
+        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return self.checkpoint_io
 
     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError
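With get_checkpoint_io now building a HypridParallelCheckpointIO that knows the ZeRO stage and is linked to the optimizer's master/working maps, the sharded optimizer checkpoint can be exercised through the regular Booster interface. A rough usage sketch, assuming a torchrun launch with 4 GPUs (tp=2, dp=2); the toy model, hyperparameters, and paths are placeholders:

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})

    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision='fp16')
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(32, 32)                              # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # ... training steps ...

    # Save the optimizer as a sharded checkpoint (an index file plus state shards).
    booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=True, size_per_shard=32)

    # Later: rebuild and boost model + optimizer the same way, then reload.
    booster.load_optimizer(optimizer, 'ckpt/optimizer')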