Mirror of https://github.com/hpcaitech/ColossalAI.git
[tensor] redistribute among different process groups (#1247)
* make it faster
* [tensor] rename convert_to_dist -> redistribute
* [tensor] ShardSpec and ReplicaSpec
* [tensor] redistribute among diff pgs
* polish code
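For orientation, a minimal sketch of the renamed call and the new spec classes, pieced together from the test changes below; the process-group degrees and the tensor shape are illustrative, and the import line assumes all of these names are exported from colossalai.tensor, as the tests use them:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

# After colossalai.launch(...) has initialized 4 ranks:
pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
t = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg1))

# Formerly convert_to_dist; redistribute returns the tensor in the requested layout.
t = t.redistribute(ShardSpec([0], [pg1.tp_world_size()]))    # shard dim 0 across the TP group
assert t.is_sharded()

# Passing a different ProcessGroup redistributes across groups in a single call.
pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
t = t.redistribute(ReplicaSpec(), pg2)
assert t.is_replicate()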
@@ -164,7 +164,7 @@ def run_check_shared_param():
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
         model.cls.predictions.bias.pg = pg
-        model.cls.predictions.bias.dist_spec = distspec.replicate()
+        model.cls.predictions.bias.dist_spec = ReplicaSpec()
         model.cls.predictions.bias.has_initialized = True
     model.cls.predictions.bias.set_tensor_spec(*col_spec)
     try:
@@ -9,7 +9,6 @@ from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec


 def _run_layer_norm():
@@ -5,7 +5,7 @@ from numpy import allclose

 import colossalai
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensorSpec
 from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
@@ -85,7 +85,7 @@ def _run_tensor_shard_init(world_size):
     shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_dist_spec(distspec.replicate())
+    t.set_dist_spec(ReplicaSpec())

     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"

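A compact sketch of the pattern this hunk exercises, assuming colossalai.launch has already set up the distributed environment and that the spec classes are imported from colossalai.tensor as in the rest of the file; shapes follow the surrounding test:

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

pg = ProcessGroup(tp_degree=dist.get_world_size())  # assumption: every rank in one TP group
# Each rank wraps its local (4, 5) shard; dim 0 is split across the TP group.
shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), ColoTensorSpec(pg, dist_attr=shard_attr))

# set_dist_spec changes the layout of t in place; under ReplicaSpec every rank
# holds the full gathered tensor, matching the (4 * world_size, 5) assertion above.
t.set_dist_spec(ReplicaSpec())
assert t.shape == torch.Size((4 * pg.tp_world_size(), 5))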
@@ -102,10 +102,26 @@ def _run_tensor_replicated_init(world_size):
 def _run_process_group(world_size):
     pg1 = ProcessGroup()
     pg2 = ProcessGroup()

     assert pg1 == pg2


+def _run_redistributed(world_size):
+    if world_size != 4:
+        return
+    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
+    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
+
+    spec1 = ColoTensorSpec(pg1)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
+    assert t1.is_sharded()
+    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
+    assert t1.is_sharded()
+    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
+    t1 = t1.redistribute(ReplicaSpec(), pg3)
+    assert t1.is_replicate()
+
+
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
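The new test walks t1 through three layouts: sharded on dim 0 under pg1 (tp=2, dp=2), re-sharded on the last dim under pg2 (tp=4, dp=1), and finally replicated under pg3 (tp=1, dp=4); the optional second argument to redistribute is what moves the tensor to a different process group. The existing harness drives this through run_dist_tests; a hypothetical standalone launcher in the same style would look like the sketch below (the partial/mp.spawn wiring is an assumption based on this file's imports, not part of the diff):

from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_redistribute_across_pgs(world_size):
    # Hypothetical wrapper: spawn one process per rank and let run_dist_tests
    # (defined in this file) exercise _run_redistributed among the other cases.
    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)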
@@ -115,6 +131,7 @@ def run_dist_tests(rank, world_size, port):
     _run_tensor_indexing()
     _run_operand(world_size)
     _run_wrapped_tensor_func()
+    _run_redistributed(world_size)


 @pytest.mark.dist
@@ -126,4 +143,4 @@ def test_dist_cases(world_size):


 if __name__ == '__main__':
-    test_dist_cases(1)
+    test_dist_cases(4)