mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 21:17:08 +00:00
[tutorial] fixed pipeline bug for sequence parallel (#1943)
This commit is contained in:
@@ -35,6 +35,17 @@ def parse_args():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_data_process_func(stage_output, micro_batch_data):
    """Split one micro-batch into the ``(data, label)`` pair for a pipeline stage.

    Args:
        stage_output: Value used as the leading element of ``data`` on every
            pipeline rank except the first.
        micro_batch_data: Six-element sequence unpacked, in order, as
            ``(tokens, types, sentence_order, loss_mask, lm_labels, padding_mask)``.

    Returns:
        A ``(data, label)`` tuple where ``data`` is
        ``(tokens or stage_output, padding_mask, types, lm_labels)`` and
        ``label`` is ``(loss_mask, sentence_order)``.
    """
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data

    # Only the first pipeline rank consumes the raw tokens; every other rank
    # substitutes ``stage_output`` in that slot.
    leading_input = tokens if gpc.is_first_rank(ParallelMode.PIPELINE) else stage_output

    # The label tuple is the same on every rank.
    return (leading_input, padding_mask, types, lm_labels), (loss_mask, sentence_order)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# initialize
|
# initialize
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
@@ -155,6 +166,7 @@ def main():
|
|||||||
if use_pipeline:
|
if use_pipeline:
|
||||||
train_data_iter = SequenceParallelDataIterator(trainloader)
|
train_data_iter = SequenceParallelDataIterator(trainloader)
|
||||||
valid_data_iter = SequenceParallelDataIterator(validloader)
|
valid_data_iter = SequenceParallelDataIterator(validloader)
|
||||||
|
engine.schedule.data_process_func = pipeline_data_process_func
|
||||||
|
|
||||||
logger.info("start training")
|
logger.info("start training")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user