Revert "[zero] update sharded optim and fix zero init ctx" (#456)

* Revert "polish code"

This reverts commit 8cf7ff08cf.

* Revert "rename variables"

This reverts commit e99af94ab8.

* Revert "remove surplus imports"

This reverts commit 46add4a5c5.

* Revert "update sharded optim and fix zero init ctx"

This reverts commit 57567ee768.
Jiarui Fang
2022-03-18 15:22:43 +08:00
committed by GitHub
parent 8cf7ff08cf
commit e2e9f82588
11 changed files with 161 additions and 161 deletions


@@ -5,7 +5,7 @@ import argparse
import os
import pprint
from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
import torch
import torch.nn as nn
@@ -21,13 +21,13 @@ from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
from colossalai.global_variables import moe_env
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero_v2
+from colossalai.engine.ophooks import BaseOpHook
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@@ -217,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose=verbose)
-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
@@ -278,10 +278,12 @@ def initialize(model: nn.Module,
cfg_ = {}
optimizer_config = zero_cfg.get('optimizer_config', None)
model_config = zero_cfg.get('model_config', None)
-model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
+model, optimizer = convert_to_zero_v2(model_builder=model,
+                                      model_config=model_config,
+                                      optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
-# FIXME() throw a warning if using zero with MP
+#FIXME() throw a warning if using zero with MP
if gpc.get_world_size(ParallelMode.MODEL) > 1:
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
else:
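
For context, below is a minimal usage sketch of the builder-based signature this revert restores. It is not part of the commit: the two-layer model, the config contents, and the Adam choice are illustrative assumptions, and the four-item return unpacking follows the usual colossalai.initialize convention. Only the parameter shapes of initialize (model as a Callable, optimizer as a class) and the zero config keys model_config / optimizer_config come from the diff above; how the optimizer class and optimizer_config are combined inside convert_to_zero_v2 is assumed.

# Hedged sketch of the restored API; run under torchrun so launch_from_torch
# can read the distributed environment variables.
import colossalai
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Passed as a Callable, matching model: Union[Callable, nn.Module].
    return nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))


# Hypothetical config dict: zero.model_config / zero.optimizer_config mirror the
# keys read via zero_cfg.get(...) in the diff, but their contents are assumptions.
config = dict(zero=dict(model_config=dict(), optimizer_config=dict(lr=1e-3)))

colossalai.launch_from_torch(config=config)

# initialize() accepts an optimizer class (Union[Type[Optimizer], Optimizer]);
# with ZeRO enabled it forwards the builder as convert_to_zero_v2(model_builder=...).
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=build_model,
    optimizer=torch.optim.Adam,
    criterion=nn.CrossEntropyLoss())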