Revert "[zero] update sharded optim and fix zero init ctx" (#456)

* Revert "polish code"

This reverts commit 8cf7ff08cf.

* Revert "rename variables"

This reverts commit e99af94ab8.

* Revert "remove surplus imports"

This reverts commit 46add4a5c5.

* Revert "update sharded optim and fix zero init ctx"

This reverts commit 57567ee768.
Author: Jiarui Fang
Date:   2022-03-18 15:22:43 +08:00
Committed by: GitHub
Parent: 8cf7ff08cf
Commit: e2e9f82588

11 changed files with 161 additions and 161 deletions


@@ -2,10 +2,11 @@ from functools import partial
 import torch
 import torch.distributed as dist
 
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.nn.optimizer import CPUAdam
 
 
 LOGGER = get_dist_logger('zero_test')
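This restored import matters because the optimizer config in the next hunk names CPUAdam in a trailing comment as the alternative to torch.optim.Adam. A minimal sketch of the implied switch; the helper below is illustrative, not part of ColossalAI, and the pairing of CPUAdam with cpu_offload=True is an assumption inferred from that comment:

import torch
from colossalai.nn.optimizer import CPUAdam  # the import restored by this hunk

def select_optimizer_class(cpu_offload: bool):
    # Assumption: CPUAdam is intended for optimizer states held in host
    # memory (cpu_offload=True); plain torch.optim.Adam otherwise.
    return CPUAdam if cpu_offload else torch.optim.Adam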
@@ -15,18 +16,20 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           fp32_reduce_scatter=False,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
-                          use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy)
+                          shard_param=True,
+                          use_memory_tracer=False)
 
-_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
-                              initial_scale=2**5,
-                              min_scale=1,
-                              growth_factor=2,
-                              backoff_factor=0.5,
-                              growth_interval=1000,
-                              hysteresis=2,
-                              max_scale=2**32,
-                              lr=1e-3)
+_ZERO_OPTIMIZER_CONFIG = dict(
+    optimizer_class=torch.optim.Adam, #CPUAdam
+    cpu_offload=False,
+    initial_scale=2**5,
+    min_scale=1,
+    growth_factor=2,
+    backoff_factor=0.5,
+    growth_interval=1000,
+    hysteresis=2,
+    max_scale=2**32,
+    lr=1e-3)
 
 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
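For orientation, here is a sketch of how a test could consume the restored _ZERO_OPTIMIZER_CONFIG: optimizer_class and lr build the inner optimizer, while the remaining keys (initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale) read as dynamic loss-scale hyperparameters for mixed-precision training. The helper name and the way the dict is split are assumptions for illustration, not ColossalAI's actual test harness:

import torch
import torch.nn as nn

def build_inner_optimizer(model: nn.Module, config: dict):
    # Hypothetical helper: separate the inner-optimizer arguments from the
    # loss-scaling hyperparameters carried in the same dict.
    cfg = dict(config)
    optimizer_class = cfg.pop('optimizer_class')  # torch.optim.Adam or CPUAdam
    lr = cfg.pop('lr')
    cpu_offload = cfg.pop('cpu_offload')
    optimizer = optimizer_class(model.parameters(), lr=lr)
    # cfg now holds only the dynamic loss-scale settings that a sharded
    # optimizer wrapper would consume.
    return optimizer, cfg, cpu_offload

# Usage on a toy module, with the values from the diff above:
model = nn.Linear(8, 8)
optimizer, scale_cfg, cpu_offload = build_inner_optimizer(
    model,
    dict(optimizer_class=torch.optim.Adam,
         cpu_offload=False,
         initial_scale=2**5,
         min_scale=1,
         growth_factor=2,
         backoff_factor=0.5,
         growth_interval=1000,
         hysteresis=2,
         max_scale=2**32,
         lr=1e-3))

Note that the growth/backoff/hysteresis scheme mirrors common dynamic loss scalers: the scale grows by growth_factor every growth_interval overflow-free steps and shrinks by backoff_factor on overflow, clamped to [min_scale, max_scale].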