[shardformer] integrated linear 1D with dtensor (#3996)

* [shardformer] integrated linear 1D with dtensor

* polish code
This commit is contained in:
Frank Lee
2023-06-15 18:03:38 +08:00
parent d3bc530849
commit 015af592f8
9 changed files with 707 additions and 408 deletions

View File

@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc
class ParallelLayer(nn.Module):
global_state_dict: bool = True
def __init__(self):