[example] make gpt example directory more clear (#2353)
examples/language/gpt/gemini/benchmark_gemini.sh (new file, 20 lines)
@@ -0,0 +1,20 @@
for MODEL_TYPE in "gpt2_medium"; do
  for BATCH_SIZE in 16; do
    for GPUNUM in 1 2 4 8; do
      for TPDEGREE in 1 2 4 8; do
        if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
          continue
        fi
        for PLACEMENT in "cpu" "auto"; do
          echo "****************** Begin ***************************"
          echo "* benchmarking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}"
          MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
            bash ./gemini/run_gemini.sh
          echo "****************** Finished ***************************"
          echo ""
          echo ""
        done
      done
    done
  done
done
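As a sanity check on the sweep above, the same configuration grid can be enumerated in Python. This is a hypothetical helper, not part of the commit; it only mirrors the loop and skip logic of benchmark_gemini.sh.

import itertools

model_types = ["gpt2_medium"]
batch_sizes = [16]
gpu_nums = [1, 2, 4, 8]
tp_degrees = [1, 2, 4, 8]
placements = ["cpu", "auto"]

# keep only the combinations the shell script actually runs:
# the tensor-parallel degree may not exceed the number of GPUs
configs = [(m, bs, g, tp, p)
           for m, bs, g, tp, p in itertools.product(model_types, batch_sizes, gpu_nums, tp_degrees, placements)
           if tp <= g]
print(f"{len(configs)} benchmark runs")    # 20 runs for the grid above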
examples/language/gpt/gemini/commons/model_zoo.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel


## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.config = GPT2Config(n_embd=hidden_size,
                                 n_layer=num_layers,
                                 n_head=num_attention_heads,
                                 n_positions=max_seq_len,
                                 n_ctx=max_seq_len,
                                 vocab_size=vocab_size)
        self.model = GPT2LMHeadModel(self.config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_14b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_20b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_24b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)


def model_builder(model_size: str) -> callable:
    if model_size == "gpt2_medium":
        return gpt2_medium
    elif model_size == "gpt2_xl":
        return gpt2_xl
    elif model_size == "gpt2_10b":
        return gpt2_10b
    elif model_size == "gpt2_14b":
        return gpt2_14b
    elif model_size == "gpt2_20b":
        return gpt2_20b
    elif model_size == "gpt2_24b":
        return gpt2_24b
    else:
        raise ValueError(f"model_builder: unknown model size {model_size}")


__all__ = ['model_builder']
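For reference, a minimal sketch of how model_zoo.py is meant to be consumed (built on CPU here just to count parameters; the real demo instantiates the model inside ColoInitContext). Assumes commons is importable from the working directory.

from commons.model_zoo import model_builder

# look up the constructor by name, then build the model
model = model_builder("gpt2_medium")(checkpoint=False)
numel = sum(p.numel() for p in model.parameters())
print(f"gpt2_medium has {numel / 1e6:.1f}M parameters")    # about 355M for this config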
examples/language/gpt/gemini/commons/utils.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import torch


# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
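The factor of 8 in get_tflops is the usual 2 FLOPs per multiply-add times (1 forward + 2 backward + 1 recomputed forward for activation checkpointing) per parameter per token. A worked example with made-up numbers, assuming commons is importable:

from commons.utils import get_tflops

model_numel = 355e6    # parameters, roughly gpt2_medium
batch_size = 16
seq_len = 1024
step_time = 2.0        # seconds per step, hypothetical

# 355e6 * 16 * 1024 * 8 / 1e12 / 2.0 ≈ 23.3 TFLOPS
print(f"{get_tflops(model_numel, batch_size, seq_len, step_time):.1f} TFLOPS")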
examples/language/gpt/gemini/run_gemini.sh (new file, 24 lines)
@@ -0,0 +1,24 @@
set -x
# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
export DISTPAN=${DISTPAN:-"colossalai"}

# The following options are only valid when DISTPAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}

# export PYTHONPATH=$PWD:$PYTHONPATH

mkdir -p gemini_logs

torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
  --tp_degree=${TPDEGREE} \
  --model_type=${MODEL_TYPE} \
  --batch_size=${BATCH_SIZE} \
  --placement=${PLACEMENT} \
  --shardinit=${USE_SHARD_INIT} \
  --distplan=${DISTPAN} \
  2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
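Because every setting in run_gemini.sh is read with the ${VAR:-default} pattern, a single run can be configured entirely from the environment. A minimal sketch of driving it from Python, assuming the script is launched from examples/language/gpt/gemini; the chosen values are illustrative:

import os
import subprocess

env = dict(os.environ,
           MODEL_TYPE="gpt2_medium",
           BATCH_SIZE="16",
           GPUNUM="2",
           TPDEGREE="2",
           PLACEMENT="auto")
# any variable left unset falls back to the defaults in run_gemini.sh
subprocess.run(["bash", "./run_gemini.sh"], env=env, check=True)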
examples/language/gpt/gemini/train_gpt_demo.py (new file, 350 lines)
@@ -0,0 +1,350 @@
import os
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
from commons.utils import get_data, get_tflops
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

CAI_VERSION = colossalai.__version__

if version.parse(CAI_VERSION) > version.parse("0.1.10"):
    # These are added after 0.1.10
    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='colossalai',
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        # argparse's type=bool treats any non-empty string (even "False") as
        # True, so parse the boolean from its string form instead
        type=lambda x: str(x).lower() in ("true", "1"),
        default=False,
        help=
        "Shard the tensors when initializing the model to shrink the peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2_medium",
        help="model scale",
    )
    args = parser.parse_args()
    return args
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel
def model_size_formatter(numel: int) -> str:
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    else:
        return str(numel)


def set_cpu_maximum_parallelism():
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environment variable OMP_NUM_THREADS is set to {max_concurrency}.")
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # NOTE() a param may be shared by two modules
            if hasattr(param, 'visited'):
                continue
            param.set_dist_spec(ReplicaSpec())
            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)    # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'wte' in mn or 'wpe' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            else:
                param.set_dist_spec(ReplicaSpec())

            param.visited = True
# Gemini + ZeRO DDP
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
    fp16_init_scale = 2**5
    gpu_margin_mem_ratio_for_auto = 0

    if version.parse(CAI_VERSION) > version.parse("0.1.10"):
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          hidden_dim=model.config.n_embd,
                          search_range_mb=64)
        # configure the const policy
        if placement_policy == 'const':
            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
        # build a highly optimized cpu optimizer
        optimizer = GeminiAdamOptimizer(model,
                                        lr=1e-3,
                                        initial_scale=fp16_init_scale,
                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
    elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
        from colossalai.gemini import ChunkManager, GeminiManager
        from colossalai.nn.optimizer import HybridAdam
        from colossalai.zero import ZeroOptimizer
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(placement_policy))
        gemini_manager = GeminiManager(placement_policy, chunk_manager)
        model = ZeroDDP(model, gemini_manager)
        optimizer = HybridAdam(model.parameters(), lr=1e-3)
        optimizer = ZeroOptimizer(optimizer,
                                  model,
                                  initial_scale=fp16_init_scale,
                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
    else:
        raise NotImplementedError(f"CAI version {CAI_VERSION} is not supported")
    return model, optimizer
def main():
    # version check
    # this example is supposed to work for versions greater than 0.1.9
    assert version.parse(CAI_VERSION) >= version.parse("0.1.9")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
        raise ValueError(f"unknown distplan {args.distplan}")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = 10
    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    torch.manual_seed(123)
    if args.distplan == "colossalai":
        # all params must use the same process group.
        default_pg = ProcessGroup(tp_degree=args.tp_degree)
        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

        # build GPT model
        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
            with ColoInitContext(device=get_current_device(),
                                 dtype=torch.half,
                                 default_dist_spec=default_dist_spec,
                                 default_pg=default_pg):
                model = model_builder(args.model_type)(checkpoint=True)
        else:
            with ColoInitContext(device=get_current_device()):
                model = model_builder(args.model_type)(checkpoint=True)

        pg = default_pg
        # Tensor Parallelism (TP)
        tensor_parallelize(model, pg)

        # build a Gemini model and a highly optimized cpu optimizer
        # Gemini + ZeRO DP, Note it must be used after TP
        model, optimizer = build_gemini(model, pg, args.placement)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    else:
        model = model_builder(args.model_type)(checkpoint=True).cuda()

    if args.distplan.startswith("torch"):
        model = DDP(model)
        if args.distplan.endswith("ddp"):
            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        elif args.distplan.endswith("zero"):
            from torch.distributed.optim import ZeroRedundancyOptimizer
            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
    elif args.distplan.startswith("zero"):
        partition_flag = args.distplan == "zero2"
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        optimizer = LowLevelZeroOptimizer(optimizer,
                                          overlap_communication=True,
                                          partition_grad=partition_flag,
                                          verbose=True)
    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of the testing model is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
    torch.cuda.synchronize()
    model.train()
    tflops_list = []
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])

        if args.distplan in ["colossalai", "zero1", "zero2"]:
            optimizer.backward(loss)
        elif args.distplan in ["torch_ddp", "torch_zero"]:
            loss.backward()
        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        if args.distplan in ["zero1", "zero2"]:
            optimizer.sync_grad()
        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    tflops_list.sort()
    # tflops_list holds only the post-warmup steps, and their count is odd
    # (see the assert above), so the median sits at the middle index
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()


if __name__ == '__main__':
    main()
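Closing note: the shift-by-one labeling in GPTLMLoss (the logit at position i is scored against the token at position i + 1) can be sanity-checked on toy tensors. The snippet below is illustrative only, not part of the commit:

import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 8, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# same shift as GPTLMLoss.forward: drop the last logit, drop the first label
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(f"toy loss: {loss.item():.3f}")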