mirror of https://github.com/hpcaitech/ColossalAI.git
[refactory] add nn.parallel module (#1068)
@@ -1,9 +1,12 @@
+import contextlib
 import functools
 from typing import Optional
 from contextlib import AbstractContextManager

 import torch
 import torch.nn as nn
 import torch.distributed as dist

 from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.context.singleton_meta import SingletonMeta
@@ -12,8 +15,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_param import ShardedParamV2
-from contextlib import AbstractContextManager
-from colossalai.utils import InsertPostInitMethodToModuleSubClasses
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses


 class ZeroContextConfig(object):
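The two hunks above only reorganize imports in the ZeRO init-context module: import contextlib and the SingletonMeta import from colossalai.context.singleton_meta come in at the top, while a duplicate AbstractContextManager import and the old-path import of InsertPostInitMethodToModuleSubClasses from colossalai.utils are dropped in favor of the colossalai.utils.model.utils path. For readers unfamiliar with SingletonMeta, the sketch below shows the usual behavior of such a metaclass; it is illustrative only, written in plain Python rather than copied from ColossalAI, and the ZeroContextTracker class name is invented for the example.

# A minimal singleton metaclass sketch; ColossalAI's SingletonMeta may differ in detail.
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Build the instance on first use, then always hand back the cached object.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class ZeroContextTracker(metaclass=SingletonMeta):
    # Hypothetical class for illustration: repeated construction yields one object.
    def __init__(self):
        self.current_config = None


assert ZeroContextTracker() is ZeroContextTracker()

A metaclass like this is a common way to expose process-wide state (for instance, the active ZeRO context configuration) without threading it through every call site.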
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from enum import Enum
 from torch.optim import Optimizer
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel.data_parallel import ColoDDPV2
 from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.core import global_context as gpc
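The last hunk touches a consumer of ColoDDPV2 (the ZeRO optimizer wrapper, judging from the surrounding Optimizer and DynamicGradScaler imports): the class is now imported from colossalai.nn.parallel.data_parallel, which matches the commit's stated goal of introducing an nn.parallel module. Downstream code written against the earlier layout only needs its import path updated; a small, version-tolerant shim such as the one below (an assumption about the two layouts, not an officially supported pattern) keeps scripts working across both.

# Import shim sketch: assumes only that ColoDDPV2 is exposed at one of these two
# paths depending on whether the installed code predates or follows this refactor.
try:
    from colossalai.nn.parallel.data_parallel import ColoDDPV2  # layout after this commit
except ImportError:
    from colossalai.nn.parallel import ColoDDPV2  # earlier layout

If the colossalai.nn.parallel package re-exports the class from its __init__, the first import alone may already cover both layouts; the fallback is purely defensive.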