[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
Hongxin Liu
2023-07-20 10:39:06 +08:00
parent 2a2eacfaf1
commit d921ce8391
26 changed files with 371 additions and 340 deletions
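The diff below shows what "inplace sharding" means at the API level: the parallel module returned by from_native_module now reuses (and shards in place) the parameters of the module passed in, rather than allocating fresh ones, which is what the new identity asserts (linear_copy.weight is linear_col.weight) verify. A minimal usage sketch of that behaviour, assuming a 2-GPU tensor-parallel process group has already been set up via colossalai.launch, as these tests do with spawn over 2 processes:

# Sketch only: assumes colossalai.launch(...) has already initialized a
# 2-GPU tensor-parallel process group, matching the test's spawn setup.
import torch
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col

linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

# In-place sharding: the wrapper shares the original module's parameter
# objects (now sharded distributed tensors) instead of copying them.
assert linear_col.weight is linear.weight
assert linear_col.weight.shape == torch.Size([64, 32])  # 128 output rows split over 2 ranks

Because the module handed to from_native_module is sharded in place, the updated tests construct a separate linear_copy to shard and keep linear intact as a full-size reference for the load_state_dict round-trip checks.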


@@ -15,14 +15,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 @parameterize('lazy_init', [False, True])
 def check_linear_1d_col(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
+    linear = nn.Linear(32, 128).cuda()
     with ctx:
-        linear = nn.Linear(32, 128).cuda()
-    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
+        linear_copy = nn.Linear(32, 128).cuda()
+    linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)

     # ensure that the parameters are distributed
     assert is_distributed_tensor(linear_col.weight)
     assert is_distributed_tensor(linear_col.bias)
+    assert linear_copy.weight is linear_col.weight
+    assert linear_copy.bias is linear_col.bias

     # ensure the shape is correct
     assert linear_col.weight.shape == torch.Size([64, 32])
@@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool):
 def check_linear_1d_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
+    linear = nn.Linear(32, 128).cuda()
     with ctx:
-        linear = nn.Linear(32, 128).cuda()
-    linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+        linear_copy = nn.Linear(32, 128).cuda()
+    linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)

     assert linear_row.weight.shape == torch.Size([128, 16])
     assert linear_row.bias.shape == torch.Size([128])
+    assert linear_copy.weight is linear_row.weight
+    assert linear_copy.bias is linear_row.bias
+
+    linear.load_state_dict(linear_row.state_dict())
+    linear_row.load_state_dict(linear.state_dict())

     # check computation correctness
     x = torch.rand(4, 32).cuda()
@@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool):
 def check_linear_col_plus_row(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
+    linear_1 = nn.Linear(32, 128).cuda()
+    linear_2 = nn.Linear(128, 32).cuda()
     with ctx:
-        linear_1 = nn.Linear(32, 128).cuda()
-        linear_2 = nn.Linear(128, 32).cuda()
-    linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
-    linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
+        linear_1_copy = nn.Linear(32, 128).cuda()
+        linear_2_copy = nn.Linear(128, 32).cuda()
+    linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
+    linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
+
+    linear_1.load_state_dict(linear_col.state_dict())
+    linear_col.load_state_dict(linear_1.state_dict())
+    linear_2.load_state_dict(linear_row.state_dict())
+    linear_row.load_state_dict(linear_2.state_dict())

     # check computation correctness
     x = torch.rand(4, 32).cuda()
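The added load_state_dict round trips are the other half of the contract: even though sharding now happens in place, a checkpoint saved from the sharded module can still be loaded into the full-size reference module and vice versa. A hedged sketch of that pattern, using a hypothetical helper name:

import torch.nn as nn

# Hypothetical helper mirroring the round-trip checks added in the hunks above.
def check_state_dict_roundtrip(reference: nn.Module, sharded: nn.Module) -> None:
    # The full-size reference module loads the sharded module's checkpoint ...
    reference.load_state_dict(sharded.state_dict())
    # ... and the sharded module loads the reference checkpoint back.
    sharded.load_state_dict(reference.state_dict())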