[CI] add test_ci.sh for palm, opt and gpt (#2475)

Author: Jiarui Fang
Date:   2023-01-16 14:44:29 +08:00
Committed by: GitHub
Commit: 7c31706227 (parent: e4c38ba367)
8 changed files with 107 additions and 38 deletions

View File

@@ -8,4 +8,4 @@ export PLACEMENT='cpu'
export USE_SHARD_INIT=False
export BATCH_SIZE=4
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log

View File

@@ -0,0 +1,9 @@
+cd "$(dirname "$0")"
+for BATCH_SIZE in 2
+do
+for GPUNUM in 1 4
+do
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+done
+done
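The same sweep (batch sizes crossed with GPU counts) can also be driven from Python when a CI runner prefers it. A minimal sketch, assuming torchrun is on PATH and train.py sits next to the driver; the driver itself is hypothetical and not part of this commit:

    # Minimal Python equivalent of the test_ci.sh matrix (a sketch, not part of the commit).
    import itertools
    import os
    import subprocess

    for batch_size, gpu_num in itertools.product([2], [1, 4]):
        subprocess.run(
            ["torchrun", "--standalone", f"--nproc_per_node={gpu_num}",
             "--master_port", "29501", "train.py",
             "--dummy_data=True", f"--batch_size={batch_size}"],
            env={**os.environ, "OMP_NUM_THREADS": "12"},  # merge into, don't replace, the environment
            check=True,  # fail the CI job on a non-zero exit
        )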

View File

@@ -1,11 +1,12 @@
import gzip
import random
-from time import time
from functools import partial
+from time import time
import numpy as np
import torch
-import torch.optim as optim
import torch.nn as nn
+import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
@@ -23,7 +24,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
# constants
-NUM_BATCHES = int(100)
+NUM_BATCHES = int(10)
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
@@ -66,9 +67,16 @@ def parse_args():
        default=8,
        help="batch size per DP group of training.",
    )
+    parser.add_argument(
+        "--dummy_data",
+        type=bool,
+        default=False,
+        help="use dummy dataset.",
+    )
    args = parser.parse_args()
    return args
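One caveat on the new flag: argparse's type=bool converts through truthiness, so any non-empty value, including the literal string "False", parses as True; only omitting the flag yields the default False. The CI script always passes --dummy_data=True, which happens to work, but a store_true action is the conventional safe form. A standalone illustration (not part of the commit):

    # Why type=bool is a footgun for CLI flags.
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--dummy_data", type=bool, default=False)
    print(p.parse_args(["--dummy_data=False"]).dummy_data)  # True: bool("False") is truthy

    q = argparse.ArgumentParser()
    q.add_argument("--dummy_data", action="store_true")     # safer pattern
    print(q.parse_args([]).dummy_data)                      # False
    print(q.parse_args(["--dummy_data"]).dummy_data)        # True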
# helpers
def cycle(loader):
    while True:
@@ -79,12 +87,15 @@ def cycle(loader):
def decode_token(token):
    return str(chr(max(32, token)))
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
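The constant 8 in get_tflops appears to follow the standard per-parameter estimate for transformer training with activation recomputation: 2 FLOPs per multiply-accumulate, times four passes (one forward, two for backward, one recomputed forward). A quick numeric check with made-up inputs:

    # Sanity check of the heuristic with made-up numbers (not values from the commit).
    numel, batch, seq_len, step_time = 500e6, 8, 1024, 1.0
    tflops = numel * batch * seq_len * 8 / 1e12 / (step_time + 1e-12)
    print(f"{tflops:.2f} TFLOPS")  # 32.77 for these inputs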
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))
def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
@@ -92,6 +103,7 @@ def get_model_size(model: nn.Module):
            total_numel += p.numel()
    return total_numel
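As a cross-check, the same count is available from the top-level parameter iterator. Note that model.parameters() deduplicates shared parameters by default, while a per-module walk can count a tied weight twice:

    # Equivalent parameter count via the top-level iterator (cross-check only).
    total_numel = sum(p.numel() for p in model.parameters())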
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
    cai_version = colossalai.__version__
@@ -115,6 +127,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model
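A typical call site for this helper, sketched with names inferred from the flags in run.sh; the ProcessGroup construction and the args attribute names are assumptions, not lines from the diff:

    # Hypothetical call site; tp_degree and placement names are inferred from run.sh flags.
    pg = ProcessGroup(tp_degree=args.tp_degree)
    model = gemini_zero_dpp(model, pg, placememt_policy=args.placement)  # kwarg spelled as in the source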
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
@@ -128,6 +141,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)
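For shape intuition only: the column variant shards the last dimension, and by symmetry the row variant shards dim 0. Plain torch.chunk reproduces the per-rank shapes; this is an illustration, not how ColossalAI stores shards:

    # Per-rank shapes under 1D row vs. column sharding, illustrated with torch.chunk.
    import torch

    w = torch.randn(1024, 4096)       # stand-in weight matrix
    rows = torch.chunk(w, 4, dim=0)   # row sharding, 4 ranks -> each (256, 4096)
    cols = torch.chunk(w, 4, dim=-1)  # column sharding, 4 ranks -> each (1024, 1024)
    print(rows[0].shape, cols[0].shape)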
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
@@ -159,15 +173,28 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise TypeError(f"{args.distplan} is not a valid distplan")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
-with gzip.open("./data/enwik8.gz") as file:
-    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
-    trX, vaX = np.split(X, [int(90e6)])
-    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+def generate_dataset(dummy_data: bool = False):
+    if not dummy_data:
+        with gzip.open("./data/enwik8.gz") as file:
+            X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+            trX, vaX = np.split(X, [int(90e6)])
+            data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+            # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
+            # print(f"data_val {data_val.shape} {data_val.dtype} {max(data_val)} {min(data_val)}")
+            return data_train, data_val
+    else:
+        return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))
+
+data_train, data_val = generate_dataset(args.dummy_data)
+print("generate dataset ready!")
class TextSamplerDataset(Dataset):
@@ -216,7 +243,7 @@ else:
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
@@ -251,7 +278,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
        )
+        if i >= WARMUP_BATCHES:
+            tflops_list.append(step_tflops)
    else:
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
@@ -261,18 +288,17 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
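A note on the accumulation loop above: with GRADIENT_ACCUMULATE_EVERY = 1 it is a no-op, but for larger values gradients sum across micro-batches, so the loss is conventionally divided by the accumulation count to keep the effective step size stable. A sketch of that convention, offered as a suggestion rather than what the commit does:

    # Loss-scaling convention for GRADIENT_ACCUMULATE_EVERY > 1 (suggestion only).
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()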
+tflops_list.sort()
+median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
+logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
# TODO
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")
# if i % GENERATE_EVERY == 0:
#     model.eval()
@@ -282,4 +308,4 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
#     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
#     output_str = decode_tokens(sample[0])
#     print(output_str)