mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)

[CI] add test_ci.sh for palm, opt and gpt (#2475)
@@ -8,4 +8,4 @@ export PLACEMENT='cpu'
 export USE_SHARD_INIT=False
 export BATCH_SIZE=4
 
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
examples/language/palm/test_ci.sh (new file, 9 lines)
@@ -0,0 +1,9 @@
+$(cd `dirname $0`;pwd)
+
+for BATCH_SIZE in 2
+do
+for GPUNUM in 1 4
+do
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+done
+done
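Reviewer note on the script's first line: $(cd `dirname $0`;pwd) command-substitutes the script's directory and then tries to execute the resulting path as a command, so it neither changes the working directory nor affects the run; the conventional idiom for this purpose is cd $(dirname $0). That intent is inferred here, not stated by the commit.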
@@ -1,11 +1,12 @@
 import gzip
 import random
-from time import time
+from functools import partial
+from time import time
 
 import numpy as np
 import torch
-import torch.optim as optim
 import torch.nn as nn
+import torch.optim as optim
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
@@ -23,7 +24,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 
 # constants
 
-NUM_BATCHES = int(100)
+NUM_BATCHES = int(10)
 WARMUP_BATCHES = 1
 GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
@@ -66,9 +67,16 @@ def parse_args():
         default=8,
         help="batch size per DP group of training.",
     )
+    parser.add_argument(
+        "--dummy_data",
+        type=bool,
+        default=False,
+        help="use dummy dataset.",
+    )
     args = parser.parse_args()
     return args
 
+
 # helpers
 def cycle(loader):
     while True:
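Reviewer note on the new flag: argparse's type=bool calls bool() on the raw string, and any non-empty string is truthy, so --dummy_data=False would still yield True; the CI script happens to pass --dummy_data=True, which masks this. A minimal sketch of the usual workaround (str2bool is an illustrative helper, not part of the commit):

    import argparse

    def str2bool(v: str) -> bool:
        # treat common truthy spellings as True; everything else as False
        return v.lower() in ("1", "true", "yes")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dummy_data", type=str2bool, default=False, help="use dummy dataset.")
    print(parser.parse_args(["--dummy_data=False"]).dummy_data)  # False, as intended
    print(parser.parse_args(["--dummy_data=True"]).dummy_data)   # True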
@@ -79,12 +87,15 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))
 
 
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
 
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 
 
 def get_model_size(model: nn.Module):
     total_numel = 0
     for module in model.modules():
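Reviewer note on get_tflops: the factor 8 matches the common accounting of 2 FLOPs per multiply-accumulate times four weight passes (one forward, roughly two forward-equivalents for backward, plus one forward recompute under activation checkpointing); without checkpointing the conventional factor is 6, and the 1e-12 guards against division by zero. A quick sanity check with made-up sizes:

    model_numel = 500e6            # hypothetical 500M-parameter configuration
    batch_size, seq_len = 8, 1024
    step_time = 1.2                # seconds per step, illustrative
    tflops = model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
    print(f"{tflops:.1f} TFLOPS")  # 27.3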
@@ -92,6 +103,7 @@ def get_model_size(model: nn.Module):
         total_numel += p.numel()
     return total_numel
 
+
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__
@@ -115,6 +127,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model
 
+
 ## Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
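Reviewer note on the version dispatch: gemini_zero_dpp branches on colossalai.__version__ (hence the packaging import), but raise NotImplemented(...) calls the NotImplemented sentinel, which is a value rather than an exception class, so this line dies with a TypeError instead of reporting the intended message; NotImplementedError is the exception meant here. A generic sketch of the pattern (the 0.1.10 cutoff is illustrative, not taken from the commit):

    from packaging import version

    def dispatch_by_version(cai_version: str) -> str:
        # route to whichever wrapper matches the installed ColossalAI release
        if version.parse(cai_version) >= version.parse("0.1.10"):  # illustrative cutoff
            return "new-style Gemini wrapper"
        # NotImplementedError is the exception class; NotImplemented is a sentinel
        raise NotImplementedError(f"CAI version {cai_version} is not supported")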
@@ -128,6 +141,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
 def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
     split_param_single_dim_tp1d(-1, param, pg)
 
+
 # Tensor Parallel
 def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """tensor_parallelize
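Reviewer note on the sharding helpers: split_param_row_tp1d shards a parameter along dim 0 and split_param_col_tp1d along the last dim, each across the pg.tp_world_size() ranks of the tensor-parallel group. The underlying arithmetic, sketched in plain PyTorch with illustrative shapes:

    import torch

    weight = torch.randn(8, 4)                               # (out_features, in_features)
    tp_world_size = 2                                        # illustrative TP group size
    row_shards = torch.chunk(weight, tp_world_size, dim=0)   # row split: each rank gets (4, 4)
    col_shards = torch.chunk(weight, tp_world_size, dim=-1)  # col split: each rank gets (8, 2)
    print(row_shards[0].shape, col_shards[0].shape)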
@@ -159,15 +173,28 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
 
 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
-    raise TypeError(f"{args.distplan} is error")
+    raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
 logger = get_dist_logger()
 
-with gzip.open("./data/enwik8.gz") as file:
-    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
-    trX, vaX = np.split(X, [int(90e6)])
-    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+
+def generate_dataset(dummy_data: bool = False):
+    if not dummy_data:
+        with gzip.open("./data/enwik8.gz") as file:
+            X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+            trX, vaX = np.split(X, [int(90e6)])
+            data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+            # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
+            # print(f"data_val {data_val.shape} {data_val.dtype} {max(data_val)} {min(data_val)}")
+            return data_train, data_val
+    else:
+        return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))
+
+
+data_train, data_val = generate_dataset(args.dummy_data)
+
+print("generate dataset ready!")
 
 
 class TextSamplerDataset(Dataset):
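Reviewer note on generate_dataset: np.fromstring on a bytes object has long been deprecated in NumPy, with np.frombuffer as the drop-in replacement; separately, torch.randint defaults to int64, so the 90,000,000-element dummy split occupies about 720 MB, while a uint8 dtype would match the enwik8 bytes. A sketch of both adjustments (same paths and sizes as the diff):

    import gzip

    import numpy as np
    import torch

    with gzip.open("./data/enwik8.gz") as file:
        X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8)  # replaces deprecated np.fromstring

    dummy_train = torch.randint(0, 100, (90000000,), dtype=torch.uint8)  # ~90 MB instead of ~720 MB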
@@ -216,7 +243,7 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
-# model is shared after TP
+# model is shared after TP
 numel = get_model_size(model)
 get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
 
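Reviewer note on the partial call: functools.partial pre-binds the static arguments (model size, batch size, sequence length), so the training loop only supplies the measured step time. Minimal illustration with made-up sizes:

    from functools import partial

    def get_tflops(model_numel, batch_size, seq_len, step_time):
        return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

    get_tflops_func = partial(get_tflops, 100e6, 4, 512)  # bind numel, batch size, seq len
    print(f"{get_tflops_func(0.5):.1f}")                  # 3.3 -- only step_time varies per step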
@@ -251,7 +278,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         )
         if i >= WARMUP_BATCHES:
             tflops_list.append(step_tflops)
-
+
     else:
         for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
@@ -261,18 +288,17 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
         optim.step()
         optim.zero_grad()
 
-
 tflops_list.sort()
 median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
 logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
 
-
-# TODO
-# if i % VALIDATE_EVERY == 0:
-#     model.eval()
-#     with torch.no_grad():
-#         loss = model(next(val_loader))
-#         print(f"validation loss: {loss.item()}")
+# TODO
+# if i % VALIDATE_EVERY == 0:
+#     model.eval()
+#     with torch.no_grad():
+#         loss = model(next(val_loader))
+#         print(f"validation loss: {loss.item()}")
+
 # if i % GENERATE_EVERY == 0:
 #     model.eval()
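Reviewer note on the median: with NUM_BATCHES = 10 and WARMUP_BATCHES = 1, tflops_list holds 9 post-warmup samples whose true median sits at sorted index 4, but the formula re-adds WARMUP_BATCHES to an index into a list that never contained the warmup steps, landing on index 5, one slot above the median. Quick check:

    NUM_BATCHES, WARMUP_BATCHES = 10, 1
    median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
    num_samples = NUM_BATCHES - WARMUP_BATCHES  # samples appended after warmup
    print(median_index, num_samples // 2)       # 5 vs 4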
@@ -282,4 +308,4 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
 
 #         sample = model.generate(inp[None, ...], GENERATE_LENGTH)
 #         output_str = decode_tokens(sample[0])
-#         print(output_str)
+#         print(output_str)