mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init * [shardformer] linear support lazy init * [shardformer] embedding support lazy init * [shardformer] norm support lazy init * [shardformer] fused linear support lazy init * [test] update shardformer test layer * [test] shardformer with lazy init fit ddp * [lazy] hotfix deepcopy of param * [shardformer] fix bert policy and update test * [shardformer] fix bloom policy and update test * [shardformer] fix opt policy and update test * [shardformer] fix t5 policy and update test * [shardformer] fix gpt2 policy and update test * [shardformer] fix llama policy and update test
This commit is contained in:
@@ -1,15 +1,22 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import Embedding1D
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_embedding_1d():
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
|
||||
|
||||
assert embedding_1d.weight.shape == torch.Size([32, 64])
|
||||
|
@@ -1,14 +1,21 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import FusedLayerNorm
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_layernorm():
|
||||
norm = nn.LayerNorm(128, 0.00001).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_layernorm(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
norm = nn.LayerNorm(128, 0.00001).cuda()
|
||||
norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
|
||||
|
||||
assert norm1d.weight.shape == torch.Size([128])
|
||||
|
@@ -1,16 +1,23 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_linear_1d_col():
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
|
||||
|
||||
# ensure that the parameters are distributed
|
||||
@@ -50,8 +57,12 @@ def check_linear_1d_col():
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_1d_row():
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear_row.weight.shape == torch.Size([128, 16])
|
||||
@@ -83,9 +94,13 @@ def check_linear_1d_row():
|
||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||
|
||||
|
||||
def check_linear_col_plus_row():
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_col_plus_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
|
||||
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
|
||||
|
||||
|
@@ -1,12 +1,15 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
# This code is copied from https://github.com/huggingface/transformers
|
||||
@@ -50,8 +53,12 @@ def rearrange(tensor: torch.Tensor, dim: int):
|
||||
return rearanged_tensor
|
||||
|
||||
|
||||
def check_linear_conv_1d_col():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_conv_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
|
||||
process_group=None,
|
||||
gather_output=True,
|
||||
@@ -80,8 +87,12 @@ def check_linear_conv_1d_col():
|
||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||
|
||||
|
||||
def check_linear_conv_1d_row():
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_conv_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
|
@@ -1,15 +1,23 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer import VocabParallelEmbedding1D
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_vocab_embedding_1d():
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_vocab_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
with ctx:
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
|
||||
|
||||
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
|
||||
|
@@ -1,19 +1,24 @@
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
||||
|
||||
def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
|
||||
# create new model
|
||||
org_model = model_fn().cuda()
|
||||
|
||||
def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
|
||||
ctx = LazyInitContext() if use_lazy_init else nullcontext()
|
||||
with ctx:
|
||||
# create new model
|
||||
org_model = model_fn()
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
if use_lazy_init:
|
||||
ctx.materialize(org_model)
|
||||
# shard model
|
||||
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
|
||||
enable_tensor_parallelism=enable_tensor_parallelism)
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
sharded_model, shared_params = shard_former.optimize(model_copy)
|
||||
return org_model, sharded_model.cuda()
|
||||
return org_model.cuda(), sharded_model.cuda()
|
||||
|
||||
|
||||
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
|
@@ -67,12 +67,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
@parameterize('enable_fused_normalization', [False, True])
|
||||
@parameterize('enable_tensor_parallelism', [False, True])
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@parameterize('enable_tensor_parallelism', [True, False])
|
||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||
@parameterize('use_lazy_init', [False, True])
|
||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
def check_shardformer_with_ddp(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@parameterize('lazy_init', [True, False])
|
||||
def check_shardformer_with_ddp(lazy_init: bool):
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
|
||||
@@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port):
|
||||
shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
|
||||
shardformer = ShardFormer(shard_config=shard_config)
|
||||
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
# create and shard model
|
||||
model = model_fn().cuda()
|
||||
with ctx:
|
||||
model = model_fn().cuda()
|
||||
sharded_model, _ = shardformer.optimize(model)
|
||||
|
||||
# add ddp
|
||||
@@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_shardformer_with_ddp()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_gpt2():
|
||||
spawn(check_shardformer_with_ddp, 4)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gpt2()
|
||||
test_gpt2()
|
||||
|
Reference in New Issue
Block a user