[shardformer] supported bloom model (#4098)

This commit is contained in:
Frank Lee
2023-06-28 15:04:35 +08:00
parent 8af29ee47a
commit b1c2901530
20 changed files with 724 additions and 154 deletions

View File

@@ -3,13 +3,13 @@ import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.shardformer.layer import Dropout1D
from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput
from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn
def check_dropout():
def check_dropout_parallel_input():
dropout = nn.Dropout().cuda()
dropout_1d = Dropout1D.from_native_module(dropout, process_group=None)
dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None)
# check computation correctness
x = torch.rand(4, 128).cuda()
@@ -39,9 +39,26 @@ def check_dropout():
assert_not_equal(out_1d_all[i], out_1d_all[0])
def check_dropout_replicated_input():
dropout = nn.Dropout().cuda()
dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None)
# check computation correctness
x = torch.rand(4, 128).cuda()
out_1d = dropout_replica(x)
# ensure out_1d is different across ranks
world_size = dist.get_world_size()
out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)]
dist.all_gather(out_1d_all, out_1d)
for i in range(1, world_size):
assert_equal(out_1d_all[i], out_1d_all[0])
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_dropout()
check_dropout_parallel_input()
check_dropout_replicated_input()
@rerun_if_address_is_in_use()