[shardformer] refactored the shardformer layer structure (#4053)

This commit is contained in:
Frank Lee
2023-06-21 14:30:06 +08:00
parent 58df720570
commit f22ddacef0
24 changed files with 196 additions and 471 deletions

View File

@@ -4,9 +4,11 @@ import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from .parallelmodule import ParallelModule
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ['Dropout1D']
class Dropout1D(ParallelModule, nn.Dropout):
"""