[zero] add a shard strategy at tensor granularity (#307)

Jiarui Fang
2022-03-04 11:59:35 +08:00
committed by Frank Lee
parent 80364c7686
commit 74f77e314b
4 changed files with 56 additions and 10 deletions


@@ -7,6 +7,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
@@ -18,19 +20,14 @@ def run_shard_tensor(rank, world_size, port):
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.shape) == [world_size * 2, 3]
-    t.shard()
-    # The shape is flattened
-    assert list(t.shape) == [6]
-    # Do nothing
-    t.shard()
-    assert list(t.shape) == [6]
-    t.gather()
+    shard_strategy = TensorShardStrategy(process_group=None)
+    # test shard strategy
+    shard_strategy.shard([t])
+    assert list(t.shape) == [6]
+    shard_strategy.gather([t])
     assert list(t.shape) == [world_size * 2, 3]
-    t.payload = torch.zeros(world_size * 2, 3)
-    assert torch.sum(t.payload).cpu() == 0

 @pytest.mark.dist
 def test_shard_tensor():
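
The hunk above replaces per-tensor t.shard() / t.gather() calls with a strategy object that operates on a list of sharded tensors. For orientation, below is a minimal standalone sketch of what such a tensor-granularity strategy can look like, written against plain torch.distributed. Only the shard/gather-on-a-list interface, the flatten-on-shard behavior, and the shard-twice-is-a-no-op behavior are taken from the test; SimpleShardedTensor, NaiveTensorShardStrategy, the is_sharded bookkeeping, and the even-division assumption are illustrative, not ColossalAI's actual implementation.

# A hedged sketch of a tensor-granularity shard strategy. Class and attribute
# names are assumptions; only the interface mirrors the test above.
import torch
import torch.distributed as dist


class SimpleShardedTensor:
    """Minimal stand-in for a sharded tensor: a payload plus shard bookkeeping."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor
        self.is_sharded = False
        self._origin_shape = tensor.shape  # remembered so gather can restore it

    @property
    def shape(self):
        return self.payload.shape


class NaiveTensorShardStrategy:
    """Shards and gathers each tensor in a list independently, one collective each."""

    def __init__(self, process_group=None):
        self.process_group = process_group

    def shard(self, tensor_list):
        for t in tensor_list:
            self._shard_one(t)

    def gather(self, tensor_list):
        for t in tensor_list:
            self._gather_one(t)

    def _shard_one(self, t):
        if t.is_sharded:  # sharding twice is a no-op, as the test expects
            return
        rank = dist.get_rank(self.process_group)
        world_size = dist.get_world_size(self.process_group)
        # Flatten, then keep only this rank's contiguous chunk.
        # For brevity this assumes numel divides evenly by world_size.
        chunks = t.payload.flatten().chunk(world_size)
        t.payload = chunks[rank].clone()
        t.is_sharded = True

    def _gather_one(self, t):
        if not t.is_sharded:
            return
        world_size = dist.get_world_size(self.process_group)
        # All-gather every rank's chunk, then restore the original shape.
        buffer = [torch.empty_like(t.payload) for _ in range(world_size)]
        dist.all_gather(buffer, t.payload, group=self.process_group)
        t.payload = torch.cat(buffer).reshape(t._origin_shape)
        t.is_sharded = False

Taking a list rather than a single tensor appears to be the key design point of the new interface: a strategy that sees all tensors at once can later batch or bucket its collectives across many parameters instead of issuing one all-gather per tensor, whereas the per-tensor t.shard() / t.gather() methods it replaces could not.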