[shardformer] support lazy init (#4202)

* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
This commit is contained in:
Hongxin Liu
2023-07-10 10:48:53 +08:00
parent f3bcc292c8
commit 890774b2fb
25 changed files with 263 additions and 157 deletions

View File

@@ -1,16 +1,23 @@
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_linear_1d_col():
linear = nn.Linear(32, 128).cuda()
@parameterize('lazy_init', [False, True])
def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
with ctx:
linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
# ensure that the parameters are distributed
@@ -50,8 +57,12 @@ def check_linear_1d_col():
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_1d_row():
linear = nn.Linear(32, 128).cuda()
@parameterize('lazy_init', [False, True])
def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
with ctx:
linear = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
assert linear_row.weight.shape == torch.Size([128, 16])
@@ -83,9 +94,13 @@ def check_linear_1d_row():
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row():
linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
@parameterize('lazy_init', [False, True])
def check_linear_col_plus_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
with ctx:
linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)