polish code

ver217
2022-03-18 13:57:20 +08:00
parent e99af94ab8
commit 8cf7ff08cf
4 changed files with 24 additions and 28 deletions


@@ -1,10 +1,8 @@
-import imp
 from functools import partial
 import torch
 import torch.distributed as dist
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import checkpoint
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
@@ -20,23 +18,22 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           use_memory_tracer=False,
                           shard_strategy=TensorShardStrategy)
 
-_ZERO_OPTIMIZER_CONFIG = dict(
-    cpu_offload=False,
-    initial_scale=2**5,
-    min_scale=1,
-    growth_factor=2,
-    backoff_factor=0.5,
-    growth_interval=1000,
-    hysteresis=2,
-    max_scale=2**32,
-    lr=1e-3)
+_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
+                              initial_scale=2**5,
+                              min_scale=1,
+                              growth_factor=2,
+                              backoff_factor=0.5,
+                              growth_interval=1000,
+                              hysteresis=2,
+                              max_scale=2**32,
+                              lr=1e-3)
 
 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
                                 model_config=_ZERO_MODEL_CONFIG,
                                 optimizer_config=_ZERO_OPTIMIZER_CONFIG,
-                            ),
-                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
+                                ),
+                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
 
 CONFIG = dict(fp16=dict(mode=None,),
               zero=dict(level=3,
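
For context: the initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, and max_scale keys in _ZERO_OPTIMIZER_CONFIG above parameterize dynamic fp16 loss scaling. The sketch below illustrates the general policy those knobs control (the same technique used by torch.cuda.amp.GradScaler and DeepSpeed's dynamic loss scaler); it is not ColossalAI's actual implementation, and the DynamicLossScaler class here is a hypothetical name for illustration.

class DynamicLossScaler:
    """Grow the loss scale after a run of overflow-free steps; back it
    off after `hysteresis` overflows. Clamped to [min_scale, max_scale]."""

    def __init__(self, initial_scale=2**5, min_scale=1, growth_factor=2,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2,
                 max_scale=2**32):
        self.scale = float(initial_scale)
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self.max_scale = max_scale
        self._good_steps = 0                  # consecutive overflow-free steps
        self._overflows_left = hysteresis     # overflows tolerated before backoff

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:
                # Back off only after `hysteresis` overflows, never below min_scale.
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                # Grow after `growth_interval` clean steps, never above max_scale.
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._overflows_left = self.hysteresis

scaler = DynamicLossScaler()          # starts at 2**5 = 32.0, as in the config above
scaler.update(found_overflow=True)    # 1st overflow tolerated (hysteresis=2), scale stays 32.0
scaler.update(found_overflow=True)    # 2nd overflow: scale backs off to 32.0 * 0.5 = 16.0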