ColossalAI/colossalai/legacy/zero/shard_utils/commons.py
Hongxin Liu b5f9e37c70
[legacy] clean up legacy code (#4743)
2023-09-18 16:31:06 +08:00


from typing import Tuple

import torch


def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor.

    The flattened tensor is split into ``world_size`` chunks; the chunk owned by
    ``rank`` is copied into a zero-padded buffer of the uniform chunk size, and
    the number of padding elements is returned together with the shard.
    """
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    # torch.chunk may return fewer than world_size chunks for small tensors,
    # so pad the list with empty tensors until every rank has a chunk.
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))

    # Determine number of padding elements.
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    assert num_to_pad >= 0, num_to_pad

    # Copy this rank's chunk into a zero-filled buffer of the uniform chunk size.
    shard = torch.zeros_like(chunks[0])
    length = chunks[rank].size(0)
    shard_temp = shard[:length]
    shard_temp.copy_(chunks[rank])

    return shard, num_to_pad
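

# Hypothetical usage sketch (not part of the original commons.py): a
# single-process demonstration of the padding behavior. Sharding a 10-element
# tensor across 4 ranks yields chunks of sizes 3, 3, 3 and 1, so the last
# rank's shard is zero-padded by 2 elements to the uniform chunk size that
# collectives such as all-gather expect.
if __name__ == "__main__":
    full = torch.arange(10, dtype=torch.float32)
    world_size = 4

    for rank in range(world_size):
        shard, num_to_pad = get_shard(full, rank, world_size)
        print(f"rank {rank}: shard={shard.tolist()}, num_to_pad={num_to_pad}")
    # rank 3 prints shard=[9.0, 0.0, 0.0], num_to_pad=2

    # Concatenating all shards and dropping the trailing padding recovers the
    # full tensor, mirroring what an all-gather followed by a narrow would do.
    shards = [get_shard(full, r, world_size)[0] for r in range(world_size)]
    restored = torch.cat(shards)[: full.numel()].reshape(full.shape)
    assert torch.equal(restored, full)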