Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
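In short: `from_native_module` now shards the given module's parameters in place, so the parallel layer it returns holds the very same parameter objects as the module it was built from, converted to distributed tensors. A minimal per-rank sketch of that contract (not part of this commit; the import paths are assumed from the ColossalAI code base, and the snippet has to run inside an already-initialized tensor-parallel process group):

```python
import torch.nn as nn

# assumed import paths for the shardformer layer and the d-tensor predicate
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.tensor.d_tensor import is_distributed_tensor


def check_inplace_contract():
    # runs on every rank of an initialized process group
    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

    # sharding happened in place: same tensor objects, now distributed
    assert linear_col.weight is linear.weight
    assert linear_col.bias is linear.bias
    assert is_distributed_tensor(linear_col.weight)
```

The updated tests below follow exactly this pattern: they pass a copy into `from_native_module` and assert that the copy and the parallel layer share their weight and bias.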
@@ -15,14 +15,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 @parameterize('lazy_init', [False, True])
 def check_linear_1d_col(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
+    linear = nn.Linear(32, 128).cuda()
     with ctx:
-        linear = nn.Linear(32, 128).cuda()
-    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
+        linear_copy = nn.Linear(32, 128).cuda()
+    linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
 
     # ensure that the parameters are distributed
     assert is_distributed_tensor(linear_col.weight)
     assert is_distributed_tensor(linear_col.bias)
+    assert linear_copy.weight is linear_col.weight
+    assert linear_copy.bias is linear_col.bias
 
     # ensure the shape is correct
     assert linear_col.weight.shape == torch.Size([64, 32])
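For reference, the shape asserted at the end of this hunk follows from column-parallel sharding with a tensor-parallel size of 2, which is what the shard shapes in these tests imply: `nn.Linear(32, 128)` stores its weight as `[out_features, in_features] = [128, 32]`, and `Linear1D_Col` splits the output dimension across ranks. A standalone sketch of that bookkeeping (the world size of 2 is inferred, not stated in this hunk):

```python
import torch

world_size = 2                       # inferred from the asserted shard shapes
out_features, in_features = 128, 32  # nn.Linear(32, 128) stores its weight as [128, 32]

# column parallelism splits the output dimension of the weight across ranks
col_weight_shape = torch.Size([out_features // world_size, in_features])
assert col_weight_shape == torch.Size([64, 32])  # matches the assert in the hunk above
```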
@@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool):
 def check_linear_1d_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
+    linear = nn.Linear(32, 128).cuda()
     with ctx:
-        linear = nn.Linear(32, 128).cuda()
-    linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+        linear_copy = nn.Linear(32, 128).cuda()
+    linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
 
     assert linear_row.weight.shape == torch.Size([128, 16])
     assert linear_row.bias.shape == torch.Size([128])
+    assert linear_copy.weight is linear_row.weight
+    assert linear_copy.bias is linear_row.bias
+
+    linear.load_state_dict(linear_row.state_dict())
+    linear_row.load_state_dict(linear.state_dict())
 
     # check computation correctness
     x = torch.rand(4, 32).cuda()
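The `load_state_dict` round trip added above syncs the eager `linear` with the sharded `linear_row` (and exercises state-dict save/load over the distributed tensors), so the computation check that follows compares like with like. A sketch of how that comparison presumably proceeds; the tolerance and the exact comparison call are assumptions, and with `parallel_input=False` the row-parallel layer accepts the full-width input and returns the full-width output:

```python
import torch
import torch.nn as nn


def check_forward_match(linear: nn.Linear, linear_row: nn.Module) -> None:
    # `linear` / `linear_row` are the synced eager and row-parallel pair built
    # in check_linear_1d_row above; run on every rank of the process group
    x = torch.rand(4, 32).cuda()
    out_ref = linear(x)      # plain nn.Linear forward -> [4, 128]
    out_row = linear_row(x)  # row-parallel forward (internal split + reduce) -> [4, 128]
    assert torch.allclose(out_ref, out_row, atol=1e-5)  # tolerance is an assumption
```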
@@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool):
 def check_linear_col_plus_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
 
+    linear_1 = nn.Linear(32, 128).cuda()
+    linear_2 = nn.Linear(128, 32).cuda()
+
     with ctx:
-        linear_1 = nn.Linear(32, 128).cuda()
-        linear_2 = nn.Linear(128, 32).cuda()
-    linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
-    linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
+        linear_1_copy = nn.Linear(32, 128).cuda()
+        linear_2_copy = nn.Linear(128, 32).cuda()
+    linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
+    linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
+
+    linear_1.load_state_dict(linear_col.state_dict())
+    linear_col.load_state_dict(linear_1.state_dict())
+    linear_2.load_state_dict(linear_row.state_dict())
+    linear_row.load_state_dict(linear_2.state_dict())
 
     # check computation correctness
     x = torch.rand(4, 32).cuda()