mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[zero] migrate zero1&2 (#1878)
* add zero1&2 optimizer * rename test ditectory * rename test files * change tolerance in test
This commit is contained in:
@@ -2,9 +2,11 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
||||
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
|
||||
|
||||
from .zero_optimizer import ZeroOptimizer
|
||||
|
||||
|
||||
@@ -36,4 +38,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
|
||||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
|
||||
__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
|
||||
|
Reference in New Issue
Block a user