[zero] polish shard strategy (#310)

* init shard param from shape tuple

* add more unit tests for shard param

* add set_payload method for ShardedParam

* [zero] add sharded tensor class

* polish code

* add shard strategy

* move shard and gather logic to shard strategy from shard tensor (see the interface sketch after this list).

* polish code
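
The last bullet is the substance of the refactor: a sharded tensor no longer shards or gathers itself; a separate strategy object operates on a list of sharded tensors. The class name, the process_group constructor argument, and the shard/gather methods below are taken from the test in this diff; the internals (flatten, chunk, all_gather) are only a sketch of one plausible implementation under stated assumptions, not the actual ColossalAI code.

from typing import List, Optional

import torch
import torch.distributed as dist


class TensorShardStrategy:
    """Flatten each tensor and keep only this rank's slice; gather() reverses it.

    Sketch only: assumes each ShardedTensor exposes `payload`, `origin_shape`
    and `is_sharded`, and that the element count divides evenly by world size.
    """

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
        self.process_group = process_group

    def shard(self, tensor_list: List) -> None:
        rank = dist.get_rank(self.process_group)
        world_size = dist.get_world_size(self.process_group)
        for t in tensor_list:
            if t.is_sharded:
                continue
            chunks = t.payload.detach().flatten().chunk(world_size)
            t.payload = chunks[rank].clone()  # drop the slices owned by other ranks
            t.is_sharded = True

    def gather(self, tensor_list: List) -> None:
        world_size = dist.get_world_size(self.process_group)
        for t in tensor_list:
            if not t.is_sharded:
                continue
            buffers = [torch.empty_like(t.payload) for _ in range(world_size)]
            dist.all_gather(buffers, t.payload.contiguous(), group=self.process_group)
            t.payload = torch.cat(buffers).reshape(t.origin_shape)
            t.is_sharded = False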
Jiarui Fang, 2022-03-04 15:35:07 +08:00 (committed by Frank Lee)
parent 3092317b80, commit c9e7d9582d
3 changed files with 56 additions and 35 deletions


@@ -7,7 +7,6 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
from colossalai.utils import free_port
@@ -18,15 +17,16 @@ from tests.test_zero_data_parallel.common import Net, CONFIG
def run_shard_tensor(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
    assert list(t.origin_shape) == [world_size * 2, 3]
    assert list(t.shape) == [world_size * 2, 3]
    shard_strategy = TensorShardStrategy(process_group=None)
    # test shard strategy
    shard_strategy.shard([t])
-   assert list(t.shape) == [6]
+   assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
    shard_strategy.gather([t])
-   assert list(t.shape) == [world_size * 2, 3]
+   assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"

@pytest.mark.dist
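
The hunk stops at the @pytest.mark.dist decorator, so the launcher it decorates is not shown. Based on the imports visible in this file (torch.multiprocessing as mp, free_port), a typical wrapper would look roughly like the sketch below; the exact body in the repository may differ.

from functools import partial

@pytest.mark.dist
def test_shard_tensor():
    # sketch of the missing launcher: spawn one process per rank,
    # each rank runs run_shard_tensor(rank, world_size, port)
    world_size = 2
    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_tensor()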