Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[gemini] improve compatibility and add static placement policy (#4479)
* [gemini] remove distributed-related part from colotensor (#4379)
* [gemini] remove process group dependency
* [gemini] remove tp part from colo tensor
* [gemini] patch inplace op
* [gemini] fix param op hook and update tests
* [test] remove useless tests
* [test] remove useless tests
* [misc] fix requirements
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [test] fix model zoo
* [misc] update requirements
* [gemini] refactor gemini optimizer and gemini ddp (#4398)
* [gemini] update optimizer interface
* [gemini] renaming gemini optimizer
* [gemini] refactor gemini ddp class
* [example] update gemini related example
* [example] update gemini related example
* [plugin] fix gemini plugin args
* [test] update gemini ckpt tests
* [gemini] fix checkpoint io
* [example] fix opt example requirements
* [example] fix opt example
* [example] fix opt example
* [example] fix opt example
* [gemini] add static placement policy (#4443)
* [gemini] add static placement policy
* [gemini] fix param offload
* [test] update gemini tests
* [plugin] update gemini plugin
* [plugin] update gemini plugin docstr
* [misc] fix flash attn requirement
* [test] fix gemini checkpoint io test
* [example] update resnet example result (#4457)
* [example] update bert example result (#4458)
* [doc] update gemini doc (#4468)
* [example] update gemini related examples (#4473)
* [example] update gpt example
* [example] update dreambooth example
* [example] update vit
* [example] update opt
* [example] update palm
* [example] update vit and opt benchmark
* [hotfix] fix bert in model zoo (#4480)
* [hotfix] fix bert in model zoo
* [test] remove chatglm gemini test
* [test] remove sam gemini test
* [test] remove vit gemini test
* [hotfix] fix opt tutorial example (#4497)
* [hotfix] fix opt tutorial example
* [hotfix] fix opt tutorial example
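Taken together, these changes move the GPT example off the old ColoInitContext / ProcessGroup tensor-parallel path and onto the Booster + GeminiPlugin API, with parameter placement handled by the plugin rather than by per-tensor specs. As orientation for the diff below, here is a minimal sketch of the simplified flow; the stand-in GPT-2 model from transformers, the learning rate and the loss function are illustrative assumptions, not taken from this commit:

# Minimal sketch of the post-refactor Gemini flow; launch with e.g.
#   torchrun --standalone --nproc_per_node=1 this_script.py
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from transformers import GPT2Config, GPT2LMHeadModel    # assumed stand-in for model_builder(...)

colossalai.launch_from_torch(config={})

# Lazy initialization defers parameter materialization until the booster shards them.
with LazyInitContext(default_device=get_current_device()):
    model = GPT2LMHeadModel(GPT2Config())

plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
booster = Booster(plugin=plugin)

optimizer = HybridAdam(model.parameters(), lr=1e-3)    # lr is an assumption
criterion = torch.nn.CrossEntropyLoss()                # the example defines its own GPTLMLoss

# boost() wraps model, optimizer and criterion for Gemini's chunk-based ZeRO execution.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

In the updated example the backward pass then runs through the boosted optimizer (optimizer.backward(loss) followed by optimizer.step()) rather than a plain loss.backward().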
@@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}

# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
@@ -21,11 +18,8 @@ fi
mkdir -p gemini_logs

torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--tp_degree=${TPDEGREE} \
--model_type=${MODEL_TYPE} \
--batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
${USE_SHARD_INIT} \
--distplan=${DISTPLAN} \
--train_step=${TRAIN_STEP} \
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
@@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
for DISTPLAN in "CAI_Gemini"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1 2; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
for PLACEMENT in "cpu" "auto"; do
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
done
done
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
bash ./run_gemini.sh
done
done
done

for DISTPLAN in "zero1" "zero2"; do
for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do
for TPDEGREE in 1; do
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
continue
fi
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
bash ./run_gemini.sh
done
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
bash ./run_gemini.sh
done
done
done
@@ -1,4 +1,5 @@
import os
from contextlib import nullcontext
from functools import partial
from time import time

@@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext

CAI_VERSION = colossalai.__version__
@@ -30,24 +30,6 @@ def parse_args():
        default='CAI_Gemini',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
@@ -71,20 +53,6 @@ def parse_args():
    return args


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


class GPTLMLoss(nn.Module):

    def __init__(self):
@@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # NOTE() a param maybe shared by two modules
            if hasattr(param, 'visited'):
                continue

            # if shard init, then convert param to replica and use the dp-only ProcessGroup
            param: ColoParameter = param
            param.set_dist_spec(ReplicaSpec())
            param.set_process_group(pg)

            # shard it w.r.t tp pattern
            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)    # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'wte' in mn or 'wpe' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True


def main():
    # version check
    # this example is supposed to work for versions greater than 0.2.0
@@ -213,30 +140,13 @@ def main():

    # build criterion
    criterion = GPTLMLoss()

    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        # all param must use the same process group.
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        if args.shardinit and args.distplan != "CAI_Gemini":
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
        # build GPT model
        with ColoInitContext(device=get_current_device(),
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
        with ctx:
            model = model_builder(args.model_type)(checkpoint=True)

        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
        # Tensor Parallelism (TP)
        # You should notice that v0.1.10 is not compatible with TP degree > 1
        if args.tp_degree > 1:
            tensor_parallelize(model, tp_pg)

        # assign running configurations
        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
@@ -254,13 +164,7 @@ def main():
                                    overlap_communication=True,
                                    verbose=True)
        elif args.distplan == "CAI_Gemini":
            plugin = GeminiPlugin(device=get_current_device(),
                                  placement_policy=args.placement,
                                  pin_memory=True,
                                  strict_ddp_mode=args.tp_degree == 1,
                                  search_range_m=128,
                                  hidden_dim=model.config.n_embd,
                                  gpu_margin_mem_ratio=0.)
            plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
        else:
            raise RuntimeError
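The headline feature, the static placement policy, is configured through the GeminiPlugin constructor rather than through the removed --placement example flag. A hedged sketch of that configuration follows; the keyword names shard_param_frac, offload_optim_frac and offload_param_frac are assumptions based on the plugin introduced by this series and do not appear in the hunk above, so check the GeminiPlugin docstring of your installed version before relying on them.

from colossalai.booster.plugin import GeminiPlugin

# "static" keeps fixed fractions of parameters and optimizer states on GPU/CPU,
# while "auto" keeps the previous dynamic, chunk-based placement behaviour.
# All keyword arguments below are assumptions for illustration only.
static_plugin = GeminiPlugin(
    placement_policy="static",
    shard_param_frac=1.0,       # assumed: fraction of parameters sharded across GPUs
    offload_optim_frac=0.5,     # assumed: fraction of optimizer states offloaded to CPU
    offload_param_frac=0.0,     # assumed: fraction of parameters offloaded to CPU
)

auto_plugin = GeminiPlugin(placement_policy="auto")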