mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[refactor] remove gpc dependency in colotensor's _ops (#1189)
This commit is contained in:
@@ -41,7 +41,7 @@ def tensor_equal(A, B):
|
||||
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
|
||||
|
||||
|
||||
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
|
||||
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
|
||||
assert tensor.ndim == shard.ndim
|
||||
if tensor.shape == shard.shape:
|
||||
return tensor_equal(tensor, shard)
|
||||
@@ -50,8 +50,10 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
|
||||
if dims_not_eq.numel() == 1:
|
||||
# 1D shard
|
||||
dim = dims_not_eq.item()
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
if world_size is None:
|
||||
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
if rank is None:
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@@ -3,14 +3,12 @@ import torch
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.tensor import ColoTensor, ProcessGroup
|
||||
from colossalai.tensor import distspec
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from functools import partial
|
||||
from colossalai.core import global_context as gpc
|
||||
from _utils import tensor_shard_equal, tensor_equal
|
||||
|
||||
|
||||
@@ -38,18 +36,14 @@ class Conv1D(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row(weight, bias, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col(weight, bias, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
bias.set_tensor_spec(spec)
|
||||
@@ -59,7 +53,9 @@ def run_with_spec(spec_init_func):
|
||||
model = Conv1D(4, 16).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
|
||||
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
|
||||
spec_init_func(weight, bias)
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
spec_init_func(weight, bias, pg)
|
||||
x = torch.rand(2, 16).cuda()
|
||||
out = model(x)
|
||||
colo_out = torch.addmm(bias, x, weight)
|
||||
@@ -68,13 +64,12 @@ def run_with_spec(spec_init_func):
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
tensor_shard_equal(model.bias.grad, bias.grad)
|
||||
tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(init_1d_row)
|
||||
run_with_spec(init_1d_col)
|
||||
|
||||
|
@@ -7,12 +7,12 @@ import torch.multiprocessing as mp
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import DistSpecManager, distspec
|
||||
from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
|
||||
from functools import partial
|
||||
|
||||
|
||||
def run():
|
||||
group = _get_default_group()
|
||||
group = ProcessGroup(tp_degree=dist.get_world_size())
|
||||
rank = dist.get_rank()
|
||||
size = dist.get_world_size()
|
||||
depth = int(math.sqrt(size))
|
||||
@@ -34,7 +34,7 @@ def run():
|
||||
|
||||
|
||||
def check_mem():
|
||||
group = _get_default_group()
|
||||
group = ProcessGroup(tp_degree=dist.get_world_size())
|
||||
size = dist.get_world_size()
|
||||
assert torch.cuda.memory_allocated() == 0
|
||||
x = torch.rand(32, 32).cuda()
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import ColoTensor, distspec, ColoParameter
|
||||
from colossalai.tensor import distspec, ColoParameter
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
|
||||
@@ -10,23 +9,21 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col(weight, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
model = torch.nn.EmbeddingBag(10, 4).cuda()
|
||||
weight = ColoParameter(model.weight.clone())
|
||||
spec_init_func(weight)
|
||||
spec_init_func(weight, pg)
|
||||
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
|
||||
offsets = torch.tensor([0, 4]).cuda()
|
||||
out = model(inputs, offsets=offsets)
|
||||
@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func):
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import ColoTensor, distspec
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
@@ -11,30 +10,26 @@ import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_row(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row(weight, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col(weight, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
def run_with_spec(spec_init_func, pg: ProcessGroup):
|
||||
model = torch.nn.Embedding(12, 32).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
|
||||
spec_init_func(weight)
|
||||
spec_init_func(weight, pg)
|
||||
x = torch.tensor((0, 3, 6, 9)).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.embedding(x, weight)
|
||||
@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func):
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
# compare grad inside a TP group
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_with_spec(init_1d_row)
|
||||
run_with_spec(init_1d_col)
|
||||
# config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
run_with_spec(init_1d_row, pg)
|
||||
run_with_spec(init_1d_col, pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@@ -1,51 +1,54 @@
|
||||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
|
||||
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
|
||||
def init_1d_row_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||
tensor_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_tensor_spec(spec)
|
||||
p.set_tensor_spec(tensor_spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col_spec(model, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
def check_param_equal(model, torch_model, pg: ProcessGroup):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert tensor_shard_equal(torch_p, p)
|
||||
assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
|
||||
assert pg.tp_world_size() is not None
|
||||
assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def check_grad_equal(model, torch_model):
|
||||
def check_grad_equal(model, torch_model, pg: ProcessGroup):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert tensor_shard_equal(torch_p.grad, p.grad)
|
||||
assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_gpt(init_spec_func, use_ddp):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp):
|
||||
model = model.cuda()
|
||||
torch_model = model_builder().cuda()
|
||||
if use_ddp:
|
||||
model = ColoDDP(model)
|
||||
# torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
|
||||
# torch.distributed.barrier()
|
||||
torch_model = DDP(torch_model,
|
||||
device_ids=[gpc.get_global_rank()],
|
||||
process_group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
model = ColoDDP(model, process_group=pg)
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p)
|
||||
init_spec_func(model)
|
||||
check_param_equal(model, torch_model)
|
||||
init_spec_func(model, pg)
|
||||
check_param_equal(model, torch_model, pg)
|
||||
model.train()
|
||||
torch_model.train()
|
||||
set_seed(gpc.get_local_rank(ParallelMode.DATA))
|
||||
set_seed(pg.tp_local_rank())
|
||||
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
logits = model(input_ids, attn_mask)
|
||||
torch_logits = torch_model(input_ids, attn_mask)
|
||||
assert tensor_equal(torch_logits, logits)
|
||||
assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
|
||||
loss = criterion(logits, input_ids)
|
||||
torch_loss = criterion(torch_logits, input_ids)
|
||||
if use_ddp:
|
||||
@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp):
|
||||
else:
|
||||
loss.backward()
|
||||
torch_loss.backward()
|
||||
check_grad_equal(model, torch_model)
|
||||
check_grad_equal(model, torch_model, pg)
|
||||
if i > 0:
|
||||
break
|
||||
|
||||
@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp):
|
||||
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt(init_1d_row_spec, use_ddp)
|
||||
# run_gpt(init_1d_row_spec, use_ddp)
|
||||
run_gpt(init_1d_col_spec, use_ddp)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip("under development")
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||
@rerun_if_address_is_in_use()
|
||||
|
@@ -1,88 +0,0 @@
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec
|
||||
|
||||
from functools import partial
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
from colossalai.nn.parallel.layers import init_colo_module
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import pytest
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.embed = torch.nn.Embedding(20, 4)
|
||||
self.proj = torch.nn.Linear(4, 8)
|
||||
|
||||
def forward(self, x):
|
||||
# move input to cpu and restore output
|
||||
current_dev = x.device
|
||||
x = x.to('cpu')
|
||||
x = self.embed(x)
|
||||
x = x.to(current_dev)
|
||||
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_hybrid_device(use_ddp, mode):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = Net()
|
||||
|
||||
real_model = model
|
||||
if use_ddp:
|
||||
model = ColoDDP(model)
|
||||
real_model = model.module
|
||||
|
||||
print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
|
||||
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||
|
||||
# use cpu gloo to handle embedding
|
||||
real_model.embed.to('cpu')
|
||||
gloo_group_tp = gpc.get_cpu_group(ParallelMode.PARALLEL_1D)
|
||||
real_model.embed.weight.spec.dist_spec.process_group = gloo_group_tp
|
||||
|
||||
print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
|
||||
#print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
|
||||
|
||||
optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
|
||||
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
|
||||
out = model(data)
|
||||
out.sum().backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp, mode):
|
||||
if use_ddp and world_size == 1:
|
||||
return
|
||||
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_hybrid_device(use_ddp, mode)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||
@pytest.mark.parametrize('mode', ['col', 'row'])
|
||||
@rerun_if_address_is_in_use()
|
||||
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
|
||||
def _test_hybrid_device(world_size, use_ddp, mode):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test_hybrid_device(4, True, 'row')
|
@@ -12,32 +12,29 @@ import torch.nn.functional as F
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row(weight, bias, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col(weight, bias, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
bias.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
model = torch.nn.Linear(4, 8).cuda()
|
||||
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
|
||||
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
|
||||
spec_init_func(weight, bias)
|
||||
spec_init_func(weight, bias, pg)
|
||||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
colo_out = F.linear(x, weight, bias)
|
||||
@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func):
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad)
|
||||
assert tensor_shard_equal(model.bias.grad, bias.grad)
|
||||
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
@@ -1,10 +1,12 @@
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
from functools import partial
|
||||
from _utils import tensor_shard_equal, set_seed
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
|
||||
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
from functools import partial
|
||||
from _utils import tensor_shard_equal, set_seed
|
||||
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def init_1d_row_linear(weight, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight, pg):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_row_embedding(weight, pg):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_embedding(weight, pg):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name):
|
||||
with torch.no_grad():
|
||||
# check param
|
||||
for p, torch_p in zip(model.parameters(), model_torch.parameters()):
|
||||
assert tensor_shard_equal(torch_p, p)
|
||||
assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
if i > 5:
|
||||
break
|
||||
|
@@ -13,12 +13,10 @@ import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import distspec
|
||||
from colossalai.tensor import distspec, ProcessGroup
|
||||
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
def run_model_with_spec(mode, model_name):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
rank = pg.rank()
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name):
|
||||
for p1, p2 in zip(model.parameters(), model_seq.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
compute_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
# Not all layers in Bert can be mod by 4.
|
||||
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
|
||||
if 'bert' == model_name:
|
||||
if 'col' == mode:
|
||||
init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode=mode)
|
||||
init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
|
||||
init_colo_module(model.classifier, parallel_action, recursive=True, mode='row')
|
||||
init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
|
||||
elif 'row' == mode:
|
||||
init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode='col')
|
||||
init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
|
||||
init_colo_module(model.classifier, parallel_action, recursive=True, mode=mode)
|
||||
init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
|
||||
init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
elif 'simple_net' == model_name:
|
||||
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
|
||||
model = model.cuda()
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
|
||||
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
|
||||
|
||||
if criterion:
|
||||
output = model(data)
|
||||
@@ -113,9 +113,10 @@ def run_linear_with_spec(mode):
|
||||
model = torch.nn.Linear(4, 8)
|
||||
|
||||
model_handy = copy(model)
|
||||
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
compute_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
|
||||
|
||||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
@@ -124,8 +125,8 @@ def run_linear_with_spec(mode):
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
|
||||
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
|
||||
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_check_shared_param():
|
||||
@@ -136,6 +137,10 @@ def run_check_shared_param():
|
||||
num_layer = 2
|
||||
vocab_size = 24
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
rank = pg.rank()
|
||||
|
||||
config = BertConfig(vocab_size=vocab_size,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=hidden_dim * 4,
|
||||
@@ -148,18 +153,16 @@ def run_check_shared_param():
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
model = model.cuda()
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
compute_spec = ComputeSpec(ComputePattern.TP1D)
|
||||
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
|
||||
assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
|
||||
# They are all Linear, so both row is allowed. This should pass check.
|
||||
init_colo_module(model, parallel_action, recursive=True, mode='row')
|
||||
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
|
||||
# This should be detected by check because you can not set weight as row while set bias as col.
|
||||
col_spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
col_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
model.cls.predictions.bias.set_tensor_spec(col_spec)
|
||||
try:
|
||||
check_colo_module(model.cls.predictions.decoder, recursive=False)
|
||||
check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
|
||||
except Exception as e:
|
||||
assert 'incorrectly sharded' in str(e)
|
||||
|
||||
|
@@ -4,10 +4,9 @@ import colossalai
|
||||
import torch.nn.functional as F
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from colossalai.tensor import ColoTensor, ColoParameter
|
||||
from colossalai.tensor import ColoTensor, ProcessGroup
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.nn import Parameter
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import distspec, TensorSpec
|
||||
@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other):
|
||||
|
||||
|
||||
def check_element_wise_ops():
|
||||
pg = _get_default_group()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
t = torch.rand(2, 2)
|
||||
x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.size()])))
|
||||
x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()])))
|
||||
check_spec_eq(x, x.cuda())
|
||||
assert torch.equal(x.cuda(), t.cuda())
|
||||
check_spec_eq(x, torch.abs(x))
|
||||
|
@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
|
||||
from colossalai.context import ParallelMode
|
||||
from functools import partial
|
||||
|
||||
|
||||
@@ -55,11 +54,9 @@ def test_operand():
|
||||
def _run_view(world_size):
|
||||
t_ref = torch.randn(4, 5)
|
||||
rank = gpc.get_global_rank()
|
||||
pg = ProcessGroup(rank, list(range(world_size)))
|
||||
assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
t = ColoTensor.from_torch_tensor(
|
||||
t_ref,
|
||||
TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
|
||||
t_ref, TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])))
|
||||
|
||||
assert t.size_global()[0] == 4 * world_size
|
||||
assert t.size_global(1) == 5
|
||||
@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
|
||||
t_ref = torch.randn(4, 5)
|
||||
|
||||
rank = gpc.get_global_rank()
|
||||
pg = ProcessGroup(rank, list(range(world_size)))
|
||||
shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
shard_spec = distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])
|
||||
tensor_spec = TensorSpec(shard_spec)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
|
||||
assert t.shape == torch.Size((4 * world_size, 5))
|
||||
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
|
||||
|
||||
|
||||
def _run_tensor_replicated_init(world_size):
|
||||
@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
|
||||
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
|
||||
|
||||
|
||||
def _run_process_group(world_size):
|
||||
pg1 = ProcessGroup()
|
||||
pg2 = ProcessGroup()
|
||||
|
||||
assert pg1 == pg2
|
||||
|
||||
|
||||
def run_dist_tests(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_tensor_shard_init(world_size)
|
||||
_run_tensor_replicated_init(world_size)
|
||||
_run_view(world_size)
|
||||
_run_process_group(world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@@ -2,13 +2,11 @@ import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.gemini import ChunkManager
|
||||
from colossalai.core import global_context as gpc
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
def check_param_equal(model, torch_model, pg: ProcessGroup):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
if p.storage().size() > 0:
|
||||
assert p.dtype == torch.half
|
||||
assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
|
||||
assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
|
||||
pg.tp_world_size()), f'{torch_p} vs {p}'
|
||||
|
||||
|
||||
def check_grad_equal(model, torch_model):
|
||||
def check_grad_equal(model, torch_model, pg: ProcessGroup):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
if p.grad is not None:
|
||||
assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
|
||||
assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
|
||||
pg.tp_local_rank(), pg.tp_world_size())
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
return logits
|
||||
|
||||
|
||||
def init_1d_row_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col_spec(model, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
# world size, dp = 2, tp =2, construct a hybrid parallelism.
|
||||
if world_size == 4:
|
||||
pg = ProcessGroup(tp_degree=2)
|
||||
else:
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
|
||||
if tp_init_spec_func:
|
||||
tp_init_spec_func(model)
|
||||
tp_init_spec_func(model, pg)
|
||||
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size,
|
||||
enable_distributed_storage=use_zero,
|
||||
init_device=GeminiManager.get_default_device(placement_policy))
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pg)
|
||||
optim = HybridAdam(model.parameters(), lr=1e-3)
|
||||
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
|
||||
print(chunk_manager)
|
||||
check_param_equal(model, torch_model)
|
||||
# print(chunk_manager)
|
||||
check_param_equal(model, torch_model, pg)
|
||||
model.train()
|
||||
torch_model.train()
|
||||
set_seed(gpc.get_local_rank(ParallelMode.DATA))
|
||||
set_seed(pg.dp_local_rank())
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
assert tensor_equal(logits, torch_logits)
|
||||
check_grad_equal(model, torch_model)
|
||||
check_grad_equal(model, torch_model, pg)
|
||||
optim.step()
|
||||
torch_optim.step()
|
||||
check_param_equal(model, torch_model)
|
||||
check_param_equal(model, torch_model, pg)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
if world_size == 4:
|
||||
config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
if world_size == 4:
|
||||
run_gpt(tp_init_spec_func=init_1d_col_spec)
|
||||
@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port):
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip("under development")
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
|
Reference in New Issue
Block a user