fix dist spec mgr (#1045)

This commit is contained in:
ver217
2022-05-31 12:14:39 +08:00
committed by GitHub
parent 9492a561c3
commit 7faef93326
2 changed files with 18 additions and 1 deletions

View File

@@ -34,7 +34,7 @@ class DistSpecManager:
chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
idx %= num_parts
return chunk.detach().contiguous()
return chunk.clone().detach().contiguous()
@staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor: