[NFC] polish colossalai/engine/schedule/_pipeline_schedule_v2.py code style (#3275)

Zirui Zhu 2023-03-28 14:31:38 +08:00 committed by binmakeswell
parent 196d4696d0
commit 1168b50e33

@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple
 
-from colossalai import engine
-import colossalai.communication.p2p_v2 as comm
 import torch.cuda
+
+import colossalai.communication.p2p_v2 as comm
+from colossalai import engine
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.cuda import get_current_device
@@ -35,7 +36,7 @@ def pack_return_tensors(return_tensors):
 class PipelineScheduleV2(PipelineSchedule):
     """Derived class of PipelineSchedule, the only difference is that
         forward_backward_step is reconstructed with p2p_v2
-
+
     Args:
         num_microbatches (int): The number of microbatches.
         data_process_func (Callable, optional):
@@ -43,9 +44,9 @@ class PipelineScheduleV2(PipelineSchedule):
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
             If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
-
+
     Example:
-
+
         # this shows an example of customized data_process_func
         def data_process_func(stage_output, dataloader_output):
            output1, output2 = stage_output
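
For context, the docstring's example is truncated in this view. Below is a minimal sketch of how a customized data_process_func and PipelineScheduleV2 might be wired together; the unpacking of dataloader_output, the torch.cat combination, the num_microbatches value, and the import path are illustrative assumptions, not part of this commit.

    # A hedged sketch, not the commit's code: the import path and the body
    # of data_process_func below are assumptions for illustration only.
    import torch

    from colossalai.engine.schedule import PipelineScheduleV2

    def data_process_func(stage_output, dataloader_output):
        # stage_output: tensors forwarded from the previous pipeline stage
        # (on non-first stages); dataloader_output: this stage's microbatch.
        output1, output2 = stage_output
        item1, item2 = dataloader_output
        # Combine the forwarded activations with the local batch into the
        # (data, label) pair this stage's forward pass is assumed to expect;
        # the concatenation is purely illustrative.
        data = torch.cat([output1, item1], dim=-1)
        label = item2
        return data, label

    schedule = PipelineScheduleV2(num_microbatches=4, data_process_func=data_process_func)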