[tensor] redistribute among different process groups (#1247)

* make it faster

* [tensor] rename convert_to_dist -> redistribute

* [tensor] ShardSpec and ReplicaSpec

* [tensor] redistribute among diff pgs

* polish code
Jiarui Fang
2022-07-12 10:24:05 +08:00
committed by GitHub
parent 9bcd2fd4af
commit 1aad903c15
8 changed files with 48 additions and 17 deletions
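
Taken together, the changes below rename convert_to_dist to redistribute, replace the distspec factory helpers with the ShardSpec and ReplicaSpec classes, and let redistribute take an optional target ProcessGroup so a tensor can move between layouts defined on different groups. A minimal sketch of the new call pattern, mirroring the test added in this commit (the tp_degree/dp_degree values come from that test; the import list and the surrounding 4-rank colossalai.launch context are assumptions):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

# Wrap a plain torch tensor in a ColoTensor under a 2-way TP x 2-way DP group.
pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
t = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg1))

# Formerly convert_to_dist(...): shard dim 0 across pg1's tensor-parallel ranks.
t = t.redistribute(ShardSpec([0], [pg1.tp_world_size()]))

# New in this commit: redistribute into a different process group, sharding the last dim.
pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
t = t.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)

# Replicate the tensor over yet another group.
pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
t = t.redistribute(ReplicaSpec(), pg3)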

View File

@@ -164,7 +164,7 @@ def run_check_shared_param():
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
         model.cls.predictions.bias.pg = pg
-        model.cls.predictions.bias.dist_spec = distspec.replicate()
+        model.cls.predictions.bias.dist_spec = ReplicaSpec()
         model.cls.predictions.bias.has_initialized = True
     model.cls.predictions.bias.set_tensor_spec(*col_spec)
     try:
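
The change above swaps the removed distspec.replicate() factory call for the new ReplicaSpec class; ShardSpec is its sharding counterpart, used further below when building a ColoTensorSpec. A minimal sketch of both spec classes, assuming it runs inside an already-launched worker (the tensor shapes and the default group here are illustrative, not taken from the diff):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec

pg = ProcessGroup()                                    # default group over all launched ranks
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), ColoTensorSpec(pg))

# Old style (removed in this commit): t.set_dist_spec(distspec.replicate())
t.set_dist_spec(ReplicaSpec())                         # replicate across the group

# ShardSpec is the class-based sharding spec, here passed as the dist_attr of a ColoTensorSpec:
# split dim 0 across however many tensor-parallel ranks the launch provides.
shard_spec = ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()]))
sharded = ColoTensor.from_torch_tensor(torch.randn(4, 5), shard_spec)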

View File

@@ -9,7 +9,6 @@ from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec
 
 
 def _run_layer_norm():

View File

@@ -5,7 +5,7 @@ from numpy import allclose
 import colossalai
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensorSpec
 from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
@@ -85,7 +85,7 @@ def _run_tensor_shard_init(world_size):
     shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_dist_spec(distspec.replicate())
+    t.set_dist_spec(ReplicaSpec())
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
@@ -102,10 +102,26 @@ def _run_tensor_replicated_init(world_size):
 def _run_process_group(world_size):
     pg1 = ProcessGroup()
     pg2 = ProcessGroup()
     assert pg1 == pg2
 
 
+def _run_redistributed(world_size):
+    if world_size != 4:
+        return
+    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
+    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
+    spec1 = ColoTensorSpec(pg1)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
+    assert t1.is_sharded()
+    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
+    assert t1.is_sharded()
+    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
+    t1 = t1.redistribute(ReplicaSpec(), pg3)
+    assert t1.is_replicate()
 
 
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
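
Note on the new test: redistribute takes the target ProcessGroup as an optional second argument, so the tensor first sharded along dim 0 over pg1's 2-way tensor-parallel group is re-sharded along the last dim over pg2's 4-way group, and finally replicated over pg3, without the caller gathering it by hand. The world_size != 4 guard keeps the test a no-op unless exactly four ranks are launched.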
@@ -115,6 +131,7 @@ def run_dist_tests(rank, world_size, port):
     _run_tensor_indexing()
     _run_operand(world_size)
     _run_wrapped_tensor_func()
+    _run_redistributed(world_size)
 
 
 @pytest.mark.dist
@@ -126,4 +143,4 @@ def test_dist_cases(world_size):
 if __name__ == '__main__':
-    test_dist_cases(1)
+    test_dist_cases(4)
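
Since _run_redistributed returns early unless world_size == 4, the standalone entry point now launches four processes; with the nccl backend used in run_dist_tests, that presumably requires four visible CUDA devices.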