From c6ea65011f1cf48f97d76209925f63a5d496ea79 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 14 Nov 2022 18:06:57 +0800
Subject: [PATCH] [tutorial] fixed pipeline bug for sequence parallel (#1943)

---
Notes (commentary only; nothing between the "---" line above and the
"diff --git" line below is part of the change):

* First hunk: adds pipeline_data_process_func(stage_output, micro_batch_data).
  It unpacks micro_batch_data into six fields, in this order: tokens, types,
  sentence_order, loss_mask, lm_labels, padding_mask.
  When gpc.is_first_rank(ParallelMode.PIPELINE) is true it returns
      data  = (tokens, padding_mask, types, lm_labels)
  otherwise it returns
      data  = (stage_output, padding_mask, types, lm_labels)
  i.e. stage_output takes the place of tokens. In both branches
      label = (loss_mask, sentence_order)
  and the function returns the pair (data, label).

* Second hunk: inside main(), under "if use_pipeline:", assigns the new
  function to engine.schedule.data_process_func.

* NOTE(review): how and when the schedule calls data_process_func, and what
  stage_output holds on the first rank, are not visible in this patch;
  confirm against the schedule implementation.

* NOTE(review): the "label = ..." assignment is identical in both branches
  and could be hoisted above the if/else in a follow-up change.

* NOTE(review): indentation and blank context lines were reconstructed from
  the hunk line counts; confirm against the original commit before applying.

 examples/tutorial/sequence_parallel/train.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py
index 2ca84e2bc..b92061000 100644
--- a/examples/tutorial/sequence_parallel/train.py
+++ b/examples/tutorial/sequence_parallel/train.py
@@ -35,6 +35,17 @@ def parse_args():
     return parser.parse_args()
 
 
+def pipeline_data_process_func(stage_output, micro_batch_data):
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
+    if gpc.is_first_rank(ParallelMode.PIPELINE):
+        data = (tokens, padding_mask, types, lm_labels)
+        label = (loss_mask, sentence_order)
+    else:
+        data = (stage_output, padding_mask, types, lm_labels)
+        label = (loss_mask, sentence_order)
+    return data, label
+
+
 def main():
     # initialize
     args = parse_args()
@@ -155,6 +166,7 @@ def main():
     if use_pipeline:
         train_data_iter = SequenceParallelDataIterator(trainloader)
         valid_data_iter = SequenceParallelDataIterator(validloader)
+        engine.schedule.data_process_func = pipeline_data_process_func
 
     logger.info("start training")
 