Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[zero] hybrid cpu adam (#445)
@@ -1,20 +1,19 @@
-from asyncio.log import logger
-from distutils.command.config import config
-from typing import Callable
-
-import torch
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
-from colossalai.zero.shard_utils import TensorShardStrategy
+import torch
+import torch.nn as nn
+from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from torch.optim import Optimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.logging import get_dist_logger
+
+from .sharded_model import ShardedModel
+from .sharded_optim import ShardedOptimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from typing import Callable, Type
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
 
 
 def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
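For context, convert_to_zero_v2 calls the given model builder and returns a (ShardedModelV2, ShardedOptimizerV2) pair configured from the two config dicts; the PR title suggests the optimizer side is the hybrid CPU/GPU Adam this commit introduces, which dispatches Adam updates to CPU for offloaded optimizer states and to GPU otherwise. Below is a minimal usage sketch, not code from this commit: the config keys (shard_strategy, offload_config, optimizer_class, lr), the HybridAdam import path, and the toy model are assumptions based on the ColossalAI ZeRO API of this period.

import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam  # import path is an assumption
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.shard_utils import TensorShardStrategy


def build_model() -> nn.Module:
    # Any zero-argument callable returning the unwrapped model will do.
    return nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))


# Hypothetical config dicts: the key names here are assumptions,
# not values taken from the hunk above.
model_config = dict(shard_strategy=TensorShardStrategy(),
                    offload_config=dict(device='cpu'))
optimizer_config = dict(optimizer_class=HybridAdam, lr=1e-3)

# convert_to_zero_v2 reads process-group state from gpc, so it must run
# inside an initialized distributed context (e.g. after colossalai.launch).
zero_model, zero_optimizer = convert_to_zero_v2(build_model, model_config, optimizer_config)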