[zero] hybrid cpu adam (#445)

This commit is contained in:
Jiarui Fang
2022-03-17 15:05:41 +08:00
committed by GitHub
parent b72b8445c6
commit 237d08e7ee
2 changed files with 81 additions and 55 deletions

View File

@@ -1,20 +1,19 @@
from asyncio.log import logger
from distutils.command.config import config
from typing import Callable
import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from colossalai.zero.shard_utils import TensorShardStrategy
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch.optim import Optimizer
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.logging import get_dist_logger
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
from colossalai.zero.init_ctx import ZeroInitContext
from typing import Callable, Type
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):