Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-12 12:47:21 +00:00
[hotfix] fix unsafe async comm in zero (#4404)
* improve stability of zero
* fix wrong index
* add record stream
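The "record stream" part of the hotfix is not visible in the BucketStore hunks below; it concerns buffers handed to asynchronous collectives on a separate CUDA stream. As a rough sketch of that general pattern only (invented helper name, not the ColossalAI code): the buffer is registered with the communication stream so the caching allocator will not recycle its memory while the collective may still be reading it.

import torch
import torch.distributed as dist

# Hypothetical sketch of the record_stream pattern for safe async communication.
# reduce_bucket_async is an invented name; assumes torch.distributed is already
# initialized and CUDA is available.
def reduce_bucket_async(flat_grad: torch.Tensor, comm_stream: torch.cuda.Stream):
    # Let the communication stream wait until the gradients are fully produced
    # on the default stream before reading them.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(flat_grad)
        # Mark flat_grad as in use by comm_stream: even if it is freed on the
        # default stream afterwards, its memory is not reused until the
        # all_reduce enqueued on comm_stream has completed.
        flat_grad.record_stream(comm_stream)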
@@ -13,15 +13,20 @@ class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
 
-        # init and reset
+        # init
         self.current_group_id = 0
         self._num_elements_in_bucket = 0
         # mapping gardient slices and parameter
         self.grad_to_param_mapping = dict()
 
         self._grad_in_bucket = dict()
         self._param_list = []
         self._padding_size = []
         for rank in range(self._world_size):
             self._grad_in_bucket[rank] = []
 
-        self.reset()
+        # offset_list records number of tensors in the bucket before each reduction
+        self.offset_list = [0]
 
     def num_elements_in_bucket(self) -> int:
         """Return the total number of elements in bucket
@@ -32,6 +37,12 @@ class BucketStore(BaseStore):
 
         return self._num_elements_in_bucket
 
+    def reset_num_elements_in_bucket(self):
+        """Set the number of elements in bucket to zero.
+        """
+
+        self._num_elements_in_bucket = 0
+
     def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
         """Add a param to bucket and record the padding size of a param for gradient padding
@@ -46,28 +57,32 @@ class BucketStore(BaseStore):
         self._num_elements_in_bucket += (param.numel() + padding_size)
         self.current_group_id = group_id
 
+        # number of tensors in current bucket
+        self.offset_list[-1] += 1
+
     def build_grad_in_bucket(self):
         """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
 
         Data structure of self._grad_in_bucket:
         {
          rank0: [grad0_rank0, grad1_rank0, ...]
-         rank1: [grad1_rank1, grad1_rank1, ...]
+         rank1: [grad0_rank1, grad1_rank1, ...]
         }
         """
         for param, padding_size in zip(self._param_list, self._padding_size):
-            with torch.no_grad():
-                grad = param.grad.detach().flatten()
-                if padding_size > 0:
-                    grad = torch.nn.functional.pad(grad, [0, padding_size])
-                grad_list = grad.split(grad.numel() // self._world_size)
-                for rank in range(self._world_size):
-                    grad_current_rank = grad_list[rank].detach()
-                    self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
-                    self._grad_in_bucket[rank].append(grad_current_rank)
+            grad = param.grad.clone().detach().flatten()
+            if padding_size > 0:
+                with torch.no_grad():
+                    grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
+            grad_list = grad.split(grad.numel() // self._world_size)
+            for rank in range(self._world_size):
+                grad_current_rank = grad_list[rank].clone().detach()
+                self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+                self._grad_in_bucket[rank].append(grad_current_rank)
             param.grad = None
 
+        self.offset_list.append(0)
+
     def get_grad(self) -> Dict:
         """Return the dictionary of gradients slices, of which the keys are ranks
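One way to read the build_grad_in_bucket change above: the old code split a flattened view of param.grad, so every per-rank slice aliased the original gradient storage, which could be released or reused while an asynchronous reduction was still reading it; the new code clones both the flattened gradient and each slice, so the tensors kept in the bucket own their storage. A standalone illustration of the aliasing difference (not ColossalAI code):

import torch

flat = torch.arange(8.0)
view_slice = flat.split(4)[0]            # a view: shares storage with flat
owned_slice = flat.split(4)[0].clone()   # a copy: owns its storage

flat.add_(100)                           # later mutation / buffer reuse
print(view_slice)    # reflects the change -- what an async reader could observe
print(owned_slice)   # unchanged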
@@ -104,10 +119,12 @@ class BucketStore(BaseStore):
         return self.grad_to_param_mapping[id(grad)]
 
     def reset(self):
-        self.grad_to_param_mapping = dict()
-        self._num_elements_in_bucket = 0
-        self._param_list = []
-        self._padding_size = []
-        self._grad_in_bucket = dict()
+        """Reset the bucket storage after reduction, only release the tensors have been reduced
+        """
+        cur_offset = self.offset_list.pop(0)
+        self._param_list = self._param_list[cur_offset:]
+        self._padding_size = self._padding_size[cur_offset:]
+        for _ in range(cur_offset):
+            del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))]
         for rank in range(self._world_size):
-            self._grad_in_bucket[rank] = []
+            self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:]
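The remaining changes are bookkeeping that makes the partial reset possible: add_param_grad counts each tensor into offset_list[-1], build_grad_in_bucket appends a fresh counter for the next reduction, and reset pops only the oldest counter and drops that many entries instead of wiping the whole store. A condensed, hypothetical illustration of the same idea (not the real BucketStore API):

# TinyBucket is an invented stand-in that mimics only the offset_list logic.
class TinyBucket:
    def __init__(self):
        self.items = []
        self.offset_list = [0]        # tensors added before each pending reduction

    def add(self, item):
        self.items.append(item)
        self.offset_list[-1] += 1     # count the item toward the current reduction

    def mark_reduction(self):
        self.offset_list.append(0)    # start a fresh counter for the next reduction

    def reset(self):
        cur_offset = self.offset_list.pop(0)   # oldest finished reduction
        self.items = self.items[cur_offset:]   # release only what it covered

bucket = TinyBucket()
bucket.add("grad_a")
bucket.add("grad_b")
bucket.mark_reduction()   # a reduction is launched for grad_a and grad_b
bucket.add("grad_c")      # new gradients keep arriving meanwhile
bucket.reset()            # drops grad_a and grad_b; grad_c stays in the bucket
print(bucket.items)       # ['grad_c']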