Mirror of https://github.com/hpcaitech/ColossalAI.git
Revert "[zero] update sharded optim and fix zero init ctx" (#456)

* Revert "polish code"
  This reverts commit 8cf7ff08cf.
* Revert "rename variables"
  This reverts commit e99af94ab8.
* Revert "remove surplus imports"
  This reverts commit 46add4a5c5.
* Revert "update sharded optim and fix zero init ctx"
  This reverts commit 57567ee768.
@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
 
 import torch
 import torch.nn as nn
@@ -21,13 +21,13 @@ from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero_v2
+from colossalai.engine.ophooks import BaseOpHook
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
 
 
@@ -217,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
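This signature hunk is the user-visible part of the revert: initialize again accepts either a built nn.Module or a callable that builds one, and either an Optimizer instance or an optimizer class (hence the restored Type import in the first hunk). Below is a minimal sketch of the two calling styles; MyModel and its hyperparameters are hypothetical placeholders, and the usual distributed launch (e.g. colossalai.launch_from_torch) is elided:

import torch.nn as nn
from torch.optim import Adam

import colossalai

class MyModel(nn.Module):
    # Hypothetical model, used only to illustrate the two call styles.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 8)

    def forward(self, x):
        return self.proj(x)

# Style 1: pass a built module and a constructed optimizer.
model = MyModel()
optimizer = Adam(model.parameters(), lr=1e-3)
engine, *_ = colossalai.initialize(model, optimizer, criterion=nn.CrossEntropyLoss())

# Style 2 (restored by this revert): pass the builder callable and the
# optimizer class. Under a ZeRO config, construction is deferred so that
# convert_to_zero_v2 can build the model itself and instantiate the
# optimizer from the config's optimizer_config (see the next hunk).
engine, *_ = colossalai.initialize(MyModel, Adam, criterion=nn.CrossEntropyLoss())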
@@ -278,10 +278,12 @@ def initialize(model: nn.Module,
         cfg_ = {}
         optimizer_config = zero_cfg.get('optimizer_config', None)
         model_config = zero_cfg.get('model_config', None)
-        model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
+        model, optimizer = convert_to_zero_v2(model_builder=model,
+                                              model_config=model_config,
+                                              optimizer_config=optimizer_config)
 
         logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
-        # FIXME() throw a warning if using zero with MP
+        #FIXME() throw a warning if using zero with MP
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
     else:
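This last hunk restores the builder-based ZeRO v2 path: when the config has a zero section, initialize forwards the callable as model_builder to convert_to_zero_v2 together with that section's model_config and optimizer_config. A hedged sketch of a matching config file follows, assuming only the top-level keys visible in the hunk; the empty dicts are illustrative placeholders, not a schema confirmed by this diff:

# config.py -- illustrative ZeRO section for colossalai.initialize.
# Only the 'zero', 'model_config', and 'optimizer_config' keys appear in
# the diff above; initialize() reads the nested keys with
# zero_cfg.get(..., None), so their contents are framework-defined.
zero = dict(
    model_config=dict(),      # forwarded as convert_to_zero_v2(model_config=...)
    optimizer_config=dict(),  # forwarded as convert_to_zero_v2(optimizer_config=...)
)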