[hotfix] fixed memory usage of shardformer module replacement (#5122)

This commit is contained in:
アマデウス
2023-11-28 15:38:26 +08:00
committed by GitHub
parent 7b789f4dd2
commit 126cf180bc
2 changed files with 6 additions and 6 deletions

View File

@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
start = length * dist.get_rank(process_group)
output = torch.narrow(tensor, dim, start, length).contiguous()
output = torch.narrow(tensor, dim, start, length).clone().contiguous()
return output