[zero] add zero optimizer for ColoTensor (#1046)
* add zero optimizer
* torch ok
* unit test ok
* polish code
* fix bugs
* polish unit test
* polish zero optim
* polish colo ddp v2
* refactor folder structure
* add comment
* polish unit test
* polish zero optim
* polish unit test
@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from .zero_optimizer import ZeroOptimizer


 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
@@ -35,4 +36,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
     return zero_model, zero_optimizer


-__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
+__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
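For context, the change re-exports the new ZeroOptimizer alongside the existing ZeRO entry points. Below is a minimal usage sketch, assuming this file is the package __init__ so the names are importable from colossalai.zero; the model, optimizer, and model_config values are placeholders, and convert_to_zero_v2 takes further parameters that are truncated in the hunk header above, so the call may need additional arguments in the real API.

# Hypothetical usage sketch; only convert_to_zero_v2 and the exported names
# in __all__ are confirmed by this diff.
import torch
import torch.nn as nn

from colossalai.zero import ZeroOptimizer, convert_to_zero_v2

model = nn.Linear(16, 16)                                  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder optimizer

# model_config is passed through as shown in the signature; its accepted keys
# are not visible in this diff, so an empty dict stands in here.
zero_model, zero_optimizer = convert_to_zero_v2(model, optimizer, model_config={})

# The returned pair wraps the originals for ZeRO training, per the
# "return zero_model, zero_optimizer" line in the second hunk.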