mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[zero] a shard strategy in granularity of tensor (#307)
This commit is contained in:
@@ -7,6 +7,8 @@ import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
||||
@@ -18,19 +20,14 @@ def run_shard_tensor(rank, world_size, port):
|
||||
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
|
||||
|
||||
assert list(t.shape) == [world_size * 2, 3]
|
||||
t.shard()
|
||||
# The shape is flattened
|
||||
assert list(t.shape) == [6]
|
||||
# Do nothing
|
||||
t.shard()
|
||||
assert list(t.shape) == [6]
|
||||
shard_strategy = TensorShardStrategy(process_group=None)
|
||||
|
||||
t.gather()
|
||||
# test shard strategy
|
||||
shard_strategy.shard([t])
|
||||
assert list(t.shape) == [6]
|
||||
shard_strategy.gather([t])
|
||||
assert list(t.shape) == [world_size * 2, 3]
|
||||
|
||||
t.payload = torch.zeros(world_size * 2, 3)
|
||||
assert torch.sum(t.payload).cpu() == 0
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_shard_tensor():
|
||||
|
Reference in New Issue
Block a user