Merge branch 'main' into feature/shardformer

This commit is contained in:
Hongxin Liu
2023-09-04 23:43:13 +08:00
committed by GitHub
138 changed files with 4664 additions and 4219 deletions

View File

@@ -57,8 +57,8 @@ class GradientStore(BaseStore):
self._grads_of_params[group_id][param_id].append(grad)
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
"""For old gradient accumulation, not in use now.
Add a gradient slice on an existing slice of the parameter's gradient
"""Add a gradient slice on an existing slice of the parameter's gradient
Used when no_sync is not activated.
Args:
grad (Tensor): The split gradient to append to list

View File

@@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
# TODO:
# 1. state_dict for checkpoint IO
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@@ -277,7 +274,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
param_id)) < self._world_size:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
@@ -291,7 +292,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
for grad in grad_in_bucket_current_rank:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
self._bucket_store.reset()
@@ -303,7 +307,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# or got a grad of param from another group
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
group_id != self._bucket_store.current_group_id:
group_id != self._bucket_store.current_group_id:
self._run_reduction()
padding_size = self._param_store.get_param_padding_size(param)
@@ -315,7 +319,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
"ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -537,9 +542,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(gather_tensor, v, group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
gather_tensor = [
torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
@@ -562,10 +570,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@@ -594,9 +601,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v, group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
working_param).cpu()
current_block_size += state_tensor.numel()
current_block[k] = state_tensor

View File

@@ -1,5 +1,41 @@
# Low Level ZeRO
>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO.
## Examples of ZeRO and gradient accumulation
The code below only shows a typical gradient accumulation process, and it drops a lot of details, such as the processing of loss.
```python
# examples of ZeRO1 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
if (idx + 1) % ACCUMULATE_STEP != 0:
with booster.no_sync(model, optimizer):
# under this context, the gradient would not sync when backward,
# left each rank having different gradient.
# It saves the backward time
booster.backward(loss, optimizer)
continue
else:
# need to sync all the accumulated gradient
booster.backward(loss, optimizer):
optimizer.step()
...
```
```python
# example of ZeRO2 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
# ZeRO2 split the gradients and can NOT accumulate gradient with syncing.
booster.backward(loss, optimizer)
if (idx + 1) % ACCUMULATE_STEP == 0:
optimizer.step()
...
```
## Design:
### Notion
@@ -25,11 +61,11 @@ The data structure looks like this:
```
After that, the gradients would be flattened by rank, and the data structure looks like this:
```
# g-0 means flatten([g-00, g-10])
# g-X0 means flatten([g-00, g-10])
{
0: [g-0],
1: [g-1],
2: [g-2]
0: [g-X0],
1: [g-X1],
2: [g-X2]
}
```
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.