Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 21:09:18 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
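The "run all files" step in the title corresponds to pre-commit's standard whole-repo invocation, shown here for reference:

    pre-commit run --all-files

This checks every tracked file rather than only the staged ones, which is why the commit reformats code across the repository.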
@@ -2,7 +2,6 @@ import threading
 from typing import Callable, Dict, List
 
 import torch
-import torch.distributed as dist
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
 
@@ -15,7 +14,6 @@ from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineB
 
 
 class FillDrainWorker(WorkerBase):
-
     def _get_work_item_key(self) -> UniqueKey:
         # execute backward first (if backward phase in work_list)
         num_microbatches = self.num_microbatches
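For orientation (context, not part of the diff): `FillDrainWorker` prefers backward work only once it shows up in its work list, and a fill-drain schedule simply runs every forward pass before any backward pass. A minimal sketch of the resulting per-stage order:

    # Illustrative only: per-stage execution order under a fill-drain schedule,
    # where "F{i}"/"B{i}" are the forward/backward passes of microbatch i.
    def fill_drain_order(num_microbatches: int) -> list:
        order = [f"F{i}" for i in range(num_microbatches)]   # fill: all forwards first
        order += [f"B{i}" for i in range(num_microbatches)]  # drain: then all backwards
        return order

    print(fill_drain_order(4))  # ['F0', 'F1', 'F2', 'F3', 'B0', 'B1', 'B2', 'B3']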
@@ -33,29 +31,40 @@ class FillDrainWorker(WorkerBase):
 
 
 class FillDrainPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 partition_fn: Callable,
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 chunk: int = 1,
-                 criterion: Callable = None,
-                 metric: Callable = None,
-                 checkpoint: bool = False,
-                 data_process_func: Callable = None) -> None:
-
+    def __init__(
+        self,
+        partition_fn: Callable,
+        stage_num: int,
+        num_microbatches: int,
+        device: str,
+        chunk: int = 1,
+        criterion: Callable = None,
+        metric: Callable = None,
+        checkpoint: bool = False,
+        data_process_func: Callable = None,
+    ) -> None:
         if chunk > 1:
-            assert num_microbatches % stage_num == 0, \
-                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
+            assert (
+                num_microbatches % stage_num == 0
+            ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
         use_1F1B = False
 
-        super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
-                         metric, checkpoint, data_process_func)
+        super().__init__(
+            FillDrainWorker,
+            partition_fn,
+            stage_num,
+            num_microbatches,
+            device,
+            use_1F1B,
+            chunk,
+            criterion,
+            metric,
+            checkpoint,
+            data_process_func,
+        )
 
 
 class OneFOneBWorker(WorkerBase):
-
     def _get_work_item_key(self) -> UniqueKey:
         # execute backward first (if backward phase in work_list)
         pp_rank = self.pp_rank
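The interleaving assertion rewritten above (same logic, in black's parenthesized form) requires `num_microbatches` to divide evenly across stages whenever `chunk > 1`. A quick check of which configurations pass, using hypothetical values that are not from the diff:

    # Hypothetical configurations checked against the assertion above.
    for stage_num, num_microbatches, chunk in [(4, 8, 2), (4, 6, 2), (4, 6, 1)]:
        ok = chunk <= 1 or num_microbatches % stage_num == 0
        verdict = "ok" if ok else "AssertionError"
        print(f"stages={stage_num} microbatches={num_microbatches} chunk={chunk} -> {verdict}")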
@@ -77,8 +86,7 @@ class OneFOneBWorker(WorkerBase):
         # change outstanding_range at:
         # 1. forward times reach actual_stage_num, this is the end of continuous forward
         # 2. forward times reach num_microbatches, this is the end of 1F1B mode
-        if not is_last_stage and \
-            target_key.phase == Phase.FORWARD:
+        if not is_last_stage and target_key.phase == Phase.FORWARD:
             if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
                 # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2
                 outstanding_min = actual_stage_num - pp_rank - 1
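The comments in this hunk describe the 1F1B steady state: after a warm-up of forwards, a stage alternates one forward with one backward, and with `num_microbatches <= 2` there is no steady phase at all. A toy schedule generator, assuming the usual warm-up of `stage_num - pp_rank - 1` forwards per stage (the same quantity as `outstanding_min` above); this is an illustration, not code from the engine:

    # Illustrative 1F1B order for a single stage: warm-up forwards, then strict
    # one-forward/one-backward alternation, then drain.
    def one_f_one_b_order(pp_rank: int, stage_num: int, num_microbatches: int) -> list:
        warmup = min(stage_num - pp_rank - 1, num_microbatches)
        order = [f"F{i}" for i in range(warmup)]
        forward, backward = warmup, 0
        while backward < num_microbatches:
            if forward < num_microbatches:
                order.append(f"F{forward}")
                forward += 1
            order.append(f"B{backward}")
            backward += 1
        return order

    print(one_f_one_b_order(pp_rank=0, stage_num=4, num_microbatches=6))
    # ['F0', 'F1', 'F2', 'F3', 'B0', 'F4', 'B1', 'F5', 'B2', 'B3', 'B4', 'B5']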
@@ -91,30 +99,41 @@ class OneFOneBWorker(WorkerBase):
 
 
 class OneFOneBPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 partition_fn: Callable,
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 chunk: int = 1,
-                 criterion: Callable = None,
-                 metric: Callable = None,
-                 checkpoint: bool = False,
-                 data_process_func: Callable = None) -> None:
-
+    def __init__(
+        self,
+        partition_fn: Callable,
+        stage_num: int,
+        num_microbatches: int,
+        device: str,
+        chunk: int = 1,
+        criterion: Callable = None,
+        metric: Callable = None,
+        checkpoint: bool = False,
+        data_process_func: Callable = None,
+    ) -> None:
         if chunk > 1:
-            assert num_microbatches % stage_num == 0, \
-                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
+            assert (
+                num_microbatches % stage_num == 0
+            ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
         # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
         use_1F1B = True
 
-        super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
-                         metric, checkpoint, data_process_func)
+        super().__init__(
+            OneFOneBWorker,
+            partition_fn,
+            stage_num,
+            num_microbatches,
+            device,
+            use_1F1B,
+            chunk,
+            criterion,
+            metric,
+            checkpoint,
+            data_process_func,
+        )
 
 
 class ChimeraWorker(WorkerBase):
-
     def _get_producer_consumer(self) -> None:
         rank = self.pp_rank
         min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
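The last lines of this hunk belong to `ChimeraWorker._get_producer_consumer`, whose rank arithmetic groups workers by pipeline copy: a rank's group starts at the nearest lower multiple of `actual_stage_num`, consistent with Chimera running paired pipelines over the same stages. A small illustration with assumed sizes (not from the diff):

    # Illustrative: grouping of global ranks into pipeline copies, mirroring
    # min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num.
    actual_stage_num = 4
    for rank in range(8):  # assumed: two pipeline copies of 4 stages each
        min_pp_rank = (rank // actual_stage_num) * actual_stage_num
        print(f"rank={rank} -> copy starting at rank {min_pp_rank}")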
@@ -143,11 +162,12 @@ class ChimeraWorker(WorkerBase):
         forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
         forward_block_num = self.forward_times // forward_block_size
 
-        if self.forward_times >= real_microbatch_num or \
-            ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
+        if self.forward_times >= real_microbatch_num or (
+            (pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times
+        ):
             target_phase = Phase.BACKWARD
             target_microbatch_id = self.backward_times
-        else:    # others
+        else:  # others
             target_phase = Phase.FORWARD
             target_microbatch_id = self.forward_times
 
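The branch above decides when a Chimera worker flips to backward: either every real microbatch has been forwarded, or (for the last stage of a copy) a full forward block is already outstanding ahead of the backwards. The block-size rule at the top of the hunk reduces to this, shown with hypothetical shapes:

    # Illustrative: forward_block_size as computed above, for assumed shapes.
    stage_num = 4
    for num_microbatches in (2, 4, 8):
        block = 1 if num_microbatches < stage_num else num_microbatches // stage_num
        print(f"microbatches={num_microbatches}, stages={stage_num} -> forward block size {block}")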
@@ -168,7 +188,7 @@ class ChimeraWorker(WorkerBase):
         # from corresponding up stage
         pp_rank = self.pp_rank
         stage_num = self.actual_stage_num
-        device = self.device
+        self.device
         if pp_rank < stage_num:
             super()._initialize_partition()
         else:
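The one-line change in this hunk replaces an unused local assignment with a bare attribute access. My reading (not stated in the commit) is that this silences an unused-variable lint while keeping the attribute lookup; as a statement, `self.device` computes the value and throws it away:

    # Miniature of the change above: an expression statement evaluates a value
    # and discards it, so no unused local is created (unlike `device = w.device`).
    class W:
        device = "cuda:0"

    w = W()
    w.device  # attribute is read, result discarded; effectively a no-op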
@@ -242,27 +262,38 @@ class ChimeraWorker(WorkerBase):
 
 
 class ChimeraPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 partition_fn: Callable,
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 criterion: Callable = None,
-                 metric: Callable = None,
-                 checkpoint: bool = False,
-                 data_process_func: Callable = None) -> None:
-
-        assert num_microbatches % stage_num == 0, \
-            "In Chimera, num_microbatches must be the multiply of stage_num!"
+    def __init__(
+        self,
+        partition_fn: Callable,
+        stage_num: int,
+        num_microbatches: int,
+        device: str,
+        criterion: Callable = None,
+        metric: Callable = None,
+        checkpoint: bool = False,
+        data_process_func: Callable = None,
+    ) -> None:
+        assert num_microbatches % stage_num == 0, "In Chimera, num_microbatches must be the multiply of stage_num!"
         use_1F1B = False
         chunk = 1
 
-        super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
-                         metric, checkpoint, data_process_func)
+        super().__init__(
+            ChimeraWorker,
+            partition_fn,
+            stage_num,
+            num_microbatches,
+            device,
+            use_1F1B,
+            chunk,
+            criterion,
+            metric,
+            checkpoint,
+            data_process_func,
+        )
 
-    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
-                            output_pp_ranks: List[int], ret_future):
+    def _consume_constraint(
+        self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
+    ):
         pass
 
     def _create_pp_rank_to_rpc_worker_id(self) -> None:
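Every signature and call rewritten in this diff follows the same rule from black (the formatter this pre-commit update appears to adopt, judging by the output style): when a bracketed construct carries a magic trailing comma or overflows the line length, each element is placed on its own line. A standalone illustration with made-up names:

    # Before (yapf-era style): arguments packed and aligned under the open paren.
    def make_engine(worker, partition_fn, stage_num, num_microbatches, device):
        return (worker, partition_fn, stage_num, num_microbatches, device)

    # After black, with a magic trailing comma: one argument per line.
    engine = make_engine(
        "FillDrainWorker",
        lambda module: module,
        4,
        8,
        "cuda",
    )
    print(engine[2:])  # (4, 8, 'cuda')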