mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
Merge branch 'main' into feature/shardformer
This commit is contained in:
@@ -57,8 +57,8 @@ class GradientStore(BaseStore):
|
||||
self._grads_of_params[group_id][param_id].append(grad)
|
||||
|
||||
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
|
||||
"""For old gradient accumulation, not in use now.
|
||||
Add a gradient slice on an existing slice of the parameter's gradient
|
||||
"""Add a gradient slice on an existing slice of the parameter's gradient
|
||||
Used when no_sync is not activated.
|
||||
|
||||
Args:
|
||||
grad (Tensor): The split gradient to append to list
|
||||
|
@@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
|
||||
# TODO:
|
||||
# 1. state_dict for checkpoint IO
|
||||
|
||||
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
|
||||
self._dtype = self.optim.param_groups[0]['params'][0].dtype
|
||||
self._logger = get_dist_logger()
|
||||
@@ -277,7 +274,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
sync_tensor(flat_grads_per_rank[rank], grad_list)
|
||||
for grad in grad_list:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
|
||||
param_id)) < self._world_size:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
|
||||
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
@@ -291,7 +292,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
|
||||
for grad in grad_in_bucket_current_rank:
|
||||
param_id = self._bucket_store.get_param_id_of_grad(grad)
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
|
||||
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
|
||||
else:
|
||||
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
|
||||
|
||||
self._bucket_store.reset()
|
||||
|
||||
@@ -303,7 +307,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
# or got a grad of param from another group
|
||||
# after reduction, the bucket will be empty
|
||||
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
|
||||
group_id != self._bucket_store.current_group_id:
|
||||
group_id != self._bucket_store.current_group_id:
|
||||
self._run_reduction()
|
||||
|
||||
padding_size = self._param_store.get_param_padding_size(param)
|
||||
@@ -315,7 +319,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
assert not(self._partition_grads and not self.require_grad_sync), \
|
||||
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
|
||||
"ZeRO2(partition_grads) and no_sync are not compatible"
|
||||
|
||||
if self.mixed_precision_mixin is not None:
|
||||
loss = self.mixed_precision_mixin.pre_backward(loss)
|
||||
|
||||
@@ -537,9 +542,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
working_param = self._param_store.master_to_working_param[id(param)]
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
|
||||
dist.all_gather(gather_tensor, v, group=self.dp_pg)
|
||||
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
|
||||
gather_tensor = [
|
||||
torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
|
||||
]
|
||||
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
|
||||
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
|
||||
working_param).cpu()
|
||||
zero_state[param][k] = param_state
|
||||
|
||||
states_dict = self._pack_state(zero_state)
|
||||
@@ -562,10 +570,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // self._world_size)
|
||||
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
|
||||
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
|
||||
|
||||
self.optim.load_state_dict(zero_state_dict)
|
||||
zero_state_dict = dict()
|
||||
|
||||
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
@@ -594,9 +601,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
for k, v in states.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
|
||||
dist.all_gather(state_tensor, v, group=self.dp_pg)
|
||||
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
|
||||
state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
|
||||
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
|
||||
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
|
||||
working_param).cpu()
|
||||
current_block_size += state_tensor.numel()
|
||||
current_block[k] = state_tensor
|
||||
|
||||
|
@@ -1,5 +1,41 @@
|
||||
# Low Level ZeRO
|
||||
>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO.
|
||||
## Examples of ZeRO and gradient accumulation
|
||||
|
||||
The code below only shows a typical gradient accumulation process, and it drops a lot of details, such as the processing of loss.
|
||||
|
||||
```python
|
||||
# examples of ZeRO1 with gradient accumulation
|
||||
...
|
||||
outputs = model(input)
|
||||
loss = SomeLoss(outputs)
|
||||
if (idx + 1) % ACCUMULATE_STEP != 0:
|
||||
with booster.no_sync(model, optimizer):
|
||||
# under this context, the gradient would not sync when backward,
|
||||
# left each rank having different gradient.
|
||||
# It saves the backward time
|
||||
booster.backward(loss, optimizer)
|
||||
continue
|
||||
else:
|
||||
# need to sync all the accumulated gradient
|
||||
booster.backward(loss, optimizer):
|
||||
optimizer.step()
|
||||
...
|
||||
```
|
||||
|
||||
```python
|
||||
# example of ZeRO2 with gradient accumulation
|
||||
|
||||
...
|
||||
outputs = model(input)
|
||||
loss = SomeLoss(outputs)
|
||||
# ZeRO2 split the gradients and can NOT accumulate gradient with syncing.
|
||||
booster.backward(loss, optimizer)
|
||||
if (idx + 1) % ACCUMULATE_STEP == 0:
|
||||
optimizer.step()
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
## Design:
|
||||
### Notion
|
||||
@@ -25,11 +61,11 @@ The data structure looks like this:
|
||||
```
|
||||
After that, the gradients would be flattened by rank, and the data structure looks like this:
|
||||
```
|
||||
# g-0 means flatten([g-00, g-10])
|
||||
# g-X0 means flatten([g-00, g-10])
|
||||
{
|
||||
0: [g-0],
|
||||
1: [g-1],
|
||||
2: [g-2]
|
||||
0: [g-X0],
|
||||
1: [g-X1],
|
||||
2: [g-X2]
|
||||
}
|
||||
```
|
||||
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
|
||||
|
Reference in New Issue
Block a user