mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-09 03:47:57 +00:00
[NFC] polish colossalai/engine/schedule/_pipeline_schedule_v2.py code style (#3275)
This commit is contained in:
parent 196d4696d0
commit 1168b50e33
diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py
--- a/colossalai/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/engine/schedule/_pipeline_schedule_v2.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple
 
-from colossalai import engine
-import colossalai.communication.p2p_v2 as comm
 import torch.cuda
+
+import colossalai.communication.p2p_v2 as comm
+from colossalai import engine
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.cuda import get_current_device
@@ -35,7 +36,7 @@ def pack_return_tensors(return_tensors):
 class PipelineScheduleV2(PipelineSchedule):
     """Derived class of PipelineSchedule, the only difference is that
     forward_backward_step is reconstructed with p2p_v2
 
     Args:
         num_microbatches (int): The number of microbatches.
         data_process_func (Callable, optional):
@@ -43,9 +44,9 @@ class PipelineScheduleV2(PipelineSchedule):
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
             If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
 
     Example:
 
         # this shows an example of customized data_process_func
         def data_process_func(stage_output, dataloader_output):
             output1, output2 = stage_output
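The docstring's example is cut off in this view, so only its first lines are known. As a minimal sketch of how the documented arguments might fit together — assuming that `colossalai.engine.schedule` re-exports PipelineScheduleV2, and inventing the unpacking of `dataloader_output` and the return contract for illustration:

# Minimal sketch, not the repository's own example: the real docstring example
# is truncated above, so the dataloader unpacking and the return value below
# are assumptions for illustration only.
from colossalai.engine.schedule import PipelineScheduleV2  # import path assumed


def data_process_func(stage_output, dataloader_output):
    # stage_output holds the tensors received from the previous pipeline stage.
    output1, output2 = stage_output
    # Assumed: the dataloader yields a (data, label) pair per microbatch.
    data, label = dataloader_output
    # Return whatever positional inputs this stage's forward() expects.
    return output1, output2, data, label


# Wire the custom function into the schedule; argument names come from the
# Args section of the docstring above.
schedule = PipelineScheduleV2(
    num_microbatches=4,
    data_process_func=data_process_func,
    scatter_gather_tensors=True,
)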