mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
* [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci
23 lines
706 B
Python
23 lines
706 B
Python
from typing import Tuple
|
|
|
|
import torch
|
|
|
|
|
|
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor.

    The tensor is flattened and split with ``torch.chunk`` so the partitioning
    matches all-gather / reduce-scatter. Every shard is returned with the size
    of the first (largest) chunk; a shorter shard is right-padded with zeros.

    Args:
        tensor: The full tensor to shard. It is flattened before splitting.
        rank: Index of the shard to return; must satisfy ``0 <= rank < world_size``.
        world_size: Number of shards to split ``tensor`` into.

    Returns:
        A tuple ``(shard, num_to_pad)``. ``shard`` is a newly allocated 1-D
        tensor (same dtype/device as ``tensor``) whose length equals the first
        chunk's length. Its first ``len(shard) - num_to_pad`` elements are this
        rank's data and the remaining ``num_to_pad`` elements are zeros.

    Raises:
        IndexError: If ``rank`` is outside ``[0, world_size)``.
    """
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    # torch.chunk can return fewer than `world_size` pieces for small tensors;
    # the missing trailing ranks get empty shards (fully padded).
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))

    # Validate after chunking so an invalid world_size still surfaces torch's
    # own error. Without this, a negative rank would silently index from the
    # end of `chunks` and return another rank's shard.
    if not 0 <= rank < world_size:
        raise IndexError(f"rank {rank} is out of range for world_size {world_size}")

    local = chunks[rank]
    # Determine number of padding elements. The first chunk is the largest, so
    # this is never negative.
    num_to_pad = chunks[0].numel() - local.numel()
    assert num_to_pad >= 0, num_to_pad

    # Zero-filled buffer of the full shard size; copy the local data into its
    # prefix so the tail (if any) remains zero padding.
    shard = torch.zeros_like(chunks[0])
    shard[: local.numel()].copy_(local)
    return shard, num_to_pad
|