[zero] support shard optimizer state dict of zero (#4194)

* support sharded optimizer state dict for ZeRO

* polish code

* support sync grad manually
Authored by LuGY on 2023-07-11 18:03:13 +08:00, committed by Hongxin Liu
parent dd7cc58299
commit 1a49a5ea00
4 changed files with 239 additions and 68 deletions
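
For context, here is a minimal usage sketch (not part of the diff) of what this commit enables: saving and loading a sharded ZeRO optimizer checkpoint through the booster API. It assumes a distributed launch (e.g. torchrun) and the standard `LowLevelZeroPlugin` workflow; the tiny linear model, hyper-parameters, and checkpoint path are placeholders.

```python
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)

model = nn.Linear(32, 32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(8, 32, device='cuda')).sum()
booster.backward(loss, optimizer)
optimizer.step()

# new in this commit: the sharded ZeRO optimizer state can be saved and loaded directly
booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=True)
booster.load_optimizer(optimizer, 'ckpt/optimizer')
```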


@@ -1,5 +1,8 @@
+import logging
+import os
 import warnings
 from functools import partial
+from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -10,10 +13,16 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader

-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
+from colossalai.checkpoint_io.utils import (
+    get_optimizer_base_filenames,
+    get_shard_filename,
+    save_param_groups,
+    save_state_dict,
+)
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -32,21 +41,104 @@ SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']

 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
-        """
-        Save optimizer to checkpoint but only on master process.
-        """
-        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
-        warnings.warn(
-            'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
-
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        warnings.warn(
-            'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().load_optimizer(optimizer, checkpoint)
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+        """Save optimizer to checkpoint but only on master process.
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save state_dict
+            checkpoint (str): Path to save checkpoint
+            gather_dtensor (bool): Whether to gather_dtensor, not used
+        """
+        # the `state_dict` of LowLevelZeroOptimizer involves collective communication;
+        # if only the master rank collected and saved the state_dict,
+        # the communication on the other ranks would not match
+        state_dict = optimizer.state_dict()
+        if self.coordinator.is_master():
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+    def save_sharded_optimizer(self,
+                               optimizer: OptimizerWrapper,
+                               checkpoint: str,
+                               gather_dtensor: bool = False,
+                               prefix: str = None,
+                               size_per_shard: int = 1024):
+        """
+        Save sharded ZeRO-optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store state tensors of the optimizer in a sharding way
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+            checkpoint (str): Path to save optimizer state_dict
+            gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of file to save
+            size_per_shard (int): Max file size of each file that stores state tensors
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # this state_dict only provides 'param_groups'
+        state_dict = optimizer.optim.state_dict()
+        # state shards are handled by the low-level zero optimizer
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            if self.coordinator.is_master():
+                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up index file.
+        index_file.append_meta_data("total_size", total_size)
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer is going to be split to checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
+    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+        """Load sharded optimizer with the given path to the index file.
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to load state_dict
+            index_file_path (str): Path to the index file
+            prefix (str): Not used.
+        """
+        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
+        current_rank_state_dict = optimizer.optim.state_dict()['state']
+        for param_idx, state in current_rank_state_dict.items():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    padding_size = (self.coordinator.world_size -
+                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                    with torch.no_grad():
+                        v = v.flatten()
+                        if padding_size > 0:
+                            v = torch.nn.functional.pad(v, [0, padding_size])
+                        v_list = v.split(v.numel() // self.coordinator.world_size)
+                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
 class LowLevelZeroModel(ModelWrapper):

@@ -74,36 +166,6 @@ class LowLevelZeroModel(ModelWrapper):
         return super().forward(*args, **kwargs)


-class LowLevelZeroOptimizer(OptimizerWrapper):
-
-    def __init__(self,
-                 module: nn.Module,
-                 optimizer: Optimizer,
-                 zero_optim_config: dict,
-                 optim_kwargs: dict,
-                 verbose: bool = False) -> None:
-        optimizer = zero_optim_wrapper(module,
-                                       optimizer,
-                                       optim_config=zero_optim_config,
-                                       **optim_kwargs,
-                                       verbose=verbose)
-        super().__init__(optimizer)
-
-    def backward(self, loss: Tensor, *args, **kwargs):
-        self.optim.backward(loss)
-
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> Tensor:
-        warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
-    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
-        raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
-
-
 class LowLevelZeroPlugin(DPPluginBase):
     """
     Plugin for low level zero.

@@ -211,8 +273,11 @@ class LowLevelZeroPlugin(DPPluginBase):
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                              self.verbose)
+            optimizer = zero_optim_wrapper(model.unwrap(),
+                                           optimizer,
+                                           optim_config=self.zero_optim_config,
+                                           **self.optim_kwargs,
+                                           verbose=self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler


@@ -2,7 +2,7 @@
 import copy
 from contextlib import contextmanager
 from functools import partial
-from typing import Optional
+from typing import Dict, Iterator, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -447,18 +447,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # Gradient Synchronization #
     ############################

+    # this method is used to sync gradient manually
+    def sync_grad(self):
+        for group_id in range(self.num_param_groups):
+            param_group = self._working_param_groups[group_id]
+            for param in param_group:
+                if param.requires_grad and param.grad is not None:
+                    self._add_to_bucket(param, group_id)
+
+        self._run_reduction()
+
     def _reduce_grad(self, partition_grad):
         # if not overlapping communication (no reduction hook is attached) when zero1
         # we need to manually reduce these gradients
         if not partition_grad and not self._overlap_communication:
-            for group_id in range(len(self._working_param_groups)):
-                param_group = self._working_param_groups[group_id]
-                for param in param_group:
-                    if param.grad is not None:
-                        self._add_to_bucket(param, group_id)
-
-        # run reduction
-        self._run_reduction()
+            self.sync_grad()
+        else:
+            self._run_reduction()

     # this context comes from pytorch DDP
     @contextmanager
@@ -473,7 +478,8 @@
     ##############
     # State Dict #
     ##############
-    def _pack_state(self, state: dict) -> dict:
+
+    def _pack_state(self, state: Dict) -> Dict:
         # comes from pytorch optimizer.state_dict()
         param_mappings = {}
         start_index = 0
@@ -487,17 +493,17 @@
             start_index += len(packed['params'])
             return packed

-        param_groups = [pack_group(g) for g in self.param_groups]
+        param_groups = [pack_group(g) for g in self.optim.param_groups]
         # Remap state to use order indices as keys
         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}

         return {'state': packed_state, 'param_groups': param_groups}

-    def state_dict(self) -> dict:
+    def state_dict(self) -> Dict:
         """Return a state_dict same with DDP

         Returns:
-            dict: the pytorch form state_dict
+            Dict: the pytorch form state_dict
         """
         zero_state = dict()
         for param, state in self.optim.state.items():
@@ -514,7 +520,7 @@
         return states_dict

-    def load_state_dict(self, state_dict: dict):
+    def load_state_dict(self, state_dict: Dict):
         """Load state dict, requires the state_dict be the pytorch form

         Args:
@@ -534,3 +540,46 @@
         self.optim.load_state_dict(zero_state_dict)
         zero_state_dict = dict()
+
+    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of a dictionary shard is specified by ``max_shard_size``.
+           Only include the 'state' in state_dict.
+
+        Args:
+            max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        ret_block = dict()
+        ret_block_size = 0
+
+        local_states = self.optim.state_dict()['state']
+        for param_idx, states in local_states.items():
+            current_block_size = 0
+            current_block = copy.deepcopy(states)
+
+            # find the working param of current param_id
+            for group_id, pg in self._master_param_groups_of_current_rank.items():
+                if (group_id + 1) * len(pg) < param_idx:
+                    continue
+                master_param = pg[param_idx - (group_id) * len(pg)]
+                working_param = self._param_store.master_to_working_param[id(master_param)]
+
+            for k, v in states.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v, group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    current_block_size += state_tensor.numel()
+                    current_block[k] = state_tensor
+
+            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                yield ret_block, ret_block_size
+                ret_block = dict()
+                ret_block_size = 0
+
+            ret_block[param_idx] = current_block
+            ret_block_size += current_block_size
+
+        yield ret_block, ret_block_size
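
A hedged consumer-side sketch of this generator's contract (each item is a shard of the optimizer `state` plus the size accounted for it); `save_sharded_optimizer` above consumes it in essentially this way, and the file-name pattern here is hypothetical:

```python
# each yielded pair is (shard, size), where shard maps packed param indices to their states
for idx, (shard, shard_size) in enumerate(optimizer.state_dict_shard(max_shard_size=1024)):
    torch.save(shard, f"optim_shard_{idx:05d}.bin")    # hypothetical file name
```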


@@ -0,0 +1,54 @@
# Low Level ZeRO
> Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it simply as ZeRO below.
## Design:
### Notation
- `p32` denotes the param copy kept in the optimizer
- `p` denotes the model param
- `g` denotes the gradient
### INIT
In low level zero (stage 1 and 2), `p32` is split. Different from the previous implementation, we split each `p32` evenly by world_size. Thus, rank 0 gets a param list `[p-00, p-10]`, rank 1 gets a param list `[p-01, p-11]`, etc.
<img width="840" alt="image" src="https://github.com/hpcaitech/ColossalAI/assets/74758262/f7758d7d-c5e5-44a4-a121-3aba8b05c904">
As for the detailed implementation, we first pad `p` so that it can be split evenly by world_size if needed. Then we view it with the shape `[world_size, -1]`, and each rank gets its own `p32` part by cloning, as sketched below.
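
A minimal sketch of that pad-view-clone procedure, assuming a single working param; `shard_master_param` and its arguments are illustrative names rather than the actual ColossalAI API:

```python
import torch
import torch.nn.functional as F

def shard_master_param(p: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Pad the flattened param so it divides evenly, then keep this rank's slice as p32."""
    flat = p.detach().float().flatten()
    padding = (world_size - flat.numel() % world_size) % world_size
    if padding > 0:
        flat = F.pad(flat, [0, padding])
    # view as [world_size, -1]; each rank clones its own row
    return flat.view(world_size, -1)[rank].clone()
```

The same padding arithmetic is what `load_sharded_optimizer` in this commit undoes when it re-splits a gathered optimizer state across ranks.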
### BWD
To make communication efficient, a gradient is first added to a bucket. When the bucket is full, each `g` in it is reshaped to `[world_size, -1]`, and the `[local_rank]` parts are grouped together.
The data structure looks like this:
```
{
0: [g-00, g-10],
1: [g-01, g-11],
2: [g-02, g-12]
}
```
After that, the gradients are flattened by rank, and the data structure looks like this:
```
# g-0 means flatten([g-00, g-10])
{
0: [g-0],
1: [g-1],
2: [g-2]
}
```
For zero1, we iterate over the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`, as sketched below.
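
A rough sketch of the zero2 path, assuming the per-rank flattened buckets above have already been built (`flat_grads[i]` playing the role of `g-i`); the function name and signature are illustrative:

```python
from typing import List

import torch
import torch.distributed as dist

def reduce_bucket_zero2(flat_grads: List[torch.Tensor], rank: int) -> torch.Tensor:
    """Sum the flattened gradients across ranks, keeping only this rank's slice."""
    output = torch.empty_like(flat_grads[rank])
    dist.reduce_scatter(output, flat_grads, op=dist.ReduceOp.SUM)
    return output
```

For zero1, the same dictionary would instead be walked with `dist.all_reduce`, so every rank ends up with the fully reduced gradients.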
### Optim
Since each rank holds its own `p32` and the corresponding `g`, it is quite easy to do `optim.step()`.
However, we have to consider the situation where a layer is dropped from the forward pass, for instance:
```
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.drop_linear = nn.Linear(256, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
```
The solution is to build a mapping between `p32`, `p`, and `g`. Before `optim.step()`, we collect every `p` with `requires_grad=True` and `p.grad != None` as a real working param, and then select its counterpart `p32` and `g`, as in the sketch below.
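
A hedged sketch of that collection step; `collect_working_params` is an illustrative helper, not the optimizer's real method, and the mapping back to `p32`/`g` is kept by the optimizer's param store:

```python
import torch.nn as nn

def collect_working_params(module: nn.Module):
    """Before optim.step(): keep only params that actually received a gradient."""
    return [p for p in module.parameters() if p.requires_grad and p.grad is not None]

# For MlpModel above, drop_linear never runs in forward(), so its weight and bias get
# no grad and are skipped; only the p32 shards of real working params are stepped.
```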


@@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
         booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
         dist.barrier()

         new_model = resnet18()
@@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)


 def run_dist(rank, world_size, port):
@@ -62,3 +61,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_low_level_zero_checkpointIO():
     spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_low_level_zero_checkpointIO()