[shardformer] refactored embedding and dropout to parallel module (#4013)

* [shardformer] refactored embedding and dropout to parallel module

* polish code
This commit is contained in:
Frank Lee
2023-06-16 15:00:26 +08:00
parent dfca9678fa
commit 3893fa1a8d
6 changed files with 198 additions and 423 deletions

View File

@@ -5,7 +5,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_linear_1d_col():