[refactory] add nn.parallel module (#1068)

Jiarui Fang
2022-06-06 15:34:41 +08:00
committed by GitHub
parent 6754f1b77f
commit 49832b2344
22 changed files with 44 additions and 46 deletions


@@ -1,9 +1,12 @@
 import contextlib
 import functools
 from typing import Optional
+from contextlib import AbstractContextManager
 import torch
 import torch.nn as nn
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
@@ -12,8 +15,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
-from contextlib import AbstractContextManager
-from colossalai.utils import InsertPostInitMethodToModuleSubClasses
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 class ZeroContextConfig(object):
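
The net effect of this file's two hunks is an import cleanup: AbstractContextManager moves up into the top-of-file import block, and InsertPostInitMethodToModuleSubClasses is now imported from its defining submodule colossalai.utils.model.utils rather than re-exported from the colossalai.utils package root. For orientation, a minimal sketch of the post-init-hook pattern this helper provides; the subclass name, the hook body, and the exact _post_init_method signature are assumptions for illustration, not taken from this diff:

import torch.nn as nn

from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses

class TracingInitContext(InsertPostInitMethodToModuleSubClasses):
    # Hypothetical subclass: while the context is active, the base class
    # wraps nn.Module.__init__ so this hook runs after each module is built.
    def _post_init_method(self, module, *args, **kwargs):
        print(f"built {module.__class__.__name__}")

with TracingInitContext():
    layer = nn.Linear(8, 8)    # hook fires here

ZeroInitContext in this file follows the same pattern, using the hook to shard parameters as modules are constructed.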


@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from enum import Enum
 from torch.optim import Optimizer
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel.data_parallel import ColoDDPV2
 from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.core import global_context as gpc
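
The second file makes the matching caller-side update: ColoDDPV2 is now imported from colossalai.nn.parallel.data_parallel, the submodule introduced by this refactor. Code that has to run on both sides of the rename can guard the import; a small compatibility sketch, assuming only the two paths visible in this diff:

try:
    # Post-refactor location (this commit).
    from colossalai.nn.parallel.data_parallel import ColoDDPV2
except ImportError:
    # Pre-refactor location, replaced in this commit.
    from colossalai.nn.parallel import ColoDDPV2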