diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 73a39e833..5bab0d524 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -1,19 +1,20 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import List, Tuple, Union, Callable
 import inspect
-import torch.cuda
+from typing import Callable, List, Tuple, Union
 
 import colossalai.communication as comm
+import torch.cuda
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.logging import get_dist_logger
+from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
-from colossalai.logging import get_dist_logger
+
 from ._base_schedule import BaseSchedule
 
 
@@ -151,7 +152,7 @@ class PipelineSchedule(BaseSchedule):
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
-                return_tensors.append(tuple((output_tensor, label)))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
@@ -414,7 +415,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
 
         if gpc.is_pipeline_last_stage():
            if return_output_label:
-                return_tensors.append(tuple(output_tensor, label))
+                return_tensors.append((output_tensor, label))
             if accum_loss is not None:
                 loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
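
Note (not part of the patch): a minimal sketch of why the two `return_tensors.append(...)` lines change, using placeholder strings instead of real tensors. `tuple(a, b)` with two positional arguments raises a TypeError, and `tuple((a, b))` only copies an existing tuple, so the plain tuple literal used in the patch is both correct and simpler.

# Placeholder stand-ins for the real output tensor and label.
output_tensor, label = "output", "label"

# Old InterleavedPipelineSchedule form: tuple() takes a single iterable,
# so passing two arguments fails at runtime.
try:
    tuple(output_tensor, label)
except TypeError as err:
    print(err)  # tuple expected at most 1 argument, got 2

# Old PipelineSchedule form: works, but the tuple() call is redundant;
# the tuple literal used in the patch produces the same value.
assert tuple((output_tensor, label)) == (output_tensor, label)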