Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure
* [zero] fix legacy zero import path
* [zero] fix legacy zero import path
* [zero] remove useless import
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] fix test import path
* [zero] fix test
* [zero] fix circular import
* [zero] update import
@@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
 
 
 # constants
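The net effect of this hunk is that `ColoInitContext`, `GeminiAdamOptimizer`, and `ZeroDDP` are now all imported from the single `colossalai.zero` namespace instead of three scattered locations. A minimal sketch of updated caller code follows; the import paths come straight from the diff, while the toy `nn.Linear` model and the `device` keyword follow ColossalAI examples of this era and should be treated as assumptions:

```python
import torch.nn as nn

from colossalai.utils import get_current_device
# Unified path introduced by this refactor. It replaces:
#   from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
#   from colossalai.nn.parallel import ZeroDDP
#   from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP

# Build the model under ColoInitContext so its parameters are created as
# ColoParameters on the target device (toy module for illustration).
with ColoInitContext(device=get_current_device()):
    model = nn.Linear(1024, 1024)
```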
@@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
     return model
 
 
-## Parameter Sharding Strategies for Tensor Parallelism
+# Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     param.set_tensor_spec(*spec)
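In the helper above, `ShardSpec([dim], [pg.tp_world_size()])` declares that the parameter is split along `dim` into as many shards as there are tensor-parallel ranks in `pg`, and `ComputePattern.TP1D` marks it for the 1-D tensor-parallel compute path. A usage sketch follows; the helper body is copied from the diff, while the 2-way `ProcessGroup` and the toy `nn.Linear` are hypothetical:

```python
import torch.nn as nn

from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext


def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    # Split `param` along `dim` into pg.tp_world_size() shards, computed with 1-D TP.
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


pg = ProcessGroup(tp_degree=2)  # assumes the script was launched with 2 TP ranks
with ColoInitContext(device=get_current_device()):
    layer = nn.Linear(1024, 1024)  # parameters materialize as ColoParameters

# Row-shard the weight: dim 0 is cut into pg.tp_world_size() pieces.
# Passing dim=-1 instead would column-shard it along the input dimension.
split_param_single_dim_tp1d(0, layer.weight, pg)
```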
@@ -232,7 +230,7 @@ if args.distplan == "colossalai":
     tensor_parallelize(model, pg)
     model = gemini_zero_dpp(model, pg, args.placement)
 
-    #optimizer
+    # optimizer
 
     #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
     optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
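Note that `GeminiAdamOptimizer` wraps the ZeRO/Gemini model itself rather than `model.parameters()`, and `initial_scale=2**5` seeds the dynamic loss scaler used for fp16 training. A training-step sketch follows; the constructor call mirrors the diff, while `LEARNING_RATE`, the forward pass, and the `optimizer.backward(loss)` pattern follow ColossalAI's Gemini examples and are assumptions here:

```python
from colossalai.zero import GeminiAdamOptimizer  # unified path after this commit

LEARNING_RATE = 1e-3  # hypothetical value; the demo defines its own constant

# `model` is the ZeroDDP-wrapped module returned by gemini_zero_dpp(...) above.
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)

# One step: Gemini routes the backward pass through the optimizer so its
# chunk-based gradient bookkeeping runs, instead of calling loss.backward().
optimizer.zero_grad()
loss = model(input_ids, attn_mask)  # hypothetical forward returning a scalar loss
optimizer.backward(loss)
optimizer.step()
```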