ColossalAI/examples/tutorial/sequence_parallel/train.py
Boyuan Yao 7a58dc5ad2
Update metainfo patch branch (#2517)

Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
2023-01-27 09:52:21 +08:00

219 lines
9.1 KiB
Python

import argparse
import torch
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
from data.dummy_dataloader import DummyDataloader
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert
import colossalai
from colossalai.amp import AMP_TYPE
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.kernel import LayerNorm
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
from colossalai.utils import MultiTimer, is_using_pp


def process_batch_data(batch_data):
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        data = dict(input_ids=tokens, attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
    else:
        data = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
    label = dict(loss_mask=loss_mask, sentence_order=sentence_order)
    return data, label


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--synthetic', action="store_true", help="whether to use synthetic data")
    return parser.parse_args()


def pipeline_data_process_func(stage_output, micro_batch_data):
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        # the first pipeline stage consumes the raw tokens
        data = (tokens, padding_mask, types, lm_labels)
        label = (loss_mask, sentence_order)
    else:
        # later stages consume the previous stage's output instead of tokens
        data = (stage_output, padding_mask, types, lm_labels)
        label = (loss_mask, sentence_order)
    return data, label


def main():
    # initialize
    args = parse_args()
    colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')

    logger = get_dist_logger()

    # build synthetic dataloader
    BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
    VOCAB_SIZE = 30528
    trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                  vocab_size=VOCAB_SIZE,
                                  seq_length=gpc.config.SEQ_LENGTH)
    validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                  vocab_size=VOCAB_SIZE,
                                  seq_length=gpc.config.SEQ_LENGTH)

    logger.info("Dataloaders are built", ranks=[0])

    # build model
    if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE:
        is_naive_fp16 = True
    else:
        is_naive_fp16 = False

    use_pipeline = is_using_pp()
    kwargs = dict(vocab_size=VOCAB_SIZE,
                  hidden_size=gpc.config.HIDDEN_SIZE,
                  max_sequence_length=gpc.config.SEQ_LENGTH,
                  num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
                  convert_fp16_to_fp32_in_softmax=True,
                  is_naive_fp16=is_naive_fp16,
                  add_binary_head=gpc.config.ADD_BINARY_HEAD)

    if use_pipeline:
        model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
    else:
        model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs)

    model = model.half()
    model.reset_parameters()
    logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0])

    total_numel = 0
    for p in model.parameters():
        total_numel += p.numel()
    logger.info(f"This model has {total_numel} parameters")

    # build criterion
    criterion = BertLoss()
    logger.info("Criterion is built", ranks=[0])

    # LayerNorm parameters and biases get no weight decay
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in model.modules():
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])

    logger.info(
        f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}"
    )

    # optimizer
    optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
                          lr=gpc.config.LR,
                          weight_decay=gpc.config.WEIGHT_DECAY)
    logger.info("Optimizer is built", ranks=[0])

    # lr scheduler
    # follow Megatron-LM setting
    warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
    lr_scheduler = AnnealingLR(optimizer=optimizer,
                               max_lr=gpc.config.LR,
                               min_lr=gpc.config.MIN_LR,
                               warmup_steps=warmup_steps,
                               decay_steps=gpc.config.DECAY_ITERS,
                               decay_style='linear')
    logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")

    # init
    engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True)

    # build timer
    timer = MultiTimer()
    skip_iters = 0

    # build loss tracker
    accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
    accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda()

    # build data iters for pipeline parallel
    if use_pipeline:
        train_data_iter = SequenceParallelDataIterator(trainloader)
        valid_data_iter = SequenceParallelDataIterator(validloader)
        engine.schedule.data_process_func = pipeline_data_process_func

    logger.info("start training")

    for step in range(1, gpc.config.TRAIN_ITERS + 1):
        timer.start('train-iterations')
        engine.train()
        if use_pipeline:
            engine.zero_grad()
            _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False)
            engine.step()
        else:
            tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                trainloader)
            engine.zero_grad()
            lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
            train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
            engine.backward(train_loss)
            engine.step()
        timer.stop('train-iterations', keep_in_history=True)

        # only the last pipeline stage (or every rank without pipeline) holds the loss
        if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
            accumulated_train_loss += train_loss

        lr_scheduler.step()

        if step % gpc.config.EVAL_INTERVAL == 0:
            engine.eval()

            for j in range(gpc.config.EVAL_ITERS):
                with torch.no_grad():
                    if use_pipeline:
                        _, _, eval_loss = engine.execute_schedule(valid_data_iter,
                                                                  forward_only=True,
                                                                  return_output_label=False)
                    else:
                        tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                            validloader)
                        lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
                        eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)

                    if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
                        accumulated_eval_loss += eval_loss

            # average the losses accumulated since the previous evaluation
            if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
                accumulated_eval_loss /= gpc.config.EVAL_ITERS
                accumulated_train_loss /= gpc.config.EVAL_INTERVAL

            timer_string = []
            for n, t in timer:
                timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}")
            timer_string = ' | '.join(timer_string)
            lr = list(engine.optimizer.param_groups)[0]['lr']
            loss_scale = engine.optimizer.optim.loss_scale.item()

            if gpc.is_initialized(ParallelMode.PIPELINE):
                ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
            else:
                ranks = [0]
            logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' +
                        f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale} ' +
                        f"| Learning rate: {lr} | " + timer_string,
                        ranks=ranks)

            for n, t in timer:
                t.reset()
            accumulated_eval_loss.zero_()
            accumulated_train_loss.zero_()


if __name__ == '__main__':
    main()
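
Note on the learning-rate schedule: main() derives warmup_steps as int(DECAY_ITERS * WARMUP_FRACTION) and passes it to AnnealingLR with decay_style='linear'. AnnealingLR lives in the example's local lr_scheduler module, which is not shown here, so its exact formula is not confirmed by this file. The standalone sketch below shows one common shape for such a schedule (linear warmup to max_lr, then linear decay to min_lr). The function name linear_warmup_decay_lr and all numeric values are hypothetical and chosen only for illustration.

```python
def linear_warmup_decay_lr(step, max_lr, min_lr, warmup_steps, decay_steps):
    """Learning rate at `step` for linear warmup followed by linear decay.

    Hypothetical helper: an assumed schedule shape, not the code of AnnealingLR.
    """
    if warmup_steps > 0 and step <= warmup_steps:
        # Ramp linearly from 0 up to max_lr over the warmup phase.
        return max_lr * step / warmup_steps
    if step >= decay_steps:
        # Past the decay horizon the rate stays at the floor.
        return min_lr
    # Fraction of the post-warmup decay phase that has elapsed, in [0, 1).
    progress = (step - warmup_steps) / (decay_steps - warmup_steps)
    return min_lr + (1.0 - progress) * (max_lr - min_lr)


# Illustrative values only; the real ones come from config.py.
DECAY_ITERS = 1000
WARMUP_FRACTION = 0.1
warmup_steps = int(DECAY_ITERS * WARMUP_FRACTION)    # same derivation as in main()

schedule = [
    (step, linear_warmup_decay_lr(step, 1e-4, 1e-5, warmup_steps, DECAY_ITERS))
    for step in (1, 50, 100, 550, 1000)
]
for step, lr in schedule:
    print(f"step {step:4d}: lr = {lr:.3e}")
```

Under these assumptions the rate peaks at the end of warmup and reaches min_lr at step DECAY_ITERS; since the loop in main() calls lr_scheduler.step() once per training step, any iterations beyond DECAY_ITERS would run at the floor value.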