update sharded optim and fix zero init ctx

ver217
2022-03-18 13:17:53 +08:00
parent f27d801a13
commit 57567ee768
11 changed files with 147 additions and 142 deletions

colossalai/initialize.py

@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,13 +21,13 @@ from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
+from colossalai.engine.ophooks import BaseOpHook
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero_v2
-from colossalai.engine.ophooks import BaseOpHook
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@@ -217,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)
 
 
-def initialize(model: Union[Callable, nn.Module],
-               optimizer: Union[Type[Optimizer], Optimizer],
+def initialize(model: nn.Module,
+               optimizer: Optimizer,
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
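
Note on the signature change above: `initialize` now expects an already-constructed `nn.Module` and a concrete `Optimizer` instance rather than a builder callable or an optimizer class. A minimal usage sketch, where the model, optimizer, and config path are illustrative placeholders and not part of this commit:

import torch
import torch.nn as nn
import colossalai

# Assumption: the distributed context was already set up, e.g. via
# colossalai.launch_from_torch(config='./config.py') with a placeholder config path.
model = nn.Linear(128, 10)                                  # concrete module instead of a model builder
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # concrete optimizer instead of an optimizer class
criterion = nn.CrossEntropyLoss()

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
                                                                                optimizer=optimizer,
                                                                                criterion=criterion)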
@@ -278,12 +278,10 @@ def initialize(model: Union[Callable, nn.Module],
             cfg_ = {}
         optimizer_config = zero_cfg.get('optimizer_config', None)
         model_config = zero_cfg.get('model_config', None)
-        model, optimizer = convert_to_zero_v2(model_builder=model,
-                                              model_config=model_config,
-                                              optimizer_config=optimizer_config)
+        model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
 
         logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
-        #FIXME() throw a warning if using zero with MP
+        # FIXME() throw a warning if using zero with MP
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
     else:
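
For context on the hunk above: `convert_to_zero_v2` now receives the model instance directly instead of a `model_builder` callable, with its settings split into the `model_config` and `optimizer_config` entries read from `zero_cfg`. A hedged sketch of such a config section follows; only the two inner key names come from this diff, while the top-level `zero` key name and the empty dicts are assumptions for illustration:

# Placeholder config sketch; the actual options inside each dict are not shown in this diff.
zero = dict(
    model_config=dict(),
    optimizer_config=dict(),
)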