[zero] solve hang

This commit is contained in:
botbw
2024-07-09 08:14:00 +00:00
committed by Hongxin Liu
parent b5bfeb2efd
commit 13b48ac0aa
8 changed files with 218 additions and 335 deletions

View File

@@ -110,12 +110,8 @@ class BucketStore(BaseStore):
flat_grad = []
for grad_list in self._grad_in_bucket.values():
- if len(grad_list) > 0:
- flat_grad.append(_flatten_dense_tensors(grad_list))
- if len(flat_grad) > 0:
- flat_grad = _flatten_dense_tensors(flat_grad)
- else:
- flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
+ flat_grad.append(_flatten_dense_tensors(grad_list))
+ flat_grad = _flatten_dense_tensors(flat_grad)
return flat_grad
def get_param_id_of_grad(self, grad: Tensor) -> int: