[zero] support shard optimizer state dict of zero (#4194)
* support shard optimizer of zero
* polish code
* support sync grad manually
@@ -2,7 +2,7 @@
 import copy
 from contextlib import contextmanager
 from functools import partial
-from typing import Optional
+from typing import Dict, Iterator, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -447,18 +447,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # Gradient Synchronization #
     ############################
 
+    # this method is used to sync gradient manually
+    def sync_grad(self):
+        for group_id in range(self.num_param_groups):
+            param_group = self._working_param_groups[group_id]
+            for param in param_group:
+                if param.requires_grad and param.grad is not None:
+                    self._add_to_bucket(param, group_id)
+
+        self._run_reduction()
+
     def _reduce_grad(self, partition_grad):
         # if not overlapping communication (no reduction hook is attached) when zero1
         # we need to manually reduce these gradients
         if not partition_grad and not self._overlap_communication:
-            for group_id in range(len(self._working_param_groups)):
-                param_group = self._working_param_groups[group_id]
-                for param in param_group:
-                    if param.grad is not None:
-                        self._add_to_bucket(param, group_id)
-
-            # run reduction
-            self._run_reduction()
+            self.sync_grad()
         else:
             self._run_reduction()
 
     # this context comes from pytorch DDP
     @contextmanager
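The new `sync_grad` method lets users trigger gradient reduction themselves when reduction is not overlapped with backward. A minimal usage sketch (the training-loop names and the `backward` call below are assumptions for illustration, not part of the diff):

```
# hypothetical loop around a LowLevelZeroOptimizer instance `optimizer`
for batch in dataloader:
    loss = model(batch).mean()
    optimizer.backward(loss)     # accumulate local gradients
    optimizer.sync_grad()        # bucket all local grads and run the reduction across ranks
    optimizer.step()             # update this rank's shard of the params
    optimizer.zero_grad()
```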
@@ -473,7 +478,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     ##############
     # State Dict #
     ##############
-    def _pack_state(self, state: dict) -> dict:
+
+    def _pack_state(self, state: Dict) -> Dict:
         # comes from pytorch optimizer.state_dict()
         param_mappings = {}
         start_index = 0
@@ -487,17 +493,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             start_index += len(packed['params'])
             return packed
 
-        param_groups = [pack_group(g) for g in self.param_groups]
+        param_groups = [pack_group(g) for g in self.optim.param_groups]
         # Remap state to use order indices as keys
         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
 
         return {'state': packed_state, 'param_groups': param_groups}
 
-    def state_dict(self) -> dict:
+    def state_dict(self) -> Dict:
         """Return a state_dict same with DDP
 
         Returns:
-            dict: the pytorch form state_dict
+            Dict: the pytorch form state_dict
         """
         zero_state = dict()
         for param, state in self.optim.state.items():
@@ -514,7 +520,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         return states_dict
 
-    def load_state_dict(self, state_dict: dict):
+    def load_state_dict(self, state_dict: Dict):
         """Load state dict, requires the state_dict be the pytorch form
 
         Args:
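Since `state_dict()` and `load_state_dict()` follow the plain PyTorch form, checkpointing should look like ordinary optimizer checkpointing. A minimal sketch (the file path and the `optimizer` variable are assumptions):

```
import torch

# save the gathered optimizer state on one process
torch.save(optimizer.state_dict(), 'zero_optim.pt')

# later, after rebuilding the model and the zero optimizer wrapper
optimizer.load_state_dict(torch.load('zero_optim.pt'))
```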
@@ -534,3 +540,46 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         self.optim.load_state_dict(zero_state_dict)
         zero_state_dict = dict()
+
+    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
+        Only include the 'state' in state_dict.
+
+        Args:
+            max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        ret_block = dict()
+        ret_block_size = 0
+
+        local_states = self.optim.state_dict()['state']
+        for param_idx, states in local_states.items():
+            current_block_size = 0
+            current_block = copy.deepcopy(states)
+
+            # find the working param of current param_id
+            for group_id, pg in self._master_param_groups_of_current_rank.items():
+                if (group_id + 1) * len(pg) < param_idx:
+                    continue
+                master_param = pg[param_idx - (group_id) * len(pg)]
+                working_param = self._param_store.master_to_working_param[id(master_param)]
+
+            for k, v in states.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v, group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    current_block_size += state_tensor.numel()
+                    current_block[k] = state_tensor
+
+            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                yield ret_block, ret_block_size
+                ret_block = dict()
+                ret_block_size = 0
+
+            ret_block[param_idx] = current_block
+            ret_block_size += current_block_size
+
+        yield ret_block, ret_block_size
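A sketch of how the new `state_dict_shard` generator might be consumed by checkpoint-saving code (the shard file naming and the `optimizer` variable are assumptions for illustration):

```
import torch

# iterate over optimizer-state shards and write each one to its own file
for shard_idx, (shard, shard_size) in enumerate(optimizer.state_dict_shard(max_shard_size=1024)):
    print(f'shard {shard_idx}: {len(shard)} params, accumulated size {shard_size}')
    torch.save(shard, f'optim_shard_{shard_idx}.pt')
```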
colossalai/zero/low_level/readme.md (new file, 54 lines)
@@ -0,0 +1,54 @@
# Low Level ZeRO
> Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it simply as ZeRO below.

## Design
### Notion
- `p32` denotes the param copy in the optimizer
- `p` denotes the model param
- `g` denotes the gradient

### INIT
In low level ZeRO (stage 1 and 2), `p32` is sharded. Different from the previous implementation, we split each `p32` evenly by `world_size`. Thus rank 0 holds a param list `[p-00, p-10]`, rank 1 holds a param list `[p-01, p-11]`, etc.
<img width="840" alt="image" src="https://github.com/hpcaitech/ColossalAI/assets/74758262/f7758d7d-c5e5-44a4-a121-3aba8b05c904">

For the detailed implementation, we first pad `p` so that it can be split evenly by `world_size` if needed. Then we view it as the shape `[world_size, -1]`, and each rank clones its own part as `p32`.
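As a rough illustration of the padding and slicing described above (a minimal standalone sketch; the function name and arguments are assumptions, not the actual ColossalAI code):

```
import torch

def shard_param(p: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # flatten the working param and make a float32 copy
    flat = p.detach().flatten().float()
    # pad so the flat tensor can be split evenly by world_size
    pad = (world_size - flat.numel() % world_size) % world_size
    if pad > 0:
        flat = torch.nn.functional.pad(flat, (0, pad))
    # view as [world_size, -1]; each rank clones only its own row as its p32
    return flat.view(world_size, -1)[rank].clone()
```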

### BWD
To reduce communication overhead, a gradient is first added to a bucket. When the bucket is full, each `g` in it is reshaped to `[world_size, -1]`, and the parts that belong to the same rank index are grouped together.
The data structure looks like this:
```
{
    0: [g-00, g-10],
    1: [g-01, g-11],
    2: [g-02, g-12]
}
```
After that, the gradients are flattened by rank, and the data structure looks like this:
```
# g-0 means flatten([g-00, g-10])
{
    0: [g-0],
    1: [g-1],
    2: [g-2]
}
```
For ZeRO-1, we iterate over the dictionary and call `all_reduce`. For ZeRO-2, we can just call `reduce-scatter`.
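A hedged sketch of how a full bucket could be reduced under the two modes described above (the data layout and function name are assumptions for illustration, not the library's internal API):

```
import torch
import torch.distributed as dist

def reduce_bucket(chunks, world_size: int, rank: int, zero2: bool) -> torch.Tensor:
    # chunks: {rank_idx: flattened gradient chunk g-r belonging to that rank}
    if zero2:
        # ZeRO-2: each rank only keeps its own chunk, so one reduce-scatter suffices
        out = torch.empty_like(chunks[rank])
        dist.reduce_scatter(out, [chunks[r] for r in range(world_size)])
        return out
    # ZeRO-1: all-reduce every chunk so every rank still holds the full gradient
    for r in range(world_size):
        dist.all_reduce(chunks[r])
    return chunks[rank]
```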

### Optim
Since each rank holds its own `p32` and the corresponding `g`, it is straightforward to run `optim.step()` on its shard.

However, we have to consider the situation of layer drop, for instance:
```
import torch.nn as nn

class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.drop_linear = nn.Linear(256, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        # self.drop_linear is never used here, so its params never receive a gradient
        x = self.linear2(x)
        return x
```
The solution is to build a mapping between `p32`, `p`, and `g`. Before `optim.step()`, we collect every `p` with `requires_grad=True` and `p.grad is not None` as a real working param, and then pick the counterpart `p32` and `g` for the update.
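A minimal sketch of that filtering step (the mapping dictionary and function name below are hypothetical, used only for illustration):

```
def collect_real_working_params(p32_list, master_to_working):
    # keep only the (p32, g) pairs whose working param actually received a gradient
    real_p32, real_grads = [], []
    for p32 in p32_list:                         # p32: this rank's optimizer copy
        p = master_to_working[id(p32)]           # p: the model (working) param
        if p.requires_grad and p.grad is not None:
            real_p32.append(p32)
            real_grads.append(p.grad)            # g: the counterpart gradient
    return real_p32, real_grads
```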