[tensor] a shorter shard and replicate spec (#1245)
Mirror of https://github.com/hpcaitech/ColossalAI.git
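This commit shortens the distribution-spec API used throughout the tensor tests: the factory calls distspec.shard(...) and distspec.replicate() are replaced by the ShardSpec and ReplicaSpec classes. A minimal before/after sketch, based only on the call sites visible in the diff below (`pg` stands for an already-initialized ProcessGroup, which is an assumption of this sketch):

# A minimal before/after sketch of the renamed spec API, based only on the
# call sites in this diff; not a complete test. Assumes torch.distributed is
# initialized and `pg` is a colossalai ProcessGroup.
from colossalai.tensor import ShardSpec, ReplicaSpec, ComputeSpec, ComputePattern

# old: distspec.shard([0], [pg.tp_world_size()])
row_spec = ShardSpec([0], [pg.tp_world_size()])       # shard dim 0 across the TP group
# old: distspec.shard([-1], [pg.tp_world_size()])
col_spec = ShardSpec([-1], [pg.tp_world_size()])      # shard the last dim
# old: distspec.replicate()
replica_spec = ReplicaSpec()                          # fully replicated

# The tests pair a dist spec with a compute spec:
spec = (row_spec, ComputeSpec(ComputePattern.TP1D))
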
@@ -4,7 +4,7 @@ import pytest
 import torch.nn as nn
 import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.tensor import distspec
+from colossalai.tensor import ShardSpec
 from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -37,13 +37,13 @@ class Conv1D(nn.Module):


 def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)


 def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
         bias.set_tensor_spec(*spec)

@@ -4,10 +4,9 @@ import torch.distributed as dist
 import pytest
 import colossalai
 import torch.multiprocessing as mp
 from torch.distributed.distributed_c10d import _get_default_group
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial
@@ -18,10 +17,10 @@ def run():
     depth = int(math.sqrt(size))
     assert depth == math.sqrt(size)
     x = torch.rand(8, 8).cuda()
-    old_dist_spec = distspec.replicate()
-    row_spec = distspec.shard([0], [size])
-    col_spec = distspec.shard([-1], [size])
-    mat_spec = distspec.shard([0, 1], [depth, depth])
+    old_dist_spec = ReplicaSpec()
+    row_spec = ShardSpec([0], [size])
+    col_spec = ShardSpec([-1], [size])
+    mat_spec = ShardSpec([0, 1], [depth, depth])
     row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
     assert torch.equal(x.chunk(size, 0)[rank], row_shard)
     assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
@@ -40,8 +39,8 @@ def check_mem():
     x = torch.rand(32, 32).cuda()
     orig_mem = x.numel() * x.element_size()
     assert torch.cuda.memory_allocated() == orig_mem
-    old_dist_spec = distspec.replicate()
-    row_spec = distspec.shard([0], [size])
+    old_dist_spec = ReplicaSpec()
+    row_spec = ShardSpec([0], [size])
     x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
     assert x.size(0) == 32 // size and x.size(1) == 32
     assert torch.cuda.memory_allocated() == orig_mem // size

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import distspec, ColoParameter
+from colossalai.tensor import ShardSpec, ColoParameter
 from torch.nn import functional as F
 from functools import partial
@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal


 def init_1d_col(weight, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec
 from torch.nn import functional as F
 from functools import partial
@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal


 def init_1d_row(weight, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)


 def init_1d_col(weight, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)

@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs


 def init_1d_row_spec(model, pg: ProcessGroup):
-    tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):


 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec

 from functools import partial
@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal


 def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)


 def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
         bias.set_tensor_spec(*spec)

@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
 from colossalai.utils import get_current_device
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern


 def check_cross_entropy():
@@ -22,7 +22,7 @@ def check_cross_entropy():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-    input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
     input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

     output = F.cross_entropy(input_t, target)

@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
+from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
 from colossalai.nn.optimizer import ColoOptimizer

@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs


 def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)


 def init_1d_col_linear(weight, pg):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)


 def init_1d_row_embedding(weight, pg):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)


 def init_1d_col_embedding(weight, pg):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)

@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp

-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed

@@ -13,7 +13,7 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext

-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec

 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -159,7 +159,7 @@ def run_check_shared_param():
     # They are all Linear, so both row is allowed. This should pass check.
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:

@@ -4,7 +4,7 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
@@ -47,7 +47,7 @@ def check_element_wise_ops():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))

     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup
+from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial


@@ -55,7 +55,7 @@ def _run_operand(world_size):

     pg = ProcessGroup(tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t.set_dist_spec(ShardSpec([0], [world_size]))
     t_new = torch.zeros_like(t)
     assert isinstance(t_new, ColoTensor)
     assert t_new.is_sharded()
@@ -69,7 +69,7 @@ def _run_view(world_size):
     rank = gpc.get_global_rank()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
+        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))

     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -82,7 +82,7 @@ def _run_view(world_size):
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
+    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())

@@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup


 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):


 def init_1d_row_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):


 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
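
For reference, the helper pattern the tests converge on after the rename, sketched from the hunks above. The function names here are hypothetical, and the snippet assumes torch.distributed is already initialized and that `weight` is a ColoTensor:

# Illustrative helpers distilled from the hunks above; not part of the commit.
from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec, DistSpecManager,
                               ProcessGroup, ReplicaSpec, ShardSpec)


def shard_rowwise(weight: ColoTensor, pg: ProcessGroup):
    # Shard dim 0 across the tensor-parallel group and mark the tensor for 1D TP compute.
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(*spec)


def replicate_back(weight: ColoTensor):
    # Return the tensor to a fully replicated layout.
    weight.set_dist_spec(ReplicaSpec())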