[shardformer] refactored layernorm (#4086)

Frank Lee
2023-06-26 18:05:00 +08:00
parent c4b1b65931
commit d33a44e8c3
4 changed files with 51 additions and 77 deletions


@@ -1,16 +1,15 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 import colossalai
-from colossalai.shardformer.layer import LayerNorm1D
+from colossalai.shardformer.layer import FusedLayerNorm
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-def check_layernorm_1d():
+def check_layernorm():
     norm = nn.LayerNorm(128, 0.00001).cuda()
-    norm1d = LayerNorm1D.from_native_module(norm, process_group=None)
+    norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
     assert norm1d.weight.shape == torch.Size([128])
@@ -33,11 +32,11 @@ def check_layernorm_1d():
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_layernorm_1d()
+    check_layernorm()
 @rerun_if_address_is_in_use()
-def test_layernorm_1d():
+def test_layernorm():
     spawn(run_dist, nprocs=2)
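
For context, below is a minimal sketch of the renamed test after this refactor, assuming only the FusedLayerNorm.from_native_module API and the two-process NCCL launch shown in the diff; the forward-output comparison against the native module is an illustrative addition, not part of this commit.

import torch
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import FusedLayerNorm
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_layernorm():
    # wrap a native LayerNorm with the fused implementation
    norm = nn.LayerNorm(128, 0.00001).cuda()
    fused_norm = FusedLayerNorm.from_native_module(norm, process_group=None)

    # the wrapper should preserve the parameter shape of the native module
    assert fused_norm.weight.shape == torch.Size([128])

    # assumed check (not in the diff): outputs should match the native module
    x = torch.rand(4, 128).cuda()
    assert_close(fused_norm(x), norm(x))


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_layernorm()


@rerun_if_address_is_in_use()
def test_layernorm():
    spawn(run_dist, nprocs=2)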