[refactory] add nn.parallel module (#1068)

This commit is contained in:
Jiarui Fang
2022-06-06 15:34:41 +08:00
committed by GitHub
parent 6754f1b77f
commit 49832b2344
22 changed files with 44 additions and 46 deletions

View File

@@ -2,7 +2,7 @@ from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.nn import register_colo_module, init_colo_module, \
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding
from torch import nn

View File

@@ -1,10 +1,7 @@
import torch
import functools
import inspect
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor