sync before creating empty grad

ver217
2022-03-16 13:40:19 +08:00
parent ea6905a898
commit fce9432f08
2 changed files with 7 additions and 4 deletions

@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         return empty_grad
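
For context, the added line makes the current (default) stream wait for the communication stream before the placeholder gradient is allocated and the original gradient storage is released, so the asynchronous reduce-scatter cannot still be reading memory that gets freed or reused. The following is a minimal, self-contained sketch of that record_stream/wait_stream ordering pattern, not the ColossalAI implementation; the helper name reduce_grad_on_comm_stream and the stand-in reduction are illustrative assumptions.

    # Sketch of the stream-ordering pattern this commit enforces (assumed example,
    # not ColossalAI code): work queued on a side "comm" stream must be waited on
    # by the current stream before dependent tensors are created or storage freed.
    import torch

    def reduce_grad_on_comm_stream(grad: torch.Tensor) -> torch.Tensor:
        comm_stream = torch.cuda.Stream()
        # Let the comm stream see all work already queued on the current stream.
        comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(comm_stream):
            reduced = grad * 0.5  # stand-in for the real reduce-scatter
            # Tell the caching allocator that `grad` is still in use on comm_stream,
            # so its memory is not handed back out while that work is in flight.
            grad.record_stream(comm_stream)
        # The ordering this commit adds: block the current stream until the comm
        # stream has finished, and only then allocate the placeholder gradient.
        torch.cuda.current_stream().wait_stream(comm_stream)
        empty_grad = torch.empty_like(grad)
        return empty_grad

    if torch.cuda.is_available():
        g = torch.randn(4, device="cuda")
        placeholder = reduce_grad_on_comm_stream(g)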