Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[zero] polish shard strategy (#310)
* init shard param from shape tuple
* add more unit tests for shard param
* add set_payload method for ShardedParam
* [zero] add sharded tensor class
* polish code
* add shard strategy
* move shard and gather logic from the sharded tensor into the shard strategy
* polish code
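The core idea of the commit is that a shard strategy object owns the shard/gather logic and operates on a list of sharded tensors, instead of each tensor sharding itself. Below is a minimal sketch of that pattern using plain torch.distributed with a 1D flatten-and-chunk layout; the names ToyShardedTensor and ToyShardStrategy are illustrative stand-ins, not ColossalAI's actual implementation, and the sketch assumes the process group has already been initialized.

# Illustrative sketch only; ColossalAI's ShardedTensor / TensorShardStrategy
# differ in details. Assumes torch.distributed is already initialized.
from typing import List

import torch
import torch.distributed as dist


class ToyShardedTensor:
    """Hypothetical stand-in for a sharded tensor wrapper."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor
        self.origin_shape = tensor.shape   # shape before sharding
        self.is_sharded = False

    @property
    def shape(self):
        return self.payload.shape


class ToyShardStrategy:
    """Owns shard/gather, as the commit moves this logic out of the tensor."""

    def __init__(self, process_group=None):
        self.group = process_group

    def shard(self, tensors: List[ToyShardedTensor]) -> None:
        rank = dist.get_rank(self.group)
        world_size = dist.get_world_size(self.group)
        for t in tensors:
            if t.is_sharded:
                continue
            # Flatten and keep only this rank's chunk (assumes numel is
            # divisible by world_size, as it is for the test's
            # (world_size * 2, 3) tensor).
            chunks = t.payload.flatten().chunk(world_size)
            t.payload = chunks[rank].clone()
            t.is_sharded = True

    def gather(self, tensors: List[ToyShardedTensor]) -> None:
        world_size = dist.get_world_size(self.group)
        for t in tensors:
            if not t.is_sharded:
                continue
            # Collect every rank's chunk and restore the original shape.
            buffers = [torch.empty_like(t.payload) for _ in range(world_size)]
            dist.all_gather(buffers, t.payload, group=self.group)
            t.payload = torch.cat(buffers).reshape(t.origin_shape)
            t.is_sharded = False

With this split, the sharded tensor stays a thin container (payload plus origin_shape), which is what the test below exercises: shape collapses to the local shard after shard() and is restored after gather().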
@@ -7,7 +7,6 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp

 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
@@ -18,15 +17,16 @@ from tests.test_zero_data_parallel.common import Net, CONFIG
 def run_shard_tensor(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))

     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]

     shard_strategy = TensorShardStrategy(process_group=None)

     # test shard strategy
     shard_strategy.shard([t])
-    assert list(t.shape) == [6]
+    assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
     shard_strategy.gather([t])
-    assert list(t.shape) == [world_size * 2, 3]
+    assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"


 @pytest.mark.dist
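The hunk imports torch.multiprocessing and free_port, but the driver that spawns run_shard_tensor is outside this diff. A minimal sketch of how such a per-rank test function is typically launched follows; the wrapper name test_shard_tensor and world_size=2 are assumptions, not part of the commit.

# Hypothetical test driver; the repo's actual wrapper is not shown in this
# diff, so the function name and world_size value are assumed.
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.utils import free_port


@pytest.mark.dist
def test_shard_tensor():
    world_size = 2
    # Spawn one process per rank; mp.spawn passes the rank as the first
    # positional argument of run_shard_tensor.
    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_tensor()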