diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 6725c07df..b17496954 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -1,22 +1,22 @@
 import gzip
 import random
-from time import time
 from functools import partial
+from time import time
+
 import numpy as np
 import torch
-import torch.optim as optim
 import torch.nn as nn
+import torch.optim as optim
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -69,6 +69,7 @@ def parse_args():
     args = parser.parse_args()
     return args
 
+
 # helpers
 def cycle(loader):
     while True:
@@ -79,12 +80,15 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))
 
+
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 
+
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 
+
 def get_model_size(model: nn.Module):
     total_numel = 0
     for module in model.modules():
@@ -92,6 +96,7 @@ def get_model_size(model: nn.Module):
             total_numel += p.numel()
     return total_numel
 
+
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__
@@ -115,6 +120,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model
 
+
 ## Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
@@ -128,6 +134,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
 def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
     split_param_single_dim_tp1d(-1, param, pg)
 
+
 # Tensor Parallel
 def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """tensor_parallelize
@@ -159,7 +166,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
 
 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
-    raise TypeError(f"{args.distplan} is error")
+    raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
 logger = get_dist_logger()
@@ -216,7 +223,7 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
-    # model is shared after TP
+# model is shared after TP
 numel = get_model_size(model)
 get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
 
@@ -251,7 +258,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         )
         if i >= WARMUP_BATCHES:
             tflops_list.append(step_tflops)
-
+
 else:
     for __ in range(GRADIENT_ACCUMULATE_EVERY):
         loss = model(next(train_loader))
@@ -261,18 +268,17 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
-
+
 tflops_list.sort()
 median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
 logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
-
-    # TODO
-    # if i % VALIDATE_EVERY == 0:
-    #     model.eval()
-    #     with torch.no_grad():
-    #         loss = model(next(val_loader))
-    #         print(f"validation loss: {loss.item()}")
+# TODO
+# if i % VALIDATE_EVERY == 0:
+#     model.eval()
+#     with torch.no_grad():
+#         loss = model(next(val_loader))
+#         print(f"validation loss: {loss.item()}")
 
 # if i % GENERATE_EVERY == 0:
 #     model.eval()
@@ -282,4 +288,4 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
 
 #     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
 #     output_str = decode_tokens(sample[0])
-#     print(output_str)
\ No newline at end of file
+#     print(output_str)