mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding * [shardformer] linear support inplace sharding * [shardformer] layernorm support inplace sharding * [shardformer] qkv support inplace sharding * [test] update shardformer layer test * [shardformer] fix shared param sharding * [shardformer] fix bert policy * [shardformer] fix bloom policy * [shardformer] fix llama policy * [shardformer] fix opt policy * [shardformer] fix t5 policy * [shardformer] fix fused qkv linear * [shardformer] fix bugs * force sync * [test] fix bugs * [test] fix transformer version
This commit is contained in:
@@ -15,11 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
def check_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
with ctx:
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
|
||||
embedding_copy = nn.Embedding(32, 128).cuda()
|
||||
embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None)
|
||||
|
||||
assert embedding_1d.weight.shape == torch.Size([32, 64])
|
||||
assert embedding_1d.weight is embedding_copy.weight
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(embedding_1d.state_dict())
|
||||
|
@@ -14,11 +14,14 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
def check_layernorm(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
norm = nn.LayerNorm(128, 0.00001).cuda()
|
||||
with ctx:
|
||||
norm = nn.LayerNorm(128, 0.00001).cuda()
|
||||
norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
|
||||
norm_copy = nn.LayerNorm(128, 0.00001).cuda()
|
||||
norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None)
|
||||
|
||||
assert norm1d.weight.shape == torch.Size([128])
|
||||
assert norm_copy.weight is norm1d.weight
|
||||
assert norm_copy.bias is norm1d.bias
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
norm.load_state_dict(norm1d.state_dict())
|
||||
|
@@ -15,14 +15,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
|
||||
|
||||
# ensure that the parameters are distributed
|
||||
assert is_distributed_tensor(linear_col.weight)
|
||||
assert is_distributed_tensor(linear_col.bias)
|
||||
assert linear_copy.weight is linear_col.weight
|
||||
assert linear_copy.bias is linear_col.bias
|
||||
|
||||
# ensure the shape is correct
|
||||
assert linear_col.weight.shape == torch.Size([64, 32])
|
||||
@@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool):
|
||||
def check_linear_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear_row.weight.shape == torch.Size([128, 16])
|
||||
assert linear_row.bias.shape == torch.Size([128])
|
||||
assert linear_copy.weight is linear_row.weight
|
||||
assert linear_copy.bias is linear_row.bias
|
||||
|
||||
linear.load_state_dict(linear_row.state_dict())
|
||||
linear_row.load_state_dict(linear.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
@@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool):
|
||||
def check_linear_col_plus_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
|
||||
with ctx:
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
|
||||
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
|
||||
linear_1_copy = nn.Linear(32, 128).cuda()
|
||||
linear_2_copy = nn.Linear(128, 32).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
|
||||
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
|
||||
|
||||
linear_1.load_state_dict(linear_col.state_dict())
|
||||
linear_col.load_state_dict(linear_1.state_dict())
|
||||
linear_2.load_state_dict(linear_row.state_dict())
|
||||
linear_row.load_state_dict(linear_2.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
|
@@ -56,10 +56,10 @@ def rearrange(tensor: torch.Tensor, dim: int):
|
||||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_conv_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
with ctx:
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
|
||||
process_group=None,
|
||||
gather_output=True,
|
||||
n_fused=3)
|
||||
@@ -68,6 +68,8 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
||||
assert linear.bias.shape == torch.Size([192])
|
||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
||||
assert linear_copy.weight is linear_conv_col.weight
|
||||
assert linear_copy.bias is linear_conv_col.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_conv_col.load_state_dict(linear.state_dict())
|
||||
@@ -91,13 +93,20 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
||||
def check_linear_conv_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
with ctx:
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||
assert linear_row.bias.shape == torch.Size([192])
|
||||
assert linear_copy.weight is linear_row.weight
|
||||
assert linear_copy.bias is linear_row.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_row.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_row.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
|
@@ -7,8 +7,7 @@ from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.shardformer.layer import VocabParallelEmbedding1D
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@@ -16,13 +15,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
def check_vocab_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
with ctx:
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
|
||||
embedding_copy = nn.Embedding(128, 32).to('cuda')
|
||||
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)
|
||||
|
||||
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
|
||||
assert dist_embedding_1d.num_embeddings == 64
|
||||
assert dist_embedding_1d.embedding_dim == 32
|
||||
assert embedding_copy.weight is dist_embedding_1d.weight
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(dist_embedding_1d.state_dict())
|
||||
|
@@ -1,8 +1,10 @@
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
||||
|
||||
@@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
||||
shard_output = output_transform_fn(shard_output)
|
||||
shard_loss = loss_fn(shard_output)
|
||||
return org_output, org_loss, shard_output, shard_loss
|
||||
|
||||
|
||||
def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
|
||||
org_sd = org_model.state_dict()
|
||||
shard_sd = sharded_model.state_dict()
|
||||
for k, v in org_sd.items():
|
||||
assert k in shard_sd, f'{name} {k} not in sharded model'
|
||||
shard_v = shard_sd[k]
|
||||
assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
|
||||
assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
|
||||
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
|
||||
|
@@ -12,7 +12,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
@@ -75,6 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
@@ -12,7 +12,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
@@ -75,6 +75,7 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -12,7 +12,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
@@ -77,6 +77,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
@@ -14,7 +14,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
@@ -78,6 +78,7 @@ def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_la
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -15,7 +15,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
@@ -77,6 +77,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@@ -14,7 +14,7 @@ from colossalai.testing import (
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
@@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
Reference in New Issue
Block a user