mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
reorgnize colotensor directory (#1062)
* reorgnize colotensor directory * polish code
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from colossalai.utils import free_port, ColoInitContext, get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, init_colo_module
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
|
||||
from functools import partial
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
|
||||
from colossalai.nn import init_colo_module
|
||||
from colossalai.nn.parallel import ColoDDP
|
||||
|
||||
import colossalai
|
||||
@@ -11,12 +14,14 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import pytest
|
||||
|
||||
|
||||
class Net(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.embed = torch.nn.Embedding(20, 4)
|
||||
self.proj = torch.nn.Linear(4, 8)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# move input to cpu and restore output
|
||||
current_dev = x.device
|
||||
@@ -27,6 +32,7 @@ class Net(torch.nn.Module):
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_hybrid_device(use_ddp):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = Net()
|
||||
@@ -36,7 +42,6 @@ def run_hybrid_device(use_ddp):
|
||||
model = ColoDDP(model)
|
||||
real_model = model.module
|
||||
|
||||
|
||||
print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
|
||||
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
@@ -49,11 +54,12 @@ def run_hybrid_device(use_ddp):
|
||||
|
||||
print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
|
||||
#print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
|
||||
|
||||
|
||||
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
|
||||
out = model(data)
|
||||
out.sum().backward()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp):
|
||||
if use_ddp and world_size == 1:
|
||||
return
|
||||
@@ -62,6 +68,7 @@ def run_dist(rank, world_size, port, use_ddp):
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_hybrid_device(use_ddp)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||
@@ -71,5 +78,6 @@ def _test_hybrid_device(world_size, use_ddp):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test_hybrid_device(1, False)
|
||||
_test_hybrid_device(1, False)
|
||||
|
@@ -10,9 +10,10 @@ from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils import ColoInitContext
|
||||
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
|
||||
ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
|
||||
ParallelAction, ColoTensor, DistSpecManager
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
from functools import partial
|
||||
from _utils import set_seed
|
||||
|
||||
|
@@ -1,24 +1,28 @@
|
||||
from copy import copy
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import ColoTensor, distspec
|
||||
|
||||
import pytest
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
from colossalai.nn import init_colo_module, check_colo_module
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import distspec
|
||||
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
def run_model_with_spec(mode, model_name):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
@@ -27,7 +31,7 @@ def run_model_with_spec(mode, model_name):
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=False)
|
||||
|
||||
|
||||
if rank == 0:
|
||||
model_seq = model_builder(checkpoint=False)
|
||||
model_seq = model_seq.cuda()
|
||||
@@ -103,15 +107,16 @@ def run_model_with_spec(mode, model_name):
|
||||
if i > 3:
|
||||
break
|
||||
|
||||
|
||||
def run_linear_with_spec(mode):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = torch.nn.Linear(4, 8)
|
||||
|
||||
model_handy = copy(model)
|
||||
|
||||
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||
|
||||
|
||||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
colo_out = model_handy(x)
|
||||
@@ -122,6 +127,7 @@ def run_linear_with_spec(mode):
|
||||
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
|
||||
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
|
||||
|
||||
|
||||
def run_check_shared_param():
|
||||
from transformers import BertForMaskedLM, BertConfig
|
||||
hidden_dim = 8
|
||||
@@ -157,12 +163,14 @@ def run_check_shared_param():
|
||||
except Exception as e:
|
||||
assert 'incorrectly sharded' in str(e)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_linear_with_spec('col')
|
||||
run_linear_with_spec('row')
|
||||
|
||||
|
||||
def run_dist_model(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@@ -170,11 +178,13 @@ def run_dist_model(rank, world_size, port):
|
||||
run_model_with_spec('col', model_name)
|
||||
run_model_with_spec('row', model_name)
|
||||
|
||||
|
||||
def run_dist_check(rank, world_size, port):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_check_shared_param()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
@@ -182,6 +192,7 @@ def test_module_linear_1d(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
@@ -189,6 +200,7 @@ def test_module_model(world_size):
|
||||
run_func = partial(run_dist_model, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
@@ -196,5 +208,6 @@ def test_module_check(world_size):
|
||||
run_func = partial(run_dist_check, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_module_check(2)
|
||||
test_module_check(2)
|
||||
|
Reference in New Issue
Block a user