Mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] remove gpc dependency in colotensor's _ops (#1189)
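Condensed from the hunks below, the change in the tests is that sharding specs now take the ProcessGroup itself, constructed with an explicit tp_degree, instead of its dp_process_group() handle. A summary sketch of the pattern (nothing beyond what the diff shows):

    # old pattern: spec built from the group's data-parallel handle
    pg = ProcessGroup(rank, list(range(world_size)))
    spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])

    # new pattern: the ProcessGroup carries the tensor-parallel degree and is passed directly
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    spec = distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])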
@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
-from colossalai.context import ParallelMode
 from functools import partial
 
 
@@ -55,11 +54,9 @@ def test_operand():
 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
     rank = gpc.get_global_rank()
-    pg = ProcessGroup(rank, list(range(world_size)))
-    assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref,
-        TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
+        t_ref, TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])))
 
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
 
     rank = gpc.get_global_rank()
-    pg = ProcessGroup(rank, list(range(world_size)))
-    shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+    shard_spec = distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
-    assert t.shape == torch.Size((4 * world_size, 5))
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
 
 
 def _run_tensor_replicated_init(world_size):
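The assertion in _run_tensor_shard_init relies on converting a dim-0 shard spec back to a replicated spec restoring the global shape, which is world_size times larger along dim 0. Shape-wise this is just an all-gather along dim 0; a plain-torch illustration (not part of the diff, not ColossalAI code):

    import torch

    world_size = 2                                        # illustrative value
    local_shard = torch.randn(4, 5)                       # each rank holds a (4, 5) shard
    global_view = torch.cat([local_shard] * world_size)   # same shape as the gathered global tensor
    assert global_view.shape == torch.Size((4 * world_size, 5))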
@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
 
 
+def _run_process_group(world_size):
+    pg1 = ProcessGroup()
+    pg2 = ProcessGroup()
+
+    assert pg1 == pg2
+
+
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
     _run_tensor_replicated_init(world_size)
     _run_view(world_size)
+    _run_process_group(world_size)
 
 
 @pytest.mark.dist
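The diff cuts off at the @pytest.mark.dist decorator. The test driver that typically follows in ColossalAI's test suite looks roughly like the sketch below, restating the imports from the first hunk for self-containment; the test name and the parametrized world sizes here are assumptions, not taken from the commit:

    import pytest
    import torch.multiprocessing as mp
    from functools import partial

    from colossalai.testing import rerun_if_address_is_in_use
    from colossalai.utils import free_port


    @pytest.mark.dist
    @pytest.mark.parametrize('world_size', [1, 2])    # assumed values
    @rerun_if_address_is_in_use()
    def test_tensor(world_size):
        # spawn one process per rank; each calls run_dist_tests(rank, world_size, port)
        run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)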