mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[hotfix] fixed memory usage of shardformer module replacement (#5122)
This commit is contained in:
@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     dim = comm_spec.shard_dim
     length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
     start = length * dist.get_rank(process_group)
-    output = torch.narrow(tensor, dim, start, length).contiguous()
+    output = torch.narrow(tensor, dim, start, length).clone().contiguous()
     return output
 
 
Reference in New Issue
Block a user