[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
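The headline change is Gemini's new static placement policy: instead of the old heuristic `cpu`/`cuda` placement flag, the refactored plugin is configured with explicit offload fractions. A minimal configuration sketch follows; `offload_optim_frac` appears verbatim in the diff below, while `placement_policy="static"` and `offload_param_frac` are assumptions based on the commit titles, not lines from this file.

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Sketch only: a "static" policy pins a fixed fraction of states to CPU instead of
# migrating tensors dynamically. offload_optim_frac is taken from this diff; the
# other keyword names are assumptions drawn from the commit messages.
plugin = GeminiPlugin(
    placement_policy="static",    # assumed name of the new policy
    offload_optim_frac=1.0,       # fraction of optimizer states offloaded to CPU
    offload_param_frac=0.0,       # assumed companion knob for parameters
    initial_scale=2**5,
)
booster = Booster(plugin=plugin)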
@@ -1,5 +1,5 @@
import gzip
import random
from contextlib import nullcontext
from functools import partial
from time import time

@@ -8,20 +8,17 @@ import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version

from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

# constants
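These import changes retire the ColoInitContext/ZeroDDP path in favor of the Booster API. A short sketch of how a training step flows through the booster once a plugin is chosen; the toy model and dummy batch are placeholders standing in for PaLM, and the launch call assumes the process was started with torchrun.

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn import HybridAdam

# Assumes torchrun set up the usual distributed environment variables.
colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8)                        # toy stand-in for PaLM
optimizer = HybridAdam(model.parameters(), lr=1e-3)
booster = Booster(plugin=GeminiPlugin(initial_scale=2**5))

# boost() wraps model and optimizer so Gemini controls placement and ZeRO sharding.
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(4, 8, device="cuda")).sum()  # dummy forward pass
booster.backward(loss, optimizer)                     # backward must go through the booster
optimizer.step()
optimizer.zero_grad()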
@@ -44,23 +41,10 @@ def parse_args():
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        type=bool,
        default=False,
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
        "--offload_optim_frac",
        type=float,
        default=1.0,
        help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
    )
    parser.add_argument('-p',
                        '--plugin',
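In this hunk the old --tp_degree/--placement/--shardinit flags are dropped and Gemini offloading is steered by the single --offload_optim_frac fraction. A small sketch of how the flag is meant to feed the plugin; it mirrors the change in the last hunk of this diff, and the 0.0/1.0 reading follows the help string above.

import argparse

from colossalai.booster.plugin import GeminiPlugin

parser = argparse.ArgumentParser()
parser.add_argument(
    "--offload_optim_frac",
    type=float,
    default=1.0,
    help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
)
args = parser.parse_args()

# 0.0 keeps optimizer states on the accelerator, 1.0 offloads all of them to CPU,
# and intermediate values split them by that fraction.
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)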
@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
    return total_numel


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.
    Args:
        model (torch.nn.Module): a torch module to be sharded
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue
            param.set_dist_spec(ReplicaSpec())
            if 'net.0' in mn:
                split_param_col_tp1d(param, pg)  # column slice
            elif 'to_q' in mn:
                split_param_col_tp1d(param, pg)  # column slice
            elif 'to_kv' in mn:
                split_param_row_tp1d(param, pg)  # row slice
            elif 'to_out' in mn:
                split_param_row_tp1d(param, pg)  # row slice
            elif '1.1' in mn:
                split_param_col_tp1d(param, pg)  # column slice
            elif '1.2' in mn:
                split_param_row_tp1d(param, pg)  # row slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True


args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise TypeError(f"{args.distplan} is error")
@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)

    default_pg = ProcessGroup(tp_degree=args.tp_degree)
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()

    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

    pg = default_pg
    tensor_parallelize(model, pg)

    # optimizer

    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
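Putting the last hunk together: ColoInitContext plus manual tensor parallelism gives way to LazyInitContext plus the booster. A condensed sketch assembled from the lines above; SEQ_LEN, LEARNING_RATE and the constant values chosen here are placeholders standing in for definitions elsewhere in the example.

from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

SEQ_LEN = 1024          # placeholder; defined with the other constants in the example
LEARNING_RATE = 2e-4    # placeholder

plugin = GeminiPlugin(offload_optim_frac=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)

# Lazy init defers materializing the PaLM weights until boost(), so Gemini can shard
# and place them directly instead of relying on the removed tensor_parallelize() path.
ctx = LazyInitContext(default_device=get_current_device())
with ctx:
    model = PaLM(num_tokens=50304, dim=4096, depth=64)
    model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
model, optimizer, *_ = booster.boost(model, optimizer)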