mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 20:54:55 +00:00
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones like, dont gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restruct dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* seperate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [worfklow] added coverage test (#2399)
* [worfklow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity yof the example readme (#2427)
* [example] improved the clarity yof the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatiblity test (#2453)
* [workflow] automated the compatiblity test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full Evoformer tracer, which is a main module of AlphaFold. Previously only a simplified version of it was supported.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen will be set automatically in __init__. But other codegen can't be set if so.
So I add an arg to control whether to set ActivationCheckpointCodeGen in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
219 lines
9.1 KiB
Python
219 lines
9.1 KiB
Python
import argparse
|
|
|
|
import torch
|
|
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
|
|
from data.dummy_dataloader import DummyDataloader
|
|
from loss_func.bert_loss import BertLoss
|
|
from lr_scheduler import AnnealingLR
|
|
from model.bert import BertForPretrain, build_pipeline_bert
|
|
|
|
import colossalai
|
|
from colossalai.amp import AMP_TYPE
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.engine.schedule import PipelineSchedule
|
|
from colossalai.kernel import LayerNorm
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.nn.optimizer import FusedAdam
|
|
from colossalai.utils import MultiTimer, is_using_pp
|
|
|
|
|
|
def process_batch_data(batch_data):
    """Split one batch into model-input kwargs and loss-target kwargs.

    Args:
        batch_data: a 6-element sequence unpacked in the order
            (tokens, types, sentence_order, loss_mask, lm_labels, padding_mask).

    Returns:
        ``(data, label)`` where ``data`` always carries ``attention_masks``,
        ``tokentype_ids`` and ``lm_labels``, plus ``input_ids`` (the tokens) only
        on the first pipeline rank; ``label`` carries ``loss_mask`` and
        ``sentence_order``.

    NOTE(review): not called from main() in this file; pipeline mode uses
    pipeline_data_process_func instead.
    """
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data

    shared_inputs = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
    # Only the first pipeline stage consumes raw token ids.
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        model_inputs = dict(input_ids=tokens, **shared_inputs)
    else:
        model_inputs = shared_inputs

    targets = dict(loss_mask=loss_mask, sentence_order=sentence_order)
    return model_inputs, targets
|
|
|
|
|
|
def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with a boolean ``synthetic`` attribute, True only
        when ``-s`` / ``--synthetic`` is passed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-s',
                     '--synthetic',
                     action="store_true",
                     help="whether use synthetic data")
    parsed = cli.parse_args()
    return parsed
|
|
|
|
|
|
def pipeline_data_process_func(stage_output, micro_batch_data):
    """Build positional model inputs and loss targets for one pipeline micro-batch.

    Installed as ``engine.schedule.data_process_func`` in pipeline mode.

    Args:
        stage_output: presumably the activation coming from the previous
            pipeline stage; it replaces the token ids on every stage except the
            first.
        micro_batch_data: a 6-element sequence unpacked in the order
            (tokens, types, sentence_order, loss_mask, lm_labels, padding_mask).

    Returns:
        ``((first_input, padding_mask, types, lm_labels), (loss_mask, sentence_order))``.
    """
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
    on_first_stage = gpc.is_first_rank(ParallelMode.PIPELINE)
    first_input = tokens if on_first_stage else stage_output
    return (first_input, padding_mask, types, lm_labels), (loss_mask, sentence_order)
|
|
|
|
|
|
def _is_last_pipeline_stage():
    """Return True when this rank should accumulate and average losses.

    That is the case when pipeline parallelism is not initialized, or when this
    rank is the last rank of its pipeline group.
    """
    return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)


def _build_param_groups(model):
    """Split ``model``'s parameters into weight-decay and no-weight-decay groups.

    Every parameter of a ``LayerNorm`` module, and every parameter named
    ``bias`` in other modules, goes into the no-decay group; everything else is
    decayed. Only each module's own (non-recursive) ``_parameters`` are read,
    since ``model.modules()`` already walks every submodule.

    Returns:
        ``(weight_decay_params, no_weight_decay_params)`` optimizer param-group
        dicts; the second explicitly sets ``weight_decay`` to 0.0.
    """
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in model.modules():
        own_params = [(name, p) for name, p in module_._parameters.items() if p is not None]
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend(p for _, p in own_params)
        else:
            weight_decay_params['params'].extend(p for name, p in own_params if name != 'bias')
            no_weight_decay_params['params'].extend(p for name, p in own_params if name == 'bias')
    return weight_decay_params, no_weight_decay_params


def _get_loss_scale(optimizer):
    """Return ``optimizer.optim.loss_scale.item()`` for logging, or ``'N/A'``.

    The ``optim.loss_scale`` attribute chain is presumably only present when
    naive fp16 AMP wraps the optimizer (TODO confirm); for other configurations
    this returns ``'N/A'`` instead of raising ``AttributeError`` at log time.
    """
    inner_optim = getattr(optimizer, 'optim', None)
    loss_scale = getattr(inner_optim, 'loss_scale', None)
    return 'N/A' if loss_scale is None else loss_scale.item()


def main():
    """Train BERT with ColossalAI, evaluating every ``EVAL_INTERVAL`` steps.

    All hyper-parameters come from ``./config.py`` through ``gpc.config``.
    Training and validation both use synthetic ``DummyDataloader`` data. At
    each evaluation point, ``EVAL_ITERS`` forward-only steps are run and the
    averaged train/eval losses, loss scale, learning rate and timings are
    logged, after which the accumulators and timers are reset.
    """
    # initialize
    # The parsed options are currently unused (data is always synthetic); parsing
    # still validates the command line and provides --help.
    parse_args()
    colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
    logger = get_dist_logger()

    # build synthetic dataloader
    BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
    VOCAB_SIZE = 30528
    trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                  vocab_size=VOCAB_SIZE,
                                  seq_length=gpc.config.SEQ_LENGTH)
    validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
                                  vocab_size=VOCAB_SIZE,
                                  seq_length=gpc.config.SEQ_LENGTH)
    logger.info("Dataloaders are built", ranks=[0])

    # build model
    is_naive_fp16 = hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE
    use_pipeline = is_using_pp()
    kwargs = dict(vocab_size=VOCAB_SIZE,
                  hidden_size=gpc.config.HIDDEN_SIZE,
                  max_sequence_length=gpc.config.SEQ_LENGTH,
                  num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
                  convert_fp16_to_fp32_in_softmax=True,
                  is_naive_fp16=is_naive_fp16,
                  add_binary_head=gpc.config.ADD_BINARY_HEAD)

    if use_pipeline:
        model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
    else:
        model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs)

    model = model.half()
    model.reset_parameters()
    # Report the softmax flag actually passed to the model (previously this label
    # printed is_naive_fp16, which is a different setting).
    logger.info(
        f"Model is built with softmax in fp32 = {kwargs['convert_fp16_to_fp32_in_softmax']}, "
        f"naive fp16 = {is_naive_fp16}",
        ranks=[0])

    total_numel = sum(p.numel() for p in model.parameters())
    logger.info(f"This model has {total_numel} parameters")

    # build criterion
    criterion = BertLoss()
    logger.info("Criterion is built", ranks=[0])

    # layernorm and bias has no weight decay
    weight_decay_params, no_weight_decay_params = _build_param_groups(model)
    logger.info(f"without weight decay param: {len(no_weight_decay_params['params'])}, "
                f"with weight decay param: {len(weight_decay_params['params'])}")

    # optimizer
    optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
                          lr=gpc.config.LR,
                          weight_decay=gpc.config.WEIGHT_DECAY)
    logger.info("Optimizer is built", ranks=[0])

    # lr scheduler (follows the Megatron-LM setting)
    warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
    lr_scheduler = AnnealingLR(optimizer=optimizer,
                               max_lr=gpc.config.LR,
                               min_lr=gpc.config.MIN_LR,
                               warmup_steps=warmup_steps,
                               decay_steps=gpc.config.DECAY_ITERS,
                               decay_style='linear')
    logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")

    # init engine
    engine, *_ = colossalai.initialize(model, optimizer, criterion, verbose=True)

    # build timer
    timer = MultiTimer()

    # build loss trackers (running sums, averaged and reset at each evaluation)
    accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
    accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda()

    # build data iters for pipeline parallel
    if use_pipeline:
        train_data_iter = SequenceParallelDataIterator(trainloader)
        valid_data_iter = SequenceParallelDataIterator(validloader)
        engine.schedule.data_process_func = pipeline_data_process_func

    logger.info("start training")

    for step in range(1, gpc.config.TRAIN_ITERS + 1):
        timer.start('train-iterations')
        engine.train()
        if use_pipeline:
            engine.zero_grad()
            _, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False)
        else:
            tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                trainloader)
            engine.zero_grad()
            lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
            train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
            engine.backward(train_loss)
        engine.step()
        timer.stop('train-iterations', keep_in_history=True)

        if _is_last_pipeline_stage():
            accumulated_train_loss += train_loss

        lr_scheduler.step()

        if step % gpc.config.EVAL_INTERVAL == 0:
            engine.eval()

            with torch.no_grad():
                for _eval_step in range(gpc.config.EVAL_ITERS):
                    if use_pipeline:
                        _, _, eval_loss = engine.execute_schedule(valid_data_iter,
                                                                  forward_only=True,
                                                                  return_output_label=False)
                    else:
                        tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
                            validloader)
                        lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
                        eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)

                    if _is_last_pipeline_stage():
                        accumulated_eval_loss += eval_loss

            if _is_last_pipeline_stage():
                accumulated_eval_loss /= gpc.config.EVAL_ITERS
                accumulated_train_loss /= gpc.config.EVAL_INTERVAL

            timer_string = ' | '.join(f"{n}: {t.get_history_mean()*1000:.5f}" for n, t in timer)
            lr = list(engine.optimizer.param_groups)[0]['lr']
            loss_scale = _get_loss_scale(engine.optimizer)

            if gpc.is_initialized(ParallelMode.PIPELINE):
                # Losses are only accumulated on the last rank of the pipeline
                # group, so report from there.
                ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
            else:
                ranks = [0]
            logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} '
                        f'| Eval Loss: {accumulated_eval_loss.item():.5g} '
                        f'| Loss Scale: {loss_scale} '
                        f'| Learning rate: {lr} | ' + timer_string,
                        ranks=ranks)

            for _, t in timer:
                t.reset()
            accumulated_eval_loss.zero_()
            accumulated_train_loss.zero_()
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|