mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[zero] solve hang
This commit is contained in:
@@ -110,12 +110,8 @@ class BucketStore(BaseStore):
|
||||
|
||||
flat_grad = []
|
||||
for grad_list in self._grad_in_bucket.values():
|
||||
if len(grad_list) > 0:
|
||||
flat_grad.append(_flatten_dense_tensors(grad_list))
|
||||
if len(flat_grad) > 0:
|
||||
flat_grad = _flatten_dense_tensors(flat_grad)
|
||||
else:
|
||||
flat_grad = torch.tensor([], device=self.comm_stream.device, dtype=dtype)
|
||||
flat_grad.append(_flatten_dense_tensors(grad_list))
|
||||
flat_grad = _flatten_dense_tensors(flat_grad)
|
||||
return flat_grad
|
||||
|
||||
def get_param_id_of_grad(self, grad: Tensor) -> int:
|
||||
|
Reference in New Issue
Block a user